TheVera committed on
Commit
369b62f
·
verified ·
1 Parent(s): 6f4563e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -112
app.py CHANGED
@@ -1,143 +1,81 @@
1
  import os
2
  import sys
3
- import json
4
- import re
5
- import time
6
- import traceback
7
- import requests
8
  from flask import Flask, request, jsonify
9
  from huggingface_hub import InferenceClient
10
 
11
  app = Flask(__name__)
12
 
13
- # ---------- Config ----------
14
- MODEL_ID = os.getenv("MODEL_ID", "HuggingFaceH4/zephyr-7b-beta").strip()
15
- API_KEY = os.getenv("API_KEY", "").strip()
16
- # Optional: your private Inference Endpoint URL (recommended for Mixtral)
17
- DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip()
18
- TIMEOUT = 25
19
 
20
- ALLOWED = {
21
- "health_wellness", "spiritual_guidance", "generate_image", "realtime_query", "other_query"
22
- }
 
 
 
 
 
23
 
24
  def log(msg, **kv):
25
- print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]),
26
- file=sys.stderr, flush=True)
27
 
28
- def format_prompt(user_message: str, instructions: str = "") -> str:
29
- sys_block = f"<<SYS>>{instructions}\nReturn EXACTLY one token from the list above. No quotes, no punctuation, no extra words.<<SYS>>" if instructions else ""
30
- return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
31
-
32
- _token_re = re.compile(
33
- r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b",
34
- re.I
35
- )
36
- def extract_category(text: str) -> str:
37
- raw = (text or "").strip().lower()
38
- m = _token_re.search(raw)
39
- if m: return m.group(1)
40
- first = raw.split()[0].strip(",.;:|") if raw else ""
41
- return first if first in ALLOWED else "other_query"
42
-
43
- def hf_conversational(prompt: str) -> str:
44
- """Call conversational endpoint (public API or your private DECISION_ENDPOINT)."""
45
- url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
46
- headers = {
47
- "Authorization": f"Bearer {API_KEY}",
48
- "Accept": "application/json",
49
- "Content-Type": "application/json",
50
- }
51
- payload = {
52
- "inputs": {
53
- "past_user_inputs": [],
54
- "generated_responses": [],
55
- "text": prompt,
56
- },
57
- "parameters": {
58
- "max_new_tokens": 3,
59
- "temperature": 0.0,
60
- "top_p": 1.0,
61
- "repetition_penalty": 1.0,
62
- "stop": ["\n"],
63
- "return_full_text": False,
64
- },
65
- "options": {"use_cache": True, "wait_for_model": True},
66
- }
67
- for attempt in range(3):
68
- r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
69
- if r.status_code == 503:
70
- time.sleep(2 + attempt)
71
- continue
72
- r.raise_for_status()
73
- data = r.json()
74
- if isinstance(data, dict) and "generated_text" in data:
75
- return str(data["generated_text"]).strip()
76
- if isinstance(data, dict) and "choices" in data and data["choices"]:
77
- ch = data["choices"][0]
78
- txt = ch.get("text") or ch.get("message", {}).get("content")
79
- if txt: return str(txt).strip()
80
- if isinstance(data, list) and data and isinstance(data[0], dict):
81
- if "generated_text" in data[0]:
82
- return str(data[0]["generated_text"]).strip()
83
- if "generated_responses" in data[0]:
84
- gresps = data[0]["generated_responses"]
85
- if isinstance(gresps, list) and gresps:
86
- return str(gresps[-1]).strip()
87
- return str(data).strip()
88
- return ""
89
-
90
- def try_text_generation(client: InferenceClient, formatted: str) -> str:
91
- return client.text_generation(
92
- formatted,
93
- model=MODEL_ID,
94
- temperature=0.0,
95
- max_new_tokens=3,
96
- top_p=1.0,
97
- repetition_penalty=1.0,
98
- do_sample=False,
99
- stop=["\n"],
100
- details=False
101
- )
102
 
103
  @app.get("/")
104
  def root():
105
- return jsonify({"ok": True, "model": MODEL_ID})
106
 
107
  @app.post("/generate_text")
108
  def generate_text():
109
  if not API_KEY:
110
  log("DECISION_ERR", reason="missing_api_key")
111
  return jsonify({"error": "Missing API_KEY"}), 400
 
 
 
112
 
113
  data = request.get_json(silent=True) or {}
114
  prompt = (data.get("prompt") or "").strip()
115
- instructions = (data.get("instructions") or "").strip()
116
- if not prompt or not instructions:
117
- log("DECISION_BAD_REQ", has_prompt=bool(prompt), has_instructions=bool(instructions))
 
118
  return jsonify({"error": "Missing required fields"}), 400
119
 
120
- formatted = format_prompt(prompt, instructions)
121
- raw = ""
 
 
 
122
  try:
123
- client = InferenceClient(token=API_KEY)
124
- log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
125
- raw = try_text_generation(client, formatted)
126
- except Exception as e:
127
- log("DECISION_TG_FAIL", error=str(e))
128
- log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
129
- try:
130
- raw = hf_conversational(formatted)
131
- except Exception as e2:
132
- trace = traceback.format_exc().replace("\n", "\\n")
133
- log("DECISION_CONV_FAIL", error=str(e2), trace=trace)
134
- return jsonify({"response": "other_query", "error": str(e2)}), 200
 
 
135
 
136
- token = extract_category(raw)
137
- log("DECISION_OK", raw=raw.replace("\n", "\\n"), token=token)
138
- return jsonify({"response": token}), 200
 
 
 
 
139
 
140
  if __name__ == "__main__":
141
  port = int(os.getenv("PORT", 7860))
142
- log("BOOT", model=MODEL_ID, port=port, endpoint=DECISION_ENDPOINT or "api-inference")
143
  app.run(host="0.0.0.0", port=port)
 
1
  import os
2
  import sys
 
 
 
 
 
3
  from flask import Flask, request, jsonify
4
  from huggingface_hub import InferenceClient
5
 
6
  app = Flask(__name__)
7
 
8
# ---------- Configuration ----------
# Hugging Face API token; empty string when unset (client init checks this).
API_KEY = (os.getenv("API_KEY") or "").strip()

# Multilingual zero-shot model (handles Hindi + English well)
ZSL_MODEL_ID = os.getenv("ZSL_MODEL_ID", "joeddav/xlm-roberta-large-xnli").strip()

# Closed set of routing decisions this service may emit, in a stable order
# for the classifier; ALLOWED gives O(1) membership checks on the same set.
LABELS = [
    "health_wellness",
    "spiritual_guidance",
    "generate_image",
    "realtime_query",
    "other_query",
]
ALLOWED = set(LABELS)
20
 
21
def log(msg, **kv):
    """Write one structured log line to stderr: "msg | k1=v1 | k2=v2".

    Flushes immediately so entries show up promptly in container logs.
    (A stray duplicate print of a bare generator expression — which wrote
    "<generator object ...>" junk to stdout — has been removed.)
    """
    print(" | ".join([msg] + [f"{k}={v}" for k, v in kv.items()]),
          file=sys.stderr, flush=True)
24
 
25
# Create the shared Hugging Face inference client once at import time.
# Left as None when no API key is configured; the route handler checks
# for that before using it.
if API_KEY:
    client = InferenceClient(token=API_KEY)
else:
    client = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
@app.get("/")
def root():
    """Liveness probe: report that the service is up and which model it uses."""
    status = {"ok": True, "model": ZSL_MODEL_ID}
    return jsonify(status)
31
 
32
def _top_label(zs):
    """Return (label, score) for the best class in a zero-shot result.

    Depending on the huggingface_hub version, zero_shot_classification
    returns either a dict with parallel "labels"/"scores" lists (best
    first) or a list of elements carrying .label/.score attributes —
    accept both shapes. Falls back to ("other_query", 0.0) on anything
    unrecognized so the route never crashes on a shape change.
    """
    if isinstance(zs, dict):
        labels = zs.get("labels") or []
        scores = zs.get("scores") or []
        if labels:
            return str(labels[0]), (float(scores[0]) if scores else 0.0)
    elif isinstance(zs, (list, tuple)) and zs:
        best = max(zs, key=lambda el: float(getattr(el, "score", 0.0)))
        return (str(getattr(best, "label", "other_query")),
                float(getattr(best, "score", 0.0)))
    return "other_query", 0.0

@app.post("/generate_text")
def generate_text():
    """Route a user prompt to one of the LABELS buckets.

    Expects a JSON body {"prompt": "..."}; an "instructions" field is
    accepted but ignored. Classification failures degrade to
    "other_query" with HTTP 200 so callers never have to special-case
    model errors.
    """
    if not API_KEY:
        log("DECISION_ERR", reason="missing_api_key")
        return jsonify({"error": "Missing API_KEY"}), 400
    if client is None:
        log("DECISION_ERR", reason="client_not_initialized")
        return jsonify({"error": "Client not initialized"}), 500

    data = request.get_json(silent=True) or {}
    prompt = (data.get("prompt") or "").strip()
    if not prompt:
        log("DECISION_BAD_REQ", has_prompt=False)
        return jsonify({"error": "Missing required fields"}), 400

    # Fast-path: an explicit image command needs no model round-trip.
    if prompt.startswith("/image "):
        log("DECISION_FAST", token="generate_image")
        return jsonify({"response": "generate_image"}), 200

    try:
        log("DECISION_CALL_ZSL", model=ZSL_MODEL_ID, prompt_len=len(prompt))
        zs = client.zero_shot_classification(
            prompt,
            LABELS,
            model=ZSL_MODEL_ID,
            hypothesis_template="This text is about {}.",
            multi_label=False,  # pick the single best label
        )
        # NOTE(fix): current huggingface_hub returns a list of elements with
        # .label/.score, not a dict — the old zs.get("labels") raised
        # AttributeError and forced every request into "other_query".
        best, score = _top_label(zs)
        # Guard against any label outside the closed routing set.
        token = best if best in ALLOWED else "other_query"
        log("DECISION_OK", token=token, top_label=best, score=round(score, 4))
        return jsonify({"response": token}), 200
    except Exception as e:
        log("DECISION_FAIL", error=str(e))
        return jsonify({"response": "other_query", "error": str(e)}), 200
77
 
78
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; fall back to 7860 for local runs.
    listen_port = int(os.environ.get("PORT", 7860))
    log("BOOT", port=listen_port, zsl_model=ZSL_MODEL_ID, api_key_set=bool(API_KEY))
    app.run(host="0.0.0.0", port=listen_port)