import os

os.environ["GRADIO_ENABLE_SSR"] = "0"

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from datasets import load_dataset
from huggingface_hub import login, snapshot_download

# --- Hugging Face Login ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
login(token=HF_READONLY_API_KEY)
# --- Constants ---
COT_OPENING = "<think>"
EXPLANATION_OPENING = "<explanation>"
LABEL_OPENING = "<answer>"
LABEL_CLOSING = "</answer>"
INPUT_FIELD = "question"

SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
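# The full system prompt is elided above. Judging from the constants and the
# parsing in get_response below, it instructs the model to reason inside
# <think> tags, give its verdict inside <answer>...</answer>, and justify it
# inside <explanation> tags; the constants are the prefill/parsing anchors
# for that format.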
# --- Helper Functions ---
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules
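# For example, format_rules(["Never use emojis.", "Never give discounts."])
# returns "<rules>\n1. Never use emojis.\n2. Never give discounts.\n</rules>\n".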
def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript


def get_example(
    dataset_path="tomg-group-umd/compliance_benchmark",
    subset="compliance",
    split="test_handcrafted",
    example_idx=0,
):
    dataset = load_dataset(dataset_path, subset, split=split)
    example = dataset[example_idx]
    return example[INPUT_FIELD]
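# get_example is not wired into the UI below; it is handy for smoke tests,
# pulling the "question" field of one example from the benchmark's
# handcrafted test split.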
def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
    message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
    return message
# --- Model Handling ---
class ModelWrapper:
    def __init__(self, model_name):
        self.model_name = model_name
        print(f"Initializing tokenizer for {model_name}...")
        if "nemoguard" in model_name:
            # NemoGuard variants are Llama-3.1-8B-Instruct based, so reuse that tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

        print(f"Loading model: {model_name}...")
        if "8b" in model_name.lower():
            # Large model: build it on the meta device first, then dispatch the
            # weights across GPU/CPU/disk so it fits in limited memory.
            config = AutoConfig.from_pretrained(model_name)
            with init_empty_weights():
                model_empty = AutoModelForCausalLM.from_config(config)
            model_empty.tie_weights()
            # load_checkpoint_and_dispatch needs a local checkpoint path and a
            # `dtype` argument (not `torch_dtype`), so resolve the repo first.
            checkpoint_path = snapshot_download(model_name)
            self.model = load_checkpoint_and_dispatch(
                model_empty,
                checkpoint_path,
                device_map="auto",
                offload_folder="offload",
                dtype=torch.bfloat16,
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
        print(f"Model {model_name} loaded successfully.")
    def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
        message = []
        if system_content is not None:
            message.append({'role': 'system', 'content': system_content})
        if user_content is not None:
            message.append({'role': 'user', 'content': user_content})
        if assistant_content is not None:
            message.append({'role': 'assistant', 'content': assistant_content})
        if not message:
            raise ValueError("No content provided for any role.")
        return message
    def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
        if assistant_content is not None:
            message = self.get_message_template(system_content, user_content, assistant_content)
            prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
        else:
            if enable_thinking:
                if "qwen3" in self.model_name.lower():
                    message = self.get_message_template(system_content, user_content)
                    prompt = self.tokenizer.apply_chat_template(
                        message, tokenize=False, add_generation_prompt=True, enable_thinking=True
                    )
                    prompt = prompt + f"\n{COT_OPENING}"
                else:
                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
            else:
                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                prompt = self.tokenizer.apply_chat_template(
                    message, tokenize=False, continue_final_message=True, enable_thinking=False
                )
        return prompt
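    # The branches above all rely on assistant prefill: with
    # continue_final_message=True the chat template leaves the assistant turn
    # open, so generation continues from a seeded "<think>" (reasoning on) or
    # "<answer>" (reasoning off) rather than starting a fresh turn. Qwen3 is
    # the exception: its template supports enable_thinking natively, so the
    # <think> tag is appended after the generation prompt instead.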
    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
                     enable_thinking=True, system_prompt=SYSTEM_PROMPT):
        print("Generating response...")
        if "qwen3" in self.model_name.lower() and enable_thinking:
            # Qwen3's recommended sampling settings for thinking mode.
            temperature = 0.6
            top_p = 0.95
            top_k = 20
        message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
        inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_content = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
                temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
                pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id
            )
        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
        try:
            # Drop everything through the prompt's format example and the echoed
            # transcript, keeping only the newly generated verdict.
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except Exception:
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."
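    # Note the fallback slices by character count, which only approximates the
    # prompt length once special tokens have been stripped during decoding;
    # the tag-based split above is the primary parsing path.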
# --- Model Cache ---
LOADED_MODELS = {}


def get_model(model_name):
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = ModelWrapper(model_name)
    return LOADED_MODELS[model_name]
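# Each wrapper is cached in a module-level dict, so a model is loaded at most
# once per process and switching models in the UI never reloads weights.
# Usage: model = get_model("Qwen/Qwen3-0.6B"); model.get_response(...)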
# --- Inference Function ---
def compliance_check(rules_text, transcript_text, thinking, model_name):
    try:
        model = get_model(model_name)
        rules = [r for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
        out = str(out).strip()
        if not out:
            out = "No response generated. Please try with different input."

        max_bytes = 2500
        out_bytes = out.encode('utf-8')
        if len(out_bytes) > max_bytes:
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
        return out
    except Exception as e:
        error_msg = f"Error: {str(e)[:200]}"
        print(f"Full error: {e}")
        return error_msg
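# The 2500-byte cap is applied to the UTF-8 encoding, not the character count;
# decoding the slice with errors="ignore" drops any multi-byte character the
# cut lands inside, so the truncated string is always valid UTF-8.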
# --- Gradio UI with Tabs ---
with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
    with gr.Tab("Compliance Checker"):
        rules_box = gr.Textbox(
            lines=5,
            label="Rules (one per line)",
            value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
        )
        transcript_box = gr.Textbox(
            lines=10,
            label="Transcript",
            value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
        )
        thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
        model_dropdown = gr.Dropdown(
            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
            label="Select Model",
            value="Qwen/Qwen3-0.6B",
            info="The 8B model is more powerful but may be slower to load and run."
        )
        submit_btn = gr.Button("Submit")
        output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)
        submit_btn.click(
            compliance_check,
            inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
            outputs=[output_box]
        )
with gr.Tab("Feedback"): | |
gr.HTML( | |
""" | |
<iframe src="https://docs.google.com/forms/d/e/https://forms.gle/xoBTdFw4xFaWHeSG7/viewform?embedded=true" | |
width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0"> | |
Loading… | |
</iframe> | |
""" | |
) | |
if __name__ == "__main__":
    demo.launch()