TheVera committed on
Commit
16db54b
·
verified ·
1 Parent(s): d2b7fba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -10,29 +10,29 @@ from huggingface_hub import InferenceClient
10
 
11
  app = Flask(__name__)
12
 
13
- # --- Config ---
14
- MODEL_ID = os.getenv("MODEL_ID", "mistralai/Mixtral-8x7B-Instruct-v0.1").strip()
15
  API_KEY = os.getenv("API_KEY", "").strip()
16
- # If you created a private Inference Endpoint, put its full URL here:
17
- # e.g. https://xxxxxxxxx-abcdefg.hf.space or https://xxxx-yyy.endpoints.huggingface.cloud
18
- DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip() # optional but recommended
19
  TIMEOUT = 25
20
 
21
  ALLOWED = {
22
- "health_wellness","spiritual_guidance","generate_image","realtime_query","other_query"
23
  }
24
 
25
def log(msg, **kv):
    """Write a structured, pipe-delimited log line ("msg | k=v | ...") to stderr."""
    parts = [msg]
    for key, value in kv.items():
        parts.append(f"{key}={value}")
    # flush=True so lines appear immediately in the Space's log stream
    print(" | ".join(parts), file=sys.stderr, flush=True)
28
 
29
- # --- Prompt formatting ---
30
def format_prompt(user_message: str, instructions: str = "") -> str:
    """Wrap *user_message* in the [INST]-style chat template used by the model.

    When *instructions* is non-empty it is embedded in a <<SYS>> block together
    with a directive to answer with exactly one category token.
    """
    if instructions:
        sys_block = f"<<SYS>>{instructions}\nReturn EXACTLY one token from the list above. No quotes, no punctuation, no extra words.<<SYS>>"
    else:
        sys_block = ""
    return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
33
 
34
- # --- Extractor ---
35
- _token_re = re.compile(r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b", re.I)
 
 
36
  def extract_category(text: str) -> str:
37
  raw = (text or "").strip().lower()
38
  m = _token_re.search(raw)
@@ -40,8 +40,8 @@ def extract_category(text: str) -> str:
40
  first = raw.split()[0].strip(",.;:|") if raw else ""
41
  return first if first in ALLOWED else "other_query"
42
 
43
- # --- Conversational REST call (works even if client lacks .conversational) ---
44
  def hf_conversational(prompt: str) -> str:
 
45
  url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
46
  headers = {
47
  "Authorization": f"Bearer {API_KEY}",
@@ -62,10 +62,7 @@ def hf_conversational(prompt: str) -> str:
62
  "stop": ["\n"],
63
  "return_full_text": False,
64
  },
65
- "options": {
66
- "use_cache": True,
67
- "wait_for_model": True
68
- }
69
  }
70
  for attempt in range(3):
71
  r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
@@ -74,8 +71,6 @@ def hf_conversational(prompt: str) -> str:
74
  continue
75
  r.raise_for_status()
76
  data = r.json()
77
-
78
- # Common shapes:
79
  if isinstance(data, dict) and "generated_text" in data:
80
  return str(data["generated_text"]).strip()
81
  if isinstance(data, dict) and "choices" in data and data["choices"]:
@@ -89,11 +84,9 @@ def hf_conversational(prompt: str) -> str:
89
  gresps = data[0]["generated_responses"]
90
  if isinstance(gresps, list) and gresps:
91
  return str(gresps[-1]).strip()
92
-
93
  return str(data).strip()
94
  return ""
95
 
96
- # --- Text-generation (use if provider supports it) ---
97
  def try_text_generation(client: InferenceClient, formatted: str) -> str:
98
  return client.text_generation(
99
  formatted,
@@ -107,6 +100,10 @@ def try_text_generation(client: InferenceClient, formatted: str) -> str:
107
  details=False
108
  )
109
 
 
 
 
 
110
  @app.post("/generate_text")
111
  def generate_text():
112
  if not API_KEY:
@@ -123,14 +120,11 @@ def generate_text():
123
  formatted = format_prompt(prompt, instructions)
124
  raw = ""
125
  try:
126
- # First, try text-generation path
127
  client = InferenceClient(token=API_KEY)
128
  log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
129
  raw = try_text_generation(client, formatted)
