update from kimi2
app.py
CHANGED
@@ -1,85 +1,95 @@
Removed (old version, partial):

 import os

-DB_HOST = "gateway01.ap-southeast-1.prod.aws.tidbcloud.com"
-DB_PORT = 4000
-DB_USER = "4V44XYoMA7okY9v.root"
-DB_PASS = "aW2CrSwcTgjFhNAb"
-DB_NAME = "final_project"

-classifier = pipeline(
-    "zero-shot-classification",
-    model="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
-)

-    "Computer Science", "Data Science", "Artificial Intelligence",
-    "Robotics", "Biotechnology", "Engineering",
-    "Fine Arts", "Music", "Dance", "Theater",
-    "Business Studies", "Accountancy", "Entrepreneurship",
-    "Physical Education", "Health Science"
-]

-class
     student_id: str
     text: str

 @app.post("/predict")
-def

-        cursor.execute(
-            """
-            INSERT INTO log_table (student_id, input_sample, subject)
-            VALUES (%s, %s, %s)
-            """,
-            (data.student_id, sample_text, predicted_subject)
-        )
-        conn.commit()

-        "sample_text": sample_text
-    }
Added (new version):

+"""
+FastAPI server inside Hugging Face Space
+POST /predict -> zero-shot subject prediction + save to TiDB
+"""
 import os
+import time
+from contextlib import asynccontextmanager

+import mysql.connector
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel

+# ---------- load model ONCE ----------
+MODEL_NAME = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
+LABELS = [
+    "Mathematics", "Physics", "Chemistry", "Biology",
+    "History", "Geography", "Literature", "Computer-Science"
+]

+ml_models = {}

+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # load at startup
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+    model.eval()
+    if torch.cuda.is_available():
+        model.cuda()
+    ml_models["tokenizer"] = tokenizer
+    ml_models["model"] = model
+    yield
+    # shutdown
+    ml_models.clear()

+app = FastAPI(lifespan=lifespan)

+# ---------- DB helper ----------
+def get_conn():
+    return mysql.connector.connect(
+        host=os.getenv("DB_HOST"),
+        port=int(os.getenv("DB_PORT", 4000)),
+        user=os.getenv("DB_USER"),
+        password=os.getenv("DB_PASS"),
+        database=os.getenv("DB_NAME"),
+        ssl_ca=os.getenv("DB_SSL_CA_PATH") or None
+    )

+# ---------- request schema ----------
+class PredictRequest(BaseModel):
     student_id: str
     text: str

+# ---------- API endpoint ----------
 @app.post("/predict")
+def predict(req: PredictRequest):
+    if not req.text.strip():
+        raise HTTPException(400, "Empty text")
+    # NLI-style zero-shot: the checkpoint only scores entailment vs. not_entailment,
+    # so pair the text with one hypothesis per candidate label and score each pair
+    tok = ml_models["tokenizer"](
+        [req.text] * len(LABELS),
+        [f"This example is about {label}." for label in LABELS],
+        padding=True,
+        truncation=True,
+        return_tensors="pt"
+    )
+    if torch.cuda.is_available():
+        tok = {k: v.cuda() for k, v in tok.items()}
+    with torch.no_grad():
+        logits = ml_models["model"](**tok).logits
+    # entailment probability for each candidate label (entailment class, else index 0)
+    entail_id = ml_models["model"].config.label2id.get("entailment", 0)
+    probs = torch.softmax(logits, dim=-1)[:, entail_id]
+    idx = int(torch.argmax(probs))
+    subject = LABELS[idx]

+    # save to DB
+    try:
+        conn = get_conn()
+        cur = conn.cursor()
+        cur.execute(
+            "INSERT INTO predictions (student_id, text, subject, created_at) "
+            "VALUES (%s, %s, %s, %s)",
+            (req.student_id, req.text, subject, time.strftime('%Y-%m-%d %H:%M:%S'))
+        )
+        conn.commit()
+        cur.close()
+        conn.close()
+    except Exception as e:
+        print("DB error:", e)
+        raise HTTPException(500, "DB write failed")

+    return {"subject": subject}

+@app.get("/")
+def root():
+    return {"message": "Subject predictor is running"}
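
A note on the prediction step: the checkpoint used here is an NLI model that only distinguishes entailment from not_entailment, which is why predict() scores one text/hypothesis pair per candidate subject. The removed version got the same behaviour from the transformers zero-shot pipeline; the sketch below shows that pipeline form for comparison, with purely illustrative input text and labels.

# sketch: the zero-shot pipeline the old code used, shown for comparison with predict()
from transformers import pipeline

classifier = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
)
result = classifier(
    "stacks, queues and hash tables",   # illustrative input
    candidate_labels=["Mathematics", "Computer-Science", "History"]
)
print(result["labels"][0])   # highest-scoring label first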
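
On the storage side, the commit replaces hard-coded TiDB credentials with DB_* environment variables (Space secrets) and writes to a predictions table instead of the old log_table. Nothing in the commit creates that table, so the one-off bootstrap below is only a sketch: the column names follow the INSERT in predict(), while the types and the id column are assumptions.

# bootstrap_predictions.py (hypothetical helper): create the table app.py writes to
import os
import mysql.connector

conn = mysql.connector.connect(
    host=os.getenv("DB_HOST"),
    port=int(os.getenv("DB_PORT", 4000)),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASS"),
    database=os.getenv("DB_NAME"),
    ssl_ca=os.getenv("DB_SSL_CA_PATH") or None,
)
cur = conn.cursor()
cur.execute(
    """
    CREATE TABLE IF NOT EXISTS predictions (
        id BIGINT AUTO_INCREMENT PRIMARY KEY,   -- assumed surrogate key
        student_id VARCHAR(64),
        `text` TEXT,
        subject VARCHAR(64),
        created_at DATETIME
    )
    """
)
conn.commit()
cur.close()
conn.close()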
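
With the Space running (or the app served locally, for example with uvicorn app:app), the endpoint can be exercised as below. BASE_URL is a placeholder rather than anything defined in the commit; the payload and response shapes follow PredictRequest and the dict returned by predict().

# sketch of a client call; BASE_URL is a placeholder
import requests

BASE_URL = "https://<your-space>.hf.space"

resp = requests.post(
    f"{BASE_URL}/predict",
    json={"student_id": "S123", "text": "I enjoy writing proofs about prime numbers."},
    timeout=120,   # generous, in case a sleeping Space has to wake up first
)
resp.raise_for_status()
print(resp.json())   # e.g. {"subject": "Mathematics"}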