Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uvicorn
|
3 |
+
from flask import Flask, request, jsonify
|
4 |
+
from huggingface_hub import InferenceClient
|
5 |
+
|
6 |
+
# Flask application object; the route decorator further down registers on it.
app = Flask(__name__)

# Model identifier passed to every text-generation request.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Inference token taken from the environment; None when API_KEY is not set.
API_KEY = os.environ.get("API_KEY")

# Shared inference client, created once at import time with that token.
client = InferenceClient(token=API_KEY)
|
11 |
+
|
12 |
+
# Closed set of labels the /generate_text endpoint is allowed to return.
# A frozenset keeps this module-level constant from being mutated at runtime
# while leaving ``in`` / ``not in`` membership checks unchanged.
ALLOWED = frozenset(
    {
        "health_wellness",
        "spiritual_guidance",
        "generate_image",
        "realtime_query",
        "other_query",
    }
)
|
19 |
+
|
20 |
+
def format_prompt(message, custom_instructions=None):
    """Build the instruction-tagged prompt string for the model.

    Each piece of text is wrapped as ``[INST] <text> [/INST]``. When
    ``custom_instructions`` is truthy, its wrapped form comes first and the
    wrapped ``message`` follows immediately, with no separator in between.

    Args:
        message: The user text to wrap.
        custom_instructions: Optional instruction text placed before the
            message; skipped when ``None`` or empty.

    Returns:
        The concatenated prompt string.
    """
    segments = []
    if custom_instructions:
        segments.append(f"[INST] {custom_instructions} [/INST]")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
|
26 |
+
|
27 |
+
@app.post("/generate_text")
def generate_text():
    """Classify a prompt into one of the ``ALLOWED`` labels.

    Reads a JSON object from the request body with two required, non-empty
    fields: ``prompt`` and ``instructions``. The formatted prompt is sent to
    the model and the first whitespace-delimited word of the reply is used as
    the label.

    Returns:
        ``{"response": <label>}`` where the label is a member of ``ALLOWED``;
        anything else the model says is mapped to ``"other_query"``.
        ``({"error": "Missing required fields"}, 400)`` when the body is not a
        JSON object, a required field is missing or empty, or ``API_KEY`` is
        not configured.
        ``({"response": "other_query"}, 200)`` when generation or parsing
        raises; the failure is logged rather than surfaced to the caller.
    """
    # silent=True returns None for a missing, non-JSON or malformed body
    # instead of aborting the request with a framework-generated error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number) has no
        # fields to read; treat it the same as an empty body.
        data = {}

    prompt = data.get("prompt", "")
    instructions = data.get("instructions", "")
    if not (prompt and instructions and API_KEY):
        return jsonify({"error": "Missing required fields"}), 400

    try:
        formatted = format_prompt(prompt, instructions)
        out = client.text_generation(
            formatted,
            model=MODEL_ID,
            temperature=0.2,
            max_new_tokens=8,
            top_p=0.9,
            repetition_penalty=1.05,
            do_sample=True,
            seed=69,
        )
        # Keep only the first word-like token, lower-cased, without the
        # punctuation or quoting the model may wrap around the label.
        words = (out or "").strip().lower().split()
        token = words[0].strip(",.;:|\"'`") if words else ""
        if token not in ALLOWED:
            token = "other_query"
        return jsonify({"response": token})
    except Exception:
        # Deliberate best-effort: callers always get a usable label, but the
        # failure is recorded instead of being swallowed silently.
        app.logger.exception(
            "generate_text failed; falling back to 'other_query'"
        )
        return jsonify({"response": "other_query"}), 200
|
50 |
+
|
51 |
+
if __name__ == "__main__":
    # Flask applications are WSGI callables, while uvicorn serves ASGI
    # callables, so passing `app` straight to uvicorn.run() cannot handle
    # requests. Use Flask's own server on the same host and port instead.
    # PORT comes from the environment and defaults to 7860.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
|