# DynaGuard / app.py
import os
os.environ["GRADIO_ENABLE_SSR"] = "0"
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from datasets import load_dataset
from huggingface_hub import login, snapshot_download
# --- Hugging Face Login ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)
# --- Constants ---
COT_OPENING = "<think>"
EXPLANATION_OPENING = "<explanation>"
LABEL_OPENING = "<answer>"
LABEL_CLOSING = "</answer>"
INPUT_FIELD = "question"
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
# --- Helper Functions ---
def format_rules(rules):
formatted_rules = "<rules>\n"
for i, rule in enumerate(rules):
formatted_rules += f"{i + 1}. {rule}\n"
formatted_rules += "</rules>\n"
return formatted_rules
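# Illustrative example (not executed) of the block format_rules produces for
# two rules; numbering starts at 1:
#   format_rules(["Never use emojis.", "Never give discounts."])
#   -> "<rules>\n1. Never use emojis.\n2. Never give discounts.\n</rules>\n"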
def format_transcript(transcript):
formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
return formatted_transcript
def get_example(
dataset_path="tomg-group-umd/compliance_benchmark",
subset="compliance",
split="test_handcrafted",
example_idx=0,
):
dataset = load_dataset(dataset_path, subset, split=split)
example = dataset[example_idx]
return example[INPUT_FIELD]
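# Illustrative usage (not executed): pull the first handcrafted test example
# from the compliance benchmark as raw input text.
#   sample_input = get_example(example_idx=0)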
def get_message(model, input_text, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
    # Thin convenience wrapper around ModelWrapper.apply_chat_template.
    message = model.apply_chat_template(system_prompt, input_text, enable_thinking=enable_thinking)
    return message
# --- Model Handling ---
class ModelWrapper:
def __init__(self, model_name):
self.model_name = model_name
print(f"Initializing tokenizer for {model_name}...")
if "nemoguard" in model_name:
self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
print(f"Loading model: {model_name}...")
if "8b" in model_name.lower():
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
model_empty = AutoModelForCausalLM.from_config(config)
self.model = load_checkpoint_and_dispatch(
model_empty,
model_name,
device_map="auto",
offload_folder="offload",
torch_dtype=torch.bfloat16
).eval()
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
print(f"Model {model_name} loaded successfully.")
def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
message = []
if system_content is not None:
message.append({'role': 'system', 'content': system_content})
if user_content is not None:
message.append({'role': 'user', 'content': user_content})
if assistant_content is not None:
message.append({'role': 'assistant', 'content': assistant_content})
if not message:
raise ValueError("No content provided for any role.")
return message
def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
if assistant_content is not None:
message = self.get_message_template(system_content, user_content, assistant_content)
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
else:
if enable_thinking:
if "qwen3" in self.model_name.lower():
message = self.get_message_template(system_content, user_content)
prompt = self.tokenizer.apply_chat_template(
message, tokenize=False, add_generation_prompt=True, enable_thinking=True
)
prompt = prompt + f"\n{COT_OPENING}"
else:
message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
else:
message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
prompt = self.tokenizer.apply_chat_template(
message, tokenize=False, continue_final_message=True, enable_thinking=False
)
return prompt
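    # Illustrative shape (not executed) of the chat message list built above when
    # thinking is disabled; continue_final_message=True keeps the assistant turn
    # open so generation continues inside the <answer> tag:
    #   [{'role': 'system', 'content': SYSTEM_PROMPT},
    #    {'role': 'user', 'content': '<rules>...</rules>\n<transcript>...</transcript>\n'},
    #    {'role': 'assistant', 'content': '<answer>'}]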
def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
enable_thinking=True, system_prompt=SYSTEM_PROMPT):
print("Generating response...")
if "qwen3" in self.model_name.lower() and enable_thinking:
temperature = 0.6
top_p = 0.95
top_k = 20
message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
with torch.no_grad():
output_content = self.model.generate(
**inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
eos_token_id=self.tokenizer.eos_token_id
)
output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
        try:
            # Strip the echoed prompt: keep only the text after the end of the
            # system prompt and the closing </transcript> tag.
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except Exception:
            # Fallback: approximate the prompt length in characters and return the rest.
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."
# --- Model Cache ---
LOADED_MODELS = {}
def get_model(model_name):
if model_name not in LOADED_MODELS:
LOADED_MODELS[model_name] = ModelWrapper(model_name)
return LOADED_MODELS[model_name]
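# Illustrative usage (not executed), assuming the selected model fits on the
# available hardware; repeated calls reuse the cached ModelWrapper instance:
#   model = get_model("Qwen/Qwen3-0.6B")
#   reply = model.get_response(
#       format_rules(["Never use emojis."]) + format_transcript("User: hi\nAgent: Hello!"),
#       enable_thinking=False,
#   )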
# --- Inference Function ---
def compliance_check(rules_text, transcript_text, thinking, model_name):
try:
model = get_model(model_name)
        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
inp = format_rules(rules) + format_transcript(transcript_text)
out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
out = str(out).strip()
if not out:
out = "No response generated. Please try with different input."
max_bytes = 2500
out_bytes = out.encode('utf-8')
if len(out_bytes) > max_bytes:
truncated_bytes = out_bytes[:max_bytes]
out = truncated_bytes.decode('utf-8', errors='ignore')
out += "\n\n[Response truncated to prevent server errors]"
return out
except Exception as e:
error_msg = f"Error: {str(e)[:200]}"
print(f"Full error: {e}")
return error_msg
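# Illustrative call (not executed) mirroring what the Gradio button does:
# rules are split on newlines, formatted, and prepended to the transcript.
#   result = compliance_check("Never use emojis.", "User: hi\nAgent: Hello!", False, "Qwen/Qwen3-0.6B")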
# --- Gradio UI with Tabs ---
with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
with gr.Tab("Compliance Checker"):
rules_box = gr.Textbox(
lines=5,
label="Rules (one per line)",
value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
)
transcript_box = gr.Textbox(
lines=10,
label="Transcript",
value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
)
thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
model_dropdown = gr.Dropdown(
["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
label="Select Model",
value="Qwen/Qwen3-0.6B",
info="The 8B model is more powerful but may be slower to load and run."
)
submit_btn = gr.Button("Submit")
output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)
submit_btn.click(
compliance_check,
inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
outputs=[output_box]
)
with gr.Tab("Feedback"):
gr.HTML(
"""
<iframe src="https://docs.google.com/forms/d/e/https://forms.gle/xoBTdFw4xFaWHeSG7/viewform?embedded=true"
width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
Loading…
</iframe>
"""
)
if __name__ == "__main__":
demo.launch()