TheVera committed on
Commit
8e98672
·
verified ·
1 Parent(s): f9faf91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -127
app.py CHANGED
@@ -1,143 +1,76 @@
1
  import os
2
- import re
3
- import sys
4
- import time
5
- import json
6
- from typing import Any, Dict
7
- import requests
8
  from flask import Flask, request, jsonify
 
9
 
10
  app = Flask(__name__)
11
 
 
12
  MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
13
- API_KEY = os.getenv("API_KEY")
14
- HF_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
15
- TIMEOUT = 25 # seconds
16
-
17
- ALLOWED = {
18
- "health_wellness",
19
- "spiritual_guidance",
20
- "generate_image",
21
- "realtime_query",
22
- "other_query",
23
- }
24
-
25
- _token_re = re.compile(
26
- r"\b(health_wellness|spiritual_guidance|generate_image|realtime_query|other_query)\b",
27
- re.I
28
- )
29
-
30
- def format_prompt(user_text: str, instructions: str) -> str:
31
- # One [INST] block with a <<SYS>> section is reliable for Mixtral
32
- system = (
33
- f"<<SYS>>{instructions}\n"
34
- f"Return EXACTLY one token from the list above. No quotes, no punctuation, no extra words."
35
- f"<<SYS>>"
36
- )
37
- return f"[INST] {system}\nUser: {user_text}\nAssistant: [/INST]"
38
-
39
- def extract_category(text: str) -> str:
40
- m = _token_re.search((text or "").strip().lower())
41
- return m.group(1) if m else "other_query"
42
 
43
- def hf_conversational_call(prompt: str) -> str:
44
- """
45
- Call HF Inference API using the 'conversational' task payload.
46
- Handles model warmup (503) and different response shapes.
47
- """
48
- headers = {
49
- "Authorization": f"Bearer {API_KEY}",
50
- "Accept": "application/json",
51
- "Content-Type": "application/json",
52
- }
53
 
