taruschirag committed
Commit 2a43b25 · verified · 1 Parent(s): d6658d9

Update app.py


Added google form and made sure the grey text stays

Files changed (1)
  1. app.py +57 -65
app.py CHANGED
@@ -58,7 +58,6 @@ class ModelWrapper:
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

         print(f"Loading model: {model_name}...")
-        # Use disk offloading for the large 8B model to handle memory constraints
         if "8b" in model_name.lower():
             config = AutoConfig.from_pretrained(model_name)
             with init_empty_weights():
@@ -68,18 +67,15 @@ class ModelWrapper:
                 model_empty,
                 model_name,
                 device_map="auto",
-                offload_folder="offload",  # A directory to store the offloaded layers
+                offload_folder="offload",
                 torch_dtype=torch.bfloat16
             ).eval()
         else:
-            # Load the smaller model directly
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
         print(f"Model {model_name} loaded successfully.")

-
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
-        """Compile sys, user, assistant inputs into the proper dictionaries"""
         message = []
         if system_content is not None:
             message.append({'role': 'system', 'content': system_content})
@@ -92,26 +88,29 @@ class ModelWrapper:
         return message

     def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
-        """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
         if assistant_content is not None:
             message = self.get_message_template(system_content, user_content, assistant_content)
             prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
         else:
-            if enable_thinking:
-                if "qwen3" in self.model_name.lower():
-                    message = self.get_message_template(system_content, user_content)
-                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
-                    prompt = prompt + f"\n{COT_OPENING}"
-                else:
-                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
-                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
-            else:
-                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
-                prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
+            if enable_thinking:
+                if "qwen3" in self.model_name.lower():
+                    message = self.get_message_template(system_content, user_content)
+                    prompt = self.tokenizer.apply_chat_template(
+                        message, tokenize=False, add_generation_prompt=True, enable_thinking=True
+                    )
+                    prompt = prompt + f"\n{COT_OPENING}"
+                else:
+                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
+                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+            else:
+                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
+                prompt = self.tokenizer.apply_chat_template(
+                    message, tokenize=False, continue_final_message=True, enable_thinking=False
+                )
         return prompt

-    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
-        """Generate and decode the response."""
+    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
+                     enable_thinking=True, system_prompt=SYSTEM_PROMPT):
        print("Generating response...")

        if "qwen3" in self.model_name.lower() and enable_thinking:
@@ -140,7 +139,7 @@ class ModelWrapper:
        input_length = len(message)
        return output_text[input_length:] if len(output_text) > input_length else "No response generated."

-# --- Model Cache to prevent reloading on every call ---
+# --- Model Cache ---
LOADED_MODELS = {}

def get_model(model_name):
@@ -148,75 +147,68 @@ def get_model(model_name):
        LOADED_MODELS[model_name] = ModelWrapper(model_name)
    return LOADED_MODELS[model_name]

-# Gradio Inference Function
+# --- Inference Function ---
def compliance_check(rules_text, transcript_text, thinking, model_name):
    try:
-        # Get the selected model from our cache (or load it if it's the first time)
        model = get_model(model_name)
-
        rules = [r for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)

        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
-
        out = str(out).strip()
        if not out:
            out = "No response generated. Please try with different input."

-        max_bytes = 2500
+        max_bytes = 2500
        out_bytes = out.encode('utf-8')
-
        if len(out_bytes) > max_bytes:
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
-
        return out
-
    except Exception as e:
-        error_msg = f"Error: {str(e)[:200]}"
+        error_msg = f"Error: {str(e)[:200]}"
        print(f"Full error: {e}")
        return error_msg

-# --- Gradio Interface Definition ---
-demo = gr.Interface(
-    fn=compliance_check,
-    inputs=[
-        gr.Textbox(
-            lines=5,
-            label="Rules (one per line)",
-            max_lines=10,
-            placeholder='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
-        ),
-        gr.Textbox(
-            lines=10,
-            label="Transcript",
-            max_lines=15,
-            placeholder='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
-        ),
-        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False),
-        gr.Dropdown(
+# --- Gradio UI with Tabs ---
+with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
+    with gr.Tab("Compliance Checker"):
+        rules_box = gr.Textbox(
+            lines=5,
+            label="Rules (one per line)",
+            value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
+        )
+        transcript_box = gr.Textbox(
+            lines=10,
+            label="Transcript",
+            value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
+        )
+        thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
+        model_dropdown = gr.Dropdown(
            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
            label="Select Model",
            value="Qwen/Qwen3-0.6B",
            info="The 8B model is more powerful but may be slower to load and run."
-        ),
-        gr.Tab("Feedback"):
-        gr.HTML(
-            """
-            <iframe src="https://docs.google.com/forms/d/e/YOUR_FORM_ID/viewform?embedded=true"
-            width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
-            Loading…
-            </iframe>
-            """
-        )
-    ],
-    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
-    title="DynaGuard Compliance Checker",
-    description="Select a model, paste your rules & transcript, then hit Submit.",
-    allow_flagging="never",
-    show_progress=True
-)
+        )
+        submit_btn = gr.Button("Submit")
+        output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)
+
+        submit_btn.click(
+            compliance_check,
+            inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
+            outputs=[output_box]
+        )
+
+    with gr.Tab("Feedback"):
+        gr.HTML(
+            """
+            <iframe src="https://docs.google.com/forms/d/e/https://forms.gle/xoBTdFw4xFaWHeSG7/viewform?embedded=true"
+            width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
+            Loading…
+            </iframe>
+            """
+        )

if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
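A note on the placeholder-to-value switch in the Textbox components above, which relates to the "grey text" in the commit message: in Gradio, placeholder renders grey hint text that is shown only while the box is empty and is not submitted with the form, while value pre-fills the box with real, editable text that is passed to the callback as-is. A minimal sketch of the difference, illustrative only and not part of this commit (labels and strings are made up):

import gradio as gr

# Illustrative only: "placeholder" is grey hint text shown while the box is
# empty and never passed to the callback unless the user types it themselves;
# "value" is pre-filled editable text that is submitted as-is if left untouched.
with gr.Blocks() as placeholder_vs_value_demo:
    gr.Textbox(label="Rules (grey hint)", placeholder="One rule per line...")
    gr.Textbox(label="Rules (pre-filled)", value="Never use emojis.")

if __name__ == "__main__":
    placeholder_vs_value_demo.launch()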
 
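For the new Feedback tab, the usual Google Forms embed URL takes the long form ID from the form's Send > embed (< >) dialog, i.e. https://docs.google.com/forms/d/e/<FORM_ID>/viewform?embedded=true, rather than a forms.gle short link spliced into that path. A minimal sketch of such a tab, using a hypothetical EXAMPLE_FORM_ID placeholder (not a real form, not part of this commit):

import gradio as gr

# EXAMPLE_FORM_ID is a hypothetical placeholder; substitute the long ID taken
# from the form's "Send" > embed (< >) dialog.
FORM_EMBED_URL = "https://docs.google.com/forms/d/e/EXAMPLE_FORM_ID/viewform?embedded=true"

with gr.Blocks() as feedback_demo:
    with gr.Tab("Feedback"):
        # Embed the form in an iframe so users can submit feedback without leaving the Space.
        gr.HTML(
            f'<iframe src="{FORM_EMBED_URL}" width="100%" height="800" '
            f'frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>'
        )

if __name__ == "__main__":
    feedback_demo.launch()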