from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForMultipleChoice, AutoTokenizer
import os
from datasets import load_dataset
import random
from typing import Optional
import gradio as gr
app = FastAPI()

# Add CORS middleware so the Gradio front end can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define input models
class QuestionRequest(BaseModel):
    question: str
    options: list[str]  # List of 4 options


class DatasetQuestion(BaseModel):
    question: str
    opa: str
    opb: str
    opc: str
    opd: str
    cop: Optional[int] = None  # Correct option (0-3)
    exp: Optional[str] = None  # Explanation if available
# Global variables
model = None
tokenizer = None
dataset = None
def load_model():
    global model, tokenizer, dataset
    try:
        # Load the fine-tuned model and tokenizer. The MODEL_NAME env var
        # overrides the default checkpoint (the original passed the model
        # nickname "BioXP-0.5b" as the env var *name*, which never matches).
        model_name = os.getenv("MODEL_NAME", "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B")
        model = AutoModelForMultipleChoice.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some tokenizers ship without a pad token; fall back to EOS so that
        # padding=True below does not fail
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load MedMCQA dataset
        dataset = load_dataset("openlifescienceai/medmcqa")
        # Move model to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e
def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
    """Gradio interface prediction function"""
    try:
        # Pair the question with each candidate answer
        options = [option_a, option_b, option_c, option_d]
        inputs = [f"{question} {option}" for option in options]
        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        # Multiple-choice models expect (batch_size, num_choices, seq_len),
        # so add a batch dimension to the (num_choices, seq_len) tensors
        device = next(model.parameters()).device
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}
        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits  # shape: (1, num_choices)
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        predicted_class = torch.argmax(logits, dim=1).item()
        # Format the output for Gradio
        result = f"Predicted Answer: {options[predicted_class]}\n\n"
        result += "Confidence Scores:\n"
        for opt, prob in zip(options, probabilities):
            result += f"{opt}: {prob:.2%}\n"
        return result
    except Exception as e:
        return f"Error: {e}"
def get_random_question():
    """Get a random question for the Gradio interface"""
    if dataset is None:
        return "Error: Dataset not loaded", "", "", "", ""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
    )
# Create Gradio interface
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Enter a medical question and its options, or get a random question from the MedMCQA dataset.")
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Question", lines=3)
            option_a = gr.Textbox(label="Option A")
            option_b = gr.Textbox(label="Option B")
            option_c = gr.Textbox(label="Option C")
            option_d = gr.Textbox(label="Option D")
    with gr.Row():
        predict_btn = gr.Button("Predict")
        random_btn = gr.Button("Get Random Question")
    output = gr.Textbox(label="Prediction", lines=5)
    predict_btn.click(
        fn=predict_gradio,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d]
    )
# NOTE: the Gradio UI is mounted at "/" at the bottom of the file, *after*
# the API routes below are registered; mounting it here would shadow them.
@app.on_event("startup")
async def startup_event():
    load_model()
@app.get("/question")  # route path assumed; the decorator was missing from the original
async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
    """Get a question from the MedMCQA dataset"""
    # Validate outside the try block so these HTTPExceptions are not
    # re-wrapped as 500s by the generic handler below
    if dataset is None:
        raise HTTPException(status_code=500, detail="Dataset not loaded")
    if random_question:
        index = random.randint(0, len(dataset['train']) - 1)
    elif index is None:
        raise HTTPException(status_code=400, detail="Either index or random_question must be provided")
    try:
        question_data = dataset['train'][index]
        return DatasetQuestion(
            question=question_data['question'],
            opa=question_data['opa'],
            opb=question_data['opb'],
            opc=question_data['opc'],
            opd=question_data['opd'],
            cop=question_data.get('cop'),
            exp=question_data.get('exp'),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
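
# Example request (the "/question" path above is an assumption, as is port
# 7860, the usual Hugging Face Spaces default):
#   curl "http://localhost:7860/question?random_question=true"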
@app.post("/predict")  # route path assumed; the decorator was missing from the original
async def predict(request: QuestionRequest):
    if len(request.options) != 4:
        raise HTTPException(status_code=400, detail="Exactly 4 options are required")
    try:
        # Pair the question with each candidate answer
        inputs = [f"{request.question} {option}" for option in request.options]
        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        # Add the batch dimension expected by multiple-choice models
        device = next(model.parameters()).device
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}
        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits  # shape: (1, num_choices)
        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        predicted_class = torch.argmax(logits, dim=1).item()
        return {
            "predicted_option": request.options[predicted_class],
            "option_index": predicted_class,
            "confidence": probabilities[predicted_class],
            "probabilities": {
                f"option_{i}": prob for i, prob in enumerate(probabilities)
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
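
# Example request (the "/predict" path above is an assumption):
#   curl -X POST "http://localhost:7860/predict" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Which vitamin deficiency causes scurvy?",
#             "options": ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"]}'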
@app.get("/health")  # route path assumed; the decorator was missing from the original
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "dataset_loaded": dataset is not None,
    }

# Mount the Gradio app last so it does not shadow the API routes above
app = gr.mount_gradio_app(app, demo, path="/")
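
# To run locally (the module name "app" and port 7860 are assumptions based
# on the usual Hugging Face Spaces setup, not part of the original file):
#   uvicorn app:app --host 0.0.0.0 --port 7860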