BioXP-0.5b-v2 / app.py
Abaryan
Update app.py
20e34ca verified
raw
history blame
3.04 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint id passed to from_pretrained for both the model and its tokenizer
# (presumably a MedMCQA fine-tune, judging by the name -- not verified here).
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Inference only: put the module in evaluation mode (disables dropout etc.).
model.eval()

# MedMCQA loaded with all of its splits; get_random_question() reads "train".
dataset = load_dataset("openlifescienceai/medmcqa")
def get_random_question():
    """Pick one example from the training split of ``dataset`` uniformly at random.

    Returns:
        A 5-tuple ``(question, opa, opb, opc, opd)``: the example's question
        text followed by its four option fields, in the same order as the
        question/option textboxes this feeds.
    """
    train_split = dataset['train']
    row = train_split[random.randrange(len(train_split))]
    return tuple(row[field] for field in ('question', 'opa', 'opb', 'opc', 'opd'))
def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str) -> str:
    """Ask the model to answer one four-option multiple-choice question.

    The question and options are formatted into a plain-text prompt that ends
    with ``"Answer:"``. The model continues that prompt using greedy decoding
    for at most 10 new tokens.

    Args:
        question: Question stem. An empty or whitespace-only value is not
            sent to the model.
        option_a: Text of option A.
        option_b: Text of option B.
        option_c: Text of option C.
        option_d: Text of option D.

    Returns:
        Only the text the model generated after the prompt, with surrounding
        whitespace stripped, or a short notice when no question was given.
    """
    if not (question or "").strip():
        return "Please enter a question."

    prompt = (
        f"Question: {question}\n\n"
        "Options:\n"
        f"A. {option_a}\n"
        f"B. {option_b}\n"
        f"C. {option_c}\n"
        f"D. {option_d}\n\n"
        "Answer:"
    )

    # A single prompt needs no padding, and padding=True raises on tokenizers
    # that define no pad token.
    # NOTE(review): truncation removes tokens from the right by default, which
    # would drop the trailing "Answer:" on very long inputs. Confirm whether
    # left-side truncation is wanted.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        # Greedy decoding. `temperature` only applies when do_sample=True
        # (transformers just warns and ignores it), so it is not passed.
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM, generate() returns the prompt tokens followed by the
    # newly generated ones. Slice the prompt off so the UI shows the model's
    # answer rather than an echo of the whole question and options.
    prompt_length = inputs["input_ids"].shape[-1]
    answer_ids = outputs[0][prompt_length:]
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
# Gradio UI: an editable question with four options, a button that asks the
# model (predict) and a button that loads a dataset example
# (get_random_question), plus a box for the model's reply.
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # User-editable inputs; also overwritten by the random-question button.
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)

            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")

            output = gr.Textbox(label="Model's Answer", lines=5)

    # The same five components are predict()'s inputs and
    # get_random_question()'s outputs, in matching order.
    mcq_fields = [question, option_a, option_b, option_c, option_d]
    predict_btn.click(predict, mcq_fields, output)
    random_btn.click(get_random_question, [], mcq_fields)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()