Abaryan commited on
Commit
fa0e902
·
verified ·
1 Parent(s): dc3747b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -14
app.py CHANGED
@@ -1,17 +1,34 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
4
 
5
  # Load model and tokenizer
6
  model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
 
 
 
 
10
  # Move model to GPU if available
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model = model.to(device)
13
  model.eval()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
16
  # Format the prompt
17
  prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
@@ -33,20 +50,40 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
33
  prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
  return prediction
35
 
36
- # Create Gradio interface
37
- demo = gr.Interface(
38
- fn=predict,
39
- inputs=[
40
- gr.Textbox(label="Question", lines=3),
41
- gr.Textbox(label="Option A"),
42
- gr.Textbox(label="Option B"),
43
- gr.Textbox(label="Option C"),
44
- gr.Textbox(label="Option D")
45
- ],
46
- outputs=gr.Textbox(label="Model's Answer", lines=5),
47
- title="Medical MCQ Predictor",
48
- description="Enter a medical question and its options to get the model's prediction."
49
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Launch the app
52
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from datasets import load_dataset
5
+ import random
6
 
7
  # Load model and tokenizer
8
  model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
 
12
+ # Load dataset
13
+ dataset = load_dataset("openlifescienceai/medmcqa")
14
+
15
  # Move model to GPU if available
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model = model.to(device)
18
  model.eval()
19
 
20
+ def get_random_question():
21
+ """Get a random question from the dataset"""
22
+ index = random.randint(0, len(dataset['train']) - 1)
23
+ question_data = dataset['train'][index]
24
+ return (
25
+ question_data['question'],
26
+ question_data['opa'],
27
+ question_data['opb'],
28
+ question_data['opc'],
29
+ question_data['opd']
30
+ )
31
+
32
  def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
33
  # Format the prompt
34
  prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
 
50
  prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
51
  return prediction
52
 
53
+ # Create Gradio interface with Blocks for more control
54
+ with gr.Blocks(title="Medical MCQ Predictor") as demo:
55
+ gr.Markdown("# Medical MCQ Predictor")
56
+ gr.Markdown("Get a random medical question or enter your own question and options.")
57
+
58
+ with gr.Row():
59
+ with gr.Column():
60
+ # Input fields
61
+ question = gr.Textbox(label="Question", lines=3, interactive=True)
62
+ option_a = gr.Textbox(label="Option A", interactive=True)
63
+ option_b = gr.Textbox(label="Option B", interactive=True)
64
+ option_c = gr.Textbox(label="Option C", interactive=True)
65
+ option_d = gr.Textbox(label="Option D", interactive=True)
66
+
67
+ # Buttons
68
+ with gr.Row():
69
+ predict_btn = gr.Button("Predict", variant="primary")
70
+ random_btn = gr.Button("Get Random Question", variant="secondary")
71
+
72
+ # Output
73
+ output = gr.Textbox(label="Model's Answer", lines=5)
74
+
75
+ # Set up button actions
76
+ predict_btn.click(
77
+ fn=predict,
78
+ inputs=[question, option_a, option_b, option_c, option_d],
79
+ outputs=output
80
+ )
81
+
82
+ random_btn.click(
83
+ fn=get_random_question,
84
+ inputs=[],
85
+ outputs=[question, option_a, option_b, option_c, option_d]
86
+ )
87
 
88
  # Launch the app
89
  if __name__ == "__main__":