Spaces: Running on Zero
Update app.py
added dropdown box
app.py CHANGED
@@ -3,13 +3,16 @@ os.environ["GRADIO_ENABLE_SSR"] = "0"
 
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from datasets import load_dataset
 from huggingface_hub import login
 
+# --- Hugging Face Login ---
 HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
 login(token=HF_READONLY_API_KEY)
 
+# --- Constants ---
 COT_OPENING = "<think>"
 EXPLANATION_OPENING = "<explanation>"
 LABEL_OPENING = "<answer>"
@@ -17,6 +20,7 @@ LABEL_CLOSING = "</answer>"
 INPUT_FIELD = "question"
 SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
 
+# --- Helper Functions ---
 def format_rules(rules):
     formatted_rules = "<rules>\n"
     for i, rule in enumerate(rules):
@@ -42,16 +46,37 @@ def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True)
     message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
     return message
 
+# --- Model Handling ---
 class ModelWrapper:
-    def __init__(self, model_name
+    def __init__(self, model_name):
         self.model_name = model_name
+        print(f"Initializing tokenizer for {model_name}...")
         if "nemoguard" in model_name:
             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
-
-
+
+        print(f"Loading model: {model_name}...")
+        # Use disk offloading for the large 8B model to handle memory constraints
+        if "8b" in model_name.lower():
+            config = AutoConfig.from_pretrained(model_name)
+            with init_empty_weights():
+                model_empty = AutoModelForCausalLM.from_config(config)
+
+            self.model = load_checkpoint_and_dispatch(
+                model_empty,
+                model_name,
+                device_map="auto",
+                offload_folder="offload",  # A directory to store the offloaded layers
+                torch_dtype=torch.bfloat16
+            ).eval()
+        else:
+            # Load the smaller model directly
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+        print(f"Model {model_name} loaded successfully.")
+
 
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
         """Compile sys, user, assistant inputs into the proper dictionaries"""
@@ -69,34 +94,27 @@ class ModelWrapper:
     def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
         """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
         if assistant_content is not None:
-            # If assistant content is passed we simply use it.
-            # This works for both Qwen3 and non-Qwen3 models. With Qwen3 any time assistant_content is provided, it automatically adds the <think></think> pair before the content, which is what we want.
             message = self.get_message_template(system_content, user_content, assistant_content)
             prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
         else:
             if enable_thinking:
                 if "qwen3" in self.model_name.lower():
-                    # Let the Qwen chat template handle the thinking token
                     message = self.get_message_template(system_content, user_content)
                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
-                    # The way the Qwen3 chat template works is it adds a <think></think> pair when enable_thinking=False, but for enable_thinking=True, it adds nothing and lets the model decide. Here we force the <think> tag to be there.
                     prompt = prompt + f"\n{COT_OPENING}"
                 else:
                     message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
             else:
-                # This works for both Qwen3 and non-Qwen3 models.
-                # When Qwen3 gets assistant_content, it automatically adds the <think></think> pair before the content like we want. And other models ignore the enable_thinking argument.
                 message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                 prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
         return prompt
 
     def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
-        """Generate and decode the response
+        """Generate and decode the response."""
         print("Generating response...")
 
         if "qwen3" in self.model_name.lower() and enable_thinking:
-            # Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
             temperature = 0.6
             top_p = 0.95
             top_k = 20
@@ -106,36 +124,36 @@ class ModelWrapper:
 
         with torch.no_grad():
             output_content = self.model.generate(
-                **inputs,
-
-
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                min_p=0,
-                pad_token_id=self.tokenizer.pad_token_id,
-                do_sample=True,
+                **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
+                temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
+                pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
                 eos_token_id=self.tokenizer.eos_token_id
             )
 
        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
 
        try:
-            sys_prompt_text = output_text.split("Brief explanation\n</explanation>")[0]
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
-            rules_transcript_text = remainder.split("</transcript>")[0]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except:
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."
 
-
-
+# --- Model Cache to prevent reloading on every call ---
+LOADED_MODELS = {}
+
+def get_model(model_name):
+    if model_name not in LOADED_MODELS:
+        LOADED_MODELS[model_name] = ModelWrapper(model_name)
+    return LOADED_MODELS[model_name]
 
-# — Gradio
-def compliance_check(rules_text, transcript_text, thinking):
+# — Gradio Inference Function —
+def compliance_check(rules_text, transcript_text, thinking, model_name):
     try:
+        # Get the selected model from our cache (or load it if it's the first time)
+        model = get_model(model_name)
+
         rules = [r for r in rules_text.split("\n") if r.strip()]
         inp = format_rules(rules) + format_transcript(transcript_text)
 
@@ -149,7 +167,6 @@ def compliance_check(rules_text, transcript_text, thinking):
         out_bytes = out.encode('utf-8')
 
         if len(out_bytes) > max_bytes:
-
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
@@ -161,7 +178,7 @@ def compliance_check(rules_text, transcript_text, thinking):
        print(f"Full error: {e}")
        return error_msg
 
-
+# --- Gradio Interface Definition ---
 demo = gr.Interface(
     fn=compliance_check,
     inputs=[
@@ -177,11 +194,17 @@ demo = gr.Interface(
            max_lines=15,
            placeholder='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
        ),
-        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
+        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False),
+        gr.Dropdown(
+            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
+            label="Select Model",
+            value="Qwen/Qwen3-0.6B",
+            info="The 8B model is more powerful but may be slower to load and run."
+        )
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
    title="DynaGuard Compliance Checker",
-    description="
+    description="Select a model, paste your rules & transcript, then hit Submit.",
    allow_flagging="never",
    show_progress=True
 )
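For reference, the disk-offloading path added to ModelWrapper.__init__ follows the accelerate "big model inference" recipe. The sketch below shows that recipe under two assumptions that differ from the committed code: the checkpoint is first materialized locally with snapshot_download (load_checkpoint_and_dispatch expects a local checkpoint path or folder, not a hub id), and the dtype is passed via accelerate's dtype argument. Treat it as an illustration of the pattern, not a drop-in replacement for the commit.

# Sketch of the accelerate big-model-inference pattern used in __init__.
# Assumptions (not from the commit): weights are fetched locally with
# snapshot_download, and load_checkpoint_and_dispatch receives `dtype`.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM

def load_offloaded(model_name: str):
    # Build the model skeleton without allocating real weights.
    config = AutoConfig.from_pretrained(model_name)
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    # load_checkpoint_and_dispatch wants a local checkpoint path/folder.
    checkpoint_dir = snapshot_download(model_name)

    model = load_checkpoint_and_dispatch(
        empty_model,
        checkpoint_dir,
        device_map="auto",          # split layers across GPU/CPU as needed
        offload_folder="offload",   # spill layers that don't fit to disk
        dtype=torch.bfloat16,
    )
    return model.eval()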
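The other structural change is the LOADED_MODELS cache plus the model_name argument threaded from the new gr.Dropdown into compliance_check, so switching models in the UI only pays the load cost once per model. A minimal, self-contained sketch of that caching pattern (FakeModel is a hypothetical stand-in for ModelWrapper so the example runs without downloading weights):

# Minimal sketch of the caching pattern added in this commit.
# FakeModel is a hypothetical stand-in for ModelWrapper; the real app
# caches ModelWrapper instances keyed by the dropdown's model name.
class FakeModel:
    def __init__(self, model_name):
        print(f"Loading {model_name}...")  # the expensive step, done once
        self.model_name = model_name

    def get_response(self, text, enable_thinking=True):
        return f"[{self.model_name}] response to: {text[:30]}..."

LOADED_MODELS = {}

def get_model(model_name):
    # Load on first request, then reuse the cached instance.
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = FakeModel(model_name)
    return LOADED_MODELS[model_name]

if __name__ == "__main__":
    m1 = get_model("Qwen/Qwen3-0.6B")  # prints "Loading ..."
    m2 = get_model("Qwen/Qwen3-0.6B")  # cache hit, no reload
    assert m1 is m2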