import os
import sys
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient

app = Flask(__name__)

API_KEY = (os.getenv("API_KEY") or "").strip()
# Multilingual zero-shot model (handles Hindi + English well)
ZSL_MODEL_ID = os.getenv("ZSL_MODEL_ID", "joeddav/xlm-roberta-large-xnli").strip()

LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]
ALLOWED = set(LABELS)

def log(msg, **kv):
    # Structured stderr logging: "MSG | key=value | ..."
    print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]), file=sys.stderr, flush=True)

# Init HF client once
client = InferenceClient(token=API_KEY) if API_KEY else None

@app.get("/")
def root():
    return jsonify({"ok": True, "model": ZSL_MODEL_ID})

@app.post("/generate_text")
def generate_text():
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    data = request.get_json(silent=True) or {}
    prompt = (data.get("prompt") or "").strip()
    instructions = (data.get("instructions") or "").strip()  # not required here

    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: explicit image command
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        # Zero-shot classification
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,          # pick the best single label
        )
        # Response shape varies with the huggingface_hub version: either a dict
        # like {'labels': [...], 'scores': [...], 'sequence': '...'} or a list
        # of result objects exposing .label and .score, best match first.
        if isinstance(zs, dict):
            labels = zs.get("labels") or []
            scores = zs.get("scores") or []
        else:
            labels = [item.label for item in zs]
            scores = [item.score for item in zs]
        best = labels[0] if labels else "other_query"
        score = float(scores[0]) if scores else 0.0

        token = best if best in ALLOWED else "other_query"
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200

    except Exception as e:
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200

if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=port)
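
# Example usage (a minimal sketch, assuming the app is running locally on the
# default port 7860 and API_KEY is set in the environment):
#
#   curl -X POST http://localhost:7860/generate_text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Suggest a morning yoga routine"}'
#
# A successful call returns JSON such as {"response": "health_wellness"}; the
# exact label depends on the zero-shot model's classification. Prompts starting
# with "/image " short-circuit to {"response": "generate_image"}.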