Abaryan committed on
Commit 20e34ca · verified · Parent: b89d5ac

Update app.py

Files changed (1): app.py (+12, −75)
app.py CHANGED
@@ -3,10 +3,9 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from datasets import load_dataset
 import random
-import re
 
 # Load model and tokenizer
-model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
+model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -27,38 +26,12 @@ def get_random_question():
         question_data['opa'],
         question_data['opb'],
         question_data['opc'],
-        question_data['opd'],
-        question_data.get('cop', None),  # Correct option (0-3)
-        question_data.get('exp', None)   # Explanation
+        question_data['opd']
     )
 
-def extract_answer(prediction: str) -> tuple:
-    """Extract answer and reasoning from model output"""
-    # Try to find the answer part
-    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
-    answer = answer_match.group(1).upper() if answer_match else "Not found"
-
-    # Try to find reasoning part
-    reasoning = ""
-    if "Reasoning:" in prediction:
-        reasoning = prediction.split("Reasoning:")[-1].strip()
-    elif "Explanation:" in prediction:
-        reasoning = prediction.split("Explanation:")[-1].strip()
-
-    return answer, reasoning
-
-def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str, correct_option: int = None, explanation: str = None):
+def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
     # Format the prompt
-    prompt = f"""Question: {question}
-
-Options:
-A. {option_a}
-B. {option_b}
-C. {option_c}
-D. {option_d}
-
-Please provide your answer and reasoning.
-Answer:"""
+    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
 
     # Tokenize and generate
     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
@@ -67,47 +40,15 @@ Answer:"""
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=256,
-            temperature=0.6,
-            top_p=0.9,
-            do_sample=True,
+            max_new_tokens=10,
+            temperature=0.7,
+            do_sample=False,
             pad_token_id=tokenizer.eos_token_id
         )
 
     # Get prediction
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    model_answer, model_reasoning = extract_answer(prediction)
-
-    # Format the output
-    output = f"""## Raw Model Output
-```
-{prediction}
-```
-
-## Evaluation
-
-### Answer
-{model_answer}
-
-### Reasoning
-{model_reasoning if model_reasoning else "No reasoning provided"}
-"""
-
-    # Add evaluation if correct answer is available
-    if correct_option is not None:
-        correct_letter = chr(65 + correct_option)  # Convert 0-3 to A-D
-        is_correct = model_answer == correct_letter
-        output += f"""
-### Results
-- Correct Answer: {correct_letter}
-- Model's Answer: {model_answer}
-- Result: {'✅ Correct' if is_correct else '❌ Incorrect'}
-
-### Expert Explanation
-{explanation if explanation else "No expert explanation available"}
-"""
-
-    return output
+    return prediction
 
 # Create Gradio interface with Blocks for more control
 with gr.Blocks(title="Medical MCQ Predictor") as demo:
@@ -123,29 +64,25 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
     option_c = gr.Textbox(label="Option C", interactive=True)
     option_d = gr.Textbox(label="Option D", interactive=True)
 
-    # Hidden fields for correct answer and explanation
-    correct_option = gr.Number(visible=False)
-    expert_explanation = gr.Textbox(visible=False)
-
     # Buttons
     with gr.Row():
         predict_btn = gr.Button("Predict", variant="primary")
         random_btn = gr.Button("Get Random Question", variant="secondary")
 
-    # Output with markdown support
-    output = gr.Markdown(label="Model's Answer")
+    # Output
+    output = gr.Textbox(label="Model's Answer", lines=5)
 
     # Set up button actions
     predict_btn.click(
         fn=predict,
-        inputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation],
+        inputs=[question, option_a, option_b, option_c, option_d],
         outputs=output
     )
 
     random_btn.click(
         fn=get_random_question,
         inputs=[],
-        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
+        outputs=[question, option_a, option_b, option_c, option_d]
     )
 
 # Launch the app
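A note on the new generation settings: with do_sample=False, generate() decodes greedily, so the temperature=0.7 added in this commit has no effect (transformers only applies temperature when sampling, and recent versions warn about the unused flag). A minimal sketch of an equivalent call without the dead parameter:

```python
# Greedy decoding: with do_sample=False, temperature/top_p are ignored,
# so this call behaves the same as the one in the commit.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,  # short budget; enough for "A"-"D" plus a few tokens
        do_sample=False,    # deterministic, greedy decoding
        pad_token_id=tokenizer.eos_token_id,
    )
```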
 
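Also worth noting: for causal LMs, generate() returns the prompt tokens followed by the completion, so decoding outputs[0] echoes the entire prompt into the Textbox. If only the model's answer should be shown, a sketch that slices off the prompt first:

```python
# Decode only the newly generated tokens; outputs[0] begins with the
# prompt tokens, so skip past the prompt length before decoding.
prompt_len = inputs["input_ids"].shape[1]
answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
```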
 
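To sanity-check predict() without launching the Gradio UI, a hypothetical call might look like this (the sample question is invented, not taken from MedMCQA):

```python
# Hypothetical smoke test; prints the greedy completion after "Answer:".
print(predict(
    "Which vitamin deficiency classically causes scurvy?",
    "Vitamin A",    # option_a
    "Vitamin B12",  # option_b
    "Vitamin C",    # option_c
    "Vitamin D",    # option_d
))
```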