import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset import random import re # Load model and tokenizer model_name = "abaryan/BioXP-0.5B-MedMCQA" SYSTEM_PROMPT = """ You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 200 words without repeating. Respond in the following format: [correct answer] [explain why the selected option is correct] """ model = AutoModelForCausalLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) # Load dataset dataset = load_dataset("openlifescienceai/medmcqa") # Move model to GPU if available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model.eval() def get_random_question(): """Get a random question from the dataset""" index = random.randint(0, len(dataset['validation']) - 1) question_data = dataset['validation'][index] return ( question_data['question'], question_data['opa'], question_data['opb'], question_data['opc'], question_data['opd'], question_data.get('cop', None), # Correct option (0-3) question_data.get('exp', None) # Explanation ) def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "", correct_option: int = None, explanation: str = None, temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256): # Determine if this is an MCQ by checking if any option is provided # Only treat as MCQ if at least one option is non-empty is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d]) if is_mcq: # Format MCQ question with only non-empty options options = [] if option_a.strip(): options.append(f"A. {option_a}") if option_b.strip(): options.append(f"B. {option_b}") if option_c.strip(): options.append(f"C. {option_c}") if option_d.strip(): options.append(f"D. {option_d}") formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options) system_prompt = SYSTEM_PROMPT else: # Format regular question formatted_question = f"Question: {question}" system_prompt = SYSTEM_PROMPT # Create chat-style prompt prompt = [ {'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': formatted_question} ] # Use apply_chat_template for better formatting text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) # Tokenize and generate model_inputs = tokenizer([text], return_tensors="pt").to(device) with torch.inference_mode(): generated_ids = model.generate( **model_inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, ) # Get only the generated response generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:] model_response = tokenizer.decode(generated_ids, skip_special_tokens=True) # Clean up the response by removing tags and formatting cleaned_response = model_response cleaned_response = re.sub(r'\s*([A-D])\s*', r'Answer: \1', cleaned_response, flags=re.IGNORECASE) cleaned_response = re.sub(r'\s*(.*?)\s*', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL) # Format output with evaluation if available (only for MCQs) output = cleaned_response if is_mcq and correct_option is not None: correct_letter = chr(65 + correct_option) answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE) model_answer = answer_match.group(1).upper() if answer_match else "Not found" is_correct = model_answer == correct_letter output += f"\n\n---\nEvaluation:\n" output += f"Correct Answer: {correct_letter}\n" output += f"Model's Answer: {model_answer}\n" output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n" if explanation: output += f"\nExpert Explanation:\n{explanation}" return output # Create Gradio interface with mobile-optimized design with gr.Blocks( title="BioXP Medical MCQ Assistant", theme=gr.themes.Soft( primary_hue="blue", secondary_hue="blue", neutral_hue="slate", radius_size="md", font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"], ) ) as demo: gr.Markdown(""" # BioXP Medical MCQ Assistant A specialized AI assistant for medical multiple-choice questions. """) with gr.Row(): with gr.Column(scale=1): # Input fields with mobile-friendly spacing question = gr.Textbox( label="Medical Question", placeholder="Enter your medical question here...", lines=3, interactive=True, elem_classes=["mobile-input"] ) # Options in a mobile-friendly accordion with gr.Accordion("Options", open=True): option_a = gr.Textbox( label="Option A", placeholder="Enter option A...", interactive=True, elem_classes=["mobile-input"] ) option_b = gr.Textbox( label="Option B", placeholder="Enter option B...", interactive=True, elem_classes=["mobile-input"] ) option_c = gr.Textbox( label="Option C", placeholder="Enter option C...", interactive=True, elem_classes=["mobile-input"] ) option_d = gr.Textbox( label="Option D", placeholder="Enter option D...", interactive=True, elem_classes=["mobile-input"] ) # Generation parameters in a collapsible section with gr.Accordion("Advanced Settings", open=False): with gr.Row(): with gr.Column(scale=1): temperature = gr.Slider( minimum=0.1, maximum=1.0, value=0.6, step=0.1, label="Temperature", info="Higher = more creative, Lower = more focused" ) with gr.Column(scale=1): top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P", info="Controls response diversity" ) max_tokens = gr.Slider( minimum=50, maximum=512, value=256, step=32, label="Max Response Length", info="Maximum length of the response" ) # Hidden fields correct_option = gr.Number(visible=False) expert_explanation = gr.Textbox(visible=False) # Buttons with mobile-friendly spacing with gr.Row(): predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"]) random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"]) with gr.Column(scale=1): # Output with mobile-friendly styling output = gr.Textbox( label="Model's Response", lines=12, elem_classes=["response-box", "mobile-output"] ) # Set up button actions predict_btn.click( fn=predict, inputs=[ question, option_a, option_b, option_c, option_d, correct_option, expert_explanation, temperature, top_p, max_tokens ], outputs=output ) random_btn.click( fn=get_random_question, inputs=[], outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation] ) # Add mobile-optimized CSS gr.HTML(""" """) # Launch the app if __name__ == "__main__": demo.launch(share=False)