Abaryan committed · verified
Commit 398a7eb · 1 Parent(s): 20e34ca

Update app.py

Files changed (1): app.py (+47, -10)
app.py CHANGED
@@ -3,9 +3,10 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from datasets import load_dataset
 import random
+import re
 
 # Load model and tokenizer
-model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
+model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -26,10 +27,27 @@ def get_random_question():
         question_data['opa'],
         question_data['opb'],
         question_data['opc'],
-        question_data['opd']
+        question_data['opd'],
+        question_data.get('cop', None),  # Correct option (0-3)
+        question_data.get('exp', None)   # Explanation
     )
 
-def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
+def extract_answer(prediction: str) -> tuple:
+    """Extract answer and reasoning from model output"""
+    # Try to find the answer part
+    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
+    answer = answer_match.group(1).upper() if answer_match else "Not found"
+
+    # Try to find reasoning part
+    reasoning = ""
+    if "Reasoning:" in prediction:
+        reasoning = prediction.split("Reasoning:")[-1].strip()
+    elif "Explanation:" in prediction:
+        reasoning = prediction.split("Explanation:")[-1].strip()
+
+    return answer, reasoning
+
+def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str, correct_option: int = None, explanation: str = None):
     # Format the prompt
     prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
 
@@ -40,15 +58,30 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=10,
-            temperature=0.7,
-            do_sample=False,
+            max_new_tokens=256,
+            temperature=0.6,
+            top_p=0.9,
+            do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
 
     # Get prediction
     prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return prediction
+    model_answer, model_reasoning = extract_answer(prediction)
+
+    # Format output with evaluation if available
+    output = prediction
+    if correct_option is not None:
+        correct_letter = chr(65 + correct_option)  # Convert 0-3 to A-D
+        is_correct = model_answer == correct_letter
+        output += f"\n\n---\nEvaluation:\n"
+        output += f"Correct Answer: {correct_letter}\n"
+        output += f"Model's Answer: {model_answer}\n"
+        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
+        if explanation:
+            output += f"\nExpert Explanation:\n{explanation}"
+
+    return output
 
 # Create Gradio interface with Blocks for more control
 with gr.Blocks(title="Medical MCQ Predictor") as demo:
@@ -64,25 +97,29 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
         option_c = gr.Textbox(label="Option C", interactive=True)
         option_d = gr.Textbox(label="Option D", interactive=True)
 
+    # Hidden fields for correct answer and explanation
+    correct_option = gr.Number(visible=False)
+    expert_explanation = gr.Textbox(visible=False)
+
     # Buttons
     with gr.Row():
         predict_btn = gr.Button("Predict", variant="primary")
         random_btn = gr.Button("Get Random Question", variant="secondary")
 
     # Output
-    output = gr.Textbox(label="Model's Answer", lines=5)
+    output = gr.Textbox(label="Model's Answer", lines=10)
 
     # Set up button actions
     predict_btn.click(
         fn=predict,
-        inputs=[question, option_a, option_b, option_c, option_d],
+        inputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation],
         outputs=output
     )
 
     random_btn.click(
         fn=get_random_question,
         inputs=[],
-        outputs=[question, option_a, option_b, option_c, option_d]
+        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
     )
 
 # Launch the app
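
Note on the decoding change: the old configuration passed temperature=0.7 together with do_sample=False, so generation was greedy and the temperature had no effect; with do_sample=True, the new temperature=0.6 and top_p=0.9 settings actually apply. Below is a minimal, self-contained sketch of the new parsing and scoring path for local testing. The helper mirrors the extract_answer added in this commit; the sample string and index values are made up for illustration and are not real model output.

import re

def extract_answer(prediction: str) -> tuple:
    """Mirror of the helper added in this commit: pull out the A-D
    letter and any trailing reasoning/explanation text."""
    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
    answer = answer_match.group(1).upper() if answer_match else "Not found"
    reasoning = ""
    if "Reasoning:" in prediction:
        reasoning = prediction.split("Reasoning:")[-1].strip()
    elif "Explanation:" in prediction:
        reasoning = prediction.split("Explanation:")[-1].strip()
    return answer, reasoning

# Hypothetical completion text, for illustration only
sample = "Question: ...\n\nAnswer: b\nReasoning: classic presentation of the condition."
answer, reasoning = extract_answer(sample)
print(answer)     # "B" -- re.IGNORECASE matches "b", .upper() normalises it
print(reasoning)  # "classic presentation of the condition."

# MedMCQA's 'cop' field is a 0-3 index; the evaluation block maps it to a
# letter with chr(65 + cop), since ord('A') == 65
assert chr(65 + 1) == "B"

One detail worth keeping in mind: tokenizer.decode(outputs[0], skip_special_tokens=True) returns the prompt plus the completion, so the regex finds the letter the model generates immediately after the prompt's trailing "Answer:" marker, which is the occurrence the evaluation block scores.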