import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random

# Load model and tokenizer
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some causal LM tokenizers ship without a pad token; fall back to EOS so padding=True works below
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load dataset
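# Each MedMCQA example provides a 'question' plus four answer options in 'opa'..'opd'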
dataset = load_dataset("openlifescienceai/medmcqa")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd']
    )

def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
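    """Format the question and options into an MCQ prompt, run the model, and return its decoded answer."""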
    # Format the prompt
    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
    
    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,  # greedy decoding; a temperature setting would be ignored without sampling
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the newly generated tokens so the prompt is not echoed back in the answer box
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    prediction = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return prediction.strip()

# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")
    
    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)
            
            # Buttons
            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")
            
            # Output
            output = gr.Textbox(label="Model's Answer", lines=5)
    
    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output
    )
    
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d]
    )

# Launch the app
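# When run locally (python app.py), Gradio serves the UI at http://127.0.0.1:7860 by default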
if __name__ == "__main__":
    demo.launch()