import random

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some causal-LM tokenizers ship without a pad token; fall back to EOS so
# tokenizer(..., padding=True) below does not fail
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")
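# Each MedMCQA row carries the fields used below: 'question' plus the four
# options 'opa'..'opd' (the dataset also includes 'cop', the correct option index)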

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
    )


def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
    # Format the prompt
    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"

    # Tokenize and move tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Greedy decoding (do_sample=False), so sampling parameters such as
    # temperature would be ignored; generate only a short answer
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the echoed prompt is not returned
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    prediction = tokenizer.decode(generated, skip_special_tokens=True)
    return prediction.strip()
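

# The raw generation can contain more than the option letter. Below is a
# minimal, hedged sketch (not part of the original app) that pulls out the
# first standalone A-D letter, assuming the model answers in an "A." or
# "Answer: B" style; wire it into predict() if only the letter is wanted.
import re


def extract_choice(prediction: str) -> str:
    """Return the first standalone A-D letter in the model output, or '' if none."""
    match = re.search(r"\b([ABCD])\b", prediction)
    return match.group(1) if match else ""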


# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)

            # Buttons
            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=5)

    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output,
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d],
    )


# Launch the app
if __name__ == "__main__":
    demo.launch()
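    # On Hugging Face Spaces the plain launch() above is all that is needed;
    # when running locally, demo.launch(share=True) (a standard Gradio option)
    # additionally serves the app through a temporary public URL.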