Update app.py
app.py (CHANGED)
@@ -1,138 +1,149 @@
Old version (lines prefixed "-" were removed; some removed lines were lost in extraction and appear as bare "-"):

 import os
 import sys
 import json
 import traceback
 from flask import Flask, request, jsonify
 from huggingface_hub import InferenceClient

 app = Flask(__name__)

-#
-MODEL_ID = os.getenv("MODEL_ID", "mistralai/
-API_KEY = os.getenv("API_KEY")
-
-#
-

 ALLOWED = {
-    "health_wellness",
-    "spiritual_guidance",
-    "generate_image",
-    "realtime_query",
-    "other_query",
 }

 def log(msg, **kv):
-
-    parts = [msg] + [f"{k}={v}" for k, v in kv.items()]
     print(" | ".join(parts), file=sys.stderr, flush=True)

-
-
-
-    """
-    sys_block = f"<<SYS>>{custom_instructions}<<SYS>>" if custom_instructions else ""
     return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"

-
-
-        return ""
-    # lowercase then normalize a few British/var variants → American/neutral
-    repl = {
-        "summarise": "summarize",
-        "colour": "color",
-        "favour": "favor",
-        "centre": "center",
-    }
-    t = text.lower()
-    for k, v in repl.items():
-        t = t.replace(k, v)
-    return t
-
-def call_model(prompt: str,
-               temperature: float = 0.0,
-               max_new_tokens: int = 3,
-               top_p: float = 1.0,
-               repetition_penalty: float = 1.0,
-               stop=None) -> str:
-    """
-    Use text_generation for models that support it.
-    We keep decoding deterministic and tiny (single-token classification).
-    """
-    if stop is None:
-        stop = ["\n"]
-    out = client.text_generation(
-        prompt,  # positional first arg
-        model=MODEL_ID,
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=False,  # deterministic
-        stop=stop,
-        details=False  # return plain str
-    )
-    return (out or "").strip()
-
 def extract_category(text: str) -> str:
-
-
-
-    # Grab the first allowed token if it appears anywhere
-    for token in ALLOWED:
-        if token in raw:
-            return token
-    # or just the first word (some models emit "category: health_wellness")
     first = raw.split()[0].strip(",.;:|") if raw else ""
     return first if first in ALLOWED else "other_query"

-
 def generate_text():
     if not API_KEY:
         log("DECISION_ERR", reason="missing_api_key")
         return jsonify({"error": "Missing API_KEY"}), 400
-    if client is None:
-        log("DECISION_ERR", reason="client_not_initialized")
-        return jsonify({"error": "Client not initialized"}), 500

     data = request.get_json(silent=True) or {}
     prompt = (data.get("prompt") or "").strip()
     instructions = (data.get("instructions") or "").strip()
-
     if not prompt or not instructions:
         log("DECISION_BAD_REQ", has_prompt=bool(prompt), has_instructions=bool(instructions))
         return jsonify({"error": "Missing required fields"}), 400

     try:
-
-
-
-
-            model=MODEL_ID,
-            prompt_len=len(norm_prompt),
-            instr_len=len(instructions))
-
-        raw = call_model(
-            formatted,
-            temperature=0.0,
-            max_new_tokens=3,  # single token is enough
-            top_p=1.0,
-            repetition_penalty=1.0,
-            stop=["\n"]
-        )
-
-        token = extract_category(raw)
-
-        log("DECISION_OK", raw=raw.replace("\n", "\\n"), token=token)
-        return jsonify({"response": token}), 200
-
     except Exception as e:
-
-
-

 if __name__ == "__main__":
-    # Use Flask’s dev server; in prod use gunicorn/uvicorn with a WSGI/ASGI wrapper as appropriate.
     port = int(os.getenv("PORT", 7860))
-    log("BOOT", model=MODEL_ID, port=port,
     app.run(host="0.0.0.0", port=port)
New version (lines prefixed "+" were added):

 import os
 import sys
 import json
+import re
+import time
 import traceback
+import requests
 from flask import Flask, request, jsonify
 from huggingface_hub import InferenceClient

 app = Flask(__name__)

+# --- Config ---
+MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mixtral-8x7B-Instruct-v0.1").strip()
+API_KEY = os.getenv("API_KEY", "").strip()
+# If you created a private Inference Endpoint, put its full URL here:
+# e.g. https://xxxxxxxxx-abcdefg.hf.space or https://xxxx-yyy.endpoints.huggingface.cloud
+DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip()  # optional but recommended
+TIMEOUT = 25

 ALLOWED = {
+    "health_wellness", "spiritual_guidance", "generate_image", "realtime_query", "other_query"
 }

 def log(msg, **kv):
+    parts = [msg] + [f"{k}={v}" for k, v in kv.items()]
     print(" | ".join(parts), file=sys.stderr, flush=True)

+# --- Prompt formatting ---
+def format_prompt(user_message: str, instructions: str = "") -> str:
+    sys_block = f"<<SYS>>{instructions}\nReturn EXACTLY one token from the list above. No quotes, no punctuation, no extra words.<<SYS>>" if instructions else ""
     return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"

+# --- Extractor ---
+_token_re = re.compile(r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b", re.I)
 def extract_category(text: str) -> str:
+    raw = (text or "").strip().lower()
+    m = _token_re.search(raw)
+    if m:
+        return m.group(1)
     first = raw.split()[0].strip(",.;:|") if raw else ""
     return first if first in ALLOWED else "other_query"

+# --- Conversational REST call (works even if the client lacks .conversational) ---
+def hf_conversational(prompt: str) -> str:
+    url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
+    headers = {
+        "Authorization": f"Bearer {API_KEY}",
+        "Accept": "application/json",
+        "Content-Type": "application/json",
+    }
+    payload = {
+        "inputs": {
+            "past_user_inputs": [],
+            "generated_responses": [],
+            "text": prompt,
+        },
+        "parameters": {
+            "max_new_tokens": 3,
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "repetition_penalty": 1.0,
+            "stop": ["\n"],
+            "return_full_text": False,
+        },
+        "options": {
+            "use_cache": True,
+            "wait_for_model": True
+        }
+    }
+    for attempt in range(3):
+        r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
+        if r.status_code == 503:
+            time.sleep(2 + attempt)
+            continue
+        r.raise_for_status()
+        data = r.json()
+
+        # Common response shapes:
+        if isinstance(data, dict) and "generated_text" in data:
+            return str(data["generated_text"]).strip()
+        if isinstance(data, dict) and "choices" in data and data["choices"]:
+            ch = data["choices"][0]
+            txt = ch.get("text") or ch.get("message", {}).get("content")
+            if txt:
+                return str(txt).strip()
+        if isinstance(data, list) and data and isinstance(data[0], dict):
+            if "generated_text" in data[0]:
+                return str(data[0]["generated_text"]).strip()
+            if "generated_responses" in data[0]:
+                gresps = data[0]["generated_responses"]
+                if isinstance(gresps, list) and gresps:
+                    return str(gresps[-1]).strip()
+
+        return str(data).strip()
+    return ""
+
+# --- Text-generation (use if the provider supports it) ---
+def try_text_generation(client: InferenceClient, formatted: str) -> str:
+    return client.text_generation(
+        formatted,
+        model=MODEL_ID,
+        temperature=0.0,
+        max_new_tokens=3,
+        top_p=1.0,
+        repetition_penalty=1.0,
+        do_sample=False,
+        stop=["\n"],
+        details=False
+    )
+
+@app.post("/generate_text")
 def generate_text():
     if not API_KEY:
         log("DECISION_ERR", reason="missing_api_key")
         return jsonify({"error": "Missing API_KEY"}), 400

     data = request.get_json(silent=True) or {}
     prompt = (data.get("prompt") or "").strip()
     instructions = (data.get("instructions") or "").strip()
     if not prompt or not instructions:
         log("DECISION_BAD_REQ", has_prompt=bool(prompt), has_instructions=bool(instructions))
         return jsonify({"error": "Missing required fields"}), 400

+    formatted = format_prompt(prompt, instructions)
+    raw = ""
     try:
+        # First, try the text-generation path
+        client = InferenceClient(token=API_KEY)
+        log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
+        raw = try_text_generation(client, formatted)
     except Exception as e:
+        # If the provider says "Supported task: conversational", fall back to REST conversational
+        msg = str(e)
+        log("DECISION_TG_FAIL", error=msg)
+        log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
+        try:
+            raw = hf_conversational(formatted)
+        except Exception as e2:
+            trace = traceback.format_exc().replace("\n", "\\n")
+            log("DECISION_CONV_FAIL", error=str(e2), trace=trace)
+            return jsonify({"response": "other_query", "error": str(e2)}), 200
+
+    token = extract_category(raw)
+    log("DECISION_OK", raw=raw.replace("\n", "\\n"), token=token)
+    return jsonify({"response": token}), 200

 if __name__ == "__main__":
     port = int(os.getenv("PORT", 7860))
+    log("BOOT", model=MODEL_ID, port=port, endpoint=DECISION_ENDPOINT or "api-inference")
     app.run(host="0.0.0.0", port=port)
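For a quick smoke test of the new route, a minimal client sketch follows. The field names `prompt` and `instructions` come from the handler above; the localhost URL (default PORT 7860) and the instruction wording are assumptions for illustration, not part of this commit.

import requests

# Hypothetical smoke test for POST /generate_text; the URL and the
# instruction text below are assumptions, not part of the commit.
resp = requests.post(
    "http://localhost:7860/generate_text",
    json={
        "prompt": "What herbs help with sleep?",
        "instructions": "Classify the message as one of: health_wellness, "
                        "spiritual_guidance, generate_image, realtime_query, other_query.",
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"response": "health_wellness"}

Note that when both backends fail, the handler deliberately returns HTTP 200 with {"response": "other_query", "error": ...}, so callers should check the "error" key rather than relying on the status code.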