54
- payload: Dict[str, Any] = {
55
- "inputs": {
56
- # Conversational schema — we pass a single-turn prompt
57
- "past_user_inputs": [],
58
- "generated_responses": [],
59
- "text": prompt,
60
- },
61
- "parameters": {
62
- "max_new_tokens": 3, # just enough to emit one category token
63
- "temperature": 0.0,
64
- "top_p": 1.0,
65
- "repetition_penalty": 1.0,
66
- "stop": ["\n"], # cut at first newline if it tries to add more
67
- # Some backends use 'return_full_text'; harmless if ignored
68
- "return_full_text": False,
69
- },
70
- "options": {
71
- "use_cache": True,
72
- "wait_for_model": True, # block until the model is loaded
73
- },
74
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # minimal retry on cold start
77
- for attempt in range(3):
78
- r = requests.post(HF_URL, headers=headers, data=json.dumps(payload), timeout=TIMEOUT)
79
- if r.status_code == 503:
80
- # model loading — wait and retry
81
- time.sleep(2 + attempt)
82
- continue
83
- r.raise_for_status()
84
- data = r.json()
85
-
86
- # Response can be a dict or a list (legacy). Try common shapes:
87
- # 1) dict with 'generated_text'
88
- if isinstance(data, dict) and "generated_text" in data:
89
- return str(data["generated_text"]).strip()
90
- # 2) dict with 'choices' (some backends)
91
- if isinstance(data, dict) and "choices" in data and data["choices"]:
92
- # choices[i].text or .message?.content
93
- ch = data["choices"][0]
94
- txt = ch.get("text") or ch.get("message", {}).get("content")
95
- if txt:
96
- return str(txt).strip()
97
- # 3) list with first item having 'generated_text'
98
- if isinstance(data, list) and data and isinstance(data[0], dict):
99
- if "generated_text" in data[0]:
100
- return str(data[0]["generated_text"]).strip()
101
- # sometimes 'conversation' shaped
102
- if "generated_responses" in data[0]:
103
- gresps = data[0]["generated_responses"]
104
- if isinstance(gresps, list) and gresps:
105
- return str(gresps[-1]).strip()
106
-
107
- # fallback: stringify
108
- return str(data).strip()
109
-
110
- # If all retries hit 503, give up gracefully
111
- return ""
112
-
113
- @app.post("/generate_text")
114
  def generate_text():
115
- data = request.get_json(silent=True) or {}
116
- prompt = (data.get("prompt") or "").strip()
117
- instructions = (data.get("instructions") or "").strip()
 
118
 
119
- if not API_KEY:
120
- return jsonify({"error": "Missing API_KEY"}), 400
121
- if not prompt or not instructions:
122
  return jsonify({"error": "Missing required fields"}), 400
123
 
124
- try:
125
- formatted = format_prompt(prompt, instructions)
126
- raw = hf_conversational_call(formatted)
127
- token = extract_category(raw)
128
-
129
- print("RAW_DECISION:", repr(raw), "->", token, file=sys.stderr, flush=True)
130
-
131
- if token not in ALLOWED:
132
- token = "other_query"
133
- return jsonify({"response": token})
134
-
135
- except requests.HTTPError as he:
136
- print("DECISION_HTTP_ERROR:", repr(he), file=sys.stderr, flush=True)
137
- return jsonify({"response": "other_query", "error": str(he)}), 200
138
- except Exception as e:
139
- print("DECISION_ERROR:", repr(e), file=sys.stderr, flush=True)
140
- return jsonify({"response": "other_query", "error": str(e)}), 200
141
 
142
  if __name__ == "__main__":
143
- app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
1
  import os
2
+ import uvicorn
 
 
 
 
 
3
  from flask import Flask, request, jsonify
4
+ from huggingface_hub import InferenceClient
5
 
6
  app = Flask(__name__)
7
 
8
+ # Fixed API URL
9
  MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Use environment variable for the API key
12
+ API_KEY = os.getenv("API_KEY")
 
 
 
 
 
 
 
 
13
 
14
def format_prompt(message, custom_instructions=None):
    """Build a Mixtral-style prompt of concatenated [INST] ... [/INST] blocks.

    When *custom_instructions* is truthy it is emitted as its own leading
    [INST] block, followed by the user *message* block. Blocks are joined
    with no separator between them.
    """
    segments = []
    if custom_instructions:
        segments.append(f"[INST] {custom_instructions} [/INST]")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
20
+
21
def normalize_text(text):
    """Lowercase *text*, swap a few British spellings for American ones,
    and capitalize only the first character of the result.

    Note: str.capitalize() lowercases everything after the first character,
    so the capital letters in the replacement values below survive only when
    the replaced word sits at the very start of the string.
    """
    # British -> American spelling map; extend as needed.
    spelling_map = {
        'summarise': 'Summarize',
        'colour': 'Color',
        'favour': 'Favor',
        'centre': 'Center',
    }

    normalized = text.lower()
    for uk_form, us_form in spelling_map.items():
        normalized = normalized.replace(uk_form, us_form)

    return normalized.capitalize()
37
+
38
def Mistral7B(prompt, instructions, api_key, temperature=0.2, max_new_tokens=18, top_p=0.9, repetition_penalty=1.0):
    """Generate text from the Mixtral model via the HF Inference API.

    The *instructions* are folded into the prompt by format_prompt(). On any
    failure the exception is stringified and returned, so callers always
    receive a string.

    NOTE(review): returning str(e) makes errors indistinguishable from model
    output to the caller — consider raising instead.
    """
    try:
        # Floor temperature at 0.01 — zero/negative values are clamped.
        temperature = max(float(temperature), 1e-2)
        top_p = float(top_p)

        formatted_prompt = format_prompt(prompt, instructions)
        client = InferenceClient(token=api_key)
        # Fixed seed keeps sampled output reproducible for identical requests.
        return client.text_generation(
            formatted_prompt,
            model=MODEL_ID,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=69,
        )
    except Exception as e:
        return str(e)
59
 
60
@app.route("/generate_text", methods=["POST"])
def generate_text():
    """POST /generate_text — JSON body: {"prompt": ..., "instructions": ...}.

    Returns {"response": <generated text>} with 200, or
    {"error": "Missing required fields"} with 400 when the prompt,
    instructions, or server-side API key is absent.
    """
    # get_json(silent=True) yields None instead of raising on a missing or
    # malformed JSON body, so bad requests hit our clean JSON 400 below
    # rather than Flask's default HTML error page.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt")
    instructions = data.get("instructions")
    api_key = API_KEY  # server-side key from the environment, never the request

    if not prompt or not instructions or not api_key:
        return jsonify({"error": "Missing required fields"}), 400

    normalized_prompt = normalize_text(prompt)
    response = Mistral7B(normalized_prompt, instructions, api_key)
    print(response)  # lightweight request logging to stdout
    return jsonify({"response": response}), 200
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
if __name__ == "__main__":
    # BUG FIX: Flask is a WSGI application; uvicorn is an ASGI server and
    # cannot serve it directly (requests fail with a call-signature error).
    # Use Flask's built-in server, binding to all interfaces on $PORT
    # (default 7860, the Hugging Face Spaces convention).
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))