Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,29 +10,29 @@ from huggingface_hub import InferenceClient
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
|
13 |
-
#
|
14 |
-
MODEL_ID = os.getenv("MODEL_ID", "
|
15 |
API_KEY = os.getenv("API_KEY", "").strip()
|
16 |
-
#
|
17 |
-
|
18 |
-
DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip() # optional but recommended
|
19 |
TIMEOUT = 25
|
20 |
|
21 |
ALLOWED = {
|
22 |
-
"health_wellness","spiritual_guidance","generate_image","realtime_query","other_query"
|
23 |
}
|
24 |
|
25 |
def log(msg, **kv):
|
26 |
-
|
27 |
-
|
28 |
|
29 |
-
# --- Prompt formatting ---
|
30 |
def format_prompt(user_message: str, instructions: str = "") -> str:
    """Build the single-turn ``[INST] ... [/INST]`` prompt for the model.

    Args:
        user_message: Text to classify; placed after ``User:``.
        instructions: Optional system instructions. When non-empty they are
            wrapped in a ``<<SYS>> ... <</SYS>>`` block, followed by a fixed
            reminder to answer with exactly one token. When empty, no system
            block is emitted.

    Returns:
        The formatted prompt string.
    """
    sys_block = ""
    if instructions:
        # The system block has to be closed with ``<</SYS>>``; repeating the
        # opening ``<<SYS>>`` tag would leave the block unterminated.
        sys_block = (
            f"<<SYS>>{instructions}\n"
            "Return EXACTLY one token from the list above. "
            "No quotes, no punctuation, no extra words.<</SYS>>"
        )
    return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
36 |
def extract_category(text: str) -> str:
|
37 |
raw = (text or "").strip().lower()
|
38 |
m = _token_re.search(raw)
|
@@ -40,8 +40,8 @@ def extract_category(text: str) -> str:
|
|
40 |
first = raw.split()[0].strip(",.;:|") if raw else ""
|
41 |
return first if first in ALLOWED else "other_query"
|
42 |
|
43 |
-
# --- Conversational REST call (works even if client lacks .conversational) ---
|
44 |
def hf_conversational(prompt: str) -> str:
|
|
|
45 |
url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
|
46 |
headers = {
|
47 |
"Authorization": f"Bearer {API_KEY}",
|
@@ -62,10 +62,7 @@ def hf_conversational(prompt: str) -> str:
|
|
62 |
"stop": ["\n"],
|
63 |
"return_full_text": False,
|
64 |
},
|
65 |
-
"options": {
|
66 |
-
"use_cache": True,
|
67 |
-
"wait_for_model": True
|
68 |
-
}
|
69 |
}
|
70 |
for attempt in range(3):
|
71 |
r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
|
@@ -74,8 +71,6 @@ def hf_conversational(prompt: str) -> str:
|
|
74 |
continue
|
75 |
r.raise_for_status()
|
76 |
data = r.json()
|
77 |
-
|
78 |
-
# Common shapes:
|
79 |
if isinstance(data, dict) and "generated_text" in data:
|
80 |
return str(data["generated_text"]).strip()
|
81 |
if isinstance(data, dict) and "choices" in data and data["choices"]:
|
@@ -89,11 +84,9 @@ def hf_conversational(prompt: str) -> str:
|
|
89 |
gresps = data[0]["generated_responses"]
|
90 |
if isinstance(gresps, list) and gresps:
|
91 |
return str(gresps[-1]).strip()
|
92 |
-
|
93 |
return str(data).strip()
|
94 |
return ""
|
95 |
|
96 |
-
# --- Text-generation (use if provider supports it) ---
|
97 |
def try_text_generation(client: InferenceClient, formatted: str) -> str:
|
98 |
return client.text_generation(
|
99 |
formatted,
|
@@ -107,6 +100,10 @@ def try_text_generation(client: InferenceClient, formatted: str) -> str:
|
|
107 |
details=False
|
108 |
)
|
109 |
|
|
|
|
|
|
|
|
|
110 |
@app.post("/generate_text")
|
111 |
def generate_text():
|
112 |
if not API_KEY:
|
@@ -123,14 +120,11 @@ def generate_text():
|
|
123 |
formatted = format_prompt(prompt, instructions)
|
124 |
raw = ""
|
125 |
try:
|
126 |
-
# First, try text-generation path
|
127 |
client = InferenceClient(token=API_KEY)
|
128 |
log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
|
129 |
raw = try_text_generation(client, formatted)
|
130 |
except Exception as e:
|
131 |
-
|
132 |
-
msg = str(e)
|
133 |
-
log("DECISION_TG_FAIL", error=msg)
|
134 |
log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
|
135 |
try:
|
136 |
raw = hf_conversational(formatted)
|
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
|
13 |
+
# ---------- Config ----------
# Model to query; overridable through the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "HuggingFaceH4/zephyr-7b-beta").strip()
# Token sent as the Bearer credential; empty string when API_KEY is unset.
API_KEY = os.environ.get("API_KEY", "").strip()
# Optional: your private Inference Endpoint URL (recommended for Mixtral)
DECISION_ENDPOINT = os.environ.get("DECISION_ENDPOINT", "").strip()
# Passed as the ``timeout`` argument of outbound HTTP requests.
TIMEOUT = 25

# Category tokens the classifier is allowed to return.
ALLOWED = {
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
}
|
23 |
|
24 |
def log(msg, **kv):
    """Write one structured log line to stderr and flush it immediately.

    The line consists of ``msg`` followed by one ``key=value`` field per
    keyword argument, all joined with ``" | "``.

    Args:
        msg: Event label. Coerced with ``str()`` so that passing a non-string
            (for example an exception object) cannot make the logger itself
            raise ``TypeError`` inside an error-handling path.
        **kv: Extra fields appended in the order they were passed.
    """
    parts = [str(msg)]
    parts.extend(f"{key}={value}" for key, value in kv.items())
    print(" | ".join(parts), file=sys.stderr, flush=True)
|
27 |
|
|
|
28 |
def format_prompt(user_message: str, instructions: str = "") -> str:
    """Build the single-turn ``[INST] ... [/INST]`` prompt for the model.

    Args:
        user_message: Text to classify; placed after ``User:``.
        instructions: Optional system instructions. When non-empty they are
            wrapped in a ``<<SYS>> ... <</SYS>>`` block, followed by a fixed
            reminder to answer with exactly one token. When empty, no system
            block is emitted.

    Returns:
        The formatted prompt string.
    """
    sys_block = ""
    if instructions:
        # The system block has to be closed with ``<</SYS>>``; repeating the
        # opening ``<<SYS>>`` tag would leave the block unterminated.
        sys_block = (
            f"<<SYS>>{instructions}\n"
            "Return EXACTLY one token from the list above. "
            "No quotes, no punctuation, no extra words.<</SYS>>"
        )
    return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
|
31 |
|
32 |
+
# Case-insensitive matcher for any allowed category token that appears as a
# whole word; the matched token is captured in group 1.
_token_re = re.compile(
    r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b",
    flags=re.IGNORECASE,
)
|
36 |
def extract_category(text: str) -> str:
|
37 |
raw = (text or "").strip().lower()
|
38 |
m = _token_re.search(raw)
|
|
|
40 |
first = raw.split()[0].strip(",.;:|") if raw else ""
|
41 |
return first if first in ALLOWED else "other_query"
|
42 |
|
|
|
43 |
def hf_conversational(prompt: str) -> str:
|
44 |
+
"""Call conversational endpoint (public API or your private DECISION_ENDPOINT)."""
|
45 |
url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
|
46 |
headers = {
|
47 |
"Authorization": f"Bearer {API_KEY}",
|
|
|
62 |
"stop": ["\n"],
|
63 |
"return_full_text": False,
|
64 |
},
|
65 |
+
"options": {"use_cache": True, "wait_for_model": True},
|
|
|
|
|
|
|
66 |
}
|
67 |
for attempt in range(3):
|
68 |
r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
|
|
|
71 |
continue
|
72 |
r.raise_for_status()
|
73 |
data = r.json()
|
|
|
|
|
74 |
if isinstance(data, dict) and "generated_text" in data:
|
75 |
return str(data["generated_text"]).strip()
|
76 |
if isinstance(data, dict) and "choices" in data and data["choices"]:
|
|
|
84 |
gresps = data[0]["generated_responses"]
|
85 |
if isinstance(gresps, list) and gresps:
|
86 |
return str(gresps[-1]).strip()
|
|
|
87 |
return str(data).strip()
|
88 |
return ""
|
89 |
|
|
|
90 |
def try_text_generation(client: InferenceClient, formatted: str) -> str:
|
91 |
return client.text_generation(
|
92 |
formatted,
|
|
|
100 |
details=False
|
101 |
)
|
102 |
|
103 |
+
@app.get("/")
def root():
    """Answer GET / with a JSON object holding an ``ok`` flag and the model id."""
    status = {"ok": True, "model": MODEL_ID}
    return jsonify(status)
|
106 |
+
|
107 |
@app.post("/generate_text")
|
108 |
def generate_text():
|
109 |
if not API_KEY:
|
|
|
120 |
formatted = format_prompt(prompt, instructions)
|
121 |
raw = ""
|
122 |
try:
|
|
|
123 |
client = InferenceClient(token=API_KEY)
|
124 |
log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
|
125 |
raw = try_text_generation(client, formatted)
|
126 |
except Exception as e:
|
127 |
+
log("DECISION_TG_FAIL", error=str(e))
|
|
|
|
|
128 |
log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
|
129 |
try:
|
130 |
raw = hf_conversational(formatted)
|