Update app.py
app.py
CHANGED
@@ -1,143 +1,81 @@
 import os
 import sys
-import json
-import re
-import time
-import traceback
-import requests
 from flask import Flask, request, jsonify
 from huggingface_hub import InferenceClient
 
 app = Flask(__name__)
 
-
-
-
-# Optional: your private Inference Endpoint URL (recommended for Mixtral)
-DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip()
-TIMEOUT = 25
 
-
-    "health_wellness",
-
 
 def log(msg, **kv):
-    print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]),
-          file=sys.stderr, flush=True)
 
-
-
-    return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
-
-_token_re = re.compile(
-    r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b",
-    re.I
-)
-def extract_category(text: str) -> str:
-    raw = (text or "").strip().lower()
-    m = _token_re.search(raw)
-    if m: return m.group(1)
-    first = raw.split()[0].strip(",.;:|") if raw else ""
-    return first if first in ALLOWED else "other_query"
-
-def hf_conversational(prompt: str) -> str:
-    """Call conversational endpoint (public API or your private DECISION_ENDPOINT)."""
-    url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
-    headers = {
-        "Authorization": f"Bearer {API_KEY}",
-        "Accept": "application/json",
-        "Content-Type": "application/json",
-    }
-    payload = {
-        "inputs": {
-            "past_user_inputs": [],
-            "generated_responses": [],
-            "text": prompt,
-        },
-        "parameters": {
-            "max_new_tokens": 3,
-            "temperature": 0.0,
-            "top_p": 1.0,
-            "repetition_penalty": 1.0,
-            "stop": ["\n"],
-            "return_full_text": False,
-        },
-        "options": {"use_cache": True, "wait_for_model": True},
-    }
-    for attempt in range(3):
-        r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
-        if r.status_code == 503:
-            time.sleep(2 + attempt)
-            continue
-        r.raise_for_status()
-        data = r.json()
-        if isinstance(data, dict) and "generated_text" in data:
-            return str(data["generated_text"]).strip()
-        if isinstance(data, dict) and "choices" in data and data["choices"]:
-            ch = data["choices"][0]
-            txt = ch.get("text") or ch.get("message", {}).get("content")
-            if txt: return str(txt).strip()
-        if isinstance(data, list) and data and isinstance(data[0], dict):
-            if "generated_text" in data[0]:
-                return str(data[0]["generated_text"]).strip()
-            if "generated_responses" in data[0]:
-                gresps = data[0]["generated_responses"]
-                if isinstance(gresps, list) and gresps:
-                    return str(gresps[-1]).strip()
-        return str(data).strip()
-    return ""
-
-def try_text_generation(client: InferenceClient, formatted: str) -> str:
-    return client.text_generation(
-        formatted,
-        model=MODEL_ID,
-        temperature=0.0,
-        max_new_tokens=3,
-        top_p=1.0,
-        repetition_penalty=1.0,
-        do_sample=False,
-        stop=["\n"],
-        details=False
-    )
 
 @app.get("/")
 def root():
-    return jsonify({"ok": True, "model": MODEL_ID})
 
 @app.post("/generate_text")
 def generate_text():
     if not API_KEY:
         log("DECISION_ERR", reason="missing_api_key")
         return jsonify({"error": "Missing API_KEY"}), 400
 
     data = request.get_json(silent=True) or {}
     prompt = (data.get("prompt") or "").strip()
-    instructions = (data.get("instructions") or "").strip()
-
-
         return jsonify({"error": "Missing required fields"}), 400
 
-
-
     try:
-
-        log("
-
-
-
-
-
-
-
-
-
 
-
-
 
 if __name__ == "__main__":
     port = int(os.getenv("PORT", 7860))
-    log("BOOT",
     app.run(host="0.0.0.0", port=port)
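Most of the removed request handler did not survive this diff view, but the helpers that did suggest the old flow: build a Mixtral-style [INST] prompt, ask the model for a category token (text-generation first, conversational API with 503 retries as fallback), then regex-extract one of the five allowed labels. A rough reconstruction, where the wiring, the `decide_category` name, and the `build_prompt` signature are all assumptions:

# Hypothetical reconstruction -- the removed handler body is mostly lost above.
def decide_category(user_message: str) -> str:
    formatted = build_prompt(user_message)  # "[INST] {sys_block}\nUser: ...\nAssistant: [/INST]"
    try:
        raw = try_text_generation(client, formatted)  # primary path: text-generation API
    except Exception:
        raw = hf_conversational(formatted)            # fallback: conversational endpoint
    return extract_category(raw)                      # map model output to one of the 5 tokens

The updated file, shown in full below, replaces all of this with a single zero-shot classification call.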
 import os
 import sys
 from flask import Flask, request, jsonify
 from huggingface_hub import InferenceClient
 
 app = Flask(__name__)
 
+API_KEY = (os.getenv("API_KEY") or "").strip()
+# Multilingual zero-shot model (handles Hindi + English well)
+ZSL_MODEL_ID = os.getenv("ZSL_MODEL_ID", "joeddav/xlm-roberta-large-xnli").strip()
 
+LABELS = [
+    "health_wellness",
+    "spiritual_guidance",
+    "generate_image",
+    "realtime_query",
+    "other_query",
+]
+ALLOWED = set(LABELS)
 
 def log(msg, **kv):
+    print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]), file=sys.stderr, flush=True)
 
+# Init HF client once
+client = InferenceClient(token=API_KEY) if API_KEY else None
 
 @app.get("/")
 def root():
+    return jsonify({"ok": True, "model": ZSL_MODEL_ID})
 
 @app.post("/generate_text")
 def generate_text():
     if not API_KEY:
         log("DECISION_ERR", reason="missing_api_key")
         return jsonify({"error": "Missing API_KEY"}), 400
+    if client is None:
+        log("DECISION_ERR", reason="client_not_initialized")
+        return jsonify({"error": "Client not initialized"}), 500
 
     data = request.get_json(silent=True) or {}
     prompt = (data.get("prompt") or "").strip()
+    instructions = (data.get("instructions") or "").strip()  # not required here
+
+    if not prompt:
+        log("DECISION_BAD_REQ", has_prompt=False)
         return jsonify({"error": "Missing required fields"}), 400
 
+    # Fast-path: explicit image command
+    if prompt.startswith("/image "):
+        log("DECISION_FAST", token="generate_image")
+        return jsonify({"response": "generate_image"}), 200
+
     try:
+        # Zero-shot classification
+        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
+        zs = client.zero_shot_classification(
+            prompt,
+            LABELS,
+            model=ZSL_MODEL_ID,
+            hypothesis_template="This text is about {}.",
+            multi_label=False,  # pick the best single label
+        )
+        # Response shape: {'labels': [...], 'scores': [...], 'sequence': '...'}
+        labels = zs.get("labels") or []
+        scores = zs.get("scores") or []
+        best = labels[0] if labels else "other_query"
+        score = float(scores[0]) if scores else 0.0
 
+        token = best if best in ALLOWED else "other_query"
+        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
+        return jsonify({"response": token}), 200
+
+    except Exception as e:
+        log("DECISION_FAIL", error=str(e))
+        return jsonify({"response": "other_query", "error": str(e)}), 200
 
 if __name__ == "__main__":
     port = int(os.getenv("PORT", 7860))
+    log("BOOT", port=port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
     app.run(host="0.0.0.0", port=port)
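For a quick check of the updated route, a minimal client sketch (the URL assumes the Space is reachable locally on the default port 7860; the Hindi prompt merely illustrates the multilingual claim in the model comment):

import requests

resp = requests.post(
    "http://localhost:7860/generate_text",
    json={"prompt": "Mujhe sar dard ke liye koi gharelu upay batao"},  # a health question in Hindi
    timeout=30,
)
print(resp.json())  # expected shape: {"response": "health_wellness"}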
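One caveat: the handler parses the result with zs.get(...), which matches the raw Inference API payload described in the inline comment. Newer huggingface_hub releases instead return a list of ZeroShotClassificationOutputElement objects (each carrying label and score), so a version-tolerant parse could look like the sketch below (top_label is a hypothetical helper, not part of the app above):

def top_label(zs):
    """Return (label, score) from either zero-shot response shape."""
    if isinstance(zs, dict):  # raw API payload: {'labels': [...], 'scores': [...], 'sequence': ...}
        labels = zs.get("labels") or []
        scores = zs.get("scores") or []
        if labels:
            return labels[0], float(scores[0]) if scores else 0.0
    elif isinstance(zs, list) and zs:  # dataclass elements from huggingface_hub
        best = max(zs, key=lambda el: el.score)
        return best.label, float(best.score)
    return "other_query", 0.0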