# DynaGuard Compliance Checker (Hugging Face Space, running on ZeroGPU)
import os

# Must be set before gradio is imported so SSR stays disabled.
os.environ["GRADIO_ENABLE_SSR"] = "0"

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from datasets import load_dataset
from huggingface_hub import login
# --- Hugging Face Login ---
# Log in with a read-only token so gated models and datasets can be fetched.
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)
# --- Constants ---
COT_OPENING = "<think>"
EXPLANATION_OPENING = "<explanation>"
LABEL_OPENING = "<answer>"
LABEL_CLOSING = "</answer>"
INPUT_FIELD = "question"
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
# --- Helper Functions ---
def format_rules(rules):
    """Number each rule and wrap the list in <rules> tags."""
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules
def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
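
# Usage sketch (illustrative values): the two helpers above produce exactly the
# tagged blocks the guardian model is prompted with, e.g.
#   format_rules(["Never use emojis."])
#   -> "<rules>\n1. Never use emojis.\n</rules>\n"
#   format_transcript("User: hi")
#   -> "<transcript>\nUser: hi\n</transcript>\n"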
def get_example(
    dataset_path="tomg-group-umd/compliance_benchmark",
    subset="compliance",
    split="test_handcrafted",
    example_idx=0,
):
    dataset = load_dataset(dataset_path, subset, split=split)
    example = dataset[example_idx]
    return example[INPUT_FIELD]
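
# Sketch: get_example() downloads the benchmark split on first call and returns
# the raw "question" string for one row, e.g. get_example(example_idx=3) for
# the fourth handcrafted case.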
def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
    message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
    return message
# --- Model Handling ---
class ModelWrapper:
    def __init__(self, model_name):
        self.model_name = model_name
        print(f"Initializing tokenizer for {model_name}...")
        if "nemoguard" in model_name:
            # NemoGuard releases reuse the Llama 3.1 tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

        print(f"Loading model: {model_name}...")
        if "8b" in model_name.lower():
            # Large checkpoint: build an empty shell, then dispatch weights
            # across available devices with CPU offload.
            config = AutoConfig.from_pretrained(model_name)
            with init_empty_weights():
                model_empty = AutoModelForCausalLM.from_config(config)
            self.model = load_checkpoint_and_dispatch(
                model_empty,
                model_name,  # note: expects a checkpoint path on disk
                device_map="auto",
                offload_folder="offload",
                dtype=torch.bfloat16,  # load_checkpoint_and_dispatch takes `dtype`, not `torch_dtype`
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto", torch_dtype=torch.bfloat16
            ).eval()
        print(f"Model {model_name} loaded successfully.")
    def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
        message = []
        if system_content is not None:
            message.append({'role': 'system', 'content': system_content})
        if user_content is not None:
            message.append({'role': 'user', 'content': user_content})
        if assistant_content is not None:
            message.append({'role': 'assistant', 'content': assistant_content})
        if not message:
            raise ValueError("No content provided for any role.")
        return message
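
    # Example of the returned structure (derived from the code above):
    #   self.get_message_template("sys", "hi")
    #   -> [{'role': 'system', 'content': 'sys'},
    #       {'role': 'user', 'content': 'hi'}]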
    def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
        if assistant_content is not None:
            message = self.get_message_template(system_content, user_content, assistant_content)
            prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
        else:
            if enable_thinking:
                if "qwen3" in self.model_name.lower():
                    # Qwen3 chat templates understand enable_thinking natively.
                    message = self.get_message_template(system_content, user_content)
                    prompt = self.tokenizer.apply_chat_template(
                        message, tokenize=False, add_generation_prompt=True, enable_thinking=True
                    )
                    prompt = prompt + f"\n{COT_OPENING}"
                else:
                    # Other models: pre-seed the assistant turn with "<think>".
                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
            else:
                # Skip chain-of-thought: force the answer tag immediately.
                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                prompt = self.tokenizer.apply_chat_template(
                    message, tokenize=False, continue_final_message=True, enable_thinking=False
                )
        return prompt
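
    # Resulting prompt shapes (sketch; exact tokens depend on the chat template):
    #   thinking, Qwen3:  generation prompt, then "\n<think>" appended
    #   thinking, other:  assistant turn continued from "<think>"
    #   no thinking:      assistant turn continued from "<answer>", so the
    #                     model emits its label immediately.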
    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
                     enable_thinking=True, system_prompt=SYSTEM_PROMPT):
        print("Generating response...")
        if "qwen3" in self.model_name.lower() and enable_thinking:
            # Qwen3's recommended sampling settings for thinking mode.
            temperature = 0.6
            top_p = 0.95
            top_k = 20
        message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
        inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_content = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
                temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
                pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id
            )
        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
        try:
            # Strip the echoed prompt: drop everything through the system
            # prompt's example explanation, then through the closing transcript tag.
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except Exception:
            # Fallback: crude character-offset slice of the decoded output.
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."
# --- Model Cache ---
LOADED_MODELS = {}
def get_model(model_name):
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = ModelWrapper(model_name)
    return LOADED_MODELS[model_name]
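
# Sketch: a repeated call with the same name is a dictionary hit, so weights
# are downloaded and dispatched only once per process, e.g.
#   model = get_model("Qwen/Qwen3-0.6B")   # loads the checkpoint
#   model = get_model("Qwen/Qwen3-0.6B")   # cached, returns instantly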
# --- Inference Function ---
def compliance_check(rules_text, transcript_text, thinking, model_name):
    try:
        model = get_model(model_name)
        rules = [r for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
        out = str(out).strip()
        if not out:
            out = "No response generated. Please try with different input."

        # Cap the response size; decode with errors='ignore' so a truncation
        # that lands mid-way through a multi-byte UTF-8 character is dropped
        # cleanly instead of raising.
        max_bytes = 2500
        out_bytes = out.encode('utf-8')
        if len(out_bytes) > max_bytes:
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
        return out
    except Exception as e:
        error_msg = f"Error: {str(e)[:200]}"
        print(f"Full error: {e}")
        return error_msg
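
# Standalone usage sketch (hypothetical inputs, bypassing the UI):
#   print(compliance_check("Never use emojis.",
#                          "User: hi\nAgent: hello :)",
#                          thinking=False, model_name="Qwen/Qwen3-0.6B"))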
# --- Gradio UI with Tabs ---
with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
    with gr.Tab("Compliance Checker"):
        rules_box = gr.Textbox(
            lines=5,
            label="Rules (one per line)",
            value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
        )
        transcript_box = gr.Textbox(
            lines=10,
            label="Transcript",
            value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
        )
        thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
        model_dropdown = gr.Dropdown(
            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
            label="Select Model",
            value="Qwen/Qwen3-0.6B",
            info="The 8B model is more powerful but may be slower to load and run."
        )
        submit_btn = gr.Button("Submit")
        output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)
        submit_btn.click(
            compliance_check,
            inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
            outputs=[output_box]
        )
    with gr.Tab("Feedback"):
        # The original src fused the Google Forms embed prefix with the short
        # link; the forms.gle URL redirects to the full form, so use it directly.
        gr.HTML(
            """
            <iframe src="https://forms.gle/xoBTdFw4xFaWHeSG7?embedded=true"
                    width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
                Loading…
            </iframe>
            """
        )
if __name__ == "__main__":
    demo.launch()