Abaryan committed · Commit dc3747b · verified · Parent: 30ca71a

Update app.py

Files changed (1): app.py (+45 −255)
app.py CHANGED
@@ -1,263 +1,53 @@
-from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
+import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import os
-from datasets import load_dataset
-import random
-from typing import Optional, List, Tuple, Union
-import gradio as gr
-from contextlib import asynccontextmanager
-
-# Global variables
-model = None
-tokenizer = None
-dataset = None
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    # Startup: Load the model
-    global model, tokenizer, dataset
-    try:
-        # Load your fine-tuned model and tokenizer
-        model_name = os.getenv("MODEL_NAME", "rgb2gbr/BioXP-0.5B-MedMCQA")
-        model = AutoModelForCausalLM.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-        # Load MedMCQA dataset
-        dataset = load_dataset("openlifescienceai/medmcqa")
-
-        # Move model to GPU if available
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model = model.to(device)
-        model.eval()
-    except Exception as e:
-        print(f"Error loading model: {str(e)}")
-        raise e
-
-    yield  # This is where FastAPI serves the application
-
-    # Shutdown: Clean up resources if needed
-    if model is not None:
-        del model
-    if tokenizer is not None:
-        del tokenizer
-    if dataset is not None:
-        del dataset
-    torch.cuda.empty_cache()
-
-app = FastAPI(lifespan=lifespan)
-
-# Add CORS middleware for Gradio
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Define input models
-class QuestionRequest(BaseModel):
-    question: str
-    options: list[str]  # List of 4 options
-
-class DatasetQuestion(BaseModel):
-    question: str
-    opa: str
-    opb: str
-    opc: str
-    opd: str
-    cop: Optional[int] = None  # Correct option (0-3)
-    exp: Optional[str] = None  # Explanation if available
-
-def format_prompt(question: str, options: List[str]) -> str:
-    """Format the prompt for the model"""
-    prompt = f"Question: {question}\n\nOptions:\n"
-    for i, opt in enumerate(options):
-        prompt += f"{chr(65+i)}. {opt}\n"
-    prompt += "\nAnswer:"
-    return prompt
-
-def get_question(index: Optional[int] = None, random_question: bool = False, format: str = "api") -> Union[DatasetQuestion, Tuple[str, str, str, str, str]]:
-    """
-    Get a question from the dataset.
-    Args:
-        index: Optional question index
-        random_question: Whether to get a random question
-        format: 'api' for DatasetQuestion object, 'gradio' for tuple
-    """
-    if dataset is None:
-        raise Exception("Dataset not loaded")
-
-    if random_question:
-        index = random.randint(0, len(dataset['train']) - 1)
-    elif index is None:
-        raise ValueError("Either index or random_question must be provided")
-
-    question_data = dataset['train'][index]
-
-    if format == "gradio":
-        return (
-            question_data['question'],
-            question_data['opa'],
-            question_data['opb'],
-            question_data['opc'],
-            question_data['opd']
-        )
-
-    return DatasetQuestion(
-        question=question_data['question'],
-        opa=question_data['opa'],
-        opb=question_data['opb'],
-        opc=question_data['opc'],
-        opd=question_data['opd'],
-        cop=question_data['cop'] if 'cop' in question_data else None,
-        exp=question_data['exp'] if 'exp' in question_data else None
-    )
 
-def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
-    """Gradio interface prediction function"""
-    try:
-        options = [option_a, option_b, option_c, option_d]
-
-        # Format the prompt
-        prompt = format_prompt(question, options)
-
-        # Tokenize the input
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512
+# Load model and tokenizer
+model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
+model = AutoModelForCausalLM.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Move model to GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+model.eval()
+
+def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
+    # Format the prompt
+    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
+
+    # Tokenize and generate
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=10,
+            temperature=0.7,
+            do_sample=False,
+            pad_token_id=tokenizer.eos_token_id
         )
-
-        device = next(model.parameters()).device
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        # Generate prediction
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=10,
-                num_return_sequences=1,
-                temperature=0.7,
-                do_sample=False,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        # Decode the output
-        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Extract the answer from the prediction
-        answer = prediction.split("Answer:")[-1].strip()
-
-        # Format the output for Gradio
-        result = f"Model Output:\n{prediction}\n\n"
-        result += f"Extracted Answer: {answer}"
-
-        return result
 
-    except Exception as e:
-        return f"Error: {str(e)}"
+    # Get prediction
+    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return prediction
 
 # Create Gradio interface
-with gr.Blocks(title="Medical MCQ Predictor") as demo:
-    gr.Markdown("# Medical MCQ Predictor")
-    gr.Markdown("Enter a medical question and its options, or get a random question from MedMCQA dataset.")
-
-    with gr.Row():
-        with gr.Column():
-            question = gr.Textbox(label="Question", lines=3)
-            option_a = gr.Textbox(label="Option A")
-            option_b = gr.Textbox(label="Option B")
-            option_c = gr.Textbox(label="Option C")
-            option_d = gr.Textbox(label="Option D")
-
-    with gr.Row():
-        predict_btn = gr.Button("Predict")
-        random_btn = gr.Button("Get Random Question")
-
-    output = gr.Textbox(label="Prediction", lines=5)
-
-    predict_btn.click(
-        fn=predict_gradio,
-        inputs=[question, option_a, option_b, option_c, option_d],
-        outputs=output
-    )
-
-    random_btn.click(
-        fn=lambda: get_question(random_question=True, format="gradio"),
-        inputs=[],
-        outputs=[question, option_a, option_b, option_c, option_d]
-    )
-
-# Mount Gradio app to FastAPI
-app = gr.mount_gradio_app(app, demo, path="/")
-
-@app.get("/dataset/question")
-async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
-    """Get a question from the MedMCQA dataset"""
-    try:
-        return get_question(index=index, random_question=random_question)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.post("/predict")
-async def predict(request: QuestionRequest):
-    if len(request.options) != 4:
-        raise HTTPException(status_code=400, detail="Exactly 4 options are required")
-
-    try:
-        # Format the prompt
-        prompt = format_prompt(request.question, request.options)
-
-        # Tokenize the input
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=512
-        )
-
-        device = next(model.parameters()).device
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        # Generate prediction
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=10,
-                num_return_sequences=1,
-                temperature=0.7,
-                do_sample=False,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        # Decode the output
-        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Extract the answer from the prediction
-        answer = prediction.split("Answer:")[-1].strip()
-
-        response = {
-            "model_output": prediction,
-            "extracted_answer": answer,
-            "full_response": prediction
-        }
-
-        return response
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.get("/health")
-async def health_check():
-    return {
-        "status": "healthy",
-        "model_loaded": model is not None,
-        "dataset_loaded": dataset is not None
-    }
+demo = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Textbox(label="Question", lines=3),
+        gr.Textbox(label="Option A"),
+        gr.Textbox(label="Option B"),
+        gr.Textbox(label="Option C"),
+        gr.Textbox(label="Option D")
+    ],
+    outputs=gr.Textbox(label="Model's Answer", lines=5),
+    title="Medical MCQ Predictor",
+    description="Enter a medical question and its options to get the model's prediction."
+)
+
+# Launch the app
+if __name__ == "__main__":
+    demo.launch()
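
A note for readers of the diff: the new predict returns the full decoded sequence, prompt included, whereas the removed code also exposed an extracted answer via prediction.split("Answer:")[-1].strip(). A minimal sketch of reinstating that post-processing on the caller's side, assuming the prompt format above; extract_answer is a hypothetical helper, not part of this commit:

def extract_answer(prediction: str) -> str:
    # Hypothetical helper: keep only the text after the final "Answer:" marker,
    # mirroring the extraction step the removed predict_gradio performed.
    return prediction.split("Answer:")[-1].strip()

# Usage: answer = extract_answer(predict(question, option_a, option_b, option_c, option_d))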
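
With the FastAPI routes gone, the remaining programmatic entry point is Gradio's own API. A minimal sketch using gradio_client against a locally launched copy of the app (python app.py); the local URL, the sample question, and the /predict endpoint name (Gradio's default for a function named predict) are assumptions, not part of this commit:

from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # assumed default URL printed by demo.launch()
result = client.predict(
    "Which vitamin deficiency causes scurvy?",  # hypothetical sample question
    "Vitamin A",
    "Vitamin B12",
    "Vitamin C",
    "Vitamin D",
    api_name="/predict",
)
print(result)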