abaryan committed (verified)
Commit 03b5e22 · Parent(s): 6d30915

Update app.py

Files changed (1):
  1. app.py +14 -34
app.py CHANGED
@@ -5,11 +5,8 @@ from datasets import load_dataset
 import random
 import re
 
-# Load model and tokenizer
-model_name = "abaryan/BioXP-0.5B-MedMCQA"
-
 SYSTEM_PROMPT = """
-You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 200 words without repeating.
+You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 2 sentences without repeating.
 Respond in the following format:
 <answer>
 [correct answer]
@@ -19,10 +16,9 @@ Respond in the following format:
 </reasoning>
 """
 
+model_name = "abaryan/BioXP-0.5B-MedMCQA"
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# Load dataset
 dataset = load_dataset("openlifescienceai/medmcqa")
 
 # Move model to GPU if available
@@ -49,11 +45,9 @@ def predict(question: str, option_a: str = "", option_b: str = "", option_c: str
             temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
 
     # Determine if this is an MCQ by checking if any option is provided
-    # Only treat as MCQ if at least one option is non-empty
    is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])
 
     if is_mcq:
-        # Format MCQ question with only non-empty options
         options = []
         if option_a.strip(): options.append(f"A. {option_a}")
         if option_b.strip(): options.append(f"B. {option_b}")
@@ -67,16 +61,12 @@ def predict(question: str, option_a: str = "", option_b: str = "", option_c: str
         formatted_question = f"Question: {question}"
         system_prompt = SYSTEM_PROMPT
 
-    # Create chat-style prompt
     prompt = [
         {'role': 'system', 'content': system_prompt},
         {'role': 'user', 'content': formatted_question}
     ]
 
-    # Use apply_chat_template for better formatting
-    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
-
-    # Tokenize and generate
+    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
     model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
     with torch.inference_mode():
@@ -87,7 +77,6 @@ def predict(question: str, option_a: str = "", option_b: str = "", option_c: str
             top_p=top_p,
         )
 
-    # Get only the generated response
     generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
     model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
@@ -99,22 +88,21 @@ def predict(question: str, option_a: str = "", option_b: str = "", option_c: str
     # Format output with evaluation if available (only for MCQs)
     output = cleaned_response
 
-    if is_mcq and correct_option is not None:
-        correct_letter = chr(65 + correct_option)
-        answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
-        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
+    # if is_mcq and correct_option is not None:
+    #     correct_letter = chr(65 + correct_option)
+    #     answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
+    #     model_answer = answer_match.group(1).upper() if answer_match else "Not found"
 
-        is_correct = model_answer == correct_letter
-        output += f"\n\n---\nEvaluation:\n"
-        output += f"Correct Answer: {correct_letter}\n"
-        output += f"Model's Answer: {model_answer}\n"
-        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
-        if explanation:
-            output += f"\nExpert Explanation:\n{explanation}"
+    #     is_correct = model_answer == correct_letter
+    #     output += f"\n\n---\nEvaluation:\n"
+    #     output += f"Correct Answer: {correct_letter}\n"
+    #     output += f"Model's Answer: {model_answer}\n"
+    #     output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
+    #     if explanation:
+    #         output += f"\nExpert Explanation:\n{explanation}"
 
     return output
 
-# Create Gradio interface with mobile-optimized design
 with gr.Blocks(
     title="BioXP Medical MCQ Assistant",
     theme=gr.themes.Soft(
@@ -132,7 +120,6 @@ with gr.Blocks(
 
     with gr.Row():
         with gr.Column(scale=1):
-            # Input fields with mobile-friendly spacing
            question = gr.Textbox(
                label="Medical Question",
                placeholder="Enter your medical question here...",
@@ -141,7 +128,6 @@ with gr.Blocks(
                elem_classes=["mobile-input"]
            )
 
-            # Options in a mobile-friendly accordion
            with gr.Accordion("Options", open=True):
                option_a = gr.Textbox(
                    label="Option A",
@@ -168,7 +154,6 @@ with gr.Blocks(
                    elem_classes=["mobile-input"]
                )
 
-            # Generation parameters in a collapsible section
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column(scale=1):
@@ -202,13 +187,11 @@ with gr.Blocks(
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)
 
-            # Buttons with mobile-friendly spacing
            with gr.Row():
                predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
                random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
 
        with gr.Column(scale=1):
-            # Output with mobile-friendly styling
            output = gr.Textbox(
                label="Model's Response",
                lines=12,
@@ -232,10 +215,8 @@ with gr.Blocks(
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )
 
-    # Add mobile-optimized CSS
    gr.HTML("""
    <style>
-    /* Mobile-friendly base styles */
    .container {
        max-width: 100%;
        padding: 0.5rem;
@@ -258,7 +239,6 @@ with gr.Blocks(
        font-weight: 500;
    }
 
-    /* Response box styling */
    .response-box {
        font-family: 'Inter', sans-serif;
        line-height: 1.6;
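
For context, the generation path the updated app.py relies on can be exercised on its own. Below is a minimal sketch, assuming the model name and sampling defaults shown in the diff; the example question, the do_sample=True flag, and the max_new_tokens wiring are illustrative assumptions, since the diff does not show the full generate() call:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "abaryan/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Chat-style prompt, mirroring the structure built in predict()
prompt = [
    {"role": "system", "content": "You are a medical expert. ..."},  # SYSTEM_PROMPT in app.py
    {"role": "user", "content": "Question: Which vitamin deficiency causes scurvy?\n"
                                "A. Vitamin A\nB. Vitamin B12\nC. Vitamin C\nD. Vitamin D"},
]

# Render the messages with the model's chat template and open the assistant turn
text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.inference_mode():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,   # assumed mapping of the max_tokens parameter
        temperature=0.6,
        top_p=0.9,
        do_sample=True,       # assumed; temperature/top_p only apply when sampling
    )

# Keep only the newly generated tokens, dropping the echoed prompt
response = tokenizer.decode(generated_ids[0, model_inputs.input_ids.shape[1]:],
                            skip_special_tokens=True)
print(response)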
 
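For reference, the evaluation block commented out in this commit can also be kept around as a standalone check. A minimal sketch; check_answer is a hypothetical helper name, and the 0-indexed correct_option convention follows chr(65 + correct_option) in the original code:

import re

def check_answer(response: str, correct_option: int) -> str:
    # Hypothetical helper mirroring the disabled evaluation block.
    # correct_option is 0-indexed (0 -> A), matching chr(65 + correct_option).
    correct_letter = chr(65 + correct_option)
    match = re.search(r"Answer:\s*([A-D])", response, re.IGNORECASE)
    model_answer = match.group(1).upper() if match else "Not found"
    is_correct = model_answer == correct_letter
    return (f"Correct Answer: {correct_letter}\n"
            f"Model's Answer: {model_answer}\n"
            f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}")

print(check_answer("Answer: C. Vitamin C", 2))  # -> Result: ✅ Correct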