# Spaces: Running / Running
import os
import sys

from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient
# Flask application object; served by app.run() in the __main__ guard below.
app = Flask(__name__)
# API token read from the environment. An unset variable and a whitespace-only
# value both collapse to "" so the rest of the module can test truthiness.
API_KEY = (os.getenv("API_KEY") or "").strip()

# Multilingual zero-shot model (handles Hindi + English well).
# An empty or whitespace-only ZSL_MODEL_ID falls back to the default rather
# than yielding an empty model id.
_DEFAULT_ZSL_MODEL_ID = "joeddav/xlm-roberta-large-xnli"
ZSL_MODEL_ID = (os.getenv("ZSL_MODEL_ID") or "").strip() or _DEFAULT_ZSL_MODEL_ID

# Candidate labels handed to the zero-shot classifier; also the only tokens
# the service is allowed to return.
LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]

# Set view of LABELS for O(1) membership checks on the classifier output.
ALLOWED = set(LABELS)
def log(msg: str, **kv: object) -> None:
    """Write one structured log line to stderr.

    The line is ``msg`` followed by every keyword argument rendered as
    ``key=value``, all joined with ``" | "``. Output is flushed immediately.

    Args:
        msg: Event name placed first on the line.
        **kv: Extra fields appended in the order they were passed.
    """
    fields = [f"{key}={value}" for key, value in kv.items()]
    line = " | ".join([msg] + fields)
    print(line, file=sys.stderr, flush=True)
# Init the HF client once, at import time. It stays None when no API key is
# configured, so request handlers must check it before use.
client = InferenceClient(token=API_KEY) if API_KEY else None
# NOTE(review): no route registration for this handler is visible in the file,
# so it was unreachable; "/" is inferred from the function name -- confirm.
@app.route("/", methods=["GET"])
def root():
    """Return a small JSON status document naming the configured model."""
    return jsonify({"ok": True, "model": ZSL_MODEL_ID})
def _label_score_pairs(result):
    """Flatten a zero-shot classification result into ``(label, score)`` pairs.

    Accepted shapes:
      * a dict with parallel ``labels`` (or ``candidate_labels``) and ``scores``;
      * a dict with a single ``label`` / ``score``;
      * a list of either of the above.

    Raises:
        ValueError: if a result (or list item) is neither a list nor a dict.
    """
    if isinstance(result, list):
        if not result:
            return []
        first = result[0]
        if isinstance(first, dict) and ("labels" in first or "candidate_labels" in first):
            # One dict per input text; only a single prompt is ever sent.
            return _label_score_pairs(first)
        # Otherwise treat the list as one entry per candidate label.
        return [pair for item in result for pair in _label_score_pairs(item)]

    if not isinstance(result, dict):
        raise ValueError(f"Unexpected ZSL response type: {type(result)}")

    labels = list(result.get("labels") or result.get("candidate_labels") or [])
    scores = list(result.get("scores") or [])
    if not labels and "label" in result:
        labels = [result["label"]]
        scores = [result.get("score", 0.0)]

    # Labels that arrive without a score are treated as scoring 0.0.
    scores += [0.0] * (len(labels) - len(scores))
    return [(label, float(score)) for label, score in zip(labels, scores)]


# NOTE(review): no route registration for this handler is visible in the file,
# so it was unreachable; the path is inferred from the function name -- confirm
# against the clients that call this service.
@app.route("/generate_text", methods=["POST"])
def generate_text():
    """Classify the request's ``prompt`` into one of ``LABELS``.

    Reads a JSON body and uses its ``prompt`` field. Any other field, such as
    ``instructions``, is ignored.

    Returns:
        A ``(json, status)`` pair:
          * 400 ``{"error": ...}`` when ``API_KEY`` is empty or ``prompt`` is
            missing, blank, or not a string;
          * 500 ``{"error": ...}`` when the inference client is ``None``;
          * 200 ``{"response": <label>}`` on success, where the label is the
            highest-scoring one if it is in ``ALLOWED``, else ``"other_query"``;
          * 200 ``{"response": "other_query", "error": ...}`` when the
            classifier call or the parsing of its result raises.
    """
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    # A body that is not a JSON object, or a prompt that is not a string, is
    # treated like a missing prompt instead of raising AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_prompt = data.get("prompt")
    prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ""

    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: explicit image command
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # single best label
        )

        # Pick the winner by score rather than by position, so the result does
        # not depend on the provider returning labels pre-sorted. max() keeps
        # the first entry on ties, matching a take-the-first policy.
        pairs = _label_score_pairs(zs)
        if pairs:
            best, score = max(pairs, key=lambda pair: pair[1])
        else:
            best, score = "other_query", 0.0

        token = best if best in ALLOWED else "other_query"
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    except Exception as e:
        # Deliberate best-effort boundary: any failure degrades to the generic
        # label with HTTP 200 so callers always get a usable token.
        # NOTE(review): the raw exception text is echoed to the caller; kept
        # for compatibility with the existing response shape.
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200
if __name__ == "__main__":
    # Listen on $PORT when set, otherwise 7860; bind on all interfaces.
    port = int(os.getenv("PORT", "7860"))
    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=port)