Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -19,8 +19,9 @@ LABELS = [
|
|
19 |
ALLOWED = set(LABELS)
|
20 |
|
21 |
def log(msg, **kv):
|
22 |
-
|
23 |
-
print(
|
|
|
24 |
|
25 |
# Init HF client once
|
26 |
client = InferenceClient(token=API_KEY) if API_KEY else None
|
@@ -52,29 +53,42 @@ def generate_text():
|
|
52 |
return jsonify({"response": "generate_image"}), 200
|
53 |
|
54 |
try:
|
55 |
-
# Zero-shot classification
|
56 |
log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
|
57 |
zs = client.zero_shot_classification(
|
58 |
prompt,
|
59 |
LABELS,
|
60 |
model=ZSL_MODEL_ID,
|
61 |
hypothesis_template="This text is about {}.",
|
62 |
-
multi_label=False,
|
63 |
)
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
scores = zs.get("scores") or []
|
|
|
|
|
|
|
|
|
67 |
best = labels[0] if labels else "other_query"
|
68 |
score = float(scores[0]) if scores else 0.0
|
69 |
-
|
70 |
token = best if best in ALLOWED else "other_query"
|
|
|
71 |
log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
|
72 |
return jsonify({"response": token}), 200
|
73 |
-
|
74 |
except Exception as e:
|
75 |
log("DECISION_FAIL", error=str(e))
|
76 |
return jsonify({"response": "other_query", "error": str(e)}), 200
|
77 |
|
|
|
78 |
if __name__ == "__main__":
|
79 |
port = int(os.getenv("PORT", 7860))
|
80 |
log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
|
|
|
19 |
ALLOWED = set(LABELS)
|
20 |
|
21 |
def log(msg, **kv):
|
22 |
+
line = " | ".join([msg] + [f"{k}={v}" for k, v in kv.items()])
|
23 |
+
print(line, file=sys.stderr, flush=True)
|
24 |
+
|
25 |
|
26 |
# Init HF client once
|
27 |
client = InferenceClient(token=API_KEY) if API_KEY else None
|
|
|
53 |
return jsonify({"response": "generate_image"}), 200
|
54 |
|
55 |
try:
|
|
|
56 |
log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
|
57 |
zs = client.zero_shot_classification(
|
58 |
prompt,
|
59 |
LABELS,
|
60 |
model=ZSL_MODEL_ID,
|
61 |
hypothesis_template="This text is about {}.",
|
62 |
+
multi_label=False, # single best label
|
63 |
)
|
64 |
+
|
65 |
+
# Normalize shapes:
|
66 |
+
# - Newer hub often returns a dict
|
67 |
+
# - Some providers return a list[dict] (one per input)
|
68 |
+
if isinstance(zs, list):
|
69 |
+
zs = zs[0] if zs else {}
|
70 |
+
|
71 |
+
if not isinstance(zs, dict):
|
72 |
+
raise ValueError(f"Unexpected ZSL response type: {type(zs)}")
|
73 |
+
|
74 |
+
labels = zs.get("labels") or zs.get("candidate_labels") or []
|
75 |
scores = zs.get("scores") or []
|
76 |
+
if not labels and "label" in zs:
|
77 |
+
labels = [zs["label"]]
|
78 |
+
scores = [zs.get("score", 0.0)]
|
79 |
+
|
80 |
best = labels[0] if labels else "other_query"
|
81 |
score = float(scores[0]) if scores else 0.0
|
|
|
82 |
token = best if best in ALLOWED else "other_query"
|
83 |
+
|
84 |
log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
|
85 |
return jsonify({"response": token}), 200
|
86 |
+
|
87 |
except Exception as e:
|
88 |
log("DECISION_FAIL", error=str(e))
|
89 |
return jsonify({"response": "other_query", "error": str(e)}), 200
|
90 |
|
91 |
+
|
92 |
if __name__ == "__main__":
|
93 |
port = int(os.getenv("PORT", 7860))
|
94 |
log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
|