Adchay committed on
Commit
54bcfd7
·
verified ·
1 Parent(s): cd8e991

update from kimi2

Browse files
Files changed (1) hide show
  1. app.py +79 -69
app.py CHANGED
@@ -1,85 +1,95 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import pipeline
4
- import mysql.connector
5
  import os
 
 
6
 
7
- os.environ['HF_HOME'] = '/tmp/huggingface'
 
 
 
 
8
 
9
- app = FastAPI()
 
 
 
 
 
10
 
11
- # Database connection settings
12
- DB_HOST = "gateway01.ap-southeast-1.prod.aws.tidbcloud.com"
13
- DB_PORT = 4000
14
- DB_USER = "4V44XYoMA7okY9v.root"
15
- DB_PASS = "aW2CrSwcTgjFhNAb"
16
- DB_NAME = "final_project"
17
 
18
- # Create MySQL connection
19
- conn = mysql.connector.connect(
20
- host=DB_HOST,
21
- port=DB_PORT,
22
- user=DB_USER,
23
- password=DB_PASS,
24
- database=DB_NAME,
25
- ssl_verify_cert=True,
26
- ssl_verify_identity=True
27
- )
28
- cursor = conn.cursor()
 
 
29
 
30
- # Load model
31
- classifier = pipeline(
32
- "zero-shot-classification",
33
- model="MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
34
- )
35
 
36
- # Labels
37
- subject_labels = [
38
- "Physics", "Chemistry", "Biology", "Astronomy",
39
- "Earth Science", "Environmental Science",
40
- "Algebra", "Geometry", "Calculus", "Statistics",
41
- "Probability", "Number Theory",
42
- "English Language", "English Literature",
43
- "Tamil Language", "Tamil Literature",
44
- "History", "Geography", "Political Science", "Economics",
45
- "Sociology", "Psychology", "Philosophy",
46
- "Computer Science", "Data Science", "Artificial Intelligence",
47
- "Robotics", "Biotechnology", "Engineering",
48
- "Fine Arts", "Music", "Dance", "Theater",
49
- "Business Studies", "Accountancy", "Entrepreneurship",
50
- "Physical Education", "Health Science"
51
- ]
52
 
53
- # Request model
54
- class TextInput(BaseModel):
55
  student_id: str
56
  text: str
57
 
 
58
  @app.post("/predict")
59
- def predict_topic(data: TextInput):
60
- # Predict subject
61
- result = classifier(
62
- data.text,
63
- candidate_labels=subject_labels,
64
- hypothesis_template="This text is about {}."
 
 
65
  )
66
- predicted_subject = result["labels"][0]
 
 
 
 
 
 
67
 
68
- # Get first 100 characters of the text
69
- sample_text = data.text[:100]
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- # Save to DB
72
- cursor.execute(
73
- """
74
- INSERT INTO log_table (student_id, input_sample, subject)
75
- VALUES (%s, %s, %s)
76
- """,
77
- (data.student_id, sample_text, predicted_subject)
78
- )
79
- conn.commit()
80
 
81
- return {
82
- "student_id": data.student_id,
83
- "predicted_subject": predicted_subject,
84
- "sample_text": sample_text
85
- }
 
1
+ """
2
+ FastAPI server inside Hugging Face Space
3
+ POST /predict -> zero-shot subject prediction + save to TiDB
4
+ """
5
  import os
6
+ import time
7
+ from contextlib import asynccontextmanager
8
 
9
+ import mysql.connector
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ from fastapi import FastAPI, HTTPException
13
+ from pydantic import BaseModel
14
 
15
# ---------- load model ONCE ----------
# Zero-shot NLI checkpoint used to classify student text into a subject.
MODEL_NAME = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
# Candidate subjects the API can return.
# NOTE(review): this set is far smaller than the ~38 labels the previous
# revision used — confirm the reduction is intentional.
LABELS = [
    "Mathematics", "Physics", "Chemistry", "Biology",
    "History", "Geography", "Literature", "Computer-Science"
]

# Filled by the lifespan startup hook with "tokenizer" and "model";
# cleared again at shutdown.
ml_models = {}
 
 
 
 
 
23
 
24
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the zero-shot model once at startup; release it at shutdown."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    mdl.eval()
    if torch.cuda.is_available():
        mdl.cuda()
    ml_models.update(tokenizer=tok, model=mdl)
    yield
    # shutdown: drop references so the weights can be garbage-collected
    ml_models.clear()


app = FastAPI(lifespan=lifespan)
 
 
 
 
39
 
40
# ---------- DB helper ----------
def get_conn():
    """Open a fresh TiDB/MySQL connection configured from environment variables."""
    env = os.getenv
    return mysql.connector.connect(
        host=env("DB_HOST"),
        port=int(env("DB_PORT", 4000)),
        user=env("DB_USER"),
        password=env("DB_PASS"),
        database=env("DB_NAME"),
        # TiDB Cloud requires TLS; pass a CA bundle path when provided
        ssl_ca=env("DB_SSL_CA_PATH") or None,
    )
 
 
 
 
 
 
50
 
51
# ---------- request schema ----------
class PredictRequest(BaseModel):
    """JSON body accepted by POST /predict."""
    # identifier logged alongside the prediction in the DB
    student_id: str
    # free-form student text to classify
    text: str
55
 
56
# ---------- API endpoint ----------
@app.post("/predict")
def predict(req: PredictRequest):
    """Zero-shot classify *req.text* into one of LABELS and log it to the DB.

    Returns ``{"subject": <label>}``.
    Raises HTTPException 400 on empty input, 500 on a failed DB write.
    """
    if not req.text.strip():
        raise HTTPException(400, "Empty text")

    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]

    # BUG FIX: this checkpoint is an NLI cross-encoder — it scores
    # (premise, hypothesis) PAIRS.  A single forward pass on the raw text
    # yields logits over the model's NLI classes, not over LABELS, so
    # argmax-ing them and indexing LABELS was meaningless.  Score each
    # candidate label as a hypothesis instead, exactly as the zero-shot
    # pipeline does.
    hypotheses = [f"This text is about {label}." for label in LABELS]
    tok = tokenizer(
        [req.text] * len(LABELS),
        hypotheses,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    if torch.cuda.is_available():
        tok = {k: v.cuda() for k, v in tok.items()}
    with torch.no_grad():
        logits = model(**tok).logits
    # Rank labels by the entailment probability of their hypothesis.
    # label2id gives the entailment column; fall back to 0, which is the
    # entailment index for the zeroshot-v1.1 checkpoints.
    entail_col = model.config.label2id.get("entailment", 0)
    scores = torch.softmax(logits, dim=-1)[:, entail_col]
    subject = LABELS[int(torch.argmax(scores))]

    # save to DB (fresh connection per request; always closed)
    try:
        conn = get_conn()
        try:
            cur = conn.cursor()
            cur.execute(
                "INSERT INTO predictions (student_id, text, subject, created_at) "
                "VALUES (%s, %s, %s, %s)",
                # NOTE(review): local server time — confirm UTC isn't expected
                (req.student_id, req.text, subject,
                 time.strftime('%Y-%m-%d %H:%M:%S'))
            )
            conn.commit()
            cur.close()
        finally:
            conn.close()
    except Exception as e:
        print("DB error:", e)
        # chain the cause so the real DB error survives in tracebacks
        raise HTTPException(500, "DB write failed") from e

    return {"subject": subject}
 
 
 
 
 
 
 
 
92
 
93
@app.get("/")
def root():
    """Liveness probe for the Space."""
    payload = {"message": "Subject predictor is running"}
    return payload