from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForMultipleChoice, AutoTokenizer
import os
from datasets import load_dataset
import random
from typing import Optional
import gradio as gr

app = FastAPI()

# Add CORS middleware for Gradio
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define input models
class QuestionRequest(BaseModel):
    question: str
    options: list[str]  # List of 4 options

class DatasetQuestion(BaseModel):
    question: str
    opa: str
    opb: str
    opc: str
    opd: str
    cop: Optional[int] = None  # Correct option (0-3)
    exp: Optional[str] = None  # Explanation if available

# Global variables
model = None
tokenizer = None
dataset = None

def load_model():
    global model, tokenizer, dataset
    try:
        # Load the fine-tuned model and tokenizer; the MODEL_NAME environment
        # variable can override the default checkpoint
        model_name = os.getenv("MODEL_NAME", "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B")
        model = AutoModelForMultipleChoice.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load MedMCQA dataset
        dataset = load_dataset("openlifescienceai/medmcqa")

        # Move model to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e

def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
    """Gradio interface prediction function"""
    try:
        options = [option_a, option_b, option_c, option_d]

        # Pair the question with each option
        inputs = [f"{question} {option}" for option in options]

        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        device = next(model.parameters()).device
        # Multiple-choice models expect (batch_size, num_choices, seq_len),
        # so add a batch dimension around the 4 encoded choices
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}

        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)[0].tolist()
            predicted_class = torch.argmax(logits, dim=1).item()

        # Format the output for Gradio
        result = f"Predicted Answer: {options[predicted_class]}\n\n"
        result += "Confidence Scores:\n"
        for opt, prob in zip(options, probabilities):
            result += f"{opt}: {prob:.2%}\n"

        return result
    except Exception as e:
        return f"Error: {str(e)}"

def get_random_question():
    """Get a random question for the Gradio interface"""
    if dataset is None:
        return "Error: Dataset not loaded", "", "", "", ""

    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]

    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd']
    )

# Create Gradio interface
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Enter a medical question and its options, or get a random question from the MedMCQA dataset.")

    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Question", lines=3)
            option_a = gr.Textbox(label="Option A")
            option_b = gr.Textbox(label="Option B")
            option_c = gr.Textbox(label="Option C")
            option_d = gr.Textbox(label="Option D")

            with gr.Row():
                predict_btn = gr.Button("Predict")
                random_btn = gr.Button("Get Random Question")

            output = gr.Textbox(label="Prediction", lines=5)

    predict_btn.click(
        fn=predict_gradio,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output
    )

    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d]
    )
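# A minimal smoke test (hypothetical example question, not executed on
# import); assumes load_model() has been called first:
#
#   load_model()
#   print(predict_gradio(
#       "Which vitamin deficiency causes scurvy?",
#       "Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D",
#   ))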
# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

@app.on_event("startup")
async def startup_event():
    # Load the model and dataset once when the server starts
    load_model()

@app.get("/dataset/question")
async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
    """Get a question from the MedMCQA dataset"""
    try:
        if dataset is None:
            raise HTTPException(status_code=500, detail="Dataset not loaded")

        if random_question:
            index = random.randint(0, len(dataset['train']) - 1)
        elif index is None:
            raise HTTPException(status_code=400, detail="Either index or random_question must be provided")

        question_data = dataset['train'][index]

        question = DatasetQuestion(
            question=question_data['question'],
            opa=question_data['opa'],
            opb=question_data['opb'],
            opc=question_data['opc'],
            opd=question_data['opd'],
            cop=question_data.get('cop'),
            exp=question_data.get('exp')
        )
        return question
    except HTTPException:
        # Re-raise HTTP errors unchanged instead of wrapping them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict")
async def predict(request: QuestionRequest):
    if len(request.options) != 4:
        raise HTTPException(status_code=400, detail="Exactly 4 options are required")

    try:
        # Pair the question with each option
        inputs = [f"{request.question} {option}" for option in request.options]

        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        device = next(model.parameters()).device
        # Multiple-choice models expect (batch_size, num_choices, seq_len),
        # so add a batch dimension around the 4 encoded choices
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}

        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)[0].tolist()
            predicted_class = torch.argmax(logits, dim=1).item()

        response = {
            "predicted_option": request.options[predicted_class],
            "option_index": predicted_class,
            "confidence": probabilities[predicted_class],
            "probabilities": {
                f"option_{i}": prob for i, prob in enumerate(probabilities)
            }
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "dataset_loaded": dataset is not None
    }
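# Example request against the JSON API once the server is running (the port
# is an assumption; adjust to your deployment):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"question": "...", "options": ["A", "B", "C", "D"]}'

# Minimal local entry point, assuming uvicorn is installed; port 7860 is the
# Hugging Face Spaces convention
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)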