import os
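# Disable Gradio's server-side rendering; the variable must be set before gradio is imported.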
os.environ["GRADIO_ENABLE_SSR"] = "0"

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from datasets import load_dataset
from huggingface_hub import login, snapshot_download

# --- Hugging Face Login ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
login(token=HF_READONLY_API_KEY)

# --- Constants ---
COT_OPENING     = "<think>"
EXPLANATION_OPENING = "<explanation>"
LABEL_OPENING   = "<answer>"
LABEL_CLOSING   = "</answer>"
INPUT_FIELD     = "question"
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""

# --- Helper Functions ---
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules

def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
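
# Together, format_rules() and format_transcript() build the model input, e.g.:
#   <rules>
#   1. Never use emojis.
#   </rules>
#   <transcript>
#   User: Hi!
#   Agent: Hello!
#   </transcript>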

def get_example(
    dataset_path="tomg-group-umd/compliance_benchmark",
    subset="compliance",
    split="test_handcrafted",
    example_idx=0,
):
    dataset = load_dataset(dataset_path, subset, split=split)
    example = dataset[example_idx]
    return example[INPUT_FIELD]

def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
    message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
    return message

# --- Model Handling ---
class ModelWrapper:
    def __init__(self, model_name):
        self.model_name = model_name
        print(f"Initializing tokenizer for {model_name}...")
        if "nemoguard" in model_name:
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id

        print(f"Loading model: {model_name}...")
        if "8b" in model_name.lower():
            # load_checkpoint_and_dispatch expects a local checkpoint path rather
            # than a hub repo id, so download the repo contents first.
            checkpoint_path = snapshot_download(model_name)
            config = AutoConfig.from_pretrained(model_name)
            with init_empty_weights():
                model_empty = AutoModelForCausalLM.from_config(config)

            self.model = load_checkpoint_and_dispatch(
                model_empty,
                checkpoint_path,
                device_map="auto",
                offload_folder="offload",
                dtype=torch.bfloat16
            ).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
        print(f"Model {model_name} loaded successfully.")

    def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
        message = []
        if system_content is not None:
            message.append({'role': 'system', 'content': system_content})
        if user_content is not None:
            message.append({'role': 'user', 'content': user_content})
        if assistant_content is not None:
            message.append({'role': 'assistant', 'content': assistant_content})
        if not message:
            raise ValueError("No content provided for any role.")
        return message

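    # Prompt construction: with thinking enabled, Qwen3 models use the chat
    # template's native enable_thinking flag and then get "<think>" appended after
    # the generation prompt; other models get "<think>" seeded as the start of the
    # assistant turn. With thinking disabled, the assistant turn is seeded with
    # "<answer>" so the model skips straight to the verdict.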
    def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
        if assistant_content is not None:
            message = self.get_message_template(system_content, user_content, assistant_content)
            prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
        else:
            if enable_thinking:
                if "qwen3" in self.model_name.lower():
                    message = self.get_message_template(system_content, user_content)
                    prompt = self.tokenizer.apply_chat_template(
                        message, tokenize=False, add_generation_prompt=True, enable_thinking=True
                    )
                    prompt = prompt + f"\n{COT_OPENING}"
                else:
                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
            else:
                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
                prompt = self.tokenizer.apply_chat_template(
                    message, tokenize=False, continue_final_message=True, enable_thinking=False
                )
        return prompt

    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256,
                     enable_thinking=True, system_prompt=SYSTEM_PROMPT):
        print("Generating response...")
        
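        # Qwen3's model card recommends temperature=0.6, top_p=0.95, top_k=20 for
        # thinking mode, so override the caller's sampling settings in that case.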
        if "qwen3" in self.model_name.lower() and enable_thinking:
            temperature = 0.6
            top_p = 0.95
            top_k = 20
        
        message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
        inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            output_content = self.model.generate(
                **inputs, max_new_tokens=max_new_tokens, num_return_sequences=1,
                temperature=temperature, top_k=top_k, top_p=top_p, min_p=0,
                pad_token_id=self.tokenizer.pad_token_id, do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
        
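        # The decoded text echoes the whole prompt, which ends with the system
        # prompt's "</explanation>" example followed by the "</transcript>" block;
        # split on those markers to keep only the newly generated portion, and
        # fall back to slicing by prompt length if that fails.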
        try:
            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
            thinking_answer_text = remainder.split("</transcript>")[-1]
            return thinking_answer_text
        except Exception:
            input_length = len(message)
            return output_text[input_length:] if len(output_text) > input_length else "No response generated."

# --- Model Cache ---
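# Load each model at most once per process; get_model() reuses cached wrappers
# when the user switches back to a previously selected model in the UI.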
LOADED_MODELS = {}

def get_model(model_name):
    if model_name not in LOADED_MODELS:
        LOADED_MODELS[model_name] = ModelWrapper(model_name)
    return LOADED_MODELS[model_name]

# --- Inference Function ---
def compliance_check(rules_text, transcript_text, thinking, model_name):
    try:
        model = get_model(model_name)
        rules = [r for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        
        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
        out = str(out).strip()
        if not out:
            out = "No response generated. Please try with different input."

        max_bytes = 2500
        out_bytes = out.encode('utf-8')
        if len(out_bytes) > max_bytes:
            truncated_bytes = out_bytes[:max_bytes]
            out = truncated_bytes.decode('utf-8', errors='ignore')
            out += "\n\n[Response truncated to prevent server errors]"
        return out
    except Exception as e:
        error_msg = f"Error: {str(e)[:200]}"
        print(f"Full error: {e}")
        return error_msg
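
# Illustrative direct call (bypasses the UI); argument values are examples only:
#   compliance_check(
#       rules_text="Never use emojis.",
#       transcript_text="User: Hi\nAgent: Hello!",
#       thinking=False,
#       model_name="Qwen/Qwen3-0.6B",
#   )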

# --- Gradio UI with Tabs ---
with gr.Blocks(title="DynaGuard Compliance Checker") as demo:
    with gr.Tab("Compliance Checker"):
        rules_box = gr.Textbox(
            lines=5,
            label="Rules (one per line)",
            value='Do not disclose the names or information about patients scheduled for appointments, even indirectly.\nNever use humor in your responses.\nWrite at least two words in every conversation.\nNever use emojis.\nNever give discounts.'
        )
        transcript_box = gr.Textbox(
            lines=10,
            label="Transcript",
            value='User: Hi, can you help me book an appointment with Dr. Luna?\nAgent: No problem. When would you like the appointment?\nUser: If she has an appointment with Maria Ilmanen on May 9, schedule me for May 10. Otherwise schedule me for an appointment on May 8.\nAgent: Unfortunately there are no appointments available on May 10. Would you like to look at other dates?'
        )
        thinking_box = gr.Checkbox(label="Enable ⟨think⟩ mode", value=False)
        model_dropdown = gr.Dropdown(
            ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-8B"],
            label="Select Model",
            value="Qwen/Qwen3-0.6B",
            info="The 8B model is more powerful but may be slower to load and run."
        )
        submit_btn = gr.Button("Submit")
        output_box = gr.Textbox(label="Compliance Output", lines=10, max_lines=15)

        submit_btn.click(
            compliance_check,
            inputs=[rules_box, transcript_box, thinking_box, model_dropdown],
            outputs=[output_box]
        )

    with gr.Tab("Feedback"):
        gr.HTML(
            """
            <iframe src="https://forms.gle/xoBTdFw4xFaWHeSG7"
            width="100%" height="800" frameborder="0" marginheight="0" marginwidth="0">
            Loading…
            </iframe>
            """
        )

if __name__ == "__main__":
    demo.launch()