Spaces: Running on Zero
Update app.py
Added Google Form and made sure the grey text stays
app.py
CHANGED
@@ -58,7 +58,6 @@ class ModelWrapper:
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

         print(f"Loading model: {model_name}...")
-        # Use disk offloading for the large 8B model to handle memory constraints
         if "8b" in model_name.lower():
             config = AutoConfig.from_pretrained(model_name)
             with init_empty_weights():
@@ -68,18 +67,15 @@ class ModelWrapper:
                 model_empty,
                 model_name,
                 device_map="auto",
-                offload_folder="offload",
+                offload_folder="offload",
                 torch_dtype=torch.bfloat16
             ).eval()
         else:
-            # Load the smaller model directly
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
         print(f"Model {model_name} loaded successfully.")

-
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
-        """Compile sys, user, assistant inputs into the proper dictionaries"""
         message = []
         if system_content is not None:
             message.append({'role': 'system', 'content': system_content})
@@ -92,26 +88,29 @@ class ModelWrapper:
         return message

     def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
-        """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
         if assistant_content is not None:
             message = self.get_message_template(system_content, user_content, assistant_content)
             prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
         else:
+            if enable_thinking:
+                if "qwen3" in self.model_name.lower():
+                    message = self.get_message_template(system_content, user_content)
+                    prompt = self.tokenizer.apply_chat_template(
+                        message, tokenize=False, add_generation_prompt=True, enable_thinking=True
+                    )
+                    prompt = prompt + f"\n{COT_OPENING}"
+                else:
+                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
+                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+            else:
+                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
+                prompt = self.tokenizer.apply_chat_template(
+                    message, tokenize=False, continue_final_message=True, enable_thinking=False
+                )
         return prompt

-    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
+    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
+                     enable_thinking=True, system_prompt=SYSTEM_PROMPT):
         print("Generating response...")

         if "qwen3" in self.model_name.lower() and enable_thinking:
@@ -140,7 +139,7 @@ class ModelWrapper:
         input_length = len(message)
         return output_text[input_length:] if len(output_text) > input_length else "No response generated."

-# --- Model Cache
+# --- Model Cache ---
 LOADED_MODELS = {}

 def get_model(model_name):
@@ -148,75 +147,68 @@ def get_model(model_name):
         LOADED_MODELS[model_name] = ModelWrapper(model_name)
     return LOADED_MODELS[model_name]

-#
+# --- Inference Function ---
 def compliance_check(rules_text, transcript_text, thinking, model_name):
     try:
-        # Get the selected model from our cache (or load it if it's the first time)
         model = get_model(model_name)
-
         rules = [r for r in rules_text.split("\n") if r.strip()]
         inp = format_rules(rules) + format_transcript(transcript_text)

         out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
-
         out = str(out).strip()
         if not out:
             out = "No response generated. Please try with different input."

-        max_bytes = 2500
+        max_bytes = 2500
         out_bytes = out.encode('utf-8')
-
         if len(out_bytes) > max_bytes:
             truncated_bytes = out_bytes[:max_bytes]
             out = truncated_bytes.decode('utf-8', errors='ignore')
             out += "\n\n[Response truncated to prevent server errors]"
-
         return out
-
     except Exception as e:
-        error_msg = f"Error: {str(e)[:200]}"
+        error_msg = f"Error: {str(e)[:200]}"
         print(f"Full error: {e}")
         return error_msg

-# --- Gradio
-        ),
-        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False),
-        gr.Dropdown(
+# --- Gradio UI with Tabs ---
+with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
+    with gr.Tab("Compliance Checker"):
+        rules_box = gr.Textbox(
+            lines=5,
+            label="Rules (one per line)",
+            value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
+        )
+        transcript_box = gr.Textbox(
+            lines=10,
+            label="Transcript",
+            value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
+        )
+        thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
+        model_dropdown = gr.Dropdown(
             ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
             label="Select Model",
             value="Qwen/Qwen3-0.6B",
             info="The 8B model is more powerful but may be slower to load and run."
-        )
-    gr.
+        )
+        submit_btn = gr.Button("Submit")
+        output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)
+
+        submit_btn.click(
+            compliance_check,
+            inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
+            outputs=[output_box]
+        )
+
+    with gr.Tab("Feedback"):
+        gr.HTML(
+            """
+            <iframe src="https://docs.google.com/forms/d/e/https://forms.gle/xoBTdFw4xFaWHeSG7/viewform?embedded=true"
+                    width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
+                Loading…
+            </iframe>
+            """
+        )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
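A note on the 8B loading path in the diff above: with accelerate installed, transformers can shard a model across GPU, CPU and disk. The sketch below reproduces that behaviour with only documented from_pretrained keywords; it illustrates the idea rather than the exact call app.py makes around model_empty (that call sits outside the hunks shown).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" lets accelerate place layers on GPU, CPU and, if needed, disk;
# offload_folder is where layers that fit in neither kind of RAM are written.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    offload_folder="offload",
    torch_dtype=torch.bfloat16,
).eval()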
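The new apply_chat_template branch combines two tokenizer features: Qwen3's enable_thinking template switch and prefilling an open assistant turn with continue_final_message. A minimal sketch of both, assuming COT_OPENING is the literal "<think>" tag (the constant itself is defined elsewhere in app.py):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [
    {"role": "system", "content": "You check transcripts against rules."},
    {"role": "user", "content": "Rule 1: Never use emojis.\nTranscript: ..."},
]

# Qwen3: the chat template understands enable_thinking; the diff then appends the
# chain-of-thought opener to the returned prompt string.
qwen_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
) + "\n<think>"

# Other models: prefill an assistant turn that starts with "<think>" and ask the
# template to continue that final message instead of closing it.
messages.append({"role": "assistant", "content": "<think>"})
generic_prompt = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)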
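LOADED_MODELS is a plain lazy-loading cache: each model is wrapped once on first request and reused afterwards. The guard presumably sits on the context line the hunk elides; the membership test below is an assumption made explicit, not code visible in the diff.

LOADED_MODELS = {}

def get_model(model_name):
    # Construct the wrapper only the first time this model name is requested;
    # later requests reuse the instance already resident in memory.
    if model_name not in LOADED_MODELS:  # assumed guard; this line is outside the shown hunk
        LOADED_MODELS[model_name] = ModelWrapper(model_name)
    return LOADED_MODELS[model_name]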
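The truncation in compliance_check is deliberately byte-based: the 2,500-unit limit is measured after UTF-8 encoding, and decode(errors='ignore') drops any character the cut happens to split, so no UnicodeDecodeError can reach the UI. The same logic as a standalone helper (the function name is illustrative):

def truncate_utf8(text: str, max_bytes: int = 2500) -> str:
    # Measure the limit in bytes, the unit the server actually cares about.
    data = text.encode("utf-8")
    if len(data) <= max_bytes:
        return text
    # errors="ignore" silently discards a trailing character that the byte cut split in half.
    truncated = data[:max_bytes].decode("utf-8", errors="ignore")
    return truncated + "\n\n[Response truncated to prevent server errors]"

print(truncate_utf8("é" * 2000))  # 4000 bytes of input come back safely truncated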
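The Feedback tab injects the Google Form into the Space with gr.HTML. For reference, a published form is normally embedded through its /viewform URL with embedded=true; FORM_ID below is a placeholder, not the identifier this Space uses.

import gradio as gr

FORM_EMBED = (
    '<iframe src="https://docs.google.com/forms/d/e/FORM_ID/viewform?embedded=true" '
    'width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">'
    "Loading…</iframe>"
)

with gr.Blocks() as demo:
    with gr.Tab("Feedback"):
        gr.HTML(FORM_EMBED)  # renders the raw iframe markup inside the tab

if __name__ == "__main__":
    demo.launch()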