"""Flask service that routes a user prompt to one of a fixed set of intent tokens.

POST /generate_text with a JSON object {"prompt": "..."} answers
{"response": <token>}, where <token> is always one of LABELS. Classification
is delegated to a zero-shot model via huggingface_hub's InferenceClient.
"""

import os
import sys

from flask import Flask, jsonify, request
from huggingface_hub import InferenceClient

app = Flask(__name__)

API_KEY = (os.getenv("API_KEY") or "").strip()

# Multilingual zero-shot model (handles Hindi + English well).
# An unset, empty or whitespace-only ZSL_MODEL_ID falls back to the default
# instead of passing an empty model id to the inference client.
ZSL_MODEL_ID = (
    (os.getenv("ZSL_MODEL_ID") or "").strip() or "joeddav/xlm-roberta-large-xnli"
)

# Candidate labels sent to the model; also the only tokens this service returns.
# NOTE(review): these are snake_case identifiers inserted verbatim into the
# hypothesis template; natural-language phrasings mapped back to these tokens
# may classify better — worth evaluating.
LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]
ALLOWED = set(LABELS)


def log(msg, **kv):
    """Write one ``msg | k=v | ...`` line to stderr and flush immediately."""
    line = " | ".join([msg] + [f"{k}={v}" for k, v in kv.items()])
    print(line, file=sys.stderr, flush=True)


# Init HF client once at import time; stays None when API_KEY is not set.
client = InferenceClient(token=API_KEY) if API_KEY else None


@app.get("/")
def root():
    """Liveness check reporting the configured zero-shot model id."""
    return jsonify({"ok": True, "model": ZSL_MODEL_ID})


def _get_field(obj, name, default=None):
    """Read *name* from *obj* as a mapping key, or else as an attribute."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _as_score(value):
    """Convert a raw score to float, treating None/missing as 0.0."""
    return float(value or 0.0)


def _pick_top_label(zs):
    """Return ``(label, score)`` for the highest-scoring candidate in a ZSL response.

    Accepted shapes:
      * a list of per-label elements, each carrying ``label`` and ``score``
        (as mapping keys or attributes);
      * a list holding one result dict per input (only the first is used,
        since a single prompt is sent);
      * a dict with parallel ``labels``/``candidate_labels`` and ``scores``;
      * a single dict with ``label``/``score``.

    The best candidate is chosen by score rather than by position, so the
    result does not depend on the provider returning sorted output; on ties
    the earliest candidate wins.

    Returns:
        ``(None, 0.0)`` when the response contains no labels; ``(label, 0.0)``
        when labels are present without scores.

    Raises:
        ValueError: the response (after list unwrapping) has an unknown type.
    """
    if isinstance(zs, list):
        if not zs:
            return None, 0.0
        if _get_field(zs[0], "label") is not None:
            # One element per candidate label.
            best = max(zs, key=lambda el: _as_score(_get_field(el, "score")))
            return _get_field(best, "label"), _as_score(_get_field(best, "score"))
        zs = zs[0]  # One result dict per input; we only sent one input.

    if not isinstance(zs, dict):
        raise ValueError(f"Unexpected ZSL response type: {type(zs)}")

    labels = zs.get("labels") or zs.get("candidate_labels") or []
    scores = zs.get("scores") or []
    if not labels and "label" in zs:
        labels = [zs["label"]]
        scores = [zs.get("score", 0.0)]

    if not labels:
        return None, 0.0
    if not scores:
        return labels[0], 0.0
    paired = range(min(len(labels), len(scores)))
    best_i = max(paired, key=lambda i: _as_score(scores[i]))
    return labels[best_i], _as_score(scores[best_i])


@app.post("/generate_text")
def generate_text():
    """Classify the request's ``prompt`` into one routing token.

    Request body: a JSON object with a non-empty string ``prompt``. An
    ``instructions`` field may be sent but is ignored by this endpoint.

    Responses:
      * 200 ``{"response": token}`` where token is in LABELS. Prompts that
        start with ``"/image "`` short-circuit to ``"generate_image"``
        without calling the model.
      * 200 ``{"response": "other_query", "error": ...}`` when the model
        call or response parsing fails (deliberate best-effort fallback).
      * 400 when the prompt is missing, blank, not a string, or the body is
        not a JSON object; also when API_KEY is unset (kept as 400 for
        existing callers, although it is a server-side misconfiguration).
      * 500 when the inference client was not initialised.
    """
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    # get_json(silent=True) yields None on bad/absent JSON, but valid JSON may
    # still be a list/string/number; only an object can carry "prompt".
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    raw_prompt = data.get("prompt")
    prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ""
    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: explicit image command
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # single best label
        )
        top, score = _pick_top_label(zs)
        best = top if top is not None else "other_query"
        # Never return a label the model invented outside LABELS.
        token = best if best in ALLOWED else "other_query"
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    except Exception as e:
        # Best-effort: degrade to the catch-all token instead of failing.
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200


if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=port)