# Author: TheVera
# Last commit: c696661 "Update app.py" (verified)
import os
import sys
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient
# Flask application object; the route handlers in this module register on it
# via the @app.get / @app.post decorators.
app = Flask(import_name=__name__)
# Hugging Face API token (whitespace-trimmed); generate_text refuses requests
# when it is empty.
API_KEY = (os.getenv("API_KEY") or "").strip()

# Default multilingual zero-shot model (handles Hindi + English well).
_DEFAULT_ZSL_MODEL_ID = "joeddav/xlm-roberta-large-xnli"
# Overridable via the ZSL_MODEL_ID env var. An unset, empty, or whitespace-only
# value falls back to the default rather than producing model="".
ZSL_MODEL_ID = (os.getenv("ZSL_MODEL_ID") or "").strip() or _DEFAULT_ZSL_MODEL_ID

# Candidate labels sent to the zero-shot classifier; the winning label is
# returned as the API's "response" token.
LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]
# Set form of LABELS for fast validation of labels coming back from the model.
ALLOWED = set(LABELS)
def log(msg, **kv):
    """Write one ``msg | key=value | ...`` diagnostic line to stderr.

    Keyword arguments appear in call order, each rendered through an
    f-string. The stream is flushed after every line so output is not
    held back in a buffer.
    """
    fields = [msg]
    for key, value in kv.items():
        fields.append(f"{key}={value}")
    print(" | ".join(fields), file=sys.stderr, flush=True)
# Build the Hugging Face inference client a single time at import. Without an
# API key it stays None, and generate_text reports that as an error.
if API_KEY:
    client = InferenceClient(token=API_KEY)
else:
    client = None
@app.get("/")
def root():
    """Root endpoint: confirm the service responds and name the configured ZSL model."""
    status = {"ok": True, "model": ZSL_MODEL_ID}
    return jsonify(status)
def _zsl_field(item, name, default=None):
    """Read ``name`` from a dict-style or attribute-style ZSL result item."""
    if isinstance(item, dict):
        return item.get(name, default)
    return getattr(item, name, default)


def _zsl_score(value):
    """Coerce a score to float, treating missing/non-numeric values as 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _zsl_pairs(zs):
    """Flatten a zero-shot classification response into ``(label, score)`` pairs.

    Accepted shapes:
      * a list with one ``label``/``score`` element per candidate label
        (dict items or attribute-style objects);
      * a list with one result per input (only the first is used, since a
        single prompt is sent);
      * a dict with parallel ``labels`` (or ``candidate_labels``) and
        ``scores`` lists, or a single ``label``/``score`` pair.

    Labels without a matching score get 0.0. Returns ``[]`` when no label
    is present.

    Raises:
        ValueError: if the (unwrapped) response is neither a dict nor an
            object exposing a ``label`` attribute.
    """
    if isinstance(zs, list):
        if zs and all(_zsl_field(item, "label") is not None for item in zs):
            # One element per candidate label.
            return [
                (_zsl_field(item, "label"), _zsl_field(item, "score", 0.0))
                for item in zs
            ]
        # Otherwise assume one result per input; only one input was sent.
        zs = zs[0] if zs else {}

    if not isinstance(zs, dict):
        label = getattr(zs, "label", None)
        if label is None:
            raise ValueError(f"Unexpected ZSL response type: {type(zs)}")
        return [(label, getattr(zs, "score", 0.0))]

    labels = zs.get("labels") or zs.get("candidate_labels") or []
    scores = zs.get("scores") or []
    if not labels and "label" in zs:
        return [(zs["label"], zs.get("score", 0.0))]
    return [
        (label, scores[i] if i < len(scores) else 0.0)
        for i, label in enumerate(labels)
    ]


@app.post("/generate_text")
def generate_text():
    """Classify the request's ``prompt`` into one of LABELS.

    Request JSON: ``{"prompt": str, "instructions": str}``. ``instructions``
    is accepted but not used by this endpoint.

    Responses:
      * 400 ``{"error": ...}`` when API_KEY is empty or the prompt is missing
        or not a non-blank string;
      * 500 ``{"error": ...}`` when the inference client was not created;
      * 200 ``{"response": "generate_image"}`` for prompts starting with
        ``"/image "`` (no model call);
      * 200 ``{"response": <label>}`` with the highest-scoring label, or
        ``"other_query"`` when the model returns a label outside LABELS;
      * 200 ``{"response": "other_query", "error": ...}`` when classification
        raises (best-effort fallback).
    """
    if not API_KEY:
        # NOTE(review): a missing server key is a server-side misconfiguration;
        # 500/503 would be more accurate, but 400 is kept for client compatibility.
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    data = request.get_json(silent=True)
    # A JSON array/number/string body has no .get(); treat it as an empty object.
    if not isinstance(data, dict):
        data = {}
    raw_prompt = data.get("prompt")
    # Non-string prompts are treated as missing instead of crashing on .strip().
    prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ""
    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: an explicit image command skips the model call entirely.
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # single best label
        )
        pairs = _zsl_pairs(zs)
        if pairs:
            # Pick the highest score rather than trusting the provider's ordering;
            # max() keeps the first entry on ties.
            best, raw_score = max(pairs, key=lambda pair: _zsl_score(pair[1]))
            score = _zsl_score(raw_score)
        else:
            best, score = "other_query", 0.0
        # Never echo a label outside the allowed set (or a non-string value).
        token = best if isinstance(best, str) and best in ALLOWED else "other_query"
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    except Exception as e:
        # Deliberate best-effort fallback: model/provider errors never fail the request.
        # NOTE(review): str(e) is returned to the client and may expose internal details.
        log("DECISION_FAIL", error_type=type(e).__name__, error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200
if __name__ == "__main__":
    # PORT defaults to 7860 when unset or blank; a blank value would otherwise
    # make int("") raise ValueError at startup.
    port = int((os.getenv("PORT") or "").strip() or 7860)
    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    # Bind on all interfaces so the server is reachable from other hosts
    # (e.g. when run inside a container).
    app.run(host="0.0.0.0", port=port)