Abaryan committed on
Commit dee81c5 (verified)
1 Parent(s): e4a59fb

Update app.py

Files changed (1)
  1. app.py +50 -38
app.py CHANGED
@@ -8,6 +8,18 @@ import re
 # Load model and tokenizer
 # model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
 model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
+
+SYSTEM_PROMPT = """
+You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without reapeating.
+Respond in the following format:
+<answer>
+[correct answer]
+</answer>
+<reasoning>
+[explain why the selected option is correct]
+</reasoning>
+"""
+
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -21,8 +33,8 @@ model.eval()
 
 def get_random_question():
     """Get a random question from the dataset"""
-    index = random.randint(0, len(dataset['train']) - 1)
-    question_data = dataset['train'][index]
+    index = random.randint(0, len(dataset['validation']) - 1)
+    question_data = dataset['validation'][index]
     return (
         question_data['question'],
         question_data['opa'],
@@ -33,49 +45,46 @@ def get_random_question():
         question_data.get('exp', None)  # Explanation
     )
 
-def extract_answer(prediction: str) -> tuple:
-    """Extract answer and reasoning from model output"""
-    # Try to find the answer part
-    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
-    answer = answer_match.group(1).upper() if answer_match else "Not found"
-
-    # Try to find reasoning part
-    reasoning = ""
-    if "Reasoning:" in prediction:
-        reasoning = prediction.split("Reasoning:")[-1].strip()
-    elif "Explanation:" in prediction:
-        reasoning = prediction.split("Explanation:")[-1].strip()
-
-    return answer, reasoning
-
 def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
             correct_option: int = None, explanation: str = None,
-            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 20):
-    # Format the prompt
-    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
+            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
+    # Format the question with options
+    formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"
+
+    # Create chat-style prompt
+    prompt = [
+        {'role': 'system', 'content': SYSTEM_PROMPT},
+        {'role': 'user', 'content': formatted_question}
+    ]
+
+    # Use apply_chat_template for better formatting
+    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
 
     # Tokenize and generate
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
+    with torch.inference_mode():
+        generated_ids = model.generate(
+            **model_inputs,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
-            do_sample=True,
-            # pad_token_id=tokenizer.eos_token_id
+            # repetition_penalty=1.1,
         )
 
-    # Get prediction
-    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    model_answer, model_reasoning = extract_answer(prediction)
+    # Get only the generated response (excluding the prompt)
+    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
+    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
     # Format output with evaluation if available
-    output = prediction
+    output = model_response
+
     if correct_option is not None:
         correct_letter = chr(65 + correct_option)  # Convert 0-3 to A-D
+        # Extract answer from model response for evaluation
+        answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
+        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
+
         is_correct = model_answer == correct_letter
         output += f"\n\n---\nEvaluation:\n"
         output += f"Correct Answer: {correct_letter}\n"
@@ -95,10 +104,13 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
     with gr.Column():
         # Input fields
         question = gr.Textbox(label="Question", lines=3, interactive=True)
-        option_a = gr.Textbox(label="Option A", interactive=True)
-        option_b = gr.Textbox(label="Option B", interactive=True)
-        option_c = gr.Textbox(label="Option C", interactive=True)
-        option_d = gr.Textbox(label="Option D", interactive=True)
+
+        # Options in an expandable accordion
+        with gr.Accordion("Options", open=False):
+            option_a = gr.Textbox(label="Option A", interactive=True)
+            option_b = gr.Textbox(label="Option B", interactive=True)
+            option_c = gr.Textbox(label="Option C", interactive=True)
+            option_d = gr.Textbox(label="Option D", interactive=True)
 
         # Generation parameters
         with gr.Accordion("Generation Parameters", open=False):
@@ -119,12 +131,12 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
                 info="Higher values allow more diverse tokens, lower values more focused"
             )
             max_tokens = gr.Slider(
-                minimum=10,
+                minimum=50,
                 maximum=512,
-                value=20,
+                value=256,
                 step=32,
                 label="Max Tokens",
-                info="Maximum length of the generated response"
+                info="Maximum length of the generated response (recommended: 256)"
             )
 
         # Hidden fields for correct answer and explanation
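
For reference, a minimal sketch (not part of the commit) of how the new tag-based answer extraction used in predict() behaves; the sample response below is hypothetical:

```python
import re

# Hypothetical model output following the <answer>/<reasoning> format requested by SYSTEM_PROMPT.
sample_response = (
    "<answer>\nB\n</answer>\n"
    "<reasoning>\nOption B is correct because ...\n</reasoning>"
)

# Same pattern as in the updated predict(): pull the single letter A-D out of the <answer> tag.
answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", sample_response, re.IGNORECASE)
model_answer = answer_match.group(1).upper() if answer_match else "Not found"
print(model_answer)  # prints "B"
```

If the model omits the tags, the extraction falls back to "Not found", which the evaluation block then reports as a mismatch against the correct letter.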