Abaryan committed
Update app.py
app.py
CHANGED
@@ -8,6 +8,18 @@ import re
 # Load model and tokenizer
 # model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
 model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
+
+SYSTEM_PROMPT = """
+You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without repeating.
+Respond in the following format:
+<answer>
+[correct answer]
+</answer>
+<reasoning>
+[explain why the selected option is correct]
+</reasoning>
+"""
+
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
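The system prompt added here pins the model to a tag-delimited output format, and the tag-based answer extraction added to predict later in this commit depends on it. A minimal sketch of parsing such a response (the sample string is made up for illustration; the regex is the one introduced below):

import re

# Hypothetical model output that follows the SYSTEM_PROMPT format above.
sample_response = "<answer>\nB\n</answer>\n<reasoning>\nOption B matches the described mechanism.\n</reasoning>"

answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", sample_response, re.IGNORECASE)
model_answer = answer_match.group(1).upper() if answer_match else "Not found"
print(model_answer)  # -> B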
@@ -21,8 +33,8 @@ model.eval()
 
 def get_random_question():
     """Get a random question from the dataset"""
-    index = random.randint(0, len(dataset['
-    question_data = dataset['
+    index = random.randint(0, len(dataset['validation']) - 1)
+    question_data = dataset['validation'][index]
     return (
         question_data['question'],
         question_data['opa'],
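For context on the "- 1" in the new index bound: random.randint is inclusive on both ends, so sampling a row index needs len(split) - 1 as the upper limit. A quick illustration with a hypothetical split size:

import random

n_rows = 100  # hypothetical number of rows in dataset['validation']
index = random.randint(0, n_rows - 1)  # randint includes both endpoints
assert 0 <= index <= n_rows - 1       # always a valid row index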
@@ -33,49 +45,46 @@ def get_random_question():
         question_data.get('exp', None) # Explanation
     )
 
-def extract_answer(prediction: str) -> tuple:
-    """Extract answer and reasoning from model output"""
-    # Try to find the answer part
-    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
-    answer = answer_match.group(1).upper() if answer_match else "Not found"
-
-    # Try to find reasoning part
-    reasoning = ""
-    if "Reasoning:" in prediction:
-        reasoning = prediction.split("Reasoning:")[-1].strip()
-    elif "Explanation:" in prediction:
-        reasoning = prediction.split("Explanation:")[-1].strip()
-
-    return answer, reasoning
-
 def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
             correct_option: int = None, explanation: str = None,
-            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int =
-    # Format the
-
+            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
+    # Format the question with options
+    formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"
+
+    # Create chat-style prompt
+    prompt = [
+        {'role': 'system', 'content': SYSTEM_PROMPT},
+        {'role': 'user', 'content': formatted_question}
+    ]
+
+    # Use apply_chat_template for better formatting
+    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
 
     # Tokenize and generate
-
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
-    with torch.
-
-        **
+    with torch.inference_mode():
+        generated_ids = model.generate(
+            **model_inputs,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
-
-            # pad_token_id=tokenizer.eos_token_id
+            # repetition_penalty=1.1,
         )
 
-    # Get
-
-
+    # Get only the generated response (excluding the prompt)
+    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
+    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
     # Format output with evaluation if available
-    output =
+    output = model_response
+
     if correct_option is not None:
         correct_letter = chr(65 + correct_option) # Convert 0-3 to A-D
+        # Extract answer from model response for evaluation
+        answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
+        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
+
         is_correct = model_answer == correct_letter
         output += f"\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
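Pulled out of the diff, the new generation path in predict can be exercised on its own. The sketch below uses the model name loaded above; the device choice, the abbreviated prompt strings, and the explicit do_sample=True (the diff itself relies on generate's defaults) are assumptions made so the example is self-contained:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cpu"  # app.py uses its own device variable; CPU keeps the sketch portable
model.to(device)
model.eval()

system_prompt = "You're a medical expert. ..."  # abbreviated stand-in for SYSTEM_PROMPT above
formatted_question = "Question: ...\n\nOptions:\nA. ...\nB. ...\nC. ...\nD. ..."  # placeholder

# Render the chat-style prompt through the tokenizer's chat template
prompt = [
    {'role': 'system', 'content': system_prompt},
    {'role': 'user', 'content': formatted_question},
]
text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

# Generate, then keep only the tokens produced after the prompt
model_inputs = tokenizer([text], return_tensors="pt").to(device)
with torch.inference_mode():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,  # makes temperature/top_p take effect; not passed in the diff
    )
response = tokenizer.decode(generated_ids[0, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(response)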
@@ -95,10 +104,13 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
         with gr.Column():
             # Input fields
             question = gr.Textbox(label="Question", lines=3, interactive=True)
-
-
-
-
+
+            # Options in an expandable accordion
+            with gr.Accordion("Options", open=False):
+                option_a = gr.Textbox(label="Option A", interactive=True)
+                option_b = gr.Textbox(label="Option B", interactive=True)
+                option_c = gr.Textbox(label="Option C", interactive=True)
+                option_d = gr.Textbox(label="Option D", interactive=True)
 
             # Generation parameters
             with gr.Accordion("Generation Parameters", open=False):
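The new option fields only matter once they are hooked up to predict, and that wiring sits outside this hunk. As a generic Gradio pattern only (the button, output box, and hidden-field names below are hypothetical, not taken from app.py):

# Generic wiring sketch; app.py's actual component names are not visible in this diff.
output_box = gr.Textbox(label="Model Output", lines=10)
predict_btn = gr.Button("Predict")
predict_btn.click(
    fn=predict,
    inputs=[question, option_a, option_b, option_c, option_d,
            correct_option_state, explanation_state,
            temperature, top_p, max_tokens],
    outputs=output_box,
)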
@@ -119,12 +131,12 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
                     info="Higher values allow more diverse tokens, lower values more focused"
                 )
                 max_tokens = gr.Slider(
-                    minimum=
+                    minimum=50,
                     maximum=512,
-                    value=
+                    value=256,
                     step=32,
                     label="Max Tokens",
-                    info="Maximum length of the generated response"
+                    info="Maximum length of the generated response (recommended: 256)"
                 )
 
             # Hidden fields for correct answer and explanation
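Finally, a hypothetical smoke test against the updated predict signature (the question and options are illustrative, not drawn from MedMCQA):

# Illustrative call; correct_option uses the 0-3 index converted via chr(65 + correct_option).
output = predict(
    question="Which of the following is a water-soluble vitamin?",
    option_a="Vitamin A",
    option_b="Vitamin C",
    option_c="Vitamin D",
    option_d="Vitamin K",
    correct_option=1,  # 1 -> "B"
    max_tokens=256,
)
print(output)  # model response followed by the Evaluation block, since correct_option is set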