Abaryan commited on
Commit
ecef291
·
verified ·
1 Parent(s): 238b4d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -8
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from datasets import load_dataset
5
  import random
 
6
 
7
  # Load model and tokenizer
8
  model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
@@ -26,12 +27,38 @@ def get_random_question():
26
  question_data['opa'],
27
  question_data['opb'],
28
  question_data['opc'],
29
- question_data['opd']
 
 
30
  )
31
 
32
- def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Format the prompt
34
- prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
 
 
 
 
 
 
 
 
 
35
 
36
  # Tokenize and generate
37
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
@@ -49,7 +76,32 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
49
 
50
  # Get prediction
51
  prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
- return prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Create Gradio interface with Blocks for more control
55
  with gr.Blocks(title="Medical MCQ Predictor") as demo:
@@ -65,25 +117,29 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
65
  option_c = gr.Textbox(label="Option C", interactive=True)
66
  option_d = gr.Textbox(label="Option D", interactive=True)
67
 
 
 
 
 
68
  # Buttons
69
  with gr.Row():
70
  predict_btn = gr.Button("Predict", variant="primary")
71
  random_btn = gr.Button("Get Random Question", variant="secondary")
72
 
73
- # Output
74
- output = gr.Textbox(label="Model's Answer", lines=5)
75
 
76
  # Set up button actions
77
  predict_btn.click(
78
  fn=predict,
79
- inputs=[question, option_a, option_b, option_c, option_d],
80
  outputs=output
81
  )
82
 
83
  random_btn.click(
84
  fn=get_random_question,
85
  inputs=[],
86
- outputs=[question, option_a, option_b, option_c, option_d]
87
  )
88
 
89
  # Launch the app
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from datasets import load_dataset
5
  import random
6
+ import re
7
 
8
  # Load model and tokenizer
9
  model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
 
27
  question_data['opa'],
28
  question_data['opb'],
29
  question_data['opc'],
30
+ question_data['opd'],
31
+ question_data.get('cop', None), # Correct option (0-3)
32
+ question_data.get('exp', None) # Explanation
33
  )
34
 
35
+ def extract_answer(prediction: str) -> tuple:
36
+ """Extract answer and reasoning from model output"""
37
+ # Try to find the answer part
38
+ answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
39
+ answer = answer_match.group(1).upper() if answer_match else "Not found"
40
+
41
+ # Try to find reasoning part
42
+ reasoning = ""
43
+ if "Reasoning:" in prediction:
44
+ reasoning = prediction.split("Reasoning:")[-1].strip()
45
+ elif "Explanation:" in prediction:
46
+ reasoning = prediction.split("Explanation:")[-1].strip()
47
+
48
+ return answer, reasoning
49
+
50
+ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str, correct_option: int = None, explanation: str = None):
51
  # Format the prompt
52
+ prompt = f"""Question: {question}
53
+
54
+ Options:
55
+ A. {option_a}
56
+ B. {option_b}
57
+ C. {option_c}
58
+ D. {option_d}
59
+
60
+ Please provide your answer and reasoning.
61
+ Answer:"""
62
 
63
  # Tokenize and generate
64
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
76
 
77
  # Get prediction
78
  prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
+ model_answer, model_reasoning = extract_answer(prediction)
80
+
81
+ # Format the output
82
+ output = f"""## Model's Response
83
+
84
+ ### Answer
85
+ {model_answer}
86
+
87
+ ### Reasoning
88
+ {model_reasoning if model_reasoning else "No reasoning provided"}
89
+
90
+ ### Evaluation
91
+ """
92
+
93
+ # Add evaluation if correct answer is available
94
+ if correct_option is not None:
95
+ correct_letter = chr(65 + correct_option) # Convert 0-3 to A-D
96
+ is_correct = model_answer == correct_letter
97
+ output += f"- Correct Answer: {correct_letter}\n"
98
+ output += f"- Model's Answer: {model_answer}\n"
99
+ output += f"- Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
100
+
101
+ if explanation:
102
+ output += f"\n### Expert Explanation\n{explanation}"
103
+
104
+ return output
105
 
106
  # Create Gradio interface with Blocks for more control
107
  with gr.Blocks(title="Medical MCQ Predictor") as demo:
 
117
  option_c = gr.Textbox(label="Option C", interactive=True)
118
  option_d = gr.Textbox(label="Option D", interactive=True)
119
 
120
+ # Hidden fields for correct answer and explanation
121
+ correct_option = gr.Number(visible=False)
122
+ expert_explanation = gr.Textbox(visible=False)
123
+
124
  # Buttons
125
  with gr.Row():
126
  predict_btn = gr.Button("Predict", variant="primary")
127
  random_btn = gr.Button("Get Random Question", variant="secondary")
128
 
129
+ # Output with markdown support
130
+ output = gr.Markdown(label="Model's Answer")
131
 
132
  # Set up button actions
133
  predict_btn.click(
134
  fn=predict,
135
+ inputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation],
136
  outputs=output
137
  )
138
 
139
  random_btn.click(
140
  fn=get_random_question,
141
  inputs=[],
142
+ outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
143
  )
144
 
145
  # Launch the app