from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from datasets import load_dataset
import random
from typing import Optional, List, Tuple, Union
import gradio as gr
from contextlib import asynccontextmanager

# Global variables, populated during startup
model = None
tokenizer = None
dataset = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the model, tokenizer, and dataset
    global model, tokenizer, dataset
    try:
        # Load the fine-tuned model and tokenizer
        model_name = os.getenv("MODEL_NAME", "rgb2gbr/BioXP-0.5B-MedMCQA")
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some causal LM tokenizers ship without a pad token; fall back to EOS
        # so that padding=True in the tokenizer calls below does not fail
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load the MedMCQA dataset
        dataset = load_dataset("openlifescienceai/medmcqa")

        # Move the model to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise e

    yield  # FastAPI serves requests while the context manager is suspended here

    # Shutdown: clean up resources if needed
    if model is not None:
        del model
    if tokenizer is not None:
        del tokenizer
    if dataset is not None:
        del dataset
    torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)

# Add CORS middleware for Gradio
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define input models
class QuestionRequest(BaseModel):
    question: str
    options: List[str]  # List of 4 options


class DatasetQuestion(BaseModel):
    question: str
    opa: str
    opb: str
    opc: str
    opd: str
    cop: Optional[int] = None  # Correct option (0-3)
    exp: Optional[str] = None  # Explanation if available


def format_prompt(question: str, options: List[str]) -> str:
    """Format the multiple-choice prompt for the model."""
    prompt = f"Question: {question}\n\nOptions:\n"
    for i, opt in enumerate(options):
        prompt += f"{chr(65 + i)}. {opt}\n"  # 65 is 'A', so options are labelled A, B, C, D
    prompt += "\nAnswer:"
    return prompt
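
# For illustration, a hypothetical call such as
#   format_prompt("Which vitamin deficiency causes scurvy?",
#                 ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"])
# returns the string:
#
#   Question: Which vitamin deficiency causes scurvy?
#
#   Options:
#   A. Vitamin A
#   B. Vitamin B12
#   C. Vitamin C
#   D. Vitamin D
#
#   Answer: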


def get_question(
    index: Optional[int] = None,
    random_question: bool = False,
    format: str = "api",
) -> Union[DatasetQuestion, Tuple[str, str, str, str, str]]:
    """
    Get a question from the dataset.

    Args:
        index: Optional question index
        random_question: Whether to get a random question
        format: 'api' for a DatasetQuestion object, 'gradio' for a tuple of strings
    """
    if dataset is None:
        raise Exception("Dataset not loaded")

    if random_question:
        index = random.randint(0, len(dataset['train']) - 1)
    elif index is None:
        raise ValueError("Either index or random_question must be provided")

    question_data = dataset['train'][index]

    if format == "gradio":
        return (
            question_data['question'],
            question_data['opa'],
            question_data['opb'],
            question_data['opc'],
            question_data['opd'],
        )

    return DatasetQuestion(
        question=question_data['question'],
        opa=question_data['opa'],
        opb=question_data['opb'],
        opc=question_data['opc'],
        opd=question_data['opd'],
        cop=question_data.get('cop'),
        exp=question_data.get('exp'),
    )
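
# Both call styles used later in this file, shown for reference:
#   get_question(index=0)                               -> DatasetQuestion(...)
#   get_question(random_question=True, format="gradio") -> (question, opa, opb, opc, opd)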


def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
    """Gradio interface prediction function."""
    try:
        options = [option_a, option_b, option_c, option_d]

        # Format the prompt
        prompt = format_prompt(question, options)

        # Tokenize the input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate prediction (greedy decoding; temperature has no effect while
        # do_sample=False)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the output
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the answer from the prediction
        answer = prediction.split("Answer:")[-1].strip()

        # Format the output for Gradio
        result = f"Model Output:\n{prediction}\n\n"
        result += f"Extracted Answer: {answer}"
        return result
    except Exception as e:
        return f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Enter a medical question and its options, or get a random question from the MedMCQA dataset.")

    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Question", lines=3)
            option_a = gr.Textbox(label="Option A")
            option_b = gr.Textbox(label="Option B")
            option_c = gr.Textbox(label="Option C")
            option_d = gr.Textbox(label="Option D")

            with gr.Row():
                predict_btn = gr.Button("Predict")
                random_btn = gr.Button("Get Random Question")

            output = gr.Textbox(label="Prediction", lines=5)

    predict_btn.click(
        fn=predict_gradio,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output,
    )
    random_btn.click(
        fn=lambda: get_question(random_question=True, format="gradio"),
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d],
    )
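
# The Blocks UI is mounted onto the FastAPI app at the end of this file, after
# the API routes below are registered. For quick local debugging, the interface
# could also be launched standalone with demo.launch() (a standard Gradio call,
# not used by this app).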


# REST API endpoints (the route paths below are illustrative choices)
@app.get("/question")
async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
    """Get a question from the MedMCQA dataset."""
    try:
        return get_question(index=index, random_question=random_question)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/predict")
async def predict(request: QuestionRequest):
    if len(request.options) != 4:
        raise HTTPException(status_code=400, detail="Exactly 4 options are required")
    try:
        # Format the prompt
        prompt = format_prompt(request.question, request.options)

        # Tokenize the input
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate prediction (greedy decoding; temperature has no effect while
        # do_sample=False)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the output
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the answer from the prediction
        answer = prediction.split("Answer:")[-1].strip()

        return {
            "model_output": prediction,
            "extracted_answer": answer,
            "full_response": prediction,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
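

# A minimal client-side sketch of calling the endpoint above. It assumes the
# "/predict" path chosen here and a server listening on http://localhost:7860;
# the helper name and sample payload are purely illustrative, and the function
# is never called by the app itself.
def _example_predict_request():
    import requests  # client-only dependency, assumed to be installed

    payload = {
        "question": "Which vitamin deficiency causes scurvy?",
        "options": ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"],
    }
    response = requests.post("http://localhost:7860/predict", json=payload, timeout=120)
    # Expected shape: {"model_output": ..., "extracted_answer": ..., "full_response": ...}
    return response.json()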


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "dataset_loaded": dataset is not None,
    }


# Mount the Gradio UI onto FastAPI last, so the mount at "/" does not shadow
# the API routes registered above
app = gr.mount_gradio_app(app, demo, path="/")
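

# Entry point sketch for running the server directly. Hugging Face Spaces
# usually starts the app itself (commonly on port 7860), so this block and the
# port choice are assumptions for local use rather than a required part of the
# Space setup.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))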