130
  except Exception as e:
131
- # If provider says "Supported task: conversational", fallback to REST conversational
132
- msg = str(e)
133
- log("DECISION_TG_FAIL", error=msg)
134
  log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
135
  try:
136
  raw = hf_conversational(formatted)
 
10
 
11
  app = Flask(__name__)
12
 
13
+ # ---------- Config ----------
14
+ MODEL_ID = os.getenv("MODEL_ID", "HuggingFaceH4/zephyr-7b-beta").strip()
15
  API_KEY = os.getenv("API_KEY", "").strip()
16
+ # Optional: your private Inference Endpoint URL (recommended for Mixtral)
17
+ DECISION_ENDPOINT = os.getenv("DECISION_ENDPOINT", "").strip()
 
18
  TIMEOUT = 25
19
 
20
  ALLOWED = {
21
+ "health_wellness", "spiritual_guidance", "generate_image", "realtime_query", "other_query"
22
  }
23
 
24
def log(msg, **kv):
    """Emit a pipe-delimited log line ("msg | k=v | ...") to stderr, flushed."""
    suffix = "".join(f" | {key}={value}" for key, value in kv.items())
    print(msg + suffix, file=sys.stderr, flush=True)
27
 
 
28
def format_prompt(user_message: str, instructions: str = "") -> str:
    """Render the [INST]-style decision prompt for the routing model."""
    sys_block = ""
    if instructions:
        sys_block = (
            f"<<SYS>>{instructions}\n"
            "Return EXACTLY one token from the list above. "
            "No quotes, no punctuation, no extra words.<<SYS>>"
        )
    return f"[INST] {sys_block}\nUser: {user_message}\nAssistant: [/INST]"
31
 
32
+ _token_re = re.compile(
33
+ r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b",
34
+ re.I
35
+ )
36
  def extract_category(text: str) -> str:
37
  raw = (text or "").strip().lower()
38
  m = _token_re.search(raw)
 
40
  first = raw.split()[0].strip(",.;:|") if raw else ""
41
  return first if first in ALLOWED else "other_query"
42
 
 
43
  def hf_conversational(prompt: str) -> str:
44
+ """Call conversational endpoint (public API or your private DECISION_ENDPOINT)."""
45
  url = DECISION_ENDPOINT or f"https://api-inference.huggingface.co/models/{MODEL_ID}"
46
  headers = {
47
  "Authorization": f"Bearer {API_KEY}",
 
62
  "stop": ["\n"],
63
  "return_full_text": False,
64
  },
65
+ "options": {"use_cache": True, "wait_for_model": True},
 
 
 
66
  }
67
  for attempt in range(3):
68
  r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
 
71
  continue
72
  r.raise_for_status()
73
  data = r.json()
 
 
74
  if isinstance(data, dict) and "generated_text" in data:
75
  return str(data["generated_text"]).strip()
76
  if isinstance(data, dict) and "choices" in data and data["choices"]:
 
84
  gresps = data[0]["generated_responses"]
85
  if isinstance(gresps, list) and gresps:
86
  return str(gresps[-1]).strip()
 
87
  return str(data).strip()
88
  return ""
89
 
 
90
  def try_text_generation(client: InferenceClient, formatted: str) -> str:
91
  return client.text_generation(
92
  formatted,
 
100
  details=False
101
  )
102
 
103
@app.get("/")
def root():
    """Health-check endpoint: reports liveness and the configured model id."""
    status = {"ok": True, "model": MODEL_ID}
    return jsonify(status)
106
+
107
  @app.post("/generate_text")
108
  def generate_text():
109
  if not API_KEY:
 
120
  formatted = format_prompt(prompt, instructions)
121
  raw = ""
122
  try:
 
123
  client = InferenceClient(token=API_KEY)
124
  log("DECISION_CALL_TG", model=MODEL_ID, endpoint="hf_hub_text_generation")
125
  raw = try_text_generation(client, formatted)
126
  except Exception as e:
127
+ log("DECISION_TG_FAIL", error=str(e))
 
 
128
  log("DECISION_CALL_CONV", model=MODEL_ID, endpoint=(DECISION_ENDPOINT or "api-inference"))
129
  try:
130
  raw = hf_conversational(formatted)