Abaryan committed
Commit ffe13aa · verified · 1 Parent(s): 4a74617

Update app.py

Files changed (1)
  1. app.py +38 -5
app.py CHANGED
@@ -47,7 +47,9 @@ def extract_answer(prediction: str) -> tuple:
 
     return answer, reasoning
 
-def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str, correct_option: int = None, explanation: str = None):
+def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
+            correct_option: int = None, explanation: str = None,
+            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
     # Format the prompt
     prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
 
@@ -58,9 +60,9 @@ def predict(question: str, option_a: str, option_b: str, option_c: str, option_d
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=20,
-            temperature=0.6,
-            top_p=0.9,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
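For context, the sampling path this hunk changes can be exercised outside the app. Below is a minimal standalone sketch, not part of the commit; distilgpt2 is a stand-in checkpoint and the prompt is abbreviated, since the Space's actual model and data are not shown in this diff:

# Standalone sketch of the new sampling path (stand-in model, not the Space's).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

prompt = "Question: ...\n\nOptions:\nA. ...\nB. ...\nC. ...\nD. ...\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,   # slider default; was hard-coded to 20 before this commit
        temperature=0.6,      # slider default
        top_p=0.9,            # slider default
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))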
 
@@ -97,6 +99,33 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
     option_c = gr.Textbox(label="Option C", interactive=True)
     option_d = gr.Textbox(label="Option D", interactive=True)
 
+    # Generation parameters
+    with gr.Accordion("Generation Parameters", open=False):
+        temperature = gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.6,
+            step=0.1,
+            label="Temperature",
+            info="Higher values make output more random, lower values more focused"
+        )
+        top_p = gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.9,
+            step=0.1,
+            label="Top P",
+            info="Higher values allow more diverse tokens, lower values more focused"
+        )
+        max_tokens = gr.Slider(
+            minimum=32,
+            maximum=512,
+            value=256,
+            step=32,
+            label="Max Tokens",
+            info="Maximum length of the generated response"
+        )
+
     # Hidden fields for correct answer and explanation
     correct_option = gr.Number(visible=False)
     expert_explanation = gr.Textbox(visible=False)
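The three sliders arrive in the callback as plain numbers. A minimal, self-contained sketch of that wiring follows; echo_params is a hypothetical stand-in for predict, used only to show how the slider values are delivered:

import gradio as gr

# Hypothetical stand-in for predict(), just to show how slider values arrive.
def echo_params(temperature, top_p, max_tokens):
    return f"temperature={temperature}, top_p={top_p}, max_tokens={max_tokens}"

with gr.Blocks() as demo:
    with gr.Accordion("Generation Parameters", open=False):
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
        max_tokens = gr.Slider(minimum=32, maximum=512, value=256, step=32, label="Max Tokens")
    run_btn = gr.Button("Run")
    output = gr.Textbox()
    run_btn.click(fn=echo_params, inputs=[temperature, top_p, max_tokens], outputs=output)

demo.launch()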
 
@@ -112,7 +141,11 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
     # Set up button actions
     predict_btn.click(
         fn=predict,
-        inputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation],
+        inputs=[
+            question, option_a, option_b, option_c, option_d,
+            correct_option, expert_explanation,
+            temperature, top_p, max_tokens
+        ],
         outputs=output
     )
 
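Note that Gradio passes inputs to fn positionally, so the order of this list must match predict's parameter order exactly: the three new sliders land in temperature, top_p, and max_tokens, while any caller that omits them falls back to the keyword defaults (0.6, 0.9, 256) introduced in the new signature.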