taruschirag committed on
Commit 81bc100 · verified · 1 Parent(s): cc0c804

Update app.py


added dropdown box

Files changed (1)
  1. app.py +54 -31
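For context, the diff below wires a gr.Dropdown into the existing gr.Interface so the selected model name reaches the inference function as a new model_name argument. A minimal standalone sketch of that wiring pattern (the echo function and variable names here are illustrative, not taken from app.py):

import gradio as gr

MODEL_CHOICES = ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"]

def run(prompt, model_name):
    # gr.Interface hands the input components' values to fn in order,
    # so the dropdown's current selection arrives as `model_name`.
    return f"[{model_name}] would process: {prompt}"

demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(MODEL_CHOICES, label="Select Model", value=MODEL_CHOICES[0]),
    ],
    outputs=gr.Textbox(label="Output"),
)

if __name__ == "__main__":
    demo.launch()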
app.py CHANGED
@@ -3,13 +3,16 @@ os.environ["GRADIO_ENABLE_SSR"] = "0"
 
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from datasets import load_dataset
 from huggingface_hub import login
 
+# --- Hugging Face Login ---
 HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
 login(token=HF_READONLY_API_KEY)
 
+# --- Constants ---
 COT_OPENING = "<think>"
 EXPLANATION_OPENING = "<explanation>"
 LABEL_OPENING = "<answer>"
@@ -17,6 +20,7 @@ LABEL_CLOSING = "</answer>"
 INPUT_FIELD = "question"
 SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
 
+# --- Helper Functions ---
 def format_rules(rules):
     formatted_rules = "<rules>\n"
     for i, rule in enumerate(rules):
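The next hunk replaces the unconditional from_pretrained call with an accelerate-based load that offloads the 8B model to disk. One caveat worth flagging: in current accelerate releases, load_checkpoint_and_dispatch expects a local checkpoint file or folder and a dtype= keyword, so the Hub repo would typically be materialized on disk first. A minimal sketch of that variant, assuming snapshot_download from huggingface_hub; the helper name is hypothetical and not part of this commit:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM

def load_with_disk_offload(repo_id="Qwen/Qwen3-8B"):  # hypothetical helper, for illustration only
    local_dir = snapshot_download(repo_id)        # make sure the weights exist on local disk
    config = AutoConfig.from_pretrained(local_dir)
    with init_empty_weights():                    # build the module tree without allocating real weights
        empty_model = AutoModelForCausalLM.from_config(config)
    empty_model.tie_weights()                     # re-tie shared embeddings before dispatching
    model = load_checkpoint_and_dispatch(
        empty_model,
        local_dir,                                # a local checkpoint folder, not a Hub repo id
        device_map="auto",
        offload_folder="offload",                 # layers that do not fit in memory are spilled here
        dtype=torch.bfloat16,
    )
    return model.eval()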
@@ -42,16 +46,37 @@ def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True)
     message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
     return message
 
+# --- Model Handling ---
 class ModelWrapper:
-    def __init__(self, model_name="Qwen/Qwen3-0.6B"):
+    def __init__(self, model_name):
         self.model_name = model_name
+        print(f"Initializing tokenizer for {model_name}...")
         if "nemoguard" in model_name:
             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+
+        print(f"Loading model: {model_name}...")
+        # Use disk offloading for the large 8B model to handle memory constraints
+        if "8b" in model_name.lower():
+            config = AutoConfig.from_pretrained(model_name)
+            with init_empty_weights():
+                model_empty = AutoModelForCausalLM.from_config(config)
+
+            self.model = load_checkpoint_and_dispatch(
+                model_empty,
+                model_name,
+                device_map="auto",
+                offload_folder="offload",  # A directory to store the offloaded layers
+                torch_dtype=torch.bfloat16
+            ).eval()
+        else:
+            # Load the smaller model directly
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+        print(f"Model {model_name} loaded successfully.")
+
 
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
         """Compile sys, user, assistant inputs into the proper dictionaries"""
@@ -69,34 +94,27 @@ class ModelWrapper:
     def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
         """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
         if assistant_content is not None:
-            # If assistant content is passed we simply use it.
-            # This works for both Qwen3 and non-Qwen3 models. With Qwen3 any time assistant_content is provided, it automatically adds the <think></think> pair before the content, which is what we want.
             message = self.get_message_template(system_content, user_content, assistant_content)
             prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
         else:
             if enable_thinking:
                 if "qwen3" in self.model_name.lower():
-                    # Let the Qwen chat template handle the thinking token
                     message = self.get_message_template(system_content, user_content)
                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
-                    # The way the Qwen3 chat template works is it adds a <think></think> pair when enable_thinking=False, but for enable_thinking=True, it adds nothing and lets the model decide. Here we force the <think> tag to be there.
                     prompt = prompt + f"\n{COT_OPENING}"
                 else:
                     message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
             else:
-                # This works for both Qwen3 and non-Qwen3 models.
-                # When Qwen3 gets assistant_content, it automatically adds the <think></think> pair before the content like we want. And other models ignore the enable_thinking argument.
                 message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                 prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
         return prompt
 
     def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
-        """Generate and decode the response with the recommended temperature settings for thinking and non-thinking."""
+        """Generate and decode the response."""
         print("Generating response...")
 
         if "qwen3" in self.model_name.lower() and enable_thinking:
-            # Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
             temperature = 0.6
             top_p = 0.95
             top_k = 20
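For reference, the forced values above and the defaults in the get_response signature match the Qwen3 guidance that the removed comment used to cite (the model card's thinking vs. non-thinking settings). Collected in one place; the dict name is illustrative and does not appear in app.py:

QWEN3_SAMPLING = {
    "thinking":     {"temperature": 0.6, "top_p": 0.95, "top_k": 20},  # forced when enable_thinking is on
    "non_thinking": {"temperature": 0.7, "top_p": 0.8,  "top_k": 20},  # the get_response() defaults
}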
@@ -106,36 +124,36 @@ class ModelWrapper:
 
         with torch.no_grad():
             output_content = self.model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                num_return_sequences=1,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                min_p=0,
-                pad_token_id=self.tokenizer.pad_token_id,
-                do_sample=True,
+                **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
+                temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
+                pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
                 eos_token_id=self.tokenizer.eos_token_id
             )
 
         output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
 
         try:
-            sys_prompt_text = output_text.split("Brief explanation\n</explanation>")[0]
             remainder = output_text.split("Brief explanation\n</explanation>")[-1]
-            rules_transcript_text = remainder.split("</transcript>")[0]
             thinking_answer_text = remainder.split("</transcript>")[-1]
             return thinking_answer_text
         except:
             input_length = len(message)
             return output_text[input_length:] if len(output_text) > input_length else "No response generated."
 
-MODEL_NAME = "Qwen/Qwen3-0.6B"
-model = ModelWrapper(MODEL_NAME)
+# --- Model Cache to prevent reloading on every call ---
+LOADED_MODELS = {}
+
+def get_model(model_name):
+    if model_name not in LOADED_MODELS:
+        LOADED_MODELS[model_name] = ModelWrapper(model_name)
+    return LOADED_MODELS[model_name]
 
-# — Gradio inference function
-def compliance_check(rules_text, transcript_text, thinking):
+# — Gradio Inference Function
+def compliance_check(rules_text, transcript_text, thinking, model_name):
     try:
+        # Get the selected model from our cache (or load it if it's the first time)
+        model = get_model(model_name)
+
         rules = [r for r in rules_text.split("\n") if r.strip()]
         inp = format_rules(rules) + format_transcript(transcript_text)
 
@@ -149,7 +167,6 @@ def compliance_check(rules_text, transcript_text, thinking):
         out_bytes = out.encode('utf-8')
 
         if len(out_bytes) > max_bytes:
-
             truncated_bytes = out_bytes[:max_bytes]
             out = truncated_bytes.decode('utf-8', errors='ignore')
             out += "\n\n[Response truncated to prevent server errors]"
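The errors='ignore' in the hunk above matters because slicing out_bytes at max_bytes can cut a multi-byte UTF-8 character in half. A small standalone illustration, not from app.py:

text = "compliance ✓"                         # "✓" encodes to 3 bytes in UTF-8
cut = text.encode("utf-8")[:12]               # the slice ends inside those 3 bytes
print(cut.decode("utf-8", errors="ignore"))   # prints "compliance " with the partial character dropped
# cut.decode("utf-8") with no error handler would raise UnicodeDecodeError instead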
@@ -161,7 +178,7 @@ def compliance_check(rules_text, transcript_text, thinking):
         print(f"Full error: {e}")
         return error_msg
 
-
+# --- Gradio Interface Definition ---
 demo = gr.Interface(
     fn=compliance_check,
     inputs=[
@@ -177,11 +194,17 @@ demo = gr.Interface(
             max_lines=15,
             placeholder='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
         ),
-        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
+        gr.Checkbox(label="Enable ⟨think⟩ mode", value=False),
+        gr.Dropdown(
+            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
+            label="Select Model",
+            value="Qwen/Qwen3-0.6B",
+            info="The 8B model is more powerful but may be slower to load and run."
+        )
     ],
     outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
     title="DynaGuard Compliance Checker",
-    description="Paste your rules & transcript, then hit Submit.",
+    description="Select a model, paste your rules & transcript, then hit Submit.",
     allow_flagging="never",
     show_progress=True
 )
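As a quick sanity check on the cache added in this diff (illustrative snippet, not part of the commit): repeated calls with the same name should reuse one ModelWrapper instance rather than reloading weights on every Gradio request.

m1 = get_model("Qwen/Qwen3-0.6B")   # first call constructs the ModelWrapper and loads weights
m2 = get_model("Qwen/Qwen3-0.6B")   # second call returns the cached instance from LOADED_MODELS
assert m1 is m2                     # no reload between requests in the same process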
 
3
 
4
  import gradio as gr
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
7
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
8
  from datasets import load_dataset
9
  from huggingface_hub import login
10
 
11
+ # --- Hugging Face Login ---
12
  HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
13
  login(token=HF_READONLY_API_KEY)
14
 
15
+ # --- Constants ---
16
  COT_OPENING = "<think>"
17
  EXPLANATION_OPENING = "<explanation>"
18
  LABEL_OPENING = "<answer>"
 
20
  INPUT_FIELD = "question"
21
  SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
22
 
23
+ # --- Helper Functions ---
24
  def format_rules(rules):
25
  formatted_rules = "<rules>\n"
26
  for i, rule in enumerate(rules):
 
46
  message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
47
  return message
48
 
49
+ # --- Model Handling ---
50
  class ModelWrapper:
51
+ def __init__(self, model_name):
52
  self.model_name = model_name
53
+ print(f"Initializing tokenizer for {model_name}...")
54
  if "nemoguard" in model_name:
55
  self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
56
  else:
57
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
59
+
60
+ print(f"Loading model: {model_name}...")
61
+ # Use disk offloading for the large 8B model to handle memory constraints
62
+ if "8b" in model_name.lower():
63
+ config = AutoConfig.from_pretrained(model_name)
64
+ with init_empty_weights():
65
+ model_empty = AutoModelForCausalLM.from_config(config)
66
+
67
+ self.model = load_checkpoint_and_dispatch(
68
+ model_empty,
69
+ model_name,
70
+ device_map="auto",
71
+ offload_folder="offload", # A directory to store the offloaded layers
72
+ torch_dtype=torch.bfloat16
73
+ ).eval()
74
+ else:
75
+ # Load the smaller model directly
76
+ self.model = AutoModelForCausalLM.from_pretrained(
77
+ model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
78
+ print(f"Model {model_name} loaded successfully.")
79
+
80
 
81
  def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
82
  """Compile sys, user, assistant inputs into the proper dictionaries"""
 
94
  def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
95
  """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
96
  if assistant_content is not None:
 
 
97
  message = self.get_message_template(system_content, user_content, assistant_content)
98
  prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
99
  else:
100
  if enable_thinking:
101
  if "qwen3" in self.model_name.lower():
 
102
  message = self.get_message_template(system_content, user_content)
103
  prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
 
104
  prompt = prompt + f"\n{COT_OPENING}"
105
  else:
106
  message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
107
  prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
108
  else:
 
 
109
  message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
110
  prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
111
  return prompt
112
 
113
  def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
114
+ """Generate and decode the response."""
115
  print("Generating response...")
116
 
117
  if "qwen3" in self.model_name.lower() and enable_thinking:
 
118
  temperature = 0.6
119
  top_p = 0.95
120
  top_k = 20
 
124
 
125
  with torch.no_grad():
126
  output_content = self.model.generate(
127
+ **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
128
+ temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
129
+ pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
 
 
 
 
 
 
130
  eos_token_id=self.tokenizer.eos_token_id
131
  )
132
 
133
  output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
134
 
135
  try:
 
136
  remainder = output_text.split("Brief explanation\n</explanation>")[-1]
 
137
  thinking_answer_text = remainder.split("</transcript>")[-1]
138
  return thinking_answer_text
139
  except:
140
  input_length = len(message)
141
  return output_text[input_length:] if len(output_text) > input_length else "No response generated."
142
 
143
+ # --- Model Cache to prevent reloading on every call ---
144
+ LOADED_MODELS = {}
145
+
146
+ def get_model(model_name):
147
+ if model_name not in LOADED_MODELS:
148
+ LOADED_MODELS[model_name] = ModelWrapper(model_name)
149
+ return LOADED_MODELS[model_name]
150
 
151
+ # — Gradio Inference Function
152
+ def compliance_check(rules_text, transcript_text, thinking, model_name):
153
  try:
154
+ # Get the selected model from our cache (or load it if it's the first time)
155
+ model = get_model(model_name)
156
+
157
  rules = [r for r in rules_text.split("\n") if r.strip()]
158
  inp = format_rules(rules) + format_transcript(transcript_text)
159
 
 
167
  out_bytes = out.encode('utf-8')
168
 
169
  if len(out_bytes) > max_bytes:
 
170
  truncated_bytes = out_bytes[:max_bytes]
171
  out = truncated_bytes.decode('utf-8', errors='ignore')
172
  out += "\n\n[Response truncated to prevent server errors]"
 
178
  print(f"Full error: {e}")
179
  return error_msg
180
 
181
+ # --- Gradio Interface Definition ---
182
  demo = gr.Interface(
183
  fn=compliance_check,
184
  inputs=[
 
194
  max_lines=15,
195
  placeholder='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
196
  ),
197
+ gr.Checkbox(label="Enable ⟨think⟩ mode", value=False),
198
+ gr.Dropdown(
199
+ ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
200
+ label="Select Model",
201
+ value="Qwen/Qwen3-0.6B",
202
+ info="The 8B model is more powerful but may be slower to load and run."
203
+ )
204
  ],
205
  outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
206
  title="DynaGuard Compliance Checker",
207
+ description="Select a model, paste your rules & transcript, then hit Submit.",
208
  allow_flagging="never",
209
  show_progress=True
210
  )