TheVera committed
Commit d2b7fba · verified · 1 Parent(s): 278f2bb

Update app.py

Files changed (1)
  1. app.py +109 -98
app.py CHANGED
@@ -1,138 +1,149 @@
  import os
  import sys
  import json
  import traceback
  from flask import Flask, request, jsonify
  from huggingface_hub import InferenceClient

  app = Flask(__name__)

- # Prefer a model that supports text-generation
- MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3").strip()
- API_KEY = os.getenv("API_KEY")
-
- # Init client once
- client = InferenceClient(token=API_KEY) if API_KEY else None

  ALLOWED = {
-     "health_wellness",
-     "spiritual_guidance",
-     "generate_image",
-     "realtime_query",
-     "other_query",
  }

  def log(msg, **kv):
-     # compact console logging
-     parts = [msg] + [f"{k}={v}" for k, v in kv.items()]
      print(" | ".join(parts), file=sys.stderr, flush=True)

- def format_prompt(user_message: str, custom_instructions: str = "") -> str:
-     """
-     A single [INST] block with a <<SYS>> system section works well for instruction-following.
-     """
-     sys_block = f"<<SYS>>{custom_instructions}<</SYS>>" if custom_instructions else ""
      return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"

- def normalize_text(text: str) -> str:
-     if not text:
-         return ""
-     # lowercase, then normalize a few British spelling variants → American/neutral
-     repl = {
-         "summarise": "summarize",
-         "colour": "color",
-         "favour": "favor",
-         "centre": "center",
-     }
-     t = text.lower()
-     for k, v in repl.items():
-         t = t.replace(k, v)
-     return t
-
- def call_model(prompt: str,
-                temperature: float = 0.0,
-                max_new_tokens: int = 3,
-                top_p: float = 1.0,
-                repetition_penalty: float = 1.0,
-                stop=None) -> str:
-     """
-     Use text_generation for models that support it.
-     We keep decoding deterministic and tiny (single-token classification).
-     """
-     if stop is None:
-         stop = ["\n"]
-     out = client.text_generation(
-         prompt,  # positional first arg
-         model=MODEL_ID,
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=False,  # deterministic
-         stop=stop,
-         details=False,  # return plain str
-     )
-     return (out or "").strip()
-
  def extract_category(text: str) -> str:
-     if not text:
-         return "other_query"
-     raw = text.strip().lower()
-     # Grab the first allowed token if it appears anywhere
-     for token in ALLOWED:
-         if token in raw:
-             return token
-     # or just the first word (some models emit "category: health_wellness")
      first = raw.split()[0].strip(",.;:|") if raw else ""
      return first if first in ALLOWED else "other_query"

- @app.route("/generate_text", methods=["POST"])
  def generate_text():
      if not API_KEY:
          log("DECISION_ERR", reason="missing_api_key")
          return jsonify({"error": "Missing API_KEY"}), 400
-     if client is None:
-         log("DECISION_ERR", reason="client_not_initialized")
-         return jsonify({"error": "Client not initialized"}), 500

      data = request.get_json(silent=True) or {}
      prompt = (data.get("prompt") or "").strip()
      instructions = (data.get("instructions") or "").strip()
-
      if not prompt or not instructions:
          log("DECISION_BAD_REQ", has_prompt=bool(prompt), has_instructions=bool(instructions))
          return jsonify({"error": "Missing required fields"}), 400

      try:
-         norm_prompt = normalize_text(prompt)
-         formatted = format_prompt(norm_prompt, instructions)
-
-         log("DECISION_CALL",
-             model=MODEL_ID,
-             prompt_len=len(norm_prompt),
-             instr_len=len(instructions))
-
-         raw = call_model(
-             formatted,
-             temperature=0.0,
-             max_new_tokens=3,  # a few tokens cover one category name
-             top_p=1.0,
-             repetition_penalty=1.0,
-             stop=["\n"]
-         )
-
-         token = extract_category(raw)
-
-         log("DECISION_OK", raw=raw.replace("\n", "\\n"), token=token)
-         return jsonify({"response": token}), 200
-
      except Exception as e:
-         tb = traceback.format_exc(limit=2)
-         log("DECISION_FAIL", error=str(e), trace=tb.replace("\n", "\\n"))
-         return jsonify({"response": "other_query", "error": str(e)}), 200

  if __name__ == "__main__":
-     # Use Flask’s dev server; in prod use gunicorn/uvicorn with a WSGI/ASGI wrapper as appropriate.
      port = int(os.getenv("PORT", 7860))
-     log("BOOT", model=MODEL_ID, port=port, api_key_set=bool(API_KEY))
      app.run(host="0.0.0.0", port=port)
 
  import os
  import sys
  import json
+ import re
+ import time
  import traceback
+ import requests
  from flask import Flask, request, jsonify
  from huggingface_hub import InferenceClient

  app = Flask(__name__)

+ # --- Config ---
+ MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mixtral-8x7B-Instruct-v0.1").strip()
+ API_KEY = os.getenv("API_KEY", "").strip()
+ # If you created a private Inference Endpoint, put its full URL here:
+ # e.g. https://xxxxxxxxx-abcdefg.hf.space or https://xxxx-yyy.endpoints.huggingface.cloud
+ DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip()  # optional but recommended
+ TIMEOUT = 25

  ALLOWED = {
+     "health_wellness", "spiritual_guidance", "generate_image", "realtime_query", "other_query",
  }

  def log(msg, **kv):
+     parts = [msg] + [f"{k}={v}" for k, v in kv.items()]
      print(" | ".join(parts), file=sys.stderr, flush=True)
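
For reference, log() prints one pipe-delimited line per event to stderr, so a healthy request leaves traces of this shape (values illustrative):

    BOOT | model=mistralai/Mixtral-8x7B-Instruct-v0.1 | port=7860 | endpoint=api-inference
    DECISION_CALL_TG | model=mistralai/Mixtral-8x7B-Instruct-v0.1 | endpoint=hf_hub_text_generation
    DECISION_OK | raw=health_wellness | token=health_wellness
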
+ # --- Prompt formatting ---
+ def format_prompt(user_message: str, instructions: str = "") -> str:
+     sys_block = f"<<SYS>>{instructions}\nReturn EXACTLY one token from the list above. No quotes, no punctuation, no extra words.<</SYS>>" if instructions else ""
      return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
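
For context, a request rendered through format_prompt() looks like the following (the instruction text here is illustrative; the <<SYS>> block is omitted entirely when instructions is empty):

    [INST] <<SYS>>Classify the request into one of: health_wellness, spiritual_guidance, generate_image, realtime_query, other_query.
    Return EXACTLY one token from the list above. No quotes, no punctuation, no extra words.<</SYS>>
    User: Can you suggest breathing exercises for stress?
    Assistant: [/INST]
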
+ # --- Extractor ---
+ _token_re = re.compile(r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b", re.I)

  def extract_category(text: str) -> str:
+     raw = (text or "").strip().lower()
+     m = _token_re.search(raw)
+     if m:
+         return m.group(1)
      first = raw.split()[0].strip(",.;:|") if raw else ""
      return first if first in ALLOWED else "other_query"
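
A quick sanity sketch of the extractor under messy model outputs (assumes the definitions above; all four asserts pass):

    assert extract_category("health_wellness") == "health_wellness"
    assert extract_category("Category: REALTIME_QUERY.") == "realtime_query"  # regex is case-insensitive
    assert extract_category("generate_image\nplus trailing text") == "generate_image"
    assert extract_category("") == "other_query"  # safe default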
 
+ # --- Conversational REST call (works even if the client lacks .conversational) ---
+ def hf_conversational(prompt: str) -> str:
+     url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
+     headers = {
+         "Authorization": f"Bearer {API_KEY}",
+         "Accept": "application/json",
+         "Content-Type": "application/json",
+     }
+     payload = {
+         "inputs": {
+             "past_user_inputs": [],
+             "generated_responses": [],
+             "text": prompt,
+         },
+         "parameters": {
+             "max_new_tokens": 3,
+             "temperature": 0.0,
+             "top_p": 1.0,
+             "repetition_penalty": 1.0,
+             "stop": ["\n"],
+             "return_full_text": False,
+         },
+         "options": {
+             "use_cache": True,
+             "wait_for_model": True,
+         },
+     }
+     for attempt in range(3):
+         r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
+         if r.status_code == 503:
+             time.sleep(2 + attempt)
+             continue
+         r.raise_for_status()
+         data = r.json()
+
+         # Common shapes:
+         if isinstance(data, dict) and "generated_text" in data:
+             return str(data["generated_text"]).strip()
+         if isinstance(data, dict) and "choices" in data and data["choices"]:
+             ch = data["choices"][0]
+             txt = ch.get("text") or ch.get("message", {}).get("content")
+             if txt:
+                 return str(txt).strip()
+         if isinstance(data, list) and data and isinstance(data[0], dict):
+             if "generated_text" in data[0]:
+                 return str(data[0]["generated_text"]).strip()
+             if "generated_responses" in data[0]:
+                 gresps = data[0]["generated_responses"]
+                 if isinstance(gresps, list) and gresps:
+                     return str(gresps[-1]).strip()
+
+         return str(data).strip()
+     return ""
+
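
The branch-by-branch parsing above is deliberately tolerant; these are the payload shapes it accepts, with hypothetical values (a 503 means the model is still loading, so the loop backs off and retries up to three times before giving up with an empty string):

    {"generated_text": "health_wellness"}                        # plain text-generation dict
    [{"generated_text": "health_wellness"}]                      # list-wrapped text-generation
    [{"generated_responses": ["health_wellness"]}]               # legacy conversational shape
    {"choices": [{"message": {"content": "health_wellness"}}]}   # OpenAI-style chat shape
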
+ # --- Text-generation (use if the provider supports it) ---
+ def try_text_generation(client: InferenceClient, formatted: str) -> str:
+     return client.text_generation(
+         formatted,
+         model=MODEL_ID,
+         temperature=0.0,
+         max_new_tokens=3,
+         top_p=1.0,
+         repetition_penalty=1.0,
+         do_sample=False,
+         stop=["\n"],
+         details=False,
+     )
+
+ @app.post("/generate_text")
  def generate_text():
      if not API_KEY:
          log("DECISION_ERR", reason="missing_api_key")
          return jsonify({"error": "Missing API_KEY"}), 400

      data = request.get_json(silent=True) or {}
      prompt = (data.get("prompt") or "").strip()
      instructions = (data.get("instructions") or "").strip()
      if not prompt or not instructions:
          log("DECISION_BAD_REQ", has_prompt=bool(prompt), has_instructions=bool(instructions))
          return jsonify({"error": "Missing required fields"}), 400

+     formatted = format_prompt(prompt, instructions)
+     raw = ""
      try:
+         # First, try the text-generation path
+         client = InferenceClient(token=API_KEY)
+         log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
+         raw = try_text_generation(client, formatted)
      except Exception as e:
+         # If the provider says "Supported task: conversational", fall back to the REST conversational call
+         msg = str(e)
+         log("DECISION_TG_FAIL", error=msg)
+         log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
+         try:
+             raw = hf_conversational(formatted)
+         except Exception as e2:
+             trace = traceback.format_exc().replace("\n", "\\n")
+             log("DECISION_CONV_FAIL", error=str(e2), trace=trace)
+             return jsonify({"response": "other_query", "error": str(e2)}), 200
+
+     token = extract_category(raw)
+     log("DECISION_OK", raw=raw.replace("\n", "\\n"), token=token)
+     return jsonify({"response": token}), 200

  if __name__ == "__main__":
      port = int(os.getenv("PORT", 7860))
+     log("BOOT", model=MODEL_ID, port=port, endpoint=DECISION_ENDPOINT or "api-inference")
      app.run(host="0.0.0.0", port=port)
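
A minimal smoke test for the new endpoint, assuming the app is running locally on the default port (prompt and instruction values are hypothetical):

    import requests

    resp = requests.post(
        "http://localhost:7860/generate_text",
        json={
            "prompt": "Can you suggest breathing exercises for stress?",
            "instructions": "Classify the request into exactly one of: health_wellness, spiritual_guidance, generate_image, realtime_query, other_query.",
        },
        timeout=30,
    )
    print(resp.status_code, resp.json())  # expected: 200 {"response": "health_wellness"}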