TheVera commited on
Commit
c696661
·
verified ·
1 Parent(s): 369b62f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -19,8 +19,9 @@ LABELS = [
19
  ALLOWED = set(LABELS)
20
 
21
  def log(msg, **kv):
22
- print(" | ".join([msg] + [f"{k}={v}"]) for k, v in kv.items())
23
- print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]), file=sys.stderr, flush=True)
 
24
 
25
  # Init HF client once
26
  client = InferenceClient(token=API_KEY) if API_KEY else None
@@ -52,29 +53,42 @@ def generate_text():
52
  return jsonify({"response": "generate_image"}), 200
53
 
54
  try:
55
- # Zero-shot classification
56
  log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
57
  zs = client.zero_shot_classification(
58
  prompt,
59
  LABELS,
60
  model=ZSL_MODEL_ID,
61
  hypothesis_template="This text is about {}.",
62
- multi_label=False, # pick the best single label
63
  )
64
- # Response shape: {'labels': [...], 'scores': [...], 'sequence': '...'}
65
- labels = zs.get("labels") or []
 
 
 
 
 
 
 
 
 
66
  scores = zs.get("scores") or []
 
 
 
 
67
  best = labels[0] if labels else "other_query"
68
  score = float(scores[0]) if scores else 0.0
69
-
70
  token = best if best in ALLOWED else "other_query"
 
71
  log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
72
  return jsonify({"response": token}), 200
73
-
74
  except Exception as e:
75
  log("DECISION_FAIL", error=str(e))
76
  return jsonify({"response": "other_query", "error": str(e)}), 200
77
 
 
78
  if __name__ == "__main__":
79
  port = int(os.getenv("PORT", 7860))
80
  log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
 
19
  ALLOWED = set(LABELS)
20
 
21
  def log(msg, **kv):
22
+ line = " | ".join([msg] + [f"{k}={v}" for k, v in kv.items()])
23
+ print(line, file=sys.stderr, flush=True)
24
+
25
 
26
  # Init HF client once
27
  client = InferenceClient(token=API_KEY) if API_KEY else None
 
53
  return jsonify({"response": "generate_image"}), 200
54
 
55
  try:
 
56
  log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
57
  zs = client.zero_shot_classification(
58
  prompt,
59
  LABELS,
60
  model=ZSL_MODEL_ID,
61
  hypothesis_template="This text is about {}.",
62
+ multi_label=False, # single best label
63
  )
64
+
65
+ # Normalize shapes:
66
+ # - Newer hub often returns a dict
67
+ # - Some providers return a list[dict] (one per input)
68
+ if isinstance(zs, list):
69
+ zs = zs[0] if zs else {}
70
+
71
+ if not isinstance(zs, dict):
72
+ raise ValueError(f"Unexpected ZSL response type: {type(zs)}")
73
+
74
+ labels = zs.get("labels") or zs.get("candidate_labels") or []
75
  scores = zs.get("scores") or []
76
+ if not labels and "label" in zs:
77
+ labels = [zs["label"]]
78
+ scores = [zs.get("score", 0.0)]
79
+
80
  best = labels[0] if labels else "other_query"
81
  score = float(scores[0]) if scores else 0.0
 
82
  token = best if best in ALLOWED else "other_query"
83
+
84
  log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
85
  return jsonify({"response": token}), 200
86
+
87
  except Exception as e:
88
  log("DECISION_FAIL", error=str(e))
89
  return jsonify({"response": "other_query", "error": str(e)}), 200
90
 
91
+
92
  if __name__ == "__main__":
93
  port = int(os.getenv("PORT", 7860))
94
  log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))