import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re

# Load model and tokenizer
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some causal-LM tokenizers ship without a pad token; fall back to EOS so
# padding and generation do not fail or warn
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option index (0-3)
        question_data.get('exp', None),  # Expert explanation, if any
    )


def extract_answer(prediction: str) -> tuple:
    """Extract the answer letter and any reasoning from the model output"""
    # Try to find the answer part
    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
    answer = answer_match.group(1).upper() if answer_match else "Not found"

    # Try to find a reasoning part, if the model produced one
    reasoning = ""
    if "Reasoning:" in prediction:
        reasoning = prediction.split("Reasoning:")[-1].strip()
    elif "Explanation:" in prediction:
        reasoning = prediction.split("Explanation:")[-1].strip()

    return answer, reasoning


def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 20):
    # Format the prompt
    prompt = (
        f"Question: {question}\n\n"
        f"Options:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\n"
        "Answer:"
    )
    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # Gradio sliders deliver floats
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,  # silence the missing-pad-token warning
        )

    # Get prediction (the decoded text includes the prompt plus the completion)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_answer, model_reasoning = extract_answer(prediction)

    # Format output with evaluation if available
    output = prediction
    if correct_option is not None:
        correct_letter = chr(65 + int(correct_option))  # Convert 0-3 to A-D; gr.Number passes floats
        is_correct = model_answer == correct_letter
        output += "\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"

    return output


# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
    gr.Markdown("# Medical-QA (MedMCQA) Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)

            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.6, step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused"
                )
                max_tokens = gr.Slider(
                    minimum=10, maximum=512, value=20, step=10,
                    label="Max Tokens",
                    info="Maximum length of the generated response"
                )

            # Hidden fields for the correct answer and explanation
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)

            # Buttons
            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=10)

    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d,
                 correct_option, expert_explanation]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
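# A minimal smoke test of the pipeline without the UI (a usage sketch, not part
# of the app): fetch a random question and run it through predict() directly,
# letting temperature/top_p/max_tokens fall back to their defaults.
#
#   q, a, b, c, d, cop, exp = get_random_question()
#   print(predict(q, a, b, c, d, cop, exp))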