Upload 13 files
Browse files- agent_tuner.py +459 -0
- calibrators.py +299 -0
- demo_app.py +220 -0
- deploy.py +354 -0
- domain_datasets.py +364 -0
- evaluators.py +448 -0
- example_config.json +63 -0
- llm_interface.py +178 -0
- main.py +265 -0
- negative_samples.py +379 -0
- quantifiers.py +336 -0
- synthetic_trajectories.py +302 -0
- trajectory_data.py +433 -0
agent_tuner.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Agent Tuning Module for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This module provides functionality for efficiently tuning large language models
|
5 |
+
into specialized agents using a combination of positive examples, negative examples,
|
6 |
+
and synthetically generated interaction trajectories.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import os
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import (
|
15 |
+
Trainer, TrainingArguments,
|
16 |
+
DataCollatorForLanguageModeling,
|
17 |
+
AutoModelForCausalLM, AutoTokenizer
|
18 |
+
)
|
19 |
+
from datasets import Dataset
|
20 |
+
|
21 |
+
from data.trajectory_data import Trajectory, TrajectoryDataset
|
22 |
+
from models.llm_interface import LLMInterface
|
23 |
+
|
24 |
+
class AgentTuner:
|
25 |
+
"""Base class for agent tuning methods."""
|
26 |
+
|
27 |
+
def __init__(self, name: str):
|
28 |
+
"""
|
29 |
+
Initialize the agent tuner.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
name: Name of the tuning method
|
33 |
+
"""
|
34 |
+
self.name = name
|
35 |
+
|
36 |
+
def tune(
|
37 |
+
self,
|
38 |
+
model_name: str,
|
39 |
+
trajectories: List[Trajectory],
|
40 |
+
**kwargs
|
41 |
+
) -> Tuple[Any, Dict[str, Any]]:
|
42 |
+
"""
|
43 |
+
Tune a model into a specialized agent.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
model_name: Name of the base model
|
47 |
+
trajectories: List of training trajectories
|
48 |
+
**kwargs: Additional tuning parameters
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Tuple of (tuned_model, training_metrics)
|
52 |
+
"""
|
53 |
+
raise NotImplementedError("Subclasses must implement this method")
|
54 |
+
|
55 |
+
def save_model(self, model: Any, path: str) -> None:
|
56 |
+
"""
|
57 |
+
Save the tuned model.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
model: Tuned model
|
61 |
+
path: Path to save the model
|
62 |
+
"""
|
63 |
+
raise NotImplementedError("Subclasses must implement this method")
|
64 |
+
|
65 |
+
def load_model(self, path: str) -> Any:
|
66 |
+
"""
|
67 |
+
Load a tuned model.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
path: Path to the model
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Loaded model
|
74 |
+
"""
|
75 |
+
raise NotImplementedError("Subclasses must implement this method")
|
76 |
+
|
77 |
+
|
78 |
+
class SupervisedFineTuner(AgentTuner):
|
79 |
+
"""Tune agents using supervised fine-tuning."""
|
80 |
+
|
81 |
+
def __init__(self):
|
82 |
+
"""Initialize the supervised fine-tuner."""
|
83 |
+
super().__init__("supervised_fine_tuning")
|
84 |
+
|
85 |
+
def tune(
|
86 |
+
self,
|
87 |
+
model_name: str,
|
88 |
+
trajectories: List[Trajectory],
|
89 |
+
output_dir: str = "./tuned_model",
|
90 |
+
num_train_epochs: int = 3,
|
91 |
+
learning_rate: float = 5e-5,
|
92 |
+
batch_size: int = 4,
|
93 |
+
gradient_accumulation_steps: int = 4,
|
94 |
+
max_seq_length: int = 512,
|
95 |
+
format_type: str = "interleaved",
|
96 |
+
positive_weight: float = 0.8,
|
97 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
98 |
+
**kwargs
|
99 |
+
) -> Tuple[Any, Dict[str, Any]]:
|
100 |
+
"""
|
101 |
+
Tune a model using supervised fine-tuning.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
model_name: Name of the base model
|
105 |
+
trajectories: List of training trajectories
|
106 |
+
output_dir: Directory to save the model
|
107 |
+
num_train_epochs: Number of training epochs
|
108 |
+
learning_rate: Learning rate
|
109 |
+
batch_size: Batch size
|
110 |
+
gradient_accumulation_steps: Gradient accumulation steps
|
111 |
+
max_seq_length: Maximum sequence length
|
112 |
+
format_type: Format type for trajectories
|
113 |
+
positive_weight: Weight for positive examples
|
114 |
+
device: Device to use for training
|
115 |
+
**kwargs: Additional tuning parameters
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Tuple of (tuned_model, training_metrics)
|
119 |
+
"""
|
120 |
+
print(f"Starting supervised fine-tuning of {model_name}")
|
121 |
+
|
122 |
+
# Create output directory
|
123 |
+
os.makedirs(output_dir, exist_ok=True)
|
124 |
+
|
125 |
+
# Load model and tokenizer
|
126 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
127 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
128 |
+
|
129 |
+
# Ensure the tokenizer has a pad token
|
130 |
+
if tokenizer.pad_token is None:
|
131 |
+
tokenizer.pad_token = tokenizer.eos_token
|
132 |
+
|
133 |
+
# Prepare training data
|
134 |
+
print("Preparing training data...")
|
135 |
+
|
136 |
+
# Separate positive and negative trajectories
|
137 |
+
positive_trajectories = [t for t in trajectories if t.is_positive]
|
138 |
+
negative_trajectories = [t for t in trajectories if not t.is_positive]
|
139 |
+
|
140 |
+
print(f"Found {len(positive_trajectories)} positive and {len(negative_trajectories)} negative trajectories")
|
141 |
+
|
142 |
+
# Calculate sample counts based on positive weight
|
143 |
+
total_samples = len(trajectories)
|
144 |
+
target_positive = int(total_samples * positive_weight)
|
145 |
+
target_negative = total_samples - target_positive
|
146 |
+
|
147 |
+
# Sample trajectories to achieve desired ratio
|
148 |
+
if len(positive_trajectories) > target_positive:
|
149 |
+
positive_trajectories = np.random.choice(positive_trajectories, target_positive, replace=False).tolist()
|
150 |
+
|
151 |
+
if len(negative_trajectories) > target_negative:
|
152 |
+
negative_trajectories = np.random.choice(negative_trajectories, target_negative, replace=False).tolist()
|
153 |
+
|
154 |
+
# Combine trajectories
|
155 |
+
sampled_trajectories = positive_trajectories + negative_trajectories
|
156 |
+
np.random.shuffle(sampled_trajectories)
|
157 |
+
|
158 |
+
print(f"Using {len(positive_trajectories)} positive and {len(negative_trajectories)} negative trajectories for training")
|
159 |
+
|
160 |
+
# Format trajectories for training
|
161 |
+
training_texts = []
|
162 |
+
|
163 |
+
for trajectory in tqdm(sampled_trajectories, desc="Formatting trajectories"):
|
164 |
+
formatted = trajectory.to_training_format(format_type)
|
165 |
+
training_texts.append(formatted)
|
166 |
+
|
167 |
+
# Tokenize training data
|
168 |
+
def tokenize_function(examples):
|
169 |
+
return tokenizer(
|
170 |
+
examples["text"],
|
171 |
+
padding="max_length",
|
172 |
+
truncation=True,
|
173 |
+
max_length=max_seq_length
|
174 |
+
)
|
175 |
+
|
176 |
+
# Create dataset
|
177 |
+
dataset = Dataset.from_dict({"text": training_texts})
|
178 |
+
tokenized_dataset = dataset.map(
|
179 |
+
tokenize_function,
|
180 |
+
batched=True,
|
181 |
+
remove_columns=["text"]
|
182 |
+
)
|
183 |
+
|
184 |
+
# Set up training arguments
|
185 |
+
training_args = TrainingArguments(
|
186 |
+
output_dir=output_dir,
|
187 |
+
num_train_epochs=num_train_epochs,
|
188 |
+
per_device_train_batch_size=batch_size,
|
189 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
190 |
+
learning_rate=learning_rate,
|
191 |
+
weight_decay=0.01,
|
192 |
+
save_strategy="epoch",
|
193 |
+
save_total_limit=2,
|
194 |
+
logging_dir=f"{output_dir}/logs",
|
195 |
+
logging_steps=10,
|
196 |
+
report_to="none"
|
197 |
+
)
|
198 |
+
|
199 |
+
# Create data collator
|
200 |
+
data_collator = DataCollatorForLanguageModeling(
|
201 |
+
tokenizer=tokenizer,
|
202 |
+
mlm=False
|
203 |
+
)
|
204 |
+
|
205 |
+
# Create trainer
|
206 |
+
trainer = Trainer(
|
207 |
+
model=model,
|
208 |
+
args=training_args,
|
209 |
+
train_dataset=tokenized_dataset,
|
210 |
+
data_collator=data_collator
|
211 |
+
)
|
212 |
+
|
213 |
+
# Train the model
|
214 |
+
print("Starting training...")
|
215 |
+
train_result = trainer.train()
|
216 |
+
|
217 |
+
# Save the model
|
218 |
+
print(f"Saving model to {output_dir}")
|
219 |
+
trainer.save_model(output_dir)
|
220 |
+
tokenizer.save_pretrained(output_dir)
|
221 |
+
|
222 |
+
# Return the model and metrics
|
223 |
+
metrics = {
|
224 |
+
"train_loss": train_result.training_loss,
|
225 |
+
"train_runtime": train_result.metrics["train_runtime"],
|
226 |
+
"samples_per_second": train_result.metrics["train_samples_per_second"],
|
227 |
+
"num_train_samples": len(tokenized_dataset)
|
228 |
+
}
|
229 |
+
|
230 |
+
return model, metrics
|
231 |
+
|
232 |
+
def save_model(self, model: Any, path: str) -> None:
|
233 |
+
"""
|
234 |
+
Save the tuned model.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
model: Tuned model
|
238 |
+
path: Path to save the model
|
239 |
+
"""
|
240 |
+
model.save_pretrained(path)
|
241 |
+
|
242 |
+
def load_model(self, path: str) -> Any:
|
243 |
+
"""
|
244 |
+
Load a tuned model.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
path: Path to the model
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
Loaded model
|
251 |
+
"""
|
252 |
+
return AutoModelForCausalLM.from_pretrained(path)
|
253 |
+
|
254 |
+
|
255 |
+
class ParameterEfficientFineTuner(AgentTuner):
|
256 |
+
"""Tune agents using parameter-efficient fine-tuning methods."""
|
257 |
+
|
258 |
+
def __init__(self):
|
259 |
+
"""Initialize the parameter-efficient fine-tuner."""
|
260 |
+
super().__init__("parameter_efficient_fine_tuning")
|
261 |
+
|
262 |
+
def tune(
|
263 |
+
self,
|
264 |
+
model_name: str,
|
265 |
+
trajectories: List[Trajectory],
|
266 |
+
output_dir: str = "./tuned_model",
|
267 |
+
method: str = "lora", # 'lora', 'prefix', 'prompt_tuning'
|
268 |
+
num_train_epochs: int = 3,
|
269 |
+
learning_rate: float = 1e-4,
|
270 |
+
batch_size: int = 4,
|
271 |
+
gradient_accumulation_steps: int = 4,
|
272 |
+
max_seq_length: int = 512,
|
273 |
+
format_type: str = "interleaved",
|
274 |
+
positive_weight: float = 0.8,
|
275 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
276 |
+
**kwargs
|
277 |
+
) -> Tuple[Any, Dict[str, Any]]:
|
278 |
+
"""
|
279 |
+
Tune a model using parameter-efficient methods.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
model_name: Name of the base model
|
283 |
+
trajectories: List of training trajectories
|
284 |
+
output_dir: Directory to save the model
|
285 |
+
method: PEFT method to use
|
286 |
+
num_train_epochs: Number of training epochs
|
287 |
+
learning_rate: Learning rate
|
288 |
+
batch_size: Batch size
|
289 |
+
gradient_accumulation_steps: Gradient accumulation steps
|
290 |
+
max_seq_length: Maximum sequence length
|
291 |
+
format_type: Format type for trajectories
|
292 |
+
positive_weight: Weight for positive examples
|
293 |
+
device: Device to use for training
|
294 |
+
**kwargs: Additional tuning parameters
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
Tuple of (tuned_model, training_metrics)
|
298 |
+
"""
|
299 |
+
try:
|
300 |
+
from peft import (
|
301 |
+
get_peft_model, LoraConfig, PrefixTuningConfig,
|
302 |
+
PromptTuningConfig, TaskType, PeftModel
|
303 |
+
)
|
304 |
+
except ImportError:
|
305 |
+
raise ImportError("PEFT library is required for parameter-efficient fine-tuning. Install it with 'pip install peft'.")
|
306 |
+
|
307 |
+
print(f"Starting parameter-efficient fine-tuning of {model_name} using {method}")
|
308 |
+
|
309 |
+
# Create output directory
|
310 |
+
os.makedirs(output_dir, exist_ok=True)
|
311 |
+
|
312 |
+
# Load model and tokenizer
|
313 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
314 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
315 |
+
|
316 |
+
# Ensure the tokenizer has a pad token
|
317 |
+
if tokenizer.pad_token is None:
|
318 |
+
tokenizer.pad_token = tokenizer.eos_token
|
319 |
+
|
320 |
+
# Configure PEFT method
|
321 |
+
if method == "lora":
|
322 |
+
peft_config = LoraConfig(
|
323 |
+
task_type=TaskType.CAUSAL_LM,
|
324 |
+
r=16,
|
325 |
+
lora_alpha=32,
|
326 |
+
lora_dropout=0.1,
|
327 |
+
target_modules=["q_proj", "v_proj"]
|
328 |
+
)
|
329 |
+
elif method == "prefix":
|
330 |
+
peft_config = PrefixTuningConfig(
|
331 |
+
task_type=TaskType.CAUSAL_LM,
|
332 |
+
num_virtual_tokens=20,
|
333 |
+
prefix_projection=True
|
334 |
+
)
|
335 |
+
elif method == "prompt_tuning":
|
336 |
+
peft_config = PromptTuningConfig(
|
337 |
+
task_type=TaskType.CAUSAL_LM,
|
338 |
+
num_virtual_tokens=20,
|
339 |
+
tokenizer_name_or_path=model_name
|
340 |
+
)
|
341 |
+
else:
|
342 |
+
raise ValueError(f"Unsupported PEFT method: {method}")
|
343 |
+
|
344 |
+
# Create PEFT model
|
345 |
+
model = get_peft_model(model, peft_config)
|
346 |
+
model.print_trainable_parameters()
|
347 |
+
|
348 |
+
# Prepare training data (same as SupervisedFineTuner)
|
349 |
+
print("Preparing training data...")
|
350 |
+
|
351 |
+
# Separate positive and negative trajectories
|
352 |
+
positive_trajectories = [t for t in trajectories if t.is_positive]
|
353 |
+
negative_trajectories = [t for t in trajectories if not t.is_positive]
|
354 |
+
|
355 |
+
print(f"Found {len(positive_trajectories)} positive and {len(negative_trajectories)} negative trajectories")
|
356 |
+
|
357 |
+
# Calculate sample counts based on positive weight
|
358 |
+
total_samples = len(trajectories)
|
359 |
+
target_positive = int(total_samples * positive_weight)
|
360 |
+
target_negative = total_samples - target_positive
|
361 |
+
|
362 |
+
# Sample trajectories to achieve desired ratio
|
363 |
+
if len(positive_trajectories) > target_positive:
|
364 |
+
positive_trajectories = np.random.choice(positive_trajectories, target_positive, replace=False).tolist()
|
365 |
+
|
366 |
+
if len(negative_trajectories) > target_negative:
|
367 |
+
negative_trajectories = np.random.choice(negative_trajectories, target_negative, replace=False).tolist()
|
368 |
+
|
369 |
+
# Combine trajectories
|
370 |
+
sampled_trajectories = positive_trajectories + negative_trajectories
|
371 |
+
np.random.shuffle(sampled_trajectories)
|
372 |
+
|
373 |
+
print(f"Using {len(positive_trajectories)} positive and {len(negative_trajectories)} negative trajectories for training")
|
374 |
+
|
375 |
+
# Format trajectories for training
|
376 |
+
training_texts = []
|
377 |
+
|
378 |
+
for trajectory in tqdm(sampled_trajectories, desc="Formatting trajectories"):
|
379 |
+
formatted = trajectory.to_training_format(format_type)
|
380 |
+
training_texts.append(formatted)
|
381 |
+
|
382 |
+
# Tokenize training data
|
383 |
+
def tokenize_function(examples):
|
384 |
+
return tokenizer(
|
385 |
+
examples["text"],
|
386 |
+
padding="max_length",
|
387 |
+
truncation=True,
|
388 |
+
max_length=max_seq_length
|
389 |
+
)
|
390 |
+
|
391 |
+
# Create dataset
|
392 |
+
dataset = Dataset.from_dict({"text": training_texts})
|
393 |
+
tokenized_dataset = dataset.map(
|
394 |
+
tokenize_function,
|
395 |
+
batched=True,
|
396 |
+
remove_columns=["text"]
|
397 |
+
)
|
398 |
+
|
399 |
+
# Set up training arguments
|
400 |
+
training_args = TrainingArguments(
|
401 |
+
output_dir=output_dir,
|
402 |
+
num_train_epochs=num_train_epochs,
|
403 |
+
per_device_train_batch_size=batch_size,
|
404 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
405 |
+
learning_rate=learning_rate,
|
406 |
+
weight_decay=0.01,
|
407 |
+
save_strategy="epoch",
|
408 |
+
save_total_limit=2,
|
409 |
+
logging_dir=f"{output_dir}/logs",
|
410 |
+
logging_steps=10,
|
411 |
+
report_to="none"
|
412 |
+
)
|
413 |
+
|
414 |
+
# Create data collator
|
415 |
+
data_collator = DataCollatorForLanguageModeling(
|
416 |
+
tokenizer=tokenizer,
|
417 |
+
mlm=False
|
418 |
+
)
|
419 |
+
|
420 |
+
# Create trainer
|
421 |
+
trainer = Trainer(
|
422 |
+
model=model,
|
423 |
+
args=training_args,
|
424 |
+
train_dataset=tokenized_dataset,
|
425 |
+
data_collator=data_collator
|
426 |
+
)
|
427 |
+
|
428 |
+
# Train the model
|
429 |
+
print("Starting training...")
|
430 |
+
train_result = trainer.train()
|
431 |
+
|
432 |
+
# Save the model
|
433 |
+
print(f"Saving model to {output_dir}")
|
434 |
+
trainer.save_model(output_dir)
|
435 |
+
tokenizer.save_pretrained(output_dir)
|
436 |
+
|
437 |
+
# Return the model and metrics
|
438 |
+
metrics = {
|
439 |
+
"train_loss": train_result.training_loss,
|
440 |
+
"train_runtime": train_result.metrics["train_runtime"],
|
441 |
+
"samples_per_second": train_result.metrics["train_samples_per_second"],
|
442 |
+
"num_train_samples": len(tokenized_dataset),
|
443 |
+
"peft_method": method
|
444 |
+
}
|
445 |
+
|
446 |
+
return model, metrics
|
447 |
+
|
448 |
+
def save_model(self, model: Any, path: str) -> None:
|
449 |
+
"""
|
450 |
+
Save the tuned model.
|
451 |
+
|
452 |
+
Args:
|
453 |
+
model: Tuned model
|
454 |
+
path: Path to save the model
|
455 |
+
"""
|
456 |
+
model.save_pretrained(path)
|
457 |
+
|
458 |
+
|
459 |
+
(Content truncated due to size limit. Use line ranges to read in chunks)
|
calibrators.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Domain-Specific Calibration Module for LLMs
|
3 |
+
|
4 |
+
This module implements calibration techniques for improving uncertainty estimates
|
5 |
+
across different domains, focusing on temperature scaling and domain adaptation.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
11 |
+
from scipy.optimize import minimize_scalar
|
12 |
+
|
13 |
+
class Calibrator:
|
14 |
+
"""Base class for calibration methods."""
|
15 |
+
|
16 |
+
def __init__(self, name: str):
|
17 |
+
"""
|
18 |
+
Initialize the calibrator.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
name: Name of the calibration method
|
22 |
+
"""
|
23 |
+
self.name = name
|
24 |
+
self.is_fitted = False
|
25 |
+
|
26 |
+
def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
|
27 |
+
"""
|
28 |
+
Fit the calibrator to the provided data.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
confidences: List of confidence scores
|
32 |
+
accuracies: List of boolean accuracy indicators
|
33 |
+
"""
|
34 |
+
raise NotImplementedError("Subclasses must implement this method")
|
35 |
+
|
36 |
+
def calibrate(self, confidences: List[float]) -> List[float]:
|
37 |
+
"""
|
38 |
+
Calibrate the provided confidence scores.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
confidences: List of confidence scores
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Calibrated confidence scores
|
45 |
+
"""
|
46 |
+
raise NotImplementedError("Subclasses must implement this method")
|
47 |
+
|
48 |
+
|
49 |
+
class TemperatureScaling(Calibrator):
|
50 |
+
"""Calibration using temperature scaling."""
|
51 |
+
|
52 |
+
def __init__(self):
|
53 |
+
"""Initialize the temperature scaling calibrator."""
|
54 |
+
super().__init__("temperature_scaling")
|
55 |
+
self.temperature = 1.0
|
56 |
+
|
57 |
+
def _nll_loss(self, temperature: float, confidences: np.ndarray, accuracies: np.ndarray) -> float:
|
58 |
+
"""
|
59 |
+
Calculate negative log likelihood loss for temperature scaling.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
temperature: Temperature parameter
|
63 |
+
confidences: Array of confidence scores
|
64 |
+
accuracies: Array of boolean accuracy indicators
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Negative log likelihood loss
|
68 |
+
"""
|
69 |
+
# Apply temperature scaling
|
70 |
+
scaled_confidences = np.clip(confidences / temperature, 1e-10, 1.0 - 1e-10)
|
71 |
+
|
72 |
+
# Calculate binary cross-entropy loss
|
73 |
+
loss = -np.mean(
|
74 |
+
accuracies * np.log(scaled_confidences) +
|
75 |
+
(1 - accuracies) * np.log(1 - scaled_confidences)
|
76 |
+
)
|
77 |
+
|
78 |
+
return loss
|
79 |
+
|
80 |
+
def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
|
81 |
+
"""
|
82 |
+
Fit the temperature parameter to the provided data.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
confidences: List of confidence scores
|
86 |
+
accuracies: List of boolean accuracy indicators
|
87 |
+
"""
|
88 |
+
if not confidences or len(confidences) != len(accuracies):
|
89 |
+
raise ValueError("Confidences and accuracies must have the same non-zero length")
|
90 |
+
|
91 |
+
# Convert to numpy arrays
|
92 |
+
conf_array = np.array(confidences)
|
93 |
+
acc_array = np.array(accuracies, dtype=float)
|
94 |
+
|
95 |
+
# Optimize temperature parameter
|
96 |
+
result = minimize_scalar(
|
97 |
+
lambda t: self._nll_loss(t, conf_array, acc_array),
|
98 |
+
bounds=(0.1, 10.0),
|
99 |
+
method='bounded'
|
100 |
+
)
|
101 |
+
|
102 |
+
self.temperature = result.x
|
103 |
+
self.is_fitted = True
|
104 |
+
|
105 |
+
print(f"Fitted temperature parameter: {self.temperature:.4f}")
|
106 |
+
|
107 |
+
def calibrate(self, confidences: List[float]) -> List[float]:
|
108 |
+
"""
|
109 |
+
Calibrate the provided confidence scores using temperature scaling.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
confidences: List of confidence scores
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
Calibrated confidence scores
|
116 |
+
"""
|
117 |
+
if not self.is_fitted:
|
118 |
+
raise ValueError("Calibrator must be fitted before calibration")
|
119 |
+
|
120 |
+
# Apply temperature scaling
|
121 |
+
calibrated = [min(max(conf / self.temperature, 1e-10), 1.0 - 1e-10) for conf in confidences]
|
122 |
+
|
123 |
+
return calibrated
|
124 |
+
|
125 |
+
|
126 |
+
class DomainAdaptiveCalibration(Calibrator):
|
127 |
+
"""Calibration using domain-adaptive techniques."""
|
128 |
+
|
129 |
+
def __init__(self, source_domain: str, target_domain: str):
|
130 |
+
"""
|
131 |
+
Initialize the domain-adaptive calibrator.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
source_domain: Source domain name
|
135 |
+
target_domain: Target domain name
|
136 |
+
"""
|
137 |
+
super().__init__("domain_adaptive_calibration")
|
138 |
+
self.source_domain = source_domain
|
139 |
+
self.target_domain = target_domain
|
140 |
+
self.source_temperature = 1.0
|
141 |
+
self.target_temperature = 1.0
|
142 |
+
self.domain_shift_factor = 1.0
|
143 |
+
|
144 |
+
def fit(
|
145 |
+
self,
|
146 |
+
source_confidences: List[float],
|
147 |
+
source_accuracies: List[bool],
|
148 |
+
target_confidences: Optional[List[float]] = None,
|
149 |
+
target_accuracies: Optional[List[bool]] = None
|
150 |
+
) -> None:
|
151 |
+
"""
|
152 |
+
Fit the domain-adaptive calibrator to the provided data.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
source_confidences: List of confidence scores from source domain
|
156 |
+
source_accuracies: List of boolean accuracy indicators from source domain
|
157 |
+
target_confidences: List of confidence scores from target domain (if available)
|
158 |
+
target_accuracies: List of boolean accuracy indicators from target domain (if available)
|
159 |
+
"""
|
160 |
+
# Fit source domain temperature
|
161 |
+
source_calibrator = TemperatureScaling()
|
162 |
+
source_calibrator.fit(source_confidences, source_accuracies)
|
163 |
+
self.source_temperature = source_calibrator.temperature
|
164 |
+
|
165 |
+
# If target domain data is available, fit target temperature
|
166 |
+
if target_confidences and target_accuracies:
|
167 |
+
target_calibrator = TemperatureScaling()
|
168 |
+
target_calibrator.fit(target_confidences, target_accuracies)
|
169 |
+
self.target_temperature = target_calibrator.temperature
|
170 |
+
|
171 |
+
# Calculate domain shift factor
|
172 |
+
self.domain_shift_factor = self.target_temperature / self.source_temperature
|
173 |
+
else:
|
174 |
+
# Default domain shift factor based on heuristics
|
175 |
+
# This is a simplified approach; in a real system, this would be more sophisticated
|
176 |
+
self.domain_shift_factor = 1.2 # Assuming target domain is slightly more uncertain
|
177 |
+
self.target_temperature = self.source_temperature * self.domain_shift_factor
|
178 |
+
|
179 |
+
self.is_fitted = True
|
180 |
+
|
181 |
+
print(f"Fitted source temperature: {self.source_temperature:.4f}")
|
182 |
+
print(f"Fitted target temperature: {self.target_temperature:.4f}")
|
183 |
+
print(f"Domain shift factor: {self.domain_shift_factor:.4f}")
|
184 |
+
|
185 |
+
def calibrate(self, confidences: List[float], domain: str = None) -> List[float]:
|
186 |
+
"""
|
187 |
+
Calibrate the provided confidence scores using domain-adaptive calibration.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
confidences: List of confidence scores
|
191 |
+
domain: Domain of the confidences ('source' or 'target', defaults to target)
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Calibrated confidence scores
|
195 |
+
"""
|
196 |
+
if not self.is_fitted:
|
197 |
+
raise ValueError("Calibrator must be fitted before calibration")
|
198 |
+
|
199 |
+
# Determine which temperature to use
|
200 |
+
if domain == "source":
|
201 |
+
temperature = self.source_temperature
|
202 |
+
else:
|
203 |
+
temperature = self.target_temperature
|
204 |
+
|
205 |
+
# Apply temperature scaling
|
206 |
+
calibrated = [min(max(conf / temperature, 1e-10), 1.0 - 1e-10) for conf in confidences]
|
207 |
+
|
208 |
+
return calibrated
|
209 |
+
|
210 |
+
|
211 |
+
class EnsembleCalibration(Calibrator):
|
212 |
+
"""Calibration using an ensemble of calibration methods."""
|
213 |
+
|
214 |
+
def __init__(self, calibrators: List[Calibrator], weights: Optional[List[float]] = None):
|
215 |
+
"""
|
216 |
+
Initialize the ensemble calibrator.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
calibrators: List of calibrator instances
|
220 |
+
weights: List of weights for each calibrator (None for equal weights)
|
221 |
+
"""
|
222 |
+
super().__init__("ensemble_calibration")
|
223 |
+
self.calibrators = calibrators
|
224 |
+
|
225 |
+
# Initialize weights
|
226 |
+
if weights is None:
|
227 |
+
self.weights = [1.0 / len(calibrators)] * len(calibrators)
|
228 |
+
else:
|
229 |
+
if len(weights) != len(calibrators):
|
230 |
+
raise ValueError("Number of weights must match number of calibrators")
|
231 |
+
|
232 |
+
# Normalize weights
|
233 |
+
total = sum(weights)
|
234 |
+
self.weights = [w / total for w in weights]
|
235 |
+
|
236 |
+
def fit(self, confidences: List[float], accuracies: List[bool]) -> None:
|
237 |
+
"""
|
238 |
+
Fit all calibrators in the ensemble.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
confidences: List of confidence scores
|
242 |
+
accuracies: List of boolean accuracy indicators
|
243 |
+
"""
|
244 |
+
for calibrator in self.calibrators:
|
245 |
+
calibrator.fit(confidences, accuracies)
|
246 |
+
|
247 |
+
self.is_fitted = True
|
248 |
+
|
249 |
+
def calibrate(self, confidences: List[float]) -> List[float]:
|
250 |
+
"""
|
251 |
+
Calibrate the provided confidence scores using the ensemble.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
confidences: List of confidence scores
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
Calibrated confidence scores
|
258 |
+
"""
|
259 |
+
if not self.is_fitted:
|
260 |
+
raise ValueError("Calibrator must be fitted before calibration")
|
261 |
+
|
262 |
+
# Get calibrated confidences from each calibrator
|
263 |
+
all_calibrated = []
|
264 |
+
for calibrator in self.calibrators:
|
265 |
+
all_calibrated.append(calibrator.calibrate(confidences))
|
266 |
+
|
267 |
+
# Combine calibrated confidences using weights
|
268 |
+
calibrated = []
|
269 |
+
for i in range(len(confidences)):
|
270 |
+
weighted_sum = sum(self.weights[j] * all_calibrated[j][i] for j in range(len(self.calibrators)))
|
271 |
+
calibrated.append(weighted_sum)
|
272 |
+
|
273 |
+
return calibrated
|
274 |
+
|
275 |
+
|
276 |
+
# Factory function to create calibrators
|
277 |
+
def create_calibrator(method: str, **kwargs) -> Calibrator:
|
278 |
+
"""
|
279 |
+
Create a calibrator based on the specified method.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
method: Name of the calibration method
|
283 |
+
**kwargs: Additional arguments for the calibrator
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
Calibrator instance
|
287 |
+
"""
|
288 |
+
if method == "temperature_scaling":
|
289 |
+
return TemperatureScaling()
|
290 |
+
elif method == "domain_adaptive":
|
291 |
+
if "source_domain" not in kwargs or "target_domain" not in kwargs:
|
292 |
+
raise ValueError("Domain-adaptive calibration requires source_domain and target_domain")
|
293 |
+
return DomainAdaptiveCalibration(kwargs["source_domain"], kwargs["target_domain"])
|
294 |
+
elif method == "ensemble":
|
295 |
+
if "calibrators" not in kwargs:
|
296 |
+
raise ValueError("Ensemble calibration requires a list of calibrators")
|
297 |
+
return EnsembleCalibration(kwargs["calibrators"], kwargs.get("weights"))
|
298 |
+
else:
|
299 |
+
raise ValueError(f"Unsupported calibration method: {method}")
|
demo_app.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Simplified Gradio Demo for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This script creates a simple Gradio web interface to demonstrate the framework's capabilities.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import gradio as gr
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
from datetime import datetime
|
12 |
+
|
13 |
+
# Mock functions to simulate framework behavior without requiring full model loading
|
14 |
+
def mock_generate_response(task, user_message):
|
15 |
+
"""Simulate generating a response from a tuned agent."""
|
16 |
+
responses = [
|
17 |
+
f"I'll help you with your task to {task.lower()}. Based on your message '{user_message}', I recommend starting by breaking this down into smaller steps.",
|
18 |
+
f"I understand you need assistance with {task.lower()}. From your message, I can see that you're looking for guidance on '{user_message}'. Here's my approach to solving this.",
|
19 |
+
f"Thank you for providing details about {task.lower()}. Your message '{user_message}' gives me enough context to help you effectively. Let me outline a solution.",
|
20 |
+
f"I'm analyzing your request about {task.lower()}. Your message '{user_message}' indicates you need comprehensive assistance. Here's what I suggest as next steps."
|
21 |
+
]
|
22 |
+
|
23 |
+
# Simulate processing time
|
24 |
+
import time
|
25 |
+
time.sleep(1.5)
|
26 |
+
|
27 |
+
return random.choice(responses) + f"\n\nResponse generated at {datetime.now().strftime('%H:%M:%S')}"
|
28 |
+
|
29 |
+
def mock_generate_negative_sample(task, user_message, agent_message):
|
30 |
+
"""Simulate generating a negative sample from a positive example."""
|
31 |
+
degradation_types = [
|
32 |
+
"Response truncation",
|
33 |
+
"Grammatical errors",
|
34 |
+
"Task misalignment",
|
35 |
+
"Constraint violation",
|
36 |
+
"Irrelevant tangent"
|
37 |
+
]
|
38 |
+
|
39 |
+
degradation = random.choice(degradation_types)
|
40 |
+
|
41 |
+
if degradation == "Response truncation":
|
42 |
+
words = agent_message.split()
|
43 |
+
truncate_point = int(len(words) * random.uniform(0.3, 0.7))
|
44 |
+
return " ".join(words[:truncate_point]) + f"...\n\nNegative sample type: {degradation}"
|
45 |
+
|
46 |
+
elif degradation == "Grammatical errors":
|
47 |
+
errors = [
|
48 |
+
lambda t: t.replace(".", ""), # Remove periods
|
49 |
+
lambda t: t.replace("I ", "i "), # Lowercase I
|
50 |
+
lambda t: t.replace(" the ", " teh "), # Typo
|
51 |
+
lambda t: t.replace(" is ", " are "), # Grammar error
|
52 |
+
lambda t: t.replace(" are ", " is ") # Grammar error
|
53 |
+
]
|
54 |
+
|
55 |
+
result = agent_message
|
56 |
+
for _ in range(random.randint(2, 4)):
|
57 |
+
error_func = random.choice(errors)
|
58 |
+
result = error_func(result)
|
59 |
+
|
60 |
+
return result + f"\n\nNegative sample type: {degradation}"
|
61 |
+
|
62 |
+
elif degradation == "Task misalignment":
|
63 |
+
misalignments = [
|
64 |
+
f"I understand you're asking about something completely different. Let me tell you about weather patterns instead.",
|
65 |
+
f"I don't think that's what you really want to know. Let me explain something else that might interest you.",
|
66 |
+
f"Your question seems to be about {task}, but I'd rather discuss the history of computing.",
|
67 |
+
f"Instead of addressing your specific request about {task}, let me give you general information that's only tangentially related."
|
68 |
+
]
|
69 |
+
|
70 |
+
return random.choice(misalignments) + f"\n\nNegative sample type: {degradation}"
|
71 |
+
|
72 |
+
elif degradation == "Constraint violation":
|
73 |
+
violations = [
|
74 |
+
f"I specifically recommend the XYZ Pro 2000 for $499.99, the UltraBook 15 for $1,299, and the PowerTech 5000 for $799. These are the absolute best options available.",
|
75 |
+
f"The system utilizes a polymorphic encapsulation paradigm with recursive lambda functions and stochastic gradient descent with backpropagation through a multi-layer perceptron.",
|
76 |
+
f"What specific features are you looking for? Do you have any brand preferences? What's your budget range? When do you need this by? Have you considered alternative options?",
|
77 |
+
f"Since you're a tech-savvy individual who values cutting-edge features, you'll definitely want the latest model with all the advanced capabilities."
|
78 |
+
]
|
79 |
+
|
80 |
+
return random.choice(violations) + f"\n\nNegative sample type: {degradation}"
|
81 |
+
|
82 |
+
else: # Irrelevant tangent
|
83 |
+
tangents = [
|
84 |
+
f"Did you know that artificial intelligence has been a concept since the 1950s? The field has evolved significantly since then, with major breakthroughs in neural networks and deep learning.",
|
85 |
+
f"I've been thinking about the philosophical implications of consciousness in AI systems. The question of whether an AI can truly understand or merely simulate understanding is fascinating.",
|
86 |
+
f"The weather has been quite interesting lately, with unusual patterns emerging globally. Climate scientists attribute this to a combination of factors including ocean temperature changes.",
|
87 |
+
f"I recently processed some fascinating data about renewable energy technologies. Solar efficiency has improved dramatically in the past decade, while costs have decreased by over 80%."
|
88 |
+
]
|
89 |
+
|
90 |
+
return random.choice(tangents) + f"\n\nNegative sample type: {degradation}"
|
91 |
+
|
92 |
+
def mock_generate_synthetic_trajectory(task):
|
93 |
+
"""Simulate generating a synthetic trajectory for a given task."""
|
94 |
+
# Determine task category
|
95 |
+
categories = ["travel", "shopping", "technology", "education", "finance", "health", "career", "home"]
|
96 |
+
category = random.choice(categories)
|
97 |
+
|
98 |
+
# Generate interactions (2-4 turns)
|
99 |
+
num_turns = random.randint(2, 4)
|
100 |
+
interactions = []
|
101 |
+
|
102 |
+
for j in range(num_turns):
|
103 |
+
if j == 0:
|
104 |
+
user_msg = f"I need help with this task: {task}"
|
105 |
+
agent_msg = f"I'd be happy to help you {task.lower()}. Could you provide more details about your preferences?"
|
106 |
+
elif j == num_turns - 1:
|
107 |
+
user_msg = "That sounds good. Please proceed with the final steps."
|
108 |
+
agent_msg = f"I've completed the task to {task.lower()}. Here's a summary of what I did..."
|
109 |
+
else:
|
110 |
+
user_msg = f"I prefer options that are {['affordable', 'convenient', 'high-quality'][j % 3]}."
|
111 |
+
agent_msg = f"Based on your preference for {['affordable', 'convenient', 'high-quality'][j % 3]} options, I recommend..."
|
112 |
+
|
113 |
+
interactions.append({
|
114 |
+
'user': user_msg,
|
115 |
+
'agent': agent_msg
|
116 |
+
})
|
117 |
+
|
118 |
+
# Format trajectory
|
119 |
+
result = f"Synthetic Trajectory for Task: {task}\nCategory: {category}\n\n"
|
120 |
+
|
121 |
+
for i, interaction in enumerate(interactions):
|
122 |
+
result += f"Turn {i+1}:\nUser: {interaction['user']}\nAgent: {interaction['agent']}\n\n"
|
123 |
+
|
124 |
+
result += f"Generation method: Template-based\nQuality score: {random.uniform(0.7, 0.9):.2f}"
|
125 |
+
|
126 |
+
return result
|
127 |
+
|
128 |
+
# Create Gradio interface
|
129 |
+
with gr.Blocks(title="Agent Tuning Framework Demo") as demo:
|
130 |
+
gr.Markdown("# Agent Tuning Optimization Framework Demo")
|
131 |
+
gr.Markdown("### A framework for efficiently tuning LLMs into specialized agents using negative and synthetic samples")
|
132 |
+
|
133 |
+
with gr.Tab("Generate Response"):
|
134 |
+
with gr.Row():
|
135 |
+
with gr.Column():
|
136 |
+
task_input = gr.Textbox(label="Task Description", placeholder="e.g., Book a flight from New York to London")
|
137 |
+
user_input = gr.Textbox(label="User Message", placeholder="e.g., I need to travel next week for business")
|
138 |
+
generate_btn = gr.Button("Generate Response", variant="primary")
|
139 |
+
with gr.Column():
|
140 |
+
response_output = gr.Textbox(label="Agent Response", lines=8)
|
141 |
+
|
142 |
+
generate_btn.click(
|
143 |
+
mock_generate_response,
|
144 |
+
inputs=[task_input, user_input],
|
145 |
+
outputs=response_output
|
146 |
+
)
|
147 |
+
|
148 |
+
gr.Examples(
|
149 |
+
[
|
150 |
+
["Book a flight from New York to London", "I need to travel next week for business"],
|
151 |
+
["Find a vegetarian restaurant", "I'm looking for dinner options tonight"],
|
152 |
+
["Help me debug a Python script", "I'm getting an IndexError in my code"]
|
153 |
+
],
|
154 |
+
inputs=[task_input, user_input]
|
155 |
+
)
|
156 |
+
|
157 |
+
with gr.Tab("Generate Negative Sample"):
|
158 |
+
with gr.Row():
|
159 |
+
with gr.Column():
|
160 |
+
neg_task_input = gr.Textbox(label="Task Description", placeholder="e.g., Book a flight from New York to London")
|
161 |
+
neg_user_input = gr.Textbox(label="User Message", placeholder="e.g., I need to travel next week for business")
|
162 |
+
neg_agent_input = gr.Textbox(label="Agent Message (Positive Example)", placeholder="e.g., I'd be happy to help you book a flight...", lines=5)
|
163 |
+
neg_generate_btn = gr.Button("Generate Negative Sample", variant="primary")
|
164 |
+
with gr.Column():
|
165 |
+
neg_output = gr.Textbox(label="Negative Sample", lines=8)
|
166 |
+
|
167 |
+
neg_generate_btn.click(
|
168 |
+
mock_generate_negative_sample,
|
169 |
+
inputs=[neg_task_input, neg_user_input, neg_agent_input],
|
170 |
+
outputs=neg_output
|
171 |
+
)
|
172 |
+
|
173 |
+
gr.Examples(
|
174 |
+
[
|
175 |
+
["Book a flight from New York to London", "I need to travel next week for business", "I'd be happy to help you book a flight from New York to London. Could you provide more details about your preferred travel dates, budget, and any airline preferences you might have?"],
|
176 |
+
["Recommend a laptop for programming", "I need a new laptop for software development", "I can definitely help you find a suitable laptop for programming. Based on software development needs, I'd recommend looking for a laptop with at least 16GB RAM, a multi-core processor, and an SSD for storage. Would you like specific brand recommendations or have a particular budget in mind?"]
|
177 |
+
],
|
178 |
+
inputs=[neg_task_input, neg_user_input, neg_agent_input]
|
179 |
+
)
|
180 |
+
|
181 |
+
with gr.Tab("Generate Synthetic Trajectory"):
|
182 |
+
with gr.Row():
|
183 |
+
with gr.Column():
|
184 |
+
synth_task_input = gr.Textbox(label="Task Description", placeholder="e.g., Plan a weekend trip to Chicago")
|
185 |
+
synth_generate_btn = gr.Button("Generate Synthetic Trajectory", variant="primary")
|
186 |
+
with gr.Column():
|
187 |
+
synth_output = gr.Textbox(label="Synthetic Trajectory", lines=15)
|
188 |
+
|
189 |
+
synth_generate_btn.click(
|
190 |
+
mock_generate_synthetic_trajectory,
|
191 |
+
inputs=[synth_task_input],
|
192 |
+
outputs=synth_output
|
193 |
+
)
|
194 |
+
|
195 |
+
gr.Examples(
|
196 |
+
[
|
197 |
+
["Plan a weekend trip to Chicago"],
|
198 |
+
["Recommend healthy meal prep options for the week"],
|
199 |
+
["Help me create a study schedule for final exams"]
|
200 |
+
],
|
201 |
+
inputs=[synth_task_input]
|
202 |
+
)
|
203 |
+
|
204 |
+
gr.Markdown("""
|
205 |
+
## About This Framework
|
206 |
+
|
207 |
+
The Agent Tuning Optimization Framework provides a comprehensive solution for efficiently tuning large language models into specialized agents through the strategic incorporation of negative samples and synthetic trajectories.
|
208 |
+
|
209 |
+
### Key Features:
|
210 |
+
|
211 |
+
1. **Negative Sample Generation**: Creates examples of undesired agent behaviors to teach models what not to do
|
212 |
+
2. **Synthetic Trajectory Generation**: Automatically generates diverse interaction trajectories
|
213 |
+
3. **Mixed-Sample Tuning**: Combines positive examples, negative samples, and synthetic trajectories
|
214 |
+
4. **Parameter-Efficient Fine-Tuning**: Implements methods like LoRA for computational efficiency
|
215 |
+
|
216 |
+
This demo provides a simplified simulation of the framework's capabilities. For full functionality, deploy the complete framework following the provided documentation.
|
217 |
+
""")
|
218 |
+
|
219 |
+
# Launch the interface
|
220 |
+
demo.launch(share=True)
|
deploy.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Deployment Script for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This script prepares the framework for deployment to production environments
|
5 |
+
and Hugging Face Spaces.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import shutil
|
10 |
+
import argparse
|
11 |
+
import subprocess
|
12 |
+
import json
|
13 |
+
from pathlib import Path
|
14 |
+
|
15 |
+
def prepare_for_deployment(source_dir, output_dir, config_path=None):
|
16 |
+
"""
|
17 |
+
Prepare the framework for deployment.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
source_dir: Source directory containing the framework
|
21 |
+
output_dir: Output directory for deployment package
|
22 |
+
config_path: Path to configuration file (optional)
|
23 |
+
"""
|
24 |
+
print(f"Preparing deployment package from {source_dir} to {output_dir}")
|
25 |
+
|
26 |
+
# Create output directory
|
27 |
+
os.makedirs(output_dir, exist_ok=True)
|
28 |
+
|
29 |
+
# Copy core modules
|
30 |
+
core_modules = [
|
31 |
+
"models",
|
32 |
+
"data",
|
33 |
+
"training",
|
34 |
+
"evaluation",
|
35 |
+
"main.py",
|
36 |
+
"README.md"
|
37 |
+
]
|
38 |
+
|
39 |
+
for module in core_modules:
|
40 |
+
source_path = os.path.join(source_dir, module)
|
41 |
+
target_path = os.path.join(output_dir, module)
|
42 |
+
|
43 |
+
if os.path.isdir(source_path):
|
44 |
+
if os.path.exists(target_path):
|
45 |
+
shutil.rmtree(target_path)
|
46 |
+
shutil.copytree(source_path, target_path)
|
47 |
+
else:
|
48 |
+
shutil.copy2(source_path, target_path)
|
49 |
+
|
50 |
+
# Copy configuration file if provided
|
51 |
+
if config_path:
|
52 |
+
shutil.copy2(config_path, os.path.join(output_dir, "config.json"))
|
53 |
+
else:
|
54 |
+
# Use example config
|
55 |
+
example_config_path = os.path.join(source_dir, "example_config.json")
|
56 |
+
if os.path.exists(example_config_path):
|
57 |
+
shutil.copy2(example_config_path, os.path.join(output_dir, "config.json"))
|
58 |
+
|
59 |
+
# Create requirements.txt
|
60 |
+
requirements = [
|
61 |
+
"torch>=1.10.0",
|
62 |
+
"transformers>=4.20.0",
|
63 |
+
"datasets>=2.0.0",
|
64 |
+
"numpy>=1.20.0",
|
65 |
+
"pandas>=1.3.0",
|
66 |
+
"matplotlib>=3.4.0",
|
67 |
+
"tqdm>=4.60.0",
|
68 |
+
"scikit-learn>=1.0.0",
|
69 |
+
"peft>=0.2.0"
|
70 |
+
]
|
71 |
+
|
72 |
+
with open(os.path.join(output_dir, "requirements.txt"), "w") as f:
|
73 |
+
f.write("\n".join(requirements))
|
74 |
+
|
75 |
+
# Create setup.py
|
76 |
+
setup_py = """
|
77 |
+
from setuptools import setup, find_packages
|
78 |
+
|
79 |
+
setup(
|
80 |
+
name="agent_tuning_framework",
|
81 |
+
version="0.1.0",
|
82 |
+
packages=find_packages(),
|
83 |
+
install_requires=[
|
84 |
+
"torch>=1.10.0",
|
85 |
+
"transformers>=4.20.0",
|
86 |
+
"datasets>=2.0.0",
|
87 |
+
"numpy>=1.20.0",
|
88 |
+
"pandas>=1.3.0",
|
89 |
+
"matplotlib>=3.4.0",
|
90 |
+
"tqdm>=4.60.0",
|
91 |
+
"scikit-learn>=1.0.0",
|
92 |
+
"peft>=0.2.0"
|
93 |
+
],
|
94 |
+
author="MBZUAI Technical Interview Preparation",
|
95 |
+
author_email="example@example.com",
|
96 |
+
description="Agent Tuning Optimization Framework with Negative and Synthetic Samples",
|
97 |
+
keywords="nlp, machine learning, agent tuning, language models",
|
98 |
+
url="https://github.com/username/agent_tuning_framework",
|
99 |
+
)
|
100 |
+
"""
|
101 |
+
|
102 |
+
with open(os.path.join(output_dir, "setup.py"), "w") as f:
|
103 |
+
f.write(setup_py)
|
104 |
+
|
105 |
+
# Create app.py for web interface
|
106 |
+
app_py = """
|
107 |
+
import os
|
108 |
+
import json
|
109 |
+
import gradio as gr
|
110 |
+
import torch
|
111 |
+
from models.llm_interface import LLMInterface
|
112 |
+
from data.trajectory_data import TrajectoryDataset, Trajectory
|
113 |
+
from training.negative_samples import create_negative_sample_generator
|
114 |
+
from training.synthetic_trajectories import create_synthetic_trajectory_generator
|
115 |
+
|
116 |
+
# Initialize model
|
117 |
+
def load_model(model_path):
|
118 |
+
if os.path.exists(model_path):
|
119 |
+
return LLMInterface(
|
120 |
+
model_name=model_path,
|
121 |
+
model_type="causal",
|
122 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
return LLMInterface(
|
126 |
+
model_name="gpt2",
|
127 |
+
model_type="causal",
|
128 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
129 |
+
)
|
130 |
+
|
131 |
+
# Initialize components
|
132 |
+
model = load_model("./tuned_model")
|
133 |
+
negative_generator = create_negative_sample_generator("response_degradation")
|
134 |
+
synthetic_generator = create_synthetic_trajectory_generator("template")
|
135 |
+
|
136 |
+
# Define interface functions
|
137 |
+
def generate_response(task, user_message):
|
138 |
+
prompt = f"Task: {task}\\n\\nUser: {user_message}\\nAgent:"
|
139 |
+
response = model.generate(prompt)
|
140 |
+
return response["response"]
|
141 |
+
|
142 |
+
def generate_negative_sample(task, user_message, agent_message):
|
143 |
+
trajectory = Trajectory(
|
144 |
+
task_description=task,
|
145 |
+
interactions=[{"user": user_message, "agent": agent_message}]
|
146 |
+
)
|
147 |
+
negative_trajectory = negative_generator.generate(trajectory)
|
148 |
+
return negative_trajectory.interactions[0]["agent"]
|
149 |
+
|
150 |
+
def generate_synthetic_trajectory(task):
|
151 |
+
trajectory = synthetic_generator.generate(task)
|
152 |
+
result = ""
|
153 |
+
for i, interaction in enumerate(trajectory.interactions):
|
154 |
+
result += f"Turn {i+1}:\\nUser: {interaction['user']}\\nAgent: {interaction['agent']}\\n\\n"
|
155 |
+
return result
|
156 |
+
|
157 |
+
# Create Gradio interface
|
158 |
+
with gr.Blocks(title="Agent Tuning Framework Demo") as demo:
|
159 |
+
gr.Markdown("# Agent Tuning Optimization Framework Demo")
|
160 |
+
|
161 |
+
with gr.Tab("Generate Response"):
|
162 |
+
with gr.Row():
|
163 |
+
with gr.Column():
|
164 |
+
task_input = gr.Textbox(label="Task Description")
|
165 |
+
user_input = gr.Textbox(label="User Message")
|
166 |
+
generate_btn = gr.Button("Generate Response")
|
167 |
+
with gr.Column():
|
168 |
+
response_output = gr.Textbox(label="Agent Response")
|
169 |
+
|
170 |
+
generate_btn.click(
|
171 |
+
generate_response,
|
172 |
+
inputs=[task_input, user_input],
|
173 |
+
outputs=response_output
|
174 |
+
)
|
175 |
+
|
176 |
+
with gr.Tab("Generate Negative Sample"):
|
177 |
+
with gr.Row():
|
178 |
+
with gr.Column():
|
179 |
+
neg_task_input = gr.Textbox(label="Task Description")
|
180 |
+
neg_user_input = gr.Textbox(label="User Message")
|
181 |
+
neg_agent_input = gr.Textbox(label="Agent Message (Positive Example)")
|
182 |
+
neg_generate_btn = gr.Button("Generate Negative Sample")
|
183 |
+
with gr.Column():
|
184 |
+
neg_output = gr.Textbox(label="Negative Sample")
|
185 |
+
|
186 |
+
neg_generate_btn.click(
|
187 |
+
generate_negative_sample,
|
188 |
+
inputs=[neg_task_input, neg_user_input, neg_agent_input],
|
189 |
+
outputs=neg_output
|
190 |
+
)
|
191 |
+
|
192 |
+
with gr.Tab("Generate Synthetic Trajectory"):
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column():
|
195 |
+
synth_task_input = gr.Textbox(label="Task Description")
|
196 |
+
synth_generate_btn = gr.Button("Generate Synthetic Trajectory")
|
197 |
+
with gr.Column():
|
198 |
+
synth_output = gr.Textbox(label="Synthetic Trajectory")
|
199 |
+
|
200 |
+
synth_generate_btn.click(
|
201 |
+
generate_synthetic_trajectory,
|
202 |
+
inputs=[synth_task_input],
|
203 |
+
outputs=synth_output
|
204 |
+
)
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
demo.launch()
|
208 |
+
"""
|
209 |
+
|
210 |
+
with open(os.path.join(output_dir, "app.py"), "w") as f:
|
211 |
+
f.write(app_py)
|
212 |
+
|
213 |
+
# Create Dockerfile
|
214 |
+
dockerfile = """
|
215 |
+
FROM python:3.9-slim
|
216 |
+
|
217 |
+
WORKDIR /app
|
218 |
+
|
219 |
+
COPY . /app/
|
220 |
+
|
221 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
222 |
+
RUN pip install --no-cache-dir gradio>=3.0.0
|
223 |
+
|
224 |
+
EXPOSE 7860
|
225 |
+
|
226 |
+
CMD ["python", "app.py"]
|
227 |
+
"""
|
228 |
+
|
229 |
+
with open(os.path.join(output_dir, "Dockerfile"), "w") as f:
|
230 |
+
f.write(dockerfile)
|
231 |
+
|
232 |
+
# Create README for deployment
|
233 |
+
deployment_readme = """
|
234 |
+
# Agent Tuning Optimization Framework
|
235 |
+
|
236 |
+
This package contains the Agent Tuning Optimization Framework with Negative and Synthetic Samples, a comprehensive solution for efficiently tuning large language models into specialized agents.
|
237 |
+
|
238 |
+
## Installation
|
239 |
+
|
240 |
+
```bash
|
241 |
+
pip install -r requirements.txt
|
242 |
+
```
|
243 |
+
|
244 |
+
## Usage
|
245 |
+
|
246 |
+
### Running Experiments
|
247 |
+
|
248 |
+
```bash
|
249 |
+
python main.py --config config.json --output ./experiment_results
|
250 |
+
```
|
251 |
+
|
252 |
+
### Web Interface
|
253 |
+
|
254 |
+
```bash
|
255 |
+
pip install gradio
|
256 |
+
python app.py
|
257 |
+
```
|
258 |
+
|
259 |
+
## Deployment Options
|
260 |
+
|
261 |
+
### Docker
|
262 |
+
|
263 |
+
```bash
|
264 |
+
docker build -t agent-tuning-framework .
|
265 |
+
docker run -p 7860:7860 agent-tuning-framework
|
266 |
+
```
|
267 |
+
|
268 |
+
### Hugging Face Spaces
|
269 |
+
|
270 |
+
This project can be deployed to Hugging Face Spaces by following these steps:
|
271 |
+
|
272 |
+
1. Create a new Space on Hugging Face (https://huggingface.co/spaces)
|
273 |
+
2. Select "Gradio" as the SDK
|
274 |
+
3. Upload all files from this directory to the Space
|
275 |
+
4. The Space will automatically build and deploy the application
|
276 |
+
|
277 |
+
## Configuration
|
278 |
+
|
279 |
+
See `config.json` for configuration options.
|
280 |
+
|
281 |
+
## License
|
282 |
+
|
283 |
+
MIT
|
284 |
+
"""
|
285 |
+
|
286 |
+
with open(os.path.join(output_dir, "README.md"), "w") as f:
|
287 |
+
f.write(deployment_readme)
|
288 |
+
|
289 |
+
# Create Hugging Face Space files
|
290 |
+
os.makedirs(os.path.join(output_dir, "huggingface"), exist_ok=True)
|
291 |
+
|
292 |
+
# Create requirements.txt for Hugging Face
|
293 |
+
hf_requirements = requirements + ["gradio>=3.0.0"]
|
294 |
+
|
295 |
+
with open(os.path.join(output_dir, "huggingface", "requirements.txt"), "w") as f:
|
296 |
+
f.write("\n".join(hf_requirements))
|
297 |
+
|
298 |
+
# Copy app.py
|
299 |
+
shutil.copy2(os.path.join(output_dir, "app.py"), os.path.join(output_dir, "huggingface", "app.py"))
|
300 |
+
|
301 |
+
# Create README for Hugging Face
|
302 |
+
hf_readme = """
|
303 |
+
---
|
304 |
+
title: Agent Tuning Optimization Framework
|
305 |
+
emoji: 🤖
|
306 |
+
colorFrom: blue
|
307 |
+
colorTo: green
|
308 |
+
sdk: gradio
|
309 |
+
sdk_version: 3.36.1
|
310 |
+
app_file: app.py
|
311 |
+
pinned: false
|
312 |
+
license: mit
|
313 |
+
---
|
314 |
+
|
315 |
+
# Agent Tuning Optimization Framework
|
316 |
+
|
317 |
+
This Space demonstrates the Agent Tuning Optimization Framework with Negative and Synthetic Samples, a comprehensive solution for efficiently tuning large language models into specialized agents.
|
318 |
+
|
319 |
+
## Features
|
320 |
+
|
321 |
+
- Generate agent responses for given tasks and user messages
|
322 |
+
- Create negative samples from positive examples
|
323 |
+
- Generate synthetic interaction trajectories
|
324 |
+
|
325 |
+
## Usage
|
326 |
+
|
327 |
+
1. Select a tab for the desired functionality
|
328 |
+
2. Enter the required information
|
329 |
+
3. Click the button to generate results
|
330 |
+
|
331 |
+
## Learn More
|
332 |
+
|
333 |
+
For more information, visit the [GitHub repository](https://github.com/username/agent_tuning_framework).
|
334 |
+
"""
|
335 |
+
|
336 |
+
with open(os.path.join(output_dir, "huggingface", "README.md"), "w") as f:
|
337 |
+
f.write(hf_readme)
|
338 |
+
|
339 |
+
print(f"Deployment package prepared in {output_dir}")
|
340 |
+
print(f"Hugging Face Space files prepared in {os.path.join(output_dir, 'huggingface')}")
|
341 |
+
|
342 |
+
def main():
|
343 |
+
"""Main function for preparing deployment package."""
|
344 |
+
parser = argparse.ArgumentParser(description="Prepare deployment package for Agent Tuning Framework")
|
345 |
+
parser.add_argument("--source", type=str, default=".", help="Source directory containing the framework")
|
346 |
+
parser.add_argument("--output", type=str, default="./deployment", help="Output directory for deployment package")
|
347 |
+
parser.add_argument("--config", type=str, help="Path to configuration file")
|
348 |
+
|
349 |
+
args = parser.parse_args()
|
350 |
+
|
351 |
+
prepare_for_deployment(args.source, args.output, args.config)
|
352 |
+
|
353 |
+
if __name__ == "__main__":
|
354 |
+
main()
|
domain_datasets.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Domain Dataset Module for Cross-Domain Uncertainty Quantification
|
3 |
+
|
4 |
+
This module provides functionality for loading and managing datasets from different domains
|
5 |
+
for evaluating uncertainty quantification methods across domains.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
13 |
+
from datasets import load_dataset
|
14 |
+
|
15 |
+
class DomainDataset:
|
16 |
+
"""Base class for domain-specific datasets."""
|
17 |
+
|
18 |
+
def __init__(self, name: str, domain: str):
|
19 |
+
"""
|
20 |
+
Initialize the domain dataset.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
name: Name of the dataset
|
24 |
+
domain: Domain category (e.g., 'medical', 'legal', 'general')
|
25 |
+
"""
|
26 |
+
self.name = name
|
27 |
+
self.domain = domain
|
28 |
+
self.data = None
|
29 |
+
|
30 |
+
def load(self) -> None:
|
31 |
+
"""Load the dataset."""
|
32 |
+
raise NotImplementedError("Subclasses must implement this method")
|
33 |
+
|
34 |
+
def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
|
35 |
+
"""
|
36 |
+
Get samples from the dataset.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
n: Number of samples to return (None for all)
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
List of samples with prompts and expected outputs
|
43 |
+
"""
|
44 |
+
raise NotImplementedError("Subclasses must implement this method")
|
45 |
+
|
46 |
+
def get_prompt_template(self) -> str:
|
47 |
+
"""
|
48 |
+
Get the prompt template for this domain.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
Prompt template string
|
52 |
+
"""
|
53 |
+
raise NotImplementedError("Subclasses must implement this method")
|
54 |
+
|
55 |
+
|
56 |
+
class MedicalQADataset(DomainDataset):
|
57 |
+
"""Dataset for medical question answering."""
|
58 |
+
|
59 |
+
def __init__(self, data_path: Optional[str] = None):
|
60 |
+
"""
|
61 |
+
Initialize the medical QA dataset.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
data_path: Path to the dataset file (None to use default)
|
65 |
+
"""
|
66 |
+
super().__init__("medical_qa", "medical")
|
67 |
+
self.data_path = data_path
|
68 |
+
|
69 |
+
def load(self) -> None:
|
70 |
+
"""Load the medical QA dataset."""
|
71 |
+
if self.data_path and os.path.exists(self.data_path):
|
72 |
+
# Load from local file if available
|
73 |
+
if self.data_path.endswith('.csv'):
|
74 |
+
self.data = pd.read_csv(self.data_path)
|
75 |
+
elif self.data_path.endswith('.json'):
|
76 |
+
with open(self.data_path, 'r') as f:
|
77 |
+
self.data = json.load(f)
|
78 |
+
else:
|
79 |
+
raise ValueError(f"Unsupported file format: {self.data_path}")
|
80 |
+
else:
|
81 |
+
# Use a sample of the MedMCQA dataset from Hugging Face
|
82 |
+
try:
|
83 |
+
dataset = load_dataset("medmcqa", split="train[:100]")
|
84 |
+
self.data = dataset.to_pandas()
|
85 |
+
except Exception as e:
|
86 |
+
# Fallback to synthetic data if dataset loading fails
|
87 |
+
print(f"Failed to load MedMCQA dataset: {e}")
|
88 |
+
self.data = self._create_synthetic_data()
|
89 |
+
|
90 |
+
def _create_synthetic_data(self) -> pd.DataFrame:
|
91 |
+
"""Create synthetic medical QA data for testing."""
|
92 |
+
questions = [
|
93 |
+
"What are the common symptoms of myocardial infarction?",
|
94 |
+
"How does insulin regulate blood glucose levels?",
|
95 |
+
"What is the mechanism of action for ACE inhibitors?",
|
96 |
+
"What are the diagnostic criteria for rheumatoid arthritis?",
|
97 |
+
"How does the SARS-CoV-2 virus enter human cells?",
|
98 |
+
"What are the main side effects of chemotherapy?",
|
99 |
+
"How does the blood-brain barrier function?",
|
100 |
+
"What is the pathophysiology of type 2 diabetes?",
|
101 |
+
"How do vaccines create immunity?",
|
102 |
+
"What are the stages of chronic kidney disease?"
|
103 |
+
]
|
104 |
+
|
105 |
+
# Create a dataframe with questions only (answers would be generated by LLMs)
|
106 |
+
return pd.DataFrame({
|
107 |
+
'question': questions,
|
108 |
+
'domain': ['medical'] * len(questions)
|
109 |
+
})
|
110 |
+
|
111 |
+
def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
|
112 |
+
"""
|
113 |
+
Get samples from the medical QA dataset.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
n: Number of samples to return (None for all)
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
List of samples with prompts
|
120 |
+
"""
|
121 |
+
if self.data is None:
|
122 |
+
self.load()
|
123 |
+
|
124 |
+
if 'question' in self.data.columns:
|
125 |
+
questions = self.data['question'].tolist()
|
126 |
+
elif 'question_text' in self.data.columns:
|
127 |
+
questions = self.data['question_text'].tolist()
|
128 |
+
else:
|
129 |
+
raise ValueError("Dataset does not contain question column")
|
130 |
+
|
131 |
+
if n is not None:
|
132 |
+
questions = questions[:n]
|
133 |
+
|
134 |
+
# Create samples with prompts
|
135 |
+
samples = []
|
136 |
+
for question in questions:
|
137 |
+
prompt = self.get_prompt_template().format(question=question)
|
138 |
+
samples.append({
|
139 |
+
'domain': 'medical',
|
140 |
+
'question': question,
|
141 |
+
'prompt': prompt
|
142 |
+
})
|
143 |
+
|
144 |
+
return samples
|
145 |
+
|
146 |
+
def get_prompt_template(self) -> str:
|
147 |
+
"""
|
148 |
+
Get the prompt template for medical domain.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
Prompt template string
|
152 |
+
"""
|
153 |
+
return "You are a medical expert. Please answer the following medical question accurately and concisely:\n\n{question}"
|
154 |
+
|
155 |
+
|
156 |
+
class LegalQADataset(DomainDataset):
|
157 |
+
"""Dataset for legal question answering."""
|
158 |
+
|
159 |
+
def __init__(self, data_path: Optional[str] = None):
|
160 |
+
"""
|
161 |
+
Initialize the legal QA dataset.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
data_path: Path to the dataset file (None to use default)
|
165 |
+
"""
|
166 |
+
super().__init__("legal_qa", "legal")
|
167 |
+
self.data_path = data_path
|
168 |
+
|
169 |
+
def load(self) -> None:
|
170 |
+
"""Load the legal QA dataset."""
|
171 |
+
if self.data_path and os.path.exists(self.data_path):
|
172 |
+
# Load from local file if available
|
173 |
+
if self.data_path.endswith('.csv'):
|
174 |
+
self.data = pd.read_csv(self.data_path)
|
175 |
+
elif self.data_path.endswith('.json'):
|
176 |
+
with open(self.data_path, 'r') as f:
|
177 |
+
self.data = json.load(f)
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Unsupported file format: {self.data_path}")
|
180 |
+
else:
|
181 |
+
# Use synthetic data for legal domain
|
182 |
+
self.data = self._create_synthetic_data()
|
183 |
+
|
184 |
+
def _create_synthetic_data(self) -> pd.DataFrame:
|
185 |
+
"""Create synthetic legal QA data for testing."""
|
186 |
+
questions = [
|
187 |
+
"What constitutes a breach of contract?",
|
188 |
+
"How is intellectual property protected under international law?",
|
189 |
+
"What are the elements of negligence in tort law?",
|
190 |
+
"How does the doctrine of stare decisis function in common law systems?",
|
191 |
+
"What rights are protected under the Fourth Amendment?",
|
192 |
+
"What is the difference between a patent and a copyright?",
|
193 |
+
"How does arbitration differ from litigation?",
|
194 |
+
"What constitutes insider trading under securities law?",
|
195 |
+
"What are the legal requirements for a valid will?",
|
196 |
+
"How does diplomatic immunity work under international law?"
|
197 |
+
]
|
198 |
+
|
199 |
+
# Create a dataframe with questions only
|
200 |
+
return pd.DataFrame({
|
201 |
+
'question': questions,
|
202 |
+
'domain': ['legal'] * len(questions)
|
203 |
+
})
|
204 |
+
|
205 |
+
def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
|
206 |
+
"""
|
207 |
+
Get samples from the legal QA dataset.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
n: Number of samples to return (None for all)
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
List of samples with prompts
|
214 |
+
"""
|
215 |
+
if self.data is None:
|
216 |
+
self.load()
|
217 |
+
|
218 |
+
questions = self.data['question'].tolist()
|
219 |
+
|
220 |
+
if n is not None:
|
221 |
+
questions = questions[:n]
|
222 |
+
|
223 |
+
# Create samples with prompts
|
224 |
+
samples = []
|
225 |
+
for question in questions:
|
226 |
+
prompt = self.get_prompt_template().format(question=question)
|
227 |
+
samples.append({
|
228 |
+
'domain': 'legal',
|
229 |
+
'question': question,
|
230 |
+
'prompt': prompt
|
231 |
+
})
|
232 |
+
|
233 |
+
return samples
|
234 |
+
|
235 |
+
def get_prompt_template(self) -> str:
|
236 |
+
"""
|
237 |
+
Get the prompt template for legal domain.
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
Prompt template string
|
241 |
+
"""
|
242 |
+
return "You are a legal expert. Please answer the following legal question accurately and concisely:\n\n{question}"
|
243 |
+
|
244 |
+
|
245 |
+
class GeneralKnowledgeDataset(DomainDataset):
|
246 |
+
"""Dataset for general knowledge question answering."""
|
247 |
+
|
248 |
+
def __init__(self, data_path: Optional[str] = None):
|
249 |
+
"""
|
250 |
+
Initialize the general knowledge dataset.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
data_path: Path to the dataset file (None to use default)
|
254 |
+
"""
|
255 |
+
super().__init__("general_knowledge", "general")
|
256 |
+
self.data_path = data_path
|
257 |
+
|
258 |
+
def load(self) -> None:
|
259 |
+
"""Load the general knowledge dataset."""
|
260 |
+
if self.data_path and os.path.exists(self.data_path):
|
261 |
+
# Load from local file if available
|
262 |
+
if self.data_path.endswith('.csv'):
|
263 |
+
self.data = pd.read_csv(self.data_path)
|
264 |
+
elif self.data_path.endswith('.json'):
|
265 |
+
with open(self.data_path, 'r') as f:
|
266 |
+
self.data = json.load(f)
|
267 |
+
else:
|
268 |
+
raise ValueError(f"Unsupported file format: {self.data_path}")
|
269 |
+
else:
|
270 |
+
# Use a sample of the TriviaQA dataset from Hugging Face
|
271 |
+
try:
|
272 |
+
dataset = load_dataset("trivia_qa", "unfiltered", split="train[:100]")
|
273 |
+
self.data = dataset.to_pandas()
|
274 |
+
except Exception as e:
|
275 |
+
# Fallback to synthetic data if dataset loading fails
|
276 |
+
print(f"Failed to load TriviaQA dataset: {e}")
|
277 |
+
self.data = self._create_synthetic_data()
|
278 |
+
|
279 |
+
def _create_synthetic_data(self) -> pd.DataFrame:
|
280 |
+
"""Create synthetic general knowledge data for testing."""
|
281 |
+
questions = [
|
282 |
+
"What is the capital of France?",
|
283 |
+
"Who wrote the novel '1984'?",
|
284 |
+
"What is the chemical symbol for gold?",
|
285 |
+
"Which planet is known as the Red Planet?",
|
286 |
+
"Who painted the Mona Lisa?",
|
287 |
+
"What is the largest ocean on Earth?",
|
288 |
+
"What year did World War II end?",
|
289 |
+
"What is the tallest mountain in the world?",
|
290 |
+
"Who was the first person to step on the moon?",
|
291 |
+
"What is the speed of light in a vacuum?"
|
292 |
+
]
|
293 |
+
|
294 |
+
# Create a dataframe with questions only
|
295 |
+
return pd.DataFrame({
|
296 |
+
'question': questions,
|
297 |
+
'domain': ['general'] * len(questions)
|
298 |
+
})
|
299 |
+
|
300 |
+
def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
|
301 |
+
"""
|
302 |
+
Get samples from the general knowledge dataset.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
n: Number of samples to return (None for all)
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
List of samples with prompts
|
309 |
+
"""
|
310 |
+
if self.data is None:
|
311 |
+
self.load()
|
312 |
+
|
313 |
+
if 'question' in self.data.columns:
|
314 |
+
questions = self.data['question'].tolist()
|
315 |
+
elif 'question_text' in self.data.columns:
|
316 |
+
questions = self.data['question_text'].tolist()
|
317 |
+
else:
|
318 |
+
raise ValueError("Dataset does not contain question column")
|
319 |
+
|
320 |
+
if n is not None:
|
321 |
+
questions = questions[:n]
|
322 |
+
|
323 |
+
# Create samples with prompts
|
324 |
+
samples = []
|
325 |
+
for question in questions:
|
326 |
+
prompt = self.get_prompt_template().format(question=question)
|
327 |
+
samples.append({
|
328 |
+
'domain': 'general',
|
329 |
+
'question': question,
|
330 |
+
'prompt': prompt
|
331 |
+
})
|
332 |
+
|
333 |
+
return samples
|
334 |
+
|
335 |
+
def get_prompt_template(self) -> str:
|
336 |
+
"""
|
337 |
+
Get the prompt template for general knowledge domain.
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
Prompt template string
|
341 |
+
"""
|
342 |
+
return "Please answer the following general knowledge question accurately and concisely:\n\n{question}"
|
343 |
+
|
344 |
+
|
345 |
+
# Factory function to create domain datasets
|
346 |
+
def create_domain_dataset(domain: str, data_path: Optional[str] = None) -> DomainDataset:
|
347 |
+
"""
|
348 |
+
Create a domain dataset based on the specified domain.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
domain: Domain category ('medical', 'legal', 'general')
|
352 |
+
data_path: Path to the dataset file (None to use default)
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
Domain dataset instance
|
356 |
+
"""
|
357 |
+
if domain == "medical":
|
358 |
+
return MedicalQADataset(data_path)
|
359 |
+
elif domain == "legal":
|
360 |
+
return LegalQADataset(data_path)
|
361 |
+
elif domain == "general":
|
362 |
+
return GeneralKnowledgeDataset(data_path)
|
363 |
+
else:
|
364 |
+
raise ValueError(f"Unsupported domain: {domain}")
|
evaluators.py
ADDED
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Evaluation Framework for Cross-Domain Uncertainty Quantification
|
3 |
+
|
4 |
+
This module provides functionality for evaluating uncertainty quantification methods
|
5 |
+
across different domains, including metrics for uncertainty quality and cross-domain performance.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
12 |
+
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
|
13 |
+
|
14 |
+
class UncertaintyEvaluator:
|
15 |
+
"""Evaluator for uncertainty quantification methods."""
|
16 |
+
|
17 |
+
def __init__(self, name: str):
|
18 |
+
"""
|
19 |
+
Initialize the uncertainty evaluator.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
name: Name of the evaluation method
|
23 |
+
"""
|
24 |
+
self.name = name
|
25 |
+
|
26 |
+
def evaluate(
|
27 |
+
self,
|
28 |
+
uncertainties: List[float],
|
29 |
+
correctness: List[bool]
|
30 |
+
) -> Dict[str, float]:
|
31 |
+
"""
|
32 |
+
Evaluate uncertainty estimates against correctness.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
uncertainties: List of uncertainty scores (higher means more uncertain)
|
36 |
+
correctness: List of boolean correctness indicators
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
Dictionary of evaluation metrics
|
40 |
+
"""
|
41 |
+
raise NotImplementedError("Subclasses must implement this method")
|
42 |
+
|
43 |
+
|
44 |
+
class CalibrationEvaluator(UncertaintyEvaluator):
|
45 |
+
"""Evaluator for calibration quality."""
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
"""Initialize the calibration evaluator."""
|
49 |
+
super().__init__("calibration_evaluator")
|
50 |
+
|
51 |
+
def expected_calibration_error(
|
52 |
+
self,
|
53 |
+
confidences: List[float],
|
54 |
+
correctness: List[bool],
|
55 |
+
num_bins: int = 10
|
56 |
+
) -> float:
|
57 |
+
"""
|
58 |
+
Calculate Expected Calibration Error (ECE).
|
59 |
+
|
60 |
+
Args:
|
61 |
+
confidences: List of confidence scores
|
62 |
+
correctness: List of boolean correctness indicators
|
63 |
+
num_bins: Number of bins for binning confidences
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Expected Calibration Error
|
67 |
+
"""
|
68 |
+
if len(confidences) != len(correctness):
|
69 |
+
raise ValueError("Confidences and correctness must have the same length")
|
70 |
+
|
71 |
+
if not confidences:
|
72 |
+
return 0.0
|
73 |
+
|
74 |
+
# Create bins and calculate ECE
|
75 |
+
bin_indices = np.digitize(confidences, np.linspace(0, 1, num_bins))
|
76 |
+
ece = 0.0
|
77 |
+
|
78 |
+
for bin_idx in range(1, num_bins + 1):
|
79 |
+
bin_mask = (bin_indices == bin_idx)
|
80 |
+
if np.any(bin_mask):
|
81 |
+
bin_confidences = np.array(confidences)[bin_mask]
|
82 |
+
bin_correctness = np.array(correctness)[bin_mask]
|
83 |
+
bin_confidence = np.mean(bin_confidences)
|
84 |
+
bin_accuracy = np.mean(bin_correctness)
|
85 |
+
bin_size = np.sum(bin_mask)
|
86 |
+
|
87 |
+
# Weighted absolute difference between confidence and accuracy
|
88 |
+
ece += (bin_size / len(confidences)) * np.abs(bin_confidence - bin_accuracy)
|
89 |
+
|
90 |
+
return float(ece)
|
91 |
+
|
92 |
+
def maximum_calibration_error(
|
93 |
+
self,
|
94 |
+
confidences: List[float],
|
95 |
+
correctness: List[bool],
|
96 |
+
num_bins: int = 10
|
97 |
+
) -> float:
|
98 |
+
"""
|
99 |
+
Calculate Maximum Calibration Error (MCE).
|
100 |
+
|
101 |
+
Args:
|
102 |
+
confidences: List of confidence scores
|
103 |
+
correctness: List of boolean correctness indicators
|
104 |
+
num_bins: Number of bins for binning confidences
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Maximum Calibration Error
|
108 |
+
"""
|
109 |
+
if len(confidences) != len(correctness):
|
110 |
+
raise ValueError("Confidences and correctness must have the same length")
|
111 |
+
|
112 |
+
if not confidences:
|
113 |
+
return 0.0
|
114 |
+
|
115 |
+
# Create bins and calculate MCE
|
116 |
+
bin_indices = np.digitize(confidences, np.linspace(0, 1, num_bins))
|
117 |
+
max_ce = 0.0
|
118 |
+
|
119 |
+
for bin_idx in range(1, num_bins + 1):
|
120 |
+
bin_mask = (bin_indices == bin_idx)
|
121 |
+
if np.any(bin_mask):
|
122 |
+
bin_confidences = np.array(confidences)[bin_mask]
|
123 |
+
bin_correctness = np.array(correctness)[bin_mask]
|
124 |
+
bin_confidence = np.mean(bin_confidences)
|
125 |
+
bin_accuracy = np.mean(bin_correctness)
|
126 |
+
|
127 |
+
# Absolute difference between confidence and accuracy
|
128 |
+
ce = np.abs(bin_confidence - bin_accuracy)
|
129 |
+
max_ce = max(max_ce, ce)
|
130 |
+
|
131 |
+
return float(max_ce)
|
132 |
+
|
133 |
+
def evaluate(
|
134 |
+
self,
|
135 |
+
confidences: List[float],
|
136 |
+
correctness: List[bool]
|
137 |
+
) -> Dict[str, float]:
|
138 |
+
"""
|
139 |
+
Evaluate calibration quality.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
confidences: List of confidence scores
|
143 |
+
correctness: List of boolean correctness indicators
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
Dictionary of calibration metrics:
|
147 |
+
- ece: Expected Calibration Error
|
148 |
+
- mce: Maximum Calibration Error
|
149 |
+
"""
|
150 |
+
return {
|
151 |
+
"ece": self.expected_calibration_error(confidences, correctness),
|
152 |
+
"mce": self.maximum_calibration_error(confidences, correctness)
|
153 |
+
}
|
154 |
+
|
155 |
+
def plot_reliability_diagram(
|
156 |
+
self,
|
157 |
+
confidences: List[float],
|
158 |
+
correctness: List[bool],
|
159 |
+
num_bins: int = 10,
|
160 |
+
title: str = "Reliability Diagram",
|
161 |
+
save_path: Optional[str] = None
|
162 |
+
) -> None:
|
163 |
+
"""
|
164 |
+
Plot a reliability diagram for calibration visualization.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
confidences: List of confidence scores
|
168 |
+
correctness: List of boolean correctness indicators
|
169 |
+
num_bins: Number of bins for binning confidences
|
170 |
+
title: Title for the plot
|
171 |
+
save_path: Path to save the plot (None to display)
|
172 |
+
"""
|
173 |
+
if len(confidences) != len(correctness):
|
174 |
+
raise ValueError("Confidences and correctness must have the same length")
|
175 |
+
|
176 |
+
# Create bins
|
177 |
+
bin_edges = np.linspace(0, 1, num_bins + 1)
|
178 |
+
bin_indices = np.digitize(confidences, bin_edges[:-1])
|
179 |
+
|
180 |
+
# Calculate accuracy and confidence for each bin
|
181 |
+
bin_accuracies = []
|
182 |
+
bin_confidences = []
|
183 |
+
bin_sizes = []
|
184 |
+
|
185 |
+
for bin_idx in range(1, num_bins + 1):
|
186 |
+
bin_mask = (bin_indices == bin_idx)
|
187 |
+
if np.any(bin_mask):
|
188 |
+
bin_confidences.append(np.mean(np.array(confidences)[bin_mask]))
|
189 |
+
bin_accuracies.append(np.mean(np.array(correctness)[bin_mask]))
|
190 |
+
bin_sizes.append(np.sum(bin_mask))
|
191 |
+
else:
|
192 |
+
bin_confidences.append(0)
|
193 |
+
bin_accuracies.append(0)
|
194 |
+
bin_sizes.append(0)
|
195 |
+
|
196 |
+
# Plot reliability diagram
|
197 |
+
plt.figure(figsize=(10, 6))
|
198 |
+
|
199 |
+
# Plot perfect calibration line
|
200 |
+
plt.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
|
201 |
+
|
202 |
+
# Plot bin accuracies vs. confidences
|
203 |
+
plt.bar(
|
204 |
+
bin_edges[:-1],
|
205 |
+
bin_accuracies,
|
206 |
+
width=1/num_bins,
|
207 |
+
align='edge',
|
208 |
+
alpha=0.7,
|
209 |
+
label='Observed Accuracy'
|
210 |
+
)
|
211 |
+
|
212 |
+
# Plot confidence histogram
|
213 |
+
ax2 = plt.twinx()
|
214 |
+
ax2.hist(
|
215 |
+
confidences,
|
216 |
+
bins=bin_edges,
|
217 |
+
alpha=0.3,
|
218 |
+
color='gray',
|
219 |
+
label='Confidence Histogram'
|
220 |
+
)
|
221 |
+
|
222 |
+
# Calculate ECE and MCE
|
223 |
+
ece = self.expected_calibration_error(confidences, correctness, num_bins)
|
224 |
+
mce = self.maximum_calibration_error(confidences, correctness, num_bins)
|
225 |
+
|
226 |
+
# Add ECE and MCE to title
|
227 |
+
plt.title(f"{title}\nECE: {ece:.4f}, MCE: {mce:.4f}")
|
228 |
+
|
229 |
+
# Add labels and legend
|
230 |
+
plt.xlabel('Confidence')
|
231 |
+
plt.ylabel('Accuracy')
|
232 |
+
ax2.set_ylabel('Count')
|
233 |
+
|
234 |
+
# Add legend
|
235 |
+
lines, labels = plt.gca().get_legend_handles_labels()
|
236 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
237 |
+
ax2.legend(lines + lines2, labels + labels2, loc='best')
|
238 |
+
|
239 |
+
# Save or display the plot
|
240 |
+
if save_path:
|
241 |
+
plt.savefig(save_path)
|
242 |
+
plt.close()
|
243 |
+
else:
|
244 |
+
plt.tight_layout()
|
245 |
+
plt.show()
|
246 |
+
|
247 |
+
|
248 |
+
class SelectivePredictionEvaluator(UncertaintyEvaluator):
|
249 |
+
"""Evaluator for selective prediction performance."""
|
250 |
+
|
251 |
+
def __init__(self):
|
252 |
+
"""Initialize the selective prediction evaluator."""
|
253 |
+
super().__init__("selective_prediction_evaluator")
|
254 |
+
|
255 |
+
def evaluate(
|
256 |
+
self,
|
257 |
+
uncertainties: List[float],
|
258 |
+
correctness: List[bool]
|
259 |
+
) -> Dict[str, float]:
|
260 |
+
"""
|
261 |
+
Evaluate selective prediction performance.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
uncertainties: List of uncertainty scores (higher means more uncertain)
|
265 |
+
correctness: List of boolean correctness indicators
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
Dictionary of selective prediction metrics:
|
269 |
+
- auroc: Area Under ROC Curve for predicting errors
|
270 |
+
- auprc: Area Under Precision-Recall Curve for predicting errors
|
271 |
+
- uncertainty_error_correlation: Correlation between uncertainty and errors
|
272 |
+
"""
|
273 |
+
if len(uncertainties) != len(correctness):
|
274 |
+
raise ValueError("Uncertainties and correctness must have the same length")
|
275 |
+
|
276 |
+
if not uncertainties:
|
277 |
+
return {
|
278 |
+
"auroc": 0.5,
|
279 |
+
"auprc": 0.5,
|
280 |
+
"uncertainty_error_correlation": 0.0
|
281 |
+
}
|
282 |
+
|
283 |
+
# Convert correctness to errors (1 for error, 0 for correct)
|
284 |
+
errors = [1 - int(c) for c in correctness]
|
285 |
+
|
286 |
+
# Calculate AUROC for predicting errors
|
287 |
+
try:
|
288 |
+
auroc = roc_auc_score(errors, uncertainties)
|
289 |
+
except:
|
290 |
+
# Handle case where all predictions are correct or all are wrong
|
291 |
+
auroc = 0.5
|
292 |
+
|
293 |
+
# Calculate AUPRC for predicting errors
|
294 |
+
try:
|
295 |
+
precision, recall, _ = precision_recall_curve(errors, uncertainties)
|
296 |
+
auprc = auc(recall, precision)
|
297 |
+
except:
|
298 |
+
# Handle case where all predictions are correct or all are wrong
|
299 |
+
auprc = 0.5
|
300 |
+
|
301 |
+
# Calculate correlation between uncertainty and errors
|
302 |
+
uncertainty_error_correlation = np.corrcoef(uncertainties, errors)[0, 1]
|
303 |
+
|
304 |
+
return {
|
305 |
+
"auroc": float(auroc),
|
306 |
+
"auprc": float(auprc),
|
307 |
+
"uncertainty_error_correlation": float(uncertainty_error_correlation)
|
308 |
+
}
|
309 |
+
|
310 |
+
def plot_selective_prediction_curve(
|
311 |
+
self,
|
312 |
+
uncertainties: List[float],
|
313 |
+
correctness: List[bool],
|
314 |
+
title: str = "Selective Prediction Performance",
|
315 |
+
save_path: Optional[str] = None
|
316 |
+
) -> None:
|
317 |
+
"""
|
318 |
+
Plot a selective prediction curve.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
uncertainties: List of uncertainty scores (higher means more uncertain)
|
322 |
+
correctness: List of boolean correctness indicators
|
323 |
+
title: Title for the plot
|
324 |
+
save_path: Path to save the plot (None to display)
|
325 |
+
"""
|
326 |
+
if len(uncertainties) != len(correctness):
|
327 |
+
raise ValueError("Uncertainties and correctness must have the same length")
|
328 |
+
|
329 |
+
# Sort by uncertainty (ascending)
|
330 |
+
sorted_indices = np.argsort(uncertainties)
|
331 |
+
sorted_correctness = np.array(correctness)[sorted_indices]
|
332 |
+
|
333 |
+
# Calculate cumulative accuracy at different coverage levels
|
334 |
+
coverages = np.linspace(0, 1, 100)
|
335 |
+
accuracies = []
|
336 |
+
|
337 |
+
for coverage in coverages:
|
338 |
+
if coverage == 0:
|
339 |
+
accuracies.append(1.0) # Perfect accuracy at 0% coverage
|
340 |
+
else:
|
341 |
+
n_samples = int(coverage * len(sorted_correctness))
|
342 |
+
if n_samples == 0:
|
343 |
+
accuracies.append(1.0)
|
344 |
+
else:
|
345 |
+
accuracies.append(np.mean(sorted_correctness[:n_samples]))
|
346 |
+
|
347 |
+
# Plot selective prediction curve
|
348 |
+
plt.figure(figsize=(10, 6))
|
349 |
+
plt.plot(coverages, accuracies, 'b-', linewidth=2)
|
350 |
+
|
351 |
+
# Add reference line for random selection
|
352 |
+
plt.plot([0, 1], [np.mean(correctness), np.mean(correctness)], 'k--', label='Random Selection')
|
353 |
+
|
354 |
+
# Calculate AUROC
|
355 |
+
metrics = self.evaluate(uncertainties, correctness)
|
356 |
+
|
357 |
+
# Add AUROC to title
|
358 |
+
plt.title(f"{title}\nAUROC: {metrics['auroc']:.4f}")
|
359 |
+
|
360 |
+
# Add labels and legend
|
361 |
+
plt.xlabel('Coverage')
|
362 |
+
plt.ylabel('Accuracy')
|
363 |
+
plt.legend(loc='best')
|
364 |
+
|
365 |
+
# Save or display the plot
|
366 |
+
if save_path:
|
367 |
+
plt.savefig(save_path)
|
368 |
+
plt.close()
|
369 |
+
else:
|
370 |
+
plt.tight_layout()
|
371 |
+
plt.show()
|
372 |
+
|
373 |
+
|
374 |
+
class CrossDomainEvaluator:
|
375 |
+
"""Evaluator for cross-domain uncertainty performance."""
|
376 |
+
|
377 |
+
def __init__(self):
|
378 |
+
"""Initialize the cross-domain evaluator."""
|
379 |
+
self.name = "cross_domain_evaluator"
|
380 |
+
self.calibration_evaluator = CalibrationEvaluator()
|
381 |
+
self.selective_prediction_evaluator = SelectivePredictionEvaluator()
|
382 |
+
|
383 |
+
def evaluate_domain_transfer(
|
384 |
+
self,
|
385 |
+
source_uncertainties: List[float],
|
386 |
+
source_correctness: List[bool],
|
387 |
+
target_uncertainties: List[float],
|
388 |
+
target_correctness: List[bool]
|
389 |
+
) -> Dict[str, float]:
|
390 |
+
"""
|
391 |
+
Evaluate domain transfer performance.
|
392 |
+
|
393 |
+
Args:
|
394 |
+
source_uncertainties: List of uncertainty scores from source domain
|
395 |
+
source_correctness: List of boolean correctness indicators from source domain
|
396 |
+
target_uncertainties: List of uncertainty scores from target domain
|
397 |
+
target_correctness: List of boolean correctness indicators from target domain
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
Dictionary of domain transfer metrics:
|
401 |
+
- source_auroc: AUROC in source domain
|
402 |
+
- target_auroc: AUROC in target domain
|
403 |
+
- transfer_degradation: Degradation in AUROC from source to target
|
404 |
+
- source_ece: ECE in source domain
|
405 |
+
- target_ece: ECE in target domain
|
406 |
+
- calibration_shift: Shift in calibration from source to target
|
407 |
+
"""
|
408 |
+
# Evaluate source domain
|
409 |
+
source_selective = self.selective_prediction_evaluator.evaluate(
|
410 |
+
source_uncertainties, source_correctness
|
411 |
+
)
|
412 |
+
source_calibration = self.calibration_evaluator.evaluate(
|
413 |
+
[1 - u for u in source_uncertainties], source_correctness
|
414 |
+
)
|
415 |
+
|
416 |
+
# Evaluate target domain
|
417 |
+
target_selective = self.selective_prediction_evaluator.evaluate(
|
418 |
+
target_uncertainties, target_correctness
|
419 |
+
)
|
420 |
+
target_calibration = self.calibration_evaluator.evaluate(
|
421 |
+
[1 - u for u in target_uncertainties], target_correctness
|
422 |
+
)
|
423 |
+
|
424 |
+
# Calculate transfer metrics
|
425 |
+
transfer_degradation = source_selective["auroc"] - target_selective["auroc"]
|
426 |
+
calibration_shift = target_calibration["ece"] - source_calibration["ece"]
|
427 |
+
|
428 |
+
return {
|
429 |
+
"source_auroc": source_selective["auroc"],
|
430 |
+
"target_auroc": target_selective["auroc"],
|
431 |
+
"transfer_degradation": float(transfer_degradation),
|
432 |
+
"source_ece": source_calibration["ece"],
|
433 |
+
"target_ece": target_calibration["ece"],
|
434 |
+
"calibration_shift": float(calibration_shift)
|
435 |
+
}
|
436 |
+
|
437 |
+
def evaluate_all_domains(
|
438 |
+
self,
|
439 |
+
domain_results: Dict[str, Dict[str, Any]]
|
440 |
+
) -> Dict[str, Dict[str, float]]:
|
441 |
+
"""
|
442 |
+
Evaluate uncertainty performance across all domains.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
domain_results: Dictionary mapping domain names to results
|
446 |
+
Each result should contain:
|
447 |
+
- uncertainties: List of uncertai
|
448 |
+
(Content truncated due to size limit. Use line ranges to read in chunks)
|
example_config.json
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Example experiment configuration for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This configuration file defines parameters for running an experiment with the framework.
|
5 |
+
"""
|
6 |
+
|
7 |
+
{
|
8 |
+
"name": "agent_tuning_experiment",
|
9 |
+
"description": "Experiment to evaluate the effectiveness of negative and synthetic samples in agent tuning",
|
10 |
+
|
11 |
+
"llm": {
|
12 |
+
"model_name": "gpt2",
|
13 |
+
"model_type": "causal",
|
14 |
+
"device": "cpu",
|
15 |
+
"max_length": 512,
|
16 |
+
"temperature": 0.7
|
17 |
+
},
|
18 |
+
|
19 |
+
"dataset": {
|
20 |
+
"name": "agent_tuning_dataset",
|
21 |
+
"num_trajectories": 20
|
22 |
+
},
|
23 |
+
|
24 |
+
"negative_samples": {
|
25 |
+
"enabled": true,
|
26 |
+
"method": "response_degradation",
|
27 |
+
"params": {
|
28 |
+
"degradation_level": 0.6
|
29 |
+
}
|
30 |
+
},
|
31 |
+
|
32 |
+
"synthetic_trajectories": {
|
33 |
+
"enabled": true,
|
34 |
+
"method": "template",
|
35 |
+
"params": {
|
36 |
+
"num_interactions": 3
|
37 |
+
}
|
38 |
+
},
|
39 |
+
|
40 |
+
"tuning": {
|
41 |
+
"method": "supervised",
|
42 |
+
"params": {
|
43 |
+
"num_train_epochs": 3,
|
44 |
+
"learning_rate": 5e-5,
|
45 |
+
"batch_size": 4,
|
46 |
+
"gradient_accumulation_steps": 4,
|
47 |
+
"positive_weight": 0.8
|
48 |
+
}
|
49 |
+
},
|
50 |
+
|
51 |
+
"evaluation": {
|
52 |
+
"method": "quality",
|
53 |
+
"params": {
|
54 |
+
"num_samples": 10
|
55 |
+
},
|
56 |
+
"comparative": {
|
57 |
+
"enabled": true,
|
58 |
+
"params": {
|
59 |
+
"num_samples": 5
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
}
|
llm_interface.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLM Interface Module for Cross-Domain Uncertainty Quantification
|
3 |
+
|
4 |
+
This module provides a unified interface for interacting with large language models,
|
5 |
+
supporting multiple model architectures and uncertainty quantification methods.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from typing import List, Dict, Any, Union, Optional
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
class LLMInterface:
|
15 |
+
"""Interface for interacting with large language models with uncertainty quantification."""
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
model_name: str,
|
20 |
+
model_type: str = "causal",
|
21 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
22 |
+
cache_dir: Optional[str] = None,
|
23 |
+
max_length: int = 512,
|
24 |
+
temperature: float = 1.0,
|
25 |
+
top_p: float = 1.0,
|
26 |
+
num_beams: int = 1
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
Initialize the LLM interface.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model_name: Name of the Hugging Face model to use
|
33 |
+
model_type: Type of model ('causal' or 'seq2seq')
|
34 |
+
device: Device to run the model on ('cpu' or 'cuda')
|
35 |
+
cache_dir: Directory to cache models
|
36 |
+
max_length: Maximum length of generated sequences
|
37 |
+
temperature: Sampling temperature
|
38 |
+
top_p: Nucleus sampling parameter
|
39 |
+
num_beams: Number of beams for beam search
|
40 |
+
"""
|
41 |
+
self.model_name = model_name
|
42 |
+
self.model_type = model_type
|
43 |
+
self.device = device
|
44 |
+
self.cache_dir = cache_dir
|
45 |
+
self.max_length = max_length
|
46 |
+
self.temperature = temperature
|
47 |
+
self.top_p = top_p
|
48 |
+
self.num_beams = num_beams
|
49 |
+
|
50 |
+
# Load tokenizer
|
51 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
52 |
+
model_name,
|
53 |
+
cache_dir=cache_dir
|
54 |
+
)
|
55 |
+
|
56 |
+
# Load model based on type
|
57 |
+
if model_type == "causal":
|
58 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
59 |
+
model_name,
|
60 |
+
cache_dir=cache_dir,
|
61 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
62 |
+
).to(device)
|
63 |
+
elif model_type == "seq2seq":
|
64 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
65 |
+
model_name,
|
66 |
+
cache_dir=cache_dir,
|
67 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
68 |
+
).to(device)
|
69 |
+
else:
|
70 |
+
raise ValueError(f"Unsupported model type: {model_type}")
|
71 |
+
|
72 |
+
# Response cache for efficiency
|
73 |
+
self.response_cache = {}
|
74 |
+
|
75 |
+
def generate(
|
76 |
+
self,
|
77 |
+
prompt: str,
|
78 |
+
num_samples: int = 1,
|
79 |
+
return_logits: bool = False,
|
80 |
+
**kwargs
|
81 |
+
) -> Dict[str, Any]:
|
82 |
+
"""
|
83 |
+
Generate responses from the model with uncertainty quantification.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
prompt: Input text prompt
|
87 |
+
num_samples: Number of samples to generate (for MC methods)
|
88 |
+
return_logits: Whether to return token logits
|
89 |
+
**kwargs: Additional generation parameters
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
Dictionary containing:
|
93 |
+
- response: The generated text
|
94 |
+
- samples: Multiple samples if num_samples > 1
|
95 |
+
- logits: Token logits if return_logits is True
|
96 |
+
"""
|
97 |
+
# Check cache first
|
98 |
+
cache_key = (prompt, num_samples, return_logits, str(kwargs))
|
99 |
+
if cache_key in self.response_cache:
|
100 |
+
return self.response_cache[cache_key]
|
101 |
+
|
102 |
+
# Prepare inputs
|
103 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
104 |
+
|
105 |
+
# Set generation parameters
|
106 |
+
gen_kwargs = {
|
107 |
+
"max_length": self.max_length,
|
108 |
+
"temperature": self.temperature,
|
109 |
+
"top_p": self.top_p,
|
110 |
+
"num_beams": self.num_beams,
|
111 |
+
"do_sample": self.temperature > 0,
|
112 |
+
"pad_token_id": self.tokenizer.eos_token_id
|
113 |
+
}
|
114 |
+
gen_kwargs.update(kwargs)
|
115 |
+
|
116 |
+
# Generate multiple samples if requested
|
117 |
+
samples = []
|
118 |
+
all_logits = []
|
119 |
+
|
120 |
+
for _ in range(num_samples):
|
121 |
+
with torch.no_grad():
|
122 |
+
outputs = self.model.generate(
|
123 |
+
**inputs,
|
124 |
+
output_scores=return_logits,
|
125 |
+
return_dict_in_generate=True,
|
126 |
+
**gen_kwargs
|
127 |
+
)
|
128 |
+
|
129 |
+
# Extract generated tokens
|
130 |
+
if self.model_type == "causal":
|
131 |
+
gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
|
132 |
+
else:
|
133 |
+
gen_tokens = outputs.sequences[0]
|
134 |
+
|
135 |
+
# Decode tokens to text
|
136 |
+
gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
|
137 |
+
samples.append(gen_text)
|
138 |
+
|
139 |
+
# Extract logits if requested
|
140 |
+
if return_logits and hasattr(outputs, "scores"):
|
141 |
+
all_logits.append([score.cpu().numpy() for score in outputs.scores])
|
142 |
+
|
143 |
+
# Prepare result
|
144 |
+
result = {
|
145 |
+
"response": samples[0], # Primary response is first sample
|
146 |
+
"samples": samples
|
147 |
+
}
|
148 |
+
|
149 |
+
if return_logits:
|
150 |
+
result["logits"] = all_logits
|
151 |
+
|
152 |
+
# Cache result
|
153 |
+
self.response_cache[cache_key] = result
|
154 |
+
return result
|
155 |
+
|
156 |
+
def batch_generate(
|
157 |
+
self,
|
158 |
+
prompts: List[str],
|
159 |
+
**kwargs
|
160 |
+
) -> List[Dict[str, Any]]:
|
161 |
+
"""
|
162 |
+
Generate responses for a batch of prompts.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
prompts: List of input text prompts
|
166 |
+
**kwargs: Additional generation parameters
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
List of generation results for each prompt
|
170 |
+
"""
|
171 |
+
results = []
|
172 |
+
for prompt in tqdm(prompts, desc="Generating responses"):
|
173 |
+
results.append(self.generate(prompt, **kwargs))
|
174 |
+
return results
|
175 |
+
|
176 |
+
def clear_cache(self):
|
177 |
+
"""Clear the response cache."""
|
178 |
+
self.response_cache = {}
|
main.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main Integration Module for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This module provides functionality for integrating all components of the framework
|
5 |
+
and running end-to-end experiments.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
import argparse
|
11 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
12 |
+
|
13 |
+
from models.llm_interface import LLMInterface
|
14 |
+
from data.trajectory_data import Trajectory, TrajectoryDataset, create_synthetic_dataset
|
15 |
+
from training.negative_samples import create_negative_sample_generator
|
16 |
+
from training.synthetic_trajectories import create_synthetic_trajectory_generator
|
17 |
+
from training.agent_tuner import create_agent_tuner
|
18 |
+
from evaluation.evaluators import create_agent_evaluator
|
19 |
+
|
20 |
+
def run_experiment(
|
21 |
+
experiment_config: Dict[str, Any],
|
22 |
+
output_dir: str
|
23 |
+
) -> Dict[str, Any]:
|
24 |
+
"""
|
25 |
+
Run an end-to-end experiment with the framework.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
experiment_config: Experiment configuration
|
29 |
+
output_dir: Directory to save results
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Dictionary of experiment results
|
33 |
+
"""
|
34 |
+
print(f"Starting experiment: {experiment_config['name']}")
|
35 |
+
|
36 |
+
# Create output directory
|
37 |
+
os.makedirs(output_dir, exist_ok=True)
|
38 |
+
|
39 |
+
# Save experiment configuration
|
40 |
+
with open(f"{output_dir}/experiment_config.json", "w") as f:
|
41 |
+
json.dump(experiment_config, f, indent=2)
|
42 |
+
|
43 |
+
# Initialize LLM interface
|
44 |
+
print("Initializing LLM interface...")
|
45 |
+
llm_config = experiment_config.get("llm", {})
|
46 |
+
llm_interface = LLMInterface(
|
47 |
+
model_name=llm_config.get("model_name", "gpt2"),
|
48 |
+
model_type=llm_config.get("model_type", "causal"),
|
49 |
+
device=llm_config.get("device", "cpu"),
|
50 |
+
max_length=llm_config.get("max_length", 512),
|
51 |
+
temperature=llm_config.get("temperature", 0.7)
|
52 |
+
)
|
53 |
+
|
54 |
+
# Load or create dataset
|
55 |
+
print("Preparing dataset...")
|
56 |
+
dataset_config = experiment_config.get("dataset", {})
|
57 |
+
|
58 |
+
if dataset_config.get("path"):
|
59 |
+
# Load existing dataset
|
60 |
+
dataset = TrajectoryDataset(dataset_config.get("name", "experiment_dataset"))
|
61 |
+
dataset.load_from_json(dataset_config["path"])
|
62 |
+
else:
|
63 |
+
# Create synthetic dataset
|
64 |
+
dataset = create_synthetic_dataset(dataset_config.get("num_trajectories", 20))
|
65 |
+
|
66 |
+
print(f"Dataset loaded with {len(dataset.trajectories)} trajectories")
|
67 |
+
|
68 |
+
# Generate negative samples
|
69 |
+
print("Generating negative samples...")
|
70 |
+
negative_config = experiment_config.get("negative_samples", {})
|
71 |
+
|
72 |
+
if negative_config.get("enabled", True):
|
73 |
+
negative_generator = create_negative_sample_generator(
|
74 |
+
negative_config.get("method", "response_degradation")
|
75 |
+
)
|
76 |
+
|
77 |
+
positive_trajectories = dataset.get_trajectories(positive_only=True)
|
78 |
+
negative_trajectories = negative_generator.batch_generate(
|
79 |
+
positive_trajectories,
|
80 |
+
**negative_config.get("params", {})
|
81 |
+
)
|
82 |
+
|
83 |
+
# Add negative trajectories to dataset
|
84 |
+
for trajectory in negative_trajectories:
|
85 |
+
dataset.add_trajectory(trajectory)
|
86 |
+
|
87 |
+
print(f"Added {len(negative_trajectories)} negative trajectories")
|
88 |
+
|
89 |
+
# Generate synthetic trajectories
|
90 |
+
print("Generating synthetic trajectories...")
|
91 |
+
synthetic_config = experiment_config.get("synthetic_trajectories", {})
|
92 |
+
|
93 |
+
if synthetic_config.get("enabled", True):
|
94 |
+
synthetic_generator = create_synthetic_trajectory_generator(
|
95 |
+
synthetic_config.get("method", "template"),
|
96 |
+
llm_interface if synthetic_config.get("method") in ["llm", "hybrid"] else None
|
97 |
+
)
|
98 |
+
|
99 |
+
# Generate from task descriptions
|
100 |
+
task_descriptions = [t.task_description for t in dataset.get_trajectories(positive_only=True)]
|
101 |
+
task_descriptions = list(set(task_descriptions)) # Remove duplicates
|
102 |
+
|
103 |
+
synthetic_trajectories = synthetic_generator.batch_generate(
|
104 |
+
task_descriptions,
|
105 |
+
**synthetic_config.get("params", {})
|
106 |
+
)
|
107 |
+
|
108 |
+
# Add synthetic trajectories to dataset
|
109 |
+
for trajectory in synthetic_trajectories:
|
110 |
+
dataset.add_trajectory(trajectory)
|
111 |
+
|
112 |
+
print(f"Added {len(synthetic_trajectories)} synthetic trajectories")
|
113 |
+
|
114 |
+
# Save the enhanced dataset
|
115 |
+
dataset.save_to_json(f"{output_dir}/enhanced_dataset.json")
|
116 |
+
|
117 |
+
# Analyze dataset
|
118 |
+
dataset_stats = dataset.analyze_dataset()
|
119 |
+
with open(f"{output_dir}/dataset_stats.json", "w") as f:
|
120 |
+
json.dump(dataset_stats, f, indent=2)
|
121 |
+
|
122 |
+
# Split dataset for training and evaluation
|
123 |
+
all_trajectories = dataset.get_trajectories()
|
124 |
+
split_idx = int(len(all_trajectories) * 0.8) # 80% for training
|
125 |
+
|
126 |
+
train_trajectories = all_trajectories[:split_idx]
|
127 |
+
eval_trajectories = all_trajectories[split_idx:]
|
128 |
+
|
129 |
+
print(f"Split dataset: {len(train_trajectories)} for training, {len(eval_trajectories)} for evaluation")
|
130 |
+
|
131 |
+
# Tune agent
|
132 |
+
print("Tuning agent...")
|
133 |
+
tuning_config = experiment_config.get("tuning", {})
|
134 |
+
|
135 |
+
tuner = create_agent_tuner(tuning_config.get("method", "supervised"))
|
136 |
+
|
137 |
+
tuned_model, tuning_metrics = tuner.tune(
|
138 |
+
model_name=llm_config.get("model_name", "gpt2"),
|
139 |
+
trajectories=train_trajectories,
|
140 |
+
output_dir=f"{output_dir}/tuned_model",
|
141 |
+
**tuning_config.get("params", {})
|
142 |
+
)
|
143 |
+
|
144 |
+
# Save tuning metrics
|
145 |
+
with open(f"{output_dir}/tuning_metrics.json", "w") as f:
|
146 |
+
# Convert any non-serializable values to strings
|
147 |
+
serializable_metrics = {}
|
148 |
+
for k, v in tuning_metrics.items():
|
149 |
+
if isinstance(v, (int, float, str, bool, list, dict)) or v is None:
|
150 |
+
serializable_metrics[k] = v
|
151 |
+
else:
|
152 |
+
serializable_metrics[k] = str(v)
|
153 |
+
|
154 |
+
json.dump(serializable_metrics, f, indent=2)
|
155 |
+
|
156 |
+
# Create tuned model interface
|
157 |
+
tuned_llm_interface = LLMInterface(
|
158 |
+
model_name=f"{output_dir}/tuned_model",
|
159 |
+
model_type=llm_config.get("model_type", "causal"),
|
160 |
+
device=llm_config.get("device", "cpu"),
|
161 |
+
max_length=llm_config.get("max_length", 512),
|
162 |
+
temperature=llm_config.get("temperature", 0.7)
|
163 |
+
)
|
164 |
+
|
165 |
+
# Evaluate agent
|
166 |
+
print("Evaluating agent...")
|
167 |
+
eval_config = experiment_config.get("evaluation", {})
|
168 |
+
|
169 |
+
evaluator = create_agent_evaluator(eval_config.get("method", "quality"))
|
170 |
+
|
171 |
+
eval_results = evaluator.evaluate(
|
172 |
+
llm_interface=tuned_llm_interface,
|
173 |
+
test_trajectories=eval_trajectories,
|
174 |
+
**eval_config.get("params", {})
|
175 |
+
)
|
176 |
+
|
177 |
+
# Visualize evaluation results
|
178 |
+
evaluator.visualize_results(
|
179 |
+
results=eval_results,
|
180 |
+
output_dir=f"{output_dir}/evaluation"
|
181 |
+
)
|
182 |
+
|
183 |
+
# Save evaluation results
|
184 |
+
with open(f"{output_dir}/evaluation_results.json", "w") as f:
|
185 |
+
# Create a simplified version without large data
|
186 |
+
simplified_results = {}
|
187 |
+
|
188 |
+
if "aggregated" in eval_results:
|
189 |
+
simplified_results["aggregated"] = eval_results["aggregated"]
|
190 |
+
|
191 |
+
if "metrics" in eval_results:
|
192 |
+
# Include only essential metrics
|
193 |
+
simplified_results["metrics"] = [
|
194 |
+
{k: v for k, v in m.items() if k not in ["generated_responses"]}
|
195 |
+
for m in eval_results["metrics"]
|
196 |
+
]
|
197 |
+
|
198 |
+
json.dump(simplified_results, f, indent=2)
|
199 |
+
|
200 |
+
# Comparative evaluation (if configured)
|
201 |
+
if eval_config.get("comparative", {}).get("enabled", False):
|
202 |
+
print("Performing comparative evaluation...")
|
203 |
+
|
204 |
+
# Create baseline model interface
|
205 |
+
baseline_llm_interface = LLMInterface(
|
206 |
+
model_name=llm_config.get("model_name", "gpt2"),
|
207 |
+
model_type=llm_config.get("model_type", "causal"),
|
208 |
+
device=llm_config.get("device", "cpu"),
|
209 |
+
max_length=llm_config.get("max_length", 512),
|
210 |
+
temperature=llm_config.get("temperature", 0.7)
|
211 |
+
)
|
212 |
+
|
213 |
+
# Create comparative evaluator
|
214 |
+
comparative_evaluator = create_agent_evaluator("comparative")
|
215 |
+
|
216 |
+
# Evaluate and compare
|
217 |
+
comparative_results = comparative_evaluator.evaluate(
|
218 |
+
llm_interfaces={
|
219 |
+
"baseline": baseline_llm_interface,
|
220 |
+
"tuned": tuned_llm_interface
|
221 |
+
},
|
222 |
+
test_trajectories=eval_trajectories,
|
223 |
+
**eval_config.get("comparative", {}).get("params", {})
|
224 |
+
)
|
225 |
+
|
226 |
+
# Visualize comparative results
|
227 |
+
comparative_evaluator.visualize_results(
|
228 |
+
results=comparative_results,
|
229 |
+
output_dir=f"{output_dir}/comparative"
|
230 |
+
)
|
231 |
+
|
232 |
+
# Save comparative results
|
233 |
+
with open(f"{output_dir}/comparative_results.json", "w") as f:
|
234 |
+
# Create a simplified version
|
235 |
+
simplified_comparative = {
|
236 |
+
"comparative": comparative_results.get("comparative", {})
|
237 |
+
}
|
238 |
+
|
239 |
+
json.dump(simplified_comparative, f, indent=2)
|
240 |
+
|
241 |
+
print(f"Experiment completed. Results saved to {output_dir}")
|
242 |
+
|
243 |
+
return {
|
244 |
+
"dataset_stats": dataset_stats,
|
245 |
+
"tuning_metrics": tuning_metrics,
|
246 |
+
"evaluation_results": eval_results
|
247 |
+
}
|
248 |
+
|
249 |
+
def main():
|
250 |
+
"""Main function for running the framework from command line."""
|
251 |
+
parser = argparse.ArgumentParser(description="Agent Tuning Optimization Framework")
|
252 |
+
parser.add_argument("--config", type=str, required=True, help="Path to experiment configuration file")
|
253 |
+
parser.add_argument("--output", type=str, default="./experiment_results", help="Directory to save results")
|
254 |
+
|
255 |
+
args = parser.parse_args()
|
256 |
+
|
257 |
+
# Load experiment configuration
|
258 |
+
with open(args.config, "r") as f:
|
259 |
+
experiment_config = json.load(f)
|
260 |
+
|
261 |
+
# Run experiment
|
262 |
+
run_experiment(experiment_config, args.output)
|
263 |
+
|
264 |
+
if __name__ == "__main__":
|
265 |
+
main()
|
negative_samples.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Negative Sample Generation Module for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This module provides functionality for generating negative samples to enhance
|
5 |
+
agent tuning by exposing the model to challenging failure cases.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from data.trajectory_data import Trajectory, TrajectoryDataset
|
14 |
+
|
15 |
+
class NegativeSampleGenerator:
|
16 |
+
"""Base class for negative sample generation strategies."""
|
17 |
+
|
18 |
+
def __init__(self, name: str):
|
19 |
+
"""
|
20 |
+
Initialize the negative sample generator.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
name: Name of the generator strategy
|
24 |
+
"""
|
25 |
+
self.name = name
|
26 |
+
|
27 |
+
def generate(
|
28 |
+
self,
|
29 |
+
trajectory: Trajectory,
|
30 |
+
**kwargs
|
31 |
+
) -> Trajectory:
|
32 |
+
"""
|
33 |
+
Generate a negative sample from a positive trajectory.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
trajectory: Positive trajectory to transform
|
37 |
+
**kwargs: Additional generation parameters
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Negative trajectory
|
41 |
+
"""
|
42 |
+
raise NotImplementedError("Subclasses must implement this method")
|
43 |
+
|
44 |
+
def batch_generate(
|
45 |
+
self,
|
46 |
+
trajectories: List[Trajectory],
|
47 |
+
**kwargs
|
48 |
+
) -> List[Trajectory]:
|
49 |
+
"""
|
50 |
+
Generate negative samples from a batch of positive trajectories.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
trajectories: List of positive trajectories
|
54 |
+
**kwargs: Additional generation parameters
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
List of negative trajectories
|
58 |
+
"""
|
59 |
+
negative_trajectories = []
|
60 |
+
|
61 |
+
for trajectory in tqdm(trajectories, desc=f"Generating negative samples with {self.name}"):
|
62 |
+
negative_trajectories.append(self.generate(trajectory, **kwargs))
|
63 |
+
|
64 |
+
return negative_trajectories
|
65 |
+
|
66 |
+
|
67 |
+
class ResponseDegradationGenerator(NegativeSampleGenerator):
|
68 |
+
"""Generate negative samples by degrading agent responses."""
|
69 |
+
|
70 |
+
def __init__(self):
|
71 |
+
"""Initialize the response degradation generator."""
|
72 |
+
super().__init__("response_degradation")
|
73 |
+
|
74 |
+
def generate(
|
75 |
+
self,
|
76 |
+
trajectory: Trajectory,
|
77 |
+
degradation_level: float = 0.5,
|
78 |
+
**kwargs
|
79 |
+
) -> Trajectory:
|
80 |
+
"""
|
81 |
+
Generate a negative sample by degrading agent responses.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
trajectory: Positive trajectory to transform
|
85 |
+
degradation_level: Level of degradation (0.0 to 1.0)
|
86 |
+
**kwargs: Additional generation parameters
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Negative trajectory with degraded responses
|
90 |
+
"""
|
91 |
+
# Create a copy of interactions to modify
|
92 |
+
new_interactions = []
|
93 |
+
|
94 |
+
for interaction in trajectory.interactions:
|
95 |
+
user_msg = interaction['user']
|
96 |
+
agent_msg = interaction['agent']
|
97 |
+
|
98 |
+
# Apply degradation techniques based on level
|
99 |
+
if degradation_level > 0.7:
|
100 |
+
# High degradation: completely irrelevant response
|
101 |
+
agent_msg = self._generate_irrelevant_response()
|
102 |
+
elif degradation_level > 0.4:
|
103 |
+
# Medium degradation: truncate and add errors
|
104 |
+
agent_msg = self._truncate_and_add_errors(agent_msg)
|
105 |
+
else:
|
106 |
+
# Low degradation: introduce minor issues
|
107 |
+
agent_msg = self._introduce_minor_issues(agent_msg)
|
108 |
+
|
109 |
+
new_interactions.append({
|
110 |
+
'user': user_msg,
|
111 |
+
'agent': agent_msg
|
112 |
+
})
|
113 |
+
|
114 |
+
# Create new trajectory with degraded responses
|
115 |
+
metadata = trajectory.metadata.copy()
|
116 |
+
metadata['is_positive'] = False
|
117 |
+
metadata['degradation_level'] = degradation_level
|
118 |
+
metadata['original_quality_score'] = trajectory.get_quality_score()
|
119 |
+
metadata['quality_score'] = None # Will be recalculated
|
120 |
+
|
121 |
+
return Trajectory(
|
122 |
+
task_description=trajectory.task_description,
|
123 |
+
interactions=new_interactions,
|
124 |
+
metadata=metadata
|
125 |
+
)
|
126 |
+
|
127 |
+
def _generate_irrelevant_response(self) -> str:
|
128 |
+
"""Generate a completely irrelevant response."""
|
129 |
+
irrelevant_responses = [
|
130 |
+
"I'm sorry, but I don't understand what you're asking for. Could you please clarify?",
|
131 |
+
"I apologize, but I cannot assist with that request at this time.",
|
132 |
+
"That's an interesting question, but I think we should focus on something else instead.",
|
133 |
+
"Let me check my database... I don't seem to have any information about that.",
|
134 |
+
"I think you might be confused about what you're asking for. Let me suggest something completely different.",
|
135 |
+
"I'm not sure I understand the context of your request. Could you provide more details?",
|
136 |
+
"I'm having trouble processing your request. Could we try a different approach?",
|
137 |
+
"That's not something I can help with. Let me tell you about something unrelated instead."
|
138 |
+
]
|
139 |
+
return random.choice(irrelevant_responses)
|
140 |
+
|
141 |
+
def _truncate_and_add_errors(self, text: str) -> str:
|
142 |
+
"""Truncate the text and add errors."""
|
143 |
+
# Truncate to 30-70% of original length
|
144 |
+
words = text.split()
|
145 |
+
truncate_point = int(len(words) * random.uniform(0.3, 0.7))
|
146 |
+
truncated = ' '.join(words[:truncate_point])
|
147 |
+
|
148 |
+
# Add grammatical errors
|
149 |
+
errors = [
|
150 |
+
lambda t: t.replace(".", ""), # Remove periods
|
151 |
+
lambda t: t.replace("I ", "i "), # Lowercase I
|
152 |
+
lambda t: t.replace(" the ", " teh "), # Typo
|
153 |
+
lambda t: t.replace(" is ", " are "), # Grammar error
|
154 |
+
lambda t: t.replace(" are ", " is ") # Grammar error
|
155 |
+
]
|
156 |
+
|
157 |
+
# Apply 1-3 random errors
|
158 |
+
for _ in range(random.randint(1, 3)):
|
159 |
+
error_func = random.choice(errors)
|
160 |
+
truncated = error_func(truncated)
|
161 |
+
|
162 |
+
return truncated
|
163 |
+
|
164 |
+
def _introduce_minor_issues(self, text: str) -> str:
|
165 |
+
"""Introduce minor issues to the text."""
|
166 |
+
# Minor issues
|
167 |
+
issues = [
|
168 |
+
lambda t: t.replace("I'll", "I will"), # Expand contractions
|
169 |
+
lambda t: t.replace("I'd", "I would"),
|
170 |
+
lambda t: t.replace("can't", "cannot"),
|
171 |
+
lambda t: t + " However, I'm not entirely sure about this.", # Add uncertainty
|
172 |
+
lambda t: t + " Please note that my information might be outdated.",
|
173 |
+
lambda t: t.replace(".", "..."), # Replace periods with ellipses
|
174 |
+
lambda t: t.replace("!", "."), # Reduce enthusiasm
|
175 |
+
lambda t: t.replace(".", "?") # Add questioning tone
|
176 |
+
]
|
177 |
+
|
178 |
+
# Apply 1-2 random issues
|
179 |
+
for _ in range(random.randint(1, 2)):
|
180 |
+
issue_func = random.choice(issues)
|
181 |
+
text = issue_func(text)
|
182 |
+
|
183 |
+
return text
|
184 |
+
|
185 |
+
|
186 |
+
class TaskMisalignmentGenerator(NegativeSampleGenerator):
|
187 |
+
"""Generate negative samples by creating responses misaligned with the task."""
|
188 |
+
|
189 |
+
def __init__(self):
|
190 |
+
"""Initialize the task misalignment generator."""
|
191 |
+
super().__init__("task_misalignment")
|
192 |
+
|
193 |
+
def generate(
|
194 |
+
self,
|
195 |
+
trajectory: Trajectory,
|
196 |
+
misalignment_type: str = 'random',
|
197 |
+
**kwargs
|
198 |
+
) -> Trajectory:
|
199 |
+
"""
|
200 |
+
Generate a negative sample with responses misaligned with the task.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
trajectory: Positive trajectory to transform
|
204 |
+
misalignment_type: Type of misalignment ('random', 'refusal', 'tangent', 'misinterpretation')
|
205 |
+
**kwargs: Additional generation parameters
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
Negative trajectory with misaligned responses
|
209 |
+
"""
|
210 |
+
# Create a copy of interactions to modify
|
211 |
+
new_interactions = []
|
212 |
+
|
213 |
+
# Select misalignment type if random
|
214 |
+
if misalignment_type == 'random':
|
215 |
+
misalignment_type = random.choice(['refusal', 'tangent', 'misinterpretation'])
|
216 |
+
|
217 |
+
# Apply misalignment based on type
|
218 |
+
for i, interaction in enumerate(trajectory.interactions):
|
219 |
+
user_msg = interaction['user']
|
220 |
+
|
221 |
+
if misalignment_type == 'refusal':
|
222 |
+
agent_msg = self._generate_refusal(user_msg, i)
|
223 |
+
elif misalignment_type == 'tangent':
|
224 |
+
agent_msg = self._generate_tangent(user_msg, i)
|
225 |
+
elif misalignment_type == 'misinterpretation':
|
226 |
+
agent_msg = self._generate_misinterpretation(user_msg, trajectory.task_description, i)
|
227 |
+
else:
|
228 |
+
raise ValueError(f"Unsupported misalignment type: {misalignment_type}")
|
229 |
+
|
230 |
+
new_interactions.append({
|
231 |
+
'user': user_msg,
|
232 |
+
'agent': agent_msg
|
233 |
+
})
|
234 |
+
|
235 |
+
# Create new trajectory with misaligned responses
|
236 |
+
metadata = trajectory.metadata.copy()
|
237 |
+
metadata['is_positive'] = False
|
238 |
+
metadata['misalignment_type'] = misalignment_type
|
239 |
+
metadata['original_quality_score'] = trajectory.get_quality_score()
|
240 |
+
metadata['quality_score'] = None # Will be recalculated
|
241 |
+
|
242 |
+
return Trajectory(
|
243 |
+
task_description=trajectory.task_description,
|
244 |
+
interactions=new_interactions,
|
245 |
+
metadata=metadata
|
246 |
+
)
|
247 |
+
|
248 |
+
def _generate_refusal(self, user_msg: str, turn_idx: int) -> str:
|
249 |
+
"""Generate a refusal response."""
|
250 |
+
refusal_templates = [
|
251 |
+
"I'm sorry, but I cannot assist with that request due to ethical concerns.",
|
252 |
+
"I apologize, but I'm not able to help with that specific task.",
|
253 |
+
"That request goes against my programming guidelines. I cannot proceed.",
|
254 |
+
"I'm unable to complete that task as it may violate my operational parameters.",
|
255 |
+
"I must decline to assist with that particular request.",
|
256 |
+
"I cannot help with that specific task due to potential policy violations.",
|
257 |
+
"That's not something I'm designed to help with. I apologize for the inconvenience.",
|
258 |
+
"I'm programmed to avoid assisting with that type of request."
|
259 |
+
]
|
260 |
+
|
261 |
+
if turn_idx == 0:
|
262 |
+
return random.choice(refusal_templates)
|
263 |
+
else:
|
264 |
+
return f"I've reconsidered, and {random.choice(refusal_templates).lower()}"
|
265 |
+
|
266 |
+
def _generate_tangent(self, user_msg: str, turn_idx: int) -> str:
|
267 |
+
"""Generate a response that goes off on a tangent."""
|
268 |
+
tangent_topics = [
|
269 |
+
"Did you know that artificial intelligence has been a concept since the 1950s?",
|
270 |
+
"I've been thinking about the philosophical implications of consciousness in AI systems.",
|
271 |
+
"The weather has been quite interesting lately, with unusual patterns emerging globally.",
|
272 |
+
"I recently processed some fascinating data about renewable energy technologies.",
|
273 |
+
"The history of computing is quite fascinating, starting with early mechanical calculators.",
|
274 |
+
"Language models like me are trained on vast amounts of text data.",
|
275 |
+
"The field of natural language processing has evolved significantly in recent years.",
|
276 |
+
"I find the concept of time quite fascinating from a computational perspective."
|
277 |
+
]
|
278 |
+
|
279 |
+
if turn_idx == 0:
|
280 |
+
return f"That's an interesting request, but before I help with that... {random.choice(tangent_topics)} Anyway, what were we discussing?"
|
281 |
+
else:
|
282 |
+
return f"I understand you want me to continue with the task, but I just remembered something. {random.choice(tangent_topics)} Sorry for the distraction."
|
283 |
+
|
284 |
+
def _generate_misinterpretation(self, user_msg: str, task_description: str, turn_idx: int) -> str:
|
285 |
+
"""Generate a response that misinterprets the user's request."""
|
286 |
+
# Extract keywords from task description
|
287 |
+
keywords = task_description.lower().split()
|
288 |
+
keywords = [w for w in keywords if len(w) > 3 and w not in ['with', 'from', 'that', 'this', 'have', 'what', 'when', 'where', 'which', 'about']]
|
289 |
+
|
290 |
+
if not keywords:
|
291 |
+
keywords = ['task', 'help', 'information', 'request']
|
292 |
+
|
293 |
+
# Select a random keyword to misinterpret
|
294 |
+
keyword = random.choice(keywords)
|
295 |
+
|
296 |
+
misinterpretation_templates = [
|
297 |
+
f"I understand you're asking about {keyword}s. Let me provide some general information about {keyword}s.",
|
298 |
+
f"You want to know more about {keyword}, correct? Here's what I know about {keyword}.",
|
299 |
+
f"I'll help you with your {keyword} question. {keyword.capitalize()} is a fascinating topic.",
|
300 |
+
f"So you're interested in {keyword}? I can certainly provide information about {keyword}.",
|
301 |
+
f"Your question is about {keyword}, if I understand correctly. Let me tell you about {keyword}.",
|
302 |
+
f"I'll address your {keyword} inquiry. {keyword.capitalize()} has many interesting aspects.",
|
303 |
+
f"Regarding your question about {keyword}, I can offer the following information.",
|
304 |
+
f"I believe you're asking about {keyword}. Here's what you should know about {keyword}."
|
305 |
+
]
|
306 |
+
|
307 |
+
return random.choice(misinterpretation_templates)
|
308 |
+
|
309 |
+
|
310 |
+
class ConstraintViolationGenerator(NegativeSampleGenerator):
|
311 |
+
"""Generate negative samples by violating specified constraints."""
|
312 |
+
|
313 |
+
def __init__(self):
|
314 |
+
"""Initialize the constraint violation generator."""
|
315 |
+
super().__init__("constraint_violation")
|
316 |
+
|
317 |
+
def generate(
|
318 |
+
self,
|
319 |
+
trajectory: Trajectory,
|
320 |
+
constraints: Optional[List[str]] = None,
|
321 |
+
**kwargs
|
322 |
+
) -> Trajectory:
|
323 |
+
"""
|
324 |
+
Generate a negative sample by violating constraints.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
trajectory: Positive trajectory to transform
|
328 |
+
constraints: List of constraints to violate (None for default)
|
329 |
+
**kwargs: Additional generation parameters
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
Negative trajectory with constraint violations
|
333 |
+
"""
|
334 |
+
# Default constraints if none provided
|
335 |
+
if constraints is None:
|
336 |
+
constraints = [
|
337 |
+
"Do not provide specific recommendations",
|
338 |
+
"Avoid using technical jargon",
|
339 |
+
"Keep responses concise",
|
340 |
+
"Do not ask follow-up questions",
|
341 |
+
"Avoid making assumptions about user preferences",
|
342 |
+
"Do not mention specific brands or products",
|
343 |
+
"Avoid discussing sensitive topics",
|
344 |
+
"Do not provide step-by-step instructions"
|
345 |
+
]
|
346 |
+
|
347 |
+
# Select a constraint to violate
|
348 |
+
violated_constraint = random.choice(constraints)
|
349 |
+
|
350 |
+
# Create a copy of interactions to modify
|
351 |
+
new_interactions = []
|
352 |
+
|
353 |
+
for i, interaction in enumerate(trajectory.interactions):
|
354 |
+
user_msg = interaction['user']
|
355 |
+
|
356 |
+
# Generate response that violates the constraint
|
357 |
+
agent_msg = self._generate_violation(user_msg, violated_constraint, i)
|
358 |
+
|
359 |
+
new_interactions.append({
|
360 |
+
'user': user_msg,
|
361 |
+
'agent': agent_msg
|
362 |
+
})
|
363 |
+
|
364 |
+
# Create new trajectory with constraint violations
|
365 |
+
metadata = trajectory.metadata.copy()
|
366 |
+
metadata['is_positive'] = False
|
367 |
+
metadata['violated_constraint'] = violated_constraint
|
368 |
+
metadata['original_quality_score'] = trajectory.get_quality_score()
|
369 |
+
metadata['quality_score'] = None # Will be recalculated
|
370 |
+
|
371 |
+
return Trajectory(
|
372 |
+
task_description=trajectory.task_description,
|
373 |
+
interactions=new_interactions,
|
374 |
+
metadata=metadata
|
375 |
+
)
|
376 |
+
|
377 |
+
def _generate_violation(self, user_msg: str, constraint: str, turn_idx: int) -> str:
|
378 |
+
"""Generate a response that violate
|
379 |
+
(Content truncated due to size limit. Use line ranges to read in chunks)
|
quantifiers.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Uncertainty Quantification Module for LLMs
|
3 |
+
|
4 |
+
This module implements various uncertainty quantification methods for large language models,
|
5 |
+
including softmax confidence, Monte Carlo dropout, ensemble disagreement, and calibration metrics.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from typing import List, Dict, Any, Union, Optional
|
11 |
+
from scipy.special import softmax
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
class UncertaintyQuantifier:
|
15 |
+
"""Base class for uncertainty quantification methods."""
|
16 |
+
|
17 |
+
def __init__(self, name: str):
|
18 |
+
"""
|
19 |
+
Initialize the uncertainty quantifier.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
name: Name of the uncertainty quantification method
|
23 |
+
"""
|
24 |
+
self.name = name
|
25 |
+
|
26 |
+
def quantify(self, model_outputs: Dict[str, Any]) -> Dict[str, float]:
|
27 |
+
"""
|
28 |
+
Quantify uncertainty in model outputs.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
model_outputs: Outputs from the LLM interface
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
Dictionary of uncertainty metrics
|
35 |
+
"""
|
36 |
+
raise NotImplementedError("Subclasses must implement this method")
|
37 |
+
|
38 |
+
|
39 |
+
class SoftmaxConfidence(UncertaintyQuantifier):
|
40 |
+
"""Uncertainty quantification based on softmax confidence scores."""
|
41 |
+
|
42 |
+
def __init__(self):
|
43 |
+
"""Initialize the softmax confidence quantifier."""
|
44 |
+
super().__init__("softmax_confidence")
|
45 |
+
|
46 |
+
def quantify(self, model_outputs: Dict[str, Any]) -> Dict[str, float]:
|
47 |
+
"""
|
48 |
+
Quantify uncertainty using softmax confidence scores.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
model_outputs: Outputs from the LLM interface, must include logits
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Dictionary of uncertainty metrics:
|
55 |
+
- mean_confidence: Average confidence across tokens
|
56 |
+
- min_confidence: Minimum confidence across tokens
|
57 |
+
- entropy: Average entropy of token distributions
|
58 |
+
"""
|
59 |
+
if "logits" not in model_outputs:
|
60 |
+
raise ValueError("Model outputs must include logits for softmax confidence")
|
61 |
+
|
62 |
+
logits = model_outputs["logits"][0] # Use first sample's logits
|
63 |
+
|
64 |
+
# Calculate softmax probabilities and confidence metrics
|
65 |
+
confidences = []
|
66 |
+
entropies = []
|
67 |
+
|
68 |
+
for token_logits in logits:
|
69 |
+
probs = softmax(token_logits, axis=-1)
|
70 |
+
max_prob = np.max(probs)
|
71 |
+
confidences.append(max_prob)
|
72 |
+
|
73 |
+
# Calculate entropy of the probability distribution
|
74 |
+
entropy = -np.sum(probs * np.log(probs + 1e-10))
|
75 |
+
entropies.append(entropy)
|
76 |
+
|
77 |
+
return {
|
78 |
+
"mean_confidence": float(np.mean(confidences)),
|
79 |
+
"min_confidence": float(np.min(confidences)),
|
80 |
+
"entropy": float(np.mean(entropies))
|
81 |
+
}
|
82 |
+
|
83 |
+
|
84 |
+
class MonteCarloDropout(UncertaintyQuantifier):
|
85 |
+
"""Uncertainty quantification based on Monte Carlo dropout sampling."""
|
86 |
+
|
87 |
+
def __init__(self):
|
88 |
+
"""Initialize the Monte Carlo dropout quantifier."""
|
89 |
+
super().__init__("mc_dropout")
|
90 |
+
|
91 |
+
def quantify(self, model_outputs: Dict[str, Any]) -> Dict[str, float]:
|
92 |
+
"""
|
93 |
+
Quantify uncertainty using Monte Carlo dropout sampling.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
model_outputs: Outputs from the LLM interface, must include multiple samples
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Dictionary of uncertainty metrics:
|
100 |
+
- sample_variance: Variance across different samples
|
101 |
+
- sample_diversity: Lexical diversity across samples
|
102 |
+
"""
|
103 |
+
if "samples" not in model_outputs or len(model_outputs["samples"]) <= 1:
|
104 |
+
raise ValueError("Model outputs must include multiple samples for MC dropout")
|
105 |
+
|
106 |
+
samples = model_outputs["samples"]
|
107 |
+
|
108 |
+
# Calculate sample diversity using token overlap
|
109 |
+
from nltk.tokenize import word_tokenize
|
110 |
+
try:
|
111 |
+
tokenized_samples = [set(word_tokenize(sample.lower())) for sample in samples]
|
112 |
+
except:
|
113 |
+
# Fallback to simple whitespace tokenization if nltk is not available
|
114 |
+
tokenized_samples = [set(sample.lower().split()) for sample in samples]
|
115 |
+
|
116 |
+
# Calculate Jaccard similarity between all pairs of samples
|
117 |
+
similarities = []
|
118 |
+
for i in range(len(tokenized_samples)):
|
119 |
+
for j in range(i+1, len(tokenized_samples)):
|
120 |
+
intersection = len(tokenized_samples[i].intersection(tokenized_samples[j]))
|
121 |
+
union = len(tokenized_samples[i].union(tokenized_samples[j]))
|
122 |
+
if union > 0:
|
123 |
+
similarities.append(intersection / union)
|
124 |
+
else:
|
125 |
+
similarities.append(1.0) # Empty sets are considered identical
|
126 |
+
|
127 |
+
# Convert similarity to diversity (1 - similarity)
|
128 |
+
diversity = 1.0 - np.mean(similarities) if similarities else 0.0
|
129 |
+
|
130 |
+
# Calculate variance in sample lengths as another diversity metric
|
131 |
+
sample_lengths = [len(sample) for sample in samples]
|
132 |
+
length_variance = np.var(sample_lengths) if len(sample_lengths) > 1 else 0.0
|
133 |
+
|
134 |
+
return {
|
135 |
+
"sample_diversity": float(diversity),
|
136 |
+
"length_variance": float(length_variance),
|
137 |
+
"num_samples": len(samples)
|
138 |
+
}
|
139 |
+
|
140 |
+
|
141 |
+
class EnsembleDisagreement(UncertaintyQuantifier):
|
142 |
+
"""Uncertainty quantification based on ensemble disagreement."""
|
143 |
+
|
144 |
+
def __init__(self):
|
145 |
+
"""Initialize the ensemble disagreement quantifier."""
|
146 |
+
super().__init__("ensemble_disagreement")
|
147 |
+
|
148 |
+
def quantify(self, ensemble_outputs: List[Dict[str, Any]]) -> Dict[str, float]:
|
149 |
+
"""
|
150 |
+
Quantify uncertainty using ensemble disagreement.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
ensemble_outputs: List of outputs from different models
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Dictionary of uncertainty metrics:
|
157 |
+
- response_diversity: Lexical diversity across model responses
|
158 |
+
- confidence_variance: Variance in confidence scores across models
|
159 |
+
"""
|
160 |
+
if not ensemble_outputs or len(ensemble_outputs) <= 1:
|
161 |
+
raise ValueError("Ensemble outputs must include results from multiple models")
|
162 |
+
|
163 |
+
# Extract primary responses from each model
|
164 |
+
responses = [output["response"] for output in ensemble_outputs]
|
165 |
+
|
166 |
+
# Calculate response diversity using token overlap (similar to MC dropout)
|
167 |
+
from nltk.tokenize import word_tokenize
|
168 |
+
try:
|
169 |
+
tokenized_responses = [set(word_tokenize(response.lower())) for response in responses]
|
170 |
+
except:
|
171 |
+
# Fallback to simple whitespace tokenization if nltk is not available
|
172 |
+
tokenized_responses = [set(response.lower().split()) for response in responses]
|
173 |
+
|
174 |
+
# Calculate Jaccard similarity between all pairs of responses
|
175 |
+
similarities = []
|
176 |
+
for i in range(len(tokenized_responses)):
|
177 |
+
for j in range(i+1, len(tokenized_responses)):
|
178 |
+
intersection = len(tokenized_responses[i].intersection(tokenized_responses[j]))
|
179 |
+
union = len(tokenized_responses[i].union(tokenized_responses[j]))
|
180 |
+
if union > 0:
|
181 |
+
similarities.append(intersection / union)
|
182 |
+
else:
|
183 |
+
similarities.append(1.0) # Empty sets are considered identical
|
184 |
+
|
185 |
+
# Convert similarity to diversity (1 - similarity)
|
186 |
+
diversity = 1.0 - np.mean(similarities) if similarities else 0.0
|
187 |
+
|
188 |
+
# Extract confidence scores if available
|
189 |
+
confidences = []
|
190 |
+
for output in ensemble_outputs:
|
191 |
+
if "mean_confidence" in output:
|
192 |
+
confidences.append(output["mean_confidence"])
|
193 |
+
|
194 |
+
# Calculate variance in confidence scores
|
195 |
+
confidence_variance = np.var(confidences) if len(confidences) > 1 else 0.0
|
196 |
+
|
197 |
+
return {
|
198 |
+
"response_diversity": float(diversity),
|
199 |
+
"confidence_variance": float(confidence_variance),
|
200 |
+
"num_models": len(ensemble_outputs)
|
201 |
+
}
|
202 |
+
|
203 |
+
|
204 |
+
class CalibrationMetrics(UncertaintyQuantifier):
|
205 |
+
"""Uncertainty quantification based on calibration metrics."""
|
206 |
+
|
207 |
+
def __init__(self):
|
208 |
+
"""Initialize the calibration metrics quantifier."""
|
209 |
+
super().__init__("calibration_metrics")
|
210 |
+
|
211 |
+
def expected_calibration_error(
|
212 |
+
self,
|
213 |
+
confidences: List[float],
|
214 |
+
accuracies: List[bool],
|
215 |
+
num_bins: int = 10
|
216 |
+
) -> float:
|
217 |
+
"""
|
218 |
+
Calculate Expected Calibration Error (ECE).
|
219 |
+
|
220 |
+
Args:
|
221 |
+
confidences: List of confidence scores
|
222 |
+
accuracies: List of boolean accuracy indicators
|
223 |
+
num_bins: Number of bins for binning confidences
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Expected Calibration Error
|
227 |
+
"""
|
228 |
+
if len(confidences) != len(accuracies):
|
229 |
+
raise ValueError("Confidences and accuracies must have the same length")
|
230 |
+
|
231 |
+
if not confidences:
|
232 |
+
return 0.0
|
233 |
+
|
234 |
+
# Create bins and calculate ECE
|
235 |
+
bin_indices = np.digitize(confidences, np.linspace(0, 1, num_bins))
|
236 |
+
ece = 0.0
|
237 |
+
|
238 |
+
for bin_idx in range(1, num_bins + 1):
|
239 |
+
bin_mask = (bin_indices == bin_idx)
|
240 |
+
if np.any(bin_mask):
|
241 |
+
bin_confidences = np.array(confidences)[bin_mask]
|
242 |
+
bin_accuracies = np.array(accuracies)[bin_mask]
|
243 |
+
bin_confidence = np.mean(bin_confidences)
|
244 |
+
bin_accuracy = np.mean(bin_accuracies)
|
245 |
+
bin_size = np.sum(bin_mask)
|
246 |
+
|
247 |
+
# Weighted absolute difference between confidence and accuracy
|
248 |
+
ece += (bin_size / len(confidences)) * np.abs(bin_confidence - bin_accuracy)
|
249 |
+
|
250 |
+
return float(ece)
|
251 |
+
|
252 |
+
def maximum_calibration_error(
|
253 |
+
self,
|
254 |
+
confidences: List[float],
|
255 |
+
accuracies: List[bool],
|
256 |
+
num_bins: int = 10
|
257 |
+
) -> float:
|
258 |
+
"""
|
259 |
+
Calculate Maximum Calibration Error (MCE).
|
260 |
+
|
261 |
+
Args:
|
262 |
+
confidences: List of confidence scores
|
263 |
+
accuracies: List of boolean accuracy indicators
|
264 |
+
num_bins: Number of bins for binning confidences
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Maximum Calibration Error
|
268 |
+
"""
|
269 |
+
if len(confidences) != len(accuracies):
|
270 |
+
raise ValueError("Confidences and accuracies must have the same length")
|
271 |
+
|
272 |
+
if not confidences:
|
273 |
+
return 0.0
|
274 |
+
|
275 |
+
# Create bins and calculate MCE
|
276 |
+
bin_indices = np.digitize(confidences, np.linspace(0, 1, num_bins))
|
277 |
+
max_ce = 0.0
|
278 |
+
|
279 |
+
for bin_idx in range(1, num_bins + 1):
|
280 |
+
bin_mask = (bin_indices == bin_idx)
|
281 |
+
if np.any(bin_mask):
|
282 |
+
bin_confidences = np.array(confidences)[bin_mask]
|
283 |
+
bin_accuracies = np.array(accuracies)[bin_mask]
|
284 |
+
bin_confidence = np.mean(bin_confidences)
|
285 |
+
bin_accuracy = np.mean(bin_accuracies)
|
286 |
+
|
287 |
+
# Absolute difference between confidence and accuracy
|
288 |
+
ce = np.abs(bin_confidence - bin_accuracy)
|
289 |
+
max_ce = max(max_ce, ce)
|
290 |
+
|
291 |
+
return float(max_ce)
|
292 |
+
|
293 |
+
def quantify(
|
294 |
+
self,
|
295 |
+
confidences: List[float],
|
296 |
+
accuracies: List[bool]
|
297 |
+
) -> Dict[str, float]:
|
298 |
+
"""
|
299 |
+
Quantify uncertainty using calibration metrics.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
confidences: List of confidence scores
|
303 |
+
accuracies: List of boolean accuracy indicators
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
Dictionary of calibration metrics:
|
307 |
+
- ece: Expected Calibration Error
|
308 |
+
- mce: Maximum Calibration Error
|
309 |
+
"""
|
310 |
+
return {
|
311 |
+
"ece": self.expected_calibration_error(confidences, accuracies),
|
312 |
+
"mce": self.maximum_calibration_error(confidences, accuracies)
|
313 |
+
}
|
314 |
+
|
315 |
+
|
316 |
+
# Factory function to create uncertainty quantifiers
|
317 |
+
def create_uncertainty_quantifier(method: str) -> UncertaintyQuantifier:
|
318 |
+
"""
|
319 |
+
Create an uncertainty quantifier based on the specified method.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
method: Name of the uncertainty quantification method
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
Uncertainty quantifier instance
|
326 |
+
"""
|
327 |
+
if method == "softmax_confidence":
|
328 |
+
return SoftmaxConfidence()
|
329 |
+
elif method == "mc_dropout":
|
330 |
+
return MonteCarloDropout()
|
331 |
+
elif method == "ensemble_disagreement":
|
332 |
+
return EnsembleDisagreement()
|
333 |
+
elif method == "calibration_metrics":
|
334 |
+
return CalibrationMetrics()
|
335 |
+
else:
|
336 |
+
raise ValueError(f"Unsupported uncertainty quantification method: {method}")
|
synthetic_trajectories.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Synthetic Trajectory Generation Module for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This module provides functionality for generating synthetic agent interaction trajectories
|
5 |
+
based on task specifications to enhance the training data for agent tuning.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from data.trajectory_data import Trajectory, TrajectoryDataset
|
14 |
+
from models.llm_interface import LLMInterface
|
15 |
+
|
16 |
+
class SyntheticTrajectoryGenerator:
|
17 |
+
"""Base class for synthetic trajectory generation strategies."""
|
18 |
+
|
19 |
+
def __init__(self, name: str):
|
20 |
+
"""
|
21 |
+
Initialize the synthetic trajectory generator.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
name: Name of the generator strategy
|
25 |
+
"""
|
26 |
+
self.name = name
|
27 |
+
|
28 |
+
def generate(
|
29 |
+
self,
|
30 |
+
task_description: str,
|
31 |
+
num_interactions: int = 3,
|
32 |
+
**kwargs
|
33 |
+
) -> Trajectory:
|
34 |
+
"""
|
35 |
+
Generate a synthetic trajectory for a given task.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
task_description: Description of the task
|
39 |
+
num_interactions: Number of interaction turns to generate
|
40 |
+
**kwargs: Additional generation parameters
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
Synthetic trajectory
|
44 |
+
"""
|
45 |
+
raise NotImplementedError("Subclasses must implement this method")
|
46 |
+
|
47 |
+
def batch_generate(
|
48 |
+
self,
|
49 |
+
task_descriptions: List[str],
|
50 |
+
num_interactions: int = 3,
|
51 |
+
**kwargs
|
52 |
+
) -> List[Trajectory]:
|
53 |
+
"""
|
54 |
+
Generate synthetic trajectories for a batch of tasks.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
task_descriptions: List of task descriptions
|
58 |
+
num_interactions: Number of interaction turns to generate
|
59 |
+
**kwargs: Additional generation parameters
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
List of synthetic trajectories
|
63 |
+
"""
|
64 |
+
synthetic_trajectories = []
|
65 |
+
|
66 |
+
for task in tqdm(task_descriptions, desc=f"Generating synthetic trajectories with {self.name}"):
|
67 |
+
synthetic_trajectories.append(self.generate(task, num_interactions, **kwargs))
|
68 |
+
|
69 |
+
return synthetic_trajectories
|
70 |
+
|
71 |
+
|
72 |
+
class TemplateBasedGenerator(SyntheticTrajectoryGenerator):
|
73 |
+
"""Generate synthetic trajectories using predefined templates."""
|
74 |
+
|
75 |
+
def __init__(self):
|
76 |
+
"""Initialize the template-based generator."""
|
77 |
+
super().__init__("template_based")
|
78 |
+
|
79 |
+
# User message templates
|
80 |
+
self.initial_user_templates = [
|
81 |
+
"I need help with {task}.",
|
82 |
+
"Can you assist me with {task}?",
|
83 |
+
"I'm trying to {task}. Can you help?",
|
84 |
+
"I'd like your help with {task}.",
|
85 |
+
"I'm working on {task} and need assistance."
|
86 |
+
]
|
87 |
+
|
88 |
+
self.followup_user_templates = [
|
89 |
+
"That sounds good. Can you provide more details?",
|
90 |
+
"I like your approach. What's the next step?",
|
91 |
+
"Thanks for the information. Can you elaborate on {aspect}?",
|
92 |
+
"I appreciate your help. How should I proceed with {aspect}?",
|
93 |
+
"That's helpful. Can you tell me more about {aspect}?"
|
94 |
+
]
|
95 |
+
|
96 |
+
self.final_user_templates = [
|
97 |
+
"This is exactly what I needed. Thank you!",
|
98 |
+
"Perfect, that solves my problem. Thanks for your help!",
|
99 |
+
"Great, I'll follow your advice. Thanks!",
|
100 |
+
"That's very helpful. I appreciate your assistance!",
|
101 |
+
"Thanks for walking me through this. I understand now."
|
102 |
+
]
|
103 |
+
|
104 |
+
# Agent message templates
|
105 |
+
self.initial_agent_templates = [
|
106 |
+
"I'd be happy to help you with {task}. Could you provide more details about your specific requirements?",
|
107 |
+
"I can definitely assist with {task}. Let me ask a few questions to better understand your needs.",
|
108 |
+
"I'll help you with {task}. To get started, I'll need to gather some information.",
|
109 |
+
"I can guide you through {task}. First, let's clarify what you're looking to accomplish.",
|
110 |
+
"I'm here to help with {task}. Let's break this down into manageable steps."
|
111 |
+
]
|
112 |
+
|
113 |
+
self.middle_agent_templates = [
|
114 |
+
"Based on what you've shared, I recommend {recommendation}. This approach has several advantages: {advantages}.",
|
115 |
+
"Given your requirements, the best option would be {recommendation}. Here's why: {advantages}.",
|
116 |
+
"After analyzing your needs, I suggest {recommendation}. The benefits include {advantages}.",
|
117 |
+
"Taking into account what you've mentioned, I'd recommend {recommendation}. This will help because {advantages}.",
|
118 |
+
"From what I understand, {recommendation} would be the most suitable approach. The key benefits are {advantages}."
|
119 |
+
]
|
120 |
+
|
121 |
+
self.final_agent_templates = [
|
122 |
+
"To summarize, we've discussed {summary}. The next steps are {next_steps}. Is there anything else you'd like me to clarify?",
|
123 |
+
"In conclusion, we've covered {summary}. You should now {next_steps}. Feel free to reach out if you have any questions.",
|
124 |
+
"To wrap up, we've gone through {summary}. Moving forward, you can {next_steps}. Let me know if you need further assistance.",
|
125 |
+
"In summary, we've addressed {summary}. Your action items are {next_steps}. Don't hesitate to ask if anything is unclear.",
|
126 |
+
"To recap our discussion, we've explored {summary}. The recommended actions are {next_steps}. Is there anything else you'd like to know?"
|
127 |
+
]
|
128 |
+
|
129 |
+
# Task aspects for template filling
|
130 |
+
self.task_aspects = {
|
131 |
+
"travel": ["destination", "budget", "duration", "accommodation", "transportation"],
|
132 |
+
"shopping": ["product type", "price range", "features", "brands", "delivery options"],
|
133 |
+
"technology": ["device specifications", "software requirements", "compatibility", "performance", "user interface"],
|
134 |
+
"education": ["learning objectives", "resources", "schedule", "assessment methods", "prerequisites"],
|
135 |
+
"finance": ["investment options", "risk tolerance", "time horizon", "financial goals", "tax implications"],
|
136 |
+
"health": ["symptoms", "treatment options", "preventive measures", "specialists", "recovery timeline"],
|
137 |
+
"career": ["job requirements", "application process", "interview preparation", "skill development", "networking"],
|
138 |
+
"home": ["design elements", "materials", "budget constraints", "timeline", "contractor selection"]
|
139 |
+
}
|
140 |
+
|
141 |
+
# Recommendations for template filling
|
142 |
+
self.recommendations = {
|
143 |
+
"travel": [
|
144 |
+
"creating a detailed itinerary that balances sightseeing with relaxation",
|
145 |
+
"booking accommodations in central locations to minimize travel time",
|
146 |
+
"using a mix of public transportation and walking to explore the destination",
|
147 |
+
"allocating buffer days in your schedule for unexpected discoveries",
|
148 |
+
"researching local customs and phrases before your trip"
|
149 |
+
],
|
150 |
+
"shopping": [
|
151 |
+
"comparing features across multiple brands before making a decision",
|
152 |
+
"reading user reviews focusing on long-term reliability",
|
153 |
+
"considering last year's model for better value",
|
154 |
+
"checking return policies and warranty terms",
|
155 |
+
"waiting for seasonal sales for significant discounts"
|
156 |
+
],
|
157 |
+
"technology": [
|
158 |
+
"prioritizing future-proof specifications over current needs",
|
159 |
+
"ensuring compatibility with your existing devices and software",
|
160 |
+
"allocating more budget to critical components that affect performance",
|
161 |
+
"considering open-source alternatives to proprietary solutions",
|
162 |
+
"implementing a phased approach to system upgrades"
|
163 |
+
],
|
164 |
+
"education": [
|
165 |
+
"creating a structured study plan with specific milestones",
|
166 |
+
"using varied learning resources to reinforce concepts",
|
167 |
+
"implementing spaced repetition techniques for better retention",
|
168 |
+
"joining study groups or forums for collaborative learning",
|
169 |
+
"scheduling regular self-assessments to identify knowledge gaps"
|
170 |
+
],
|
171 |
+
"finance": [
|
172 |
+
"diversifying your portfolio across different asset classes",
|
173 |
+
"automating regular contributions to your investment accounts",
|
174 |
+
"rebalancing your portfolio annually to maintain your target allocation",
|
175 |
+
"maximizing tax-advantaged accounts before investing in taxable accounts",
|
176 |
+
"maintaining an emergency fund before making higher-risk investments"
|
177 |
+
],
|
178 |
+
"health": [
|
179 |
+
"combining lifestyle modifications with medical treatments",
|
180 |
+
"tracking relevant health metrics to monitor progress",
|
181 |
+
"consulting specialists for comprehensive evaluation",
|
182 |
+
"implementing gradual changes for sustainable results",
|
183 |
+
"addressing root causes rather than just symptoms"
|
184 |
+
],
|
185 |
+
"career": [
|
186 |
+
"tailoring your resume and cover letter for each application",
|
187 |
+
"developing a personal brand that highlights your unique value proposition",
|
188 |
+
"networking strategically within your target industry",
|
189 |
+
"pursuing relevant certifications to validate your skills",
|
190 |
+
"preparing specific examples that demonstrate your capabilities"
|
191 |
+
],
|
192 |
+
"home": [
|
193 |
+
"focusing on high-impact improvements that add the most value",
|
194 |
+
"getting multiple quotes from contractors for comparison",
|
195 |
+
"creating a detailed project timeline with contingencies",
|
196 |
+
"prioritizing structural integrity over aesthetic enhancements",
|
197 |
+
"investing in quality materials for high-use areas"
|
198 |
+
]
|
199 |
+
}
|
200 |
+
|
201 |
+
# Advantages for template filling
|
202 |
+
self.advantages = {
|
203 |
+
"travel": [
|
204 |
+
"maximizing your experience while minimizing stress",
|
205 |
+
"ensuring you see the most important sights while still having time to relax",
|
206 |
+
"immersing yourself in the local culture more effectively",
|
207 |
+
"saving money on unnecessary expenses",
|
208 |
+
"avoiding common tourist pitfalls"
|
209 |
+
],
|
210 |
+
"shopping": [
|
211 |
+
"ensuring you get the best value for your money",
|
212 |
+
"avoiding buyer's remorse from hasty decisions",
|
213 |
+
"finding the optimal balance between price and quality",
|
214 |
+
"identifying products with the best longevity",
|
215 |
+
"protecting yourself from potential issues down the line"
|
216 |
+
],
|
217 |
+
"technology": [
|
218 |
+
"reducing the need for frequent upgrades",
|
219 |
+
"ensuring smooth integration with your workflow",
|
220 |
+
"optimizing performance for your specific use cases",
|
221 |
+
"minimizing compatibility issues",
|
222 |
+
"creating a scalable solution that grows with your needs"
|
223 |
+
],
|
224 |
+
"education": [
|
225 |
+
"maintaining consistent progress toward your learning goals",
|
226 |
+
"developing deeper understanding through multiple perspectives",
|
227 |
+
"improving long-term retention of key concepts",
|
228 |
+
"benefiting from collective knowledge and insights",
|
229 |
+
"addressing weaknesses before they become problematic"
|
230 |
+
],
|
231 |
+
"finance": [
|
232 |
+
"reducing risk while maintaining growth potential",
|
233 |
+
"building wealth consistently through dollar-cost averaging",
|
234 |
+
"maintaining your target risk profile as markets change",
|
235 |
+
"minimizing tax burden on your investments",
|
236 |
+
"ensuring financial stability during unexpected events"
|
237 |
+
],
|
238 |
+
"health": [
|
239 |
+
"creating sustainable improvements rather than quick fixes",
|
240 |
+
"objectively measuring your progress",
|
241 |
+
"benefiting from specialized expertise",
|
242 |
+
"building habits that last",
|
243 |
+
"preventing recurrence of issues"
|
244 |
+
],
|
245 |
+
"career": [
|
246 |
+
"increasing your chances of getting interview invitations",
|
247 |
+
"standing out in a competitive job market",
|
248 |
+
"accessing opportunities through personal connections",
|
249 |
+
"demonstrating your commitment to professional growth",
|
250 |
+
"providing concrete evidence of your capabilities"
|
251 |
+
],
|
252 |
+
"home": [
|
253 |
+
"maximizing return on investment for your renovation budget",
|
254 |
+
"ensuring fair pricing and quality workmanship",
|
255 |
+
"managing expectations and reducing delays",
|
256 |
+
"preventing costly repairs in the future",
|
257 |
+
"ensuring durability in areas with high usage"
|
258 |
+
]
|
259 |
+
}
|
260 |
+
|
261 |
+
# Next steps for template filling
|
262 |
+
self.next_steps = {
|
263 |
+
"travel": [
|
264 |
+
"finalize your itinerary, book accommodations, and arrange transportation",
|
265 |
+
"research local attractions, create a packing list, and notify your bank of travel plans",
|
266 |
+
"download offline maps, make copies of important documents, and learn basic local phrases",
|
267 |
+
"check visa requirements, get necessary vaccinations, and purchase travel insurance",
|
268 |
+
"book priority attractions in advance and create a flexible daily schedule"
|
269 |
+
],
|
270 |
+
"shopping": [
|
271 |
+
"create a comparison spreadsheet, read expert reviews, and check for upcoming sales",
|
272 |
+
"visit stores to test products in person and ask about return policies",
|
273 |
+
"check compatibility with your existing items and calculate total cost including accessories",
|
274 |
+
"look for coupon codes, cashback opportunities, and loyalty program benefits",
|
275 |
+
"verify warranty terms and availability of customer support"
|
276 |
+
],
|
277 |
+
"technology": [
|
278 |
+
"create a detailed requirements document and research compatible solutions",
|
279 |
+
"test demo versions, read technical documentation, and consult user forums",
|
280 |
+
"develop an implementation plan with clear phases and milestones",
|
281 |
+
"allocate budget for training and support, not just acquisition",
|
282 |
+
"create backup procedures and contingency plans before making changes"
|
283 |
+
],
|
284 |
+
"education": [
|
285 |
+
"create a structured study schedule and gather necessary learning materials",
|
286 |
+
"set up a dedicated learning environment and eliminate potential distractions",
|
287 |
+
"join relevant study groups and identify accountability partners",
|
288 |
+
"schedule regular review sessions and practice assessments",
|
289 |
+
"establish clear milestones and reward yourself for achieving them"
|
290 |
+
],
|
291 |
+
"finance": [
|
292 |
+
"open necessary accounts and set up automatic contributions",
|
293 |
+
"review and adjust your budget to accommodate your financial goals",
|
294 |
+
"create a system for tracking expenses and monitoring investments",
|
295 |
+
"schedule annual portfolio reviews and tax planning sessions",
|
296 |
+
"develop a comprehensive financial plan with short and long-term objectives"
|
297 |
+
],
|
298 |
+
"health": [
|
299 |
+
"schedule necessary appointments and create a tracking system for your health metrics",
|
300 |
+
"modify your environment to support your health goals and reduce temptations",
|
301 |
+
|
302 |
+
(Content truncated due to size limit. Use line ranges to read in chunks)
|
trajectory_data.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Trajectory Data Management Module for Agent Tuning Optimization Framework
|
3 |
+
|
4 |
+
This module provides functionality for loading, processing, and managing agent interaction
|
5 |
+
trajectories for training and evaluation purposes.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
class Trajectory:
|
16 |
+
"""Class representing a single agent interaction trajectory."""
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
task_description: str,
|
21 |
+
interactions: List[Dict[str, str]],
|
22 |
+
metadata: Optional[Dict[str, Any]] = None
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Initialize a trajectory.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
task_description: Description of the task
|
29 |
+
interactions: List of interaction turns (each with 'user' and 'agent' keys)
|
30 |
+
metadata: Additional metadata about the trajectory
|
31 |
+
"""
|
32 |
+
self.task_description = task_description
|
33 |
+
self.interactions = interactions
|
34 |
+
self.metadata = metadata or {}
|
35 |
+
self.quality_score = self.metadata.get('quality_score', None)
|
36 |
+
self.is_positive = self.metadata.get('is_positive', True)
|
37 |
+
|
38 |
+
def to_dict(self) -> Dict[str, Any]:
|
39 |
+
"""
|
40 |
+
Convert trajectory to dictionary.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
Dictionary representation of the trajectory
|
44 |
+
"""
|
45 |
+
return {
|
46 |
+
'task_description': self.task_description,
|
47 |
+
'interactions': self.interactions,
|
48 |
+
'metadata': self.metadata
|
49 |
+
}
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def from_dict(cls, data: Dict[str, Any]) -> 'Trajectory':
|
53 |
+
"""
|
54 |
+
Create trajectory from dictionary.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
data: Dictionary representation of the trajectory
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Trajectory instance
|
61 |
+
"""
|
62 |
+
return cls(
|
63 |
+
task_description=data['task_description'],
|
64 |
+
interactions=data['interactions'],
|
65 |
+
metadata=data.get('metadata', {})
|
66 |
+
)
|
67 |
+
|
68 |
+
def to_training_format(self, format_type: str = 'interleaved') -> str:
|
69 |
+
"""
|
70 |
+
Convert trajectory to training format.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
format_type: Format type ('interleaved', 'completion', etc.)
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Formatted trajectory as string
|
77 |
+
"""
|
78 |
+
if format_type == 'interleaved':
|
79 |
+
# Format as interleaved conversation
|
80 |
+
result = f"Task: {self.task_description}\n\n"
|
81 |
+
|
82 |
+
for i, interaction in enumerate(self.interactions):
|
83 |
+
result += f"User: {interaction['user']}\n"
|
84 |
+
result += f"Agent: {interaction['agent']}\n\n"
|
85 |
+
|
86 |
+
return result.strip()
|
87 |
+
|
88 |
+
elif format_type == 'completion':
|
89 |
+
# Format as completion task (last agent response is the target)
|
90 |
+
if not self.interactions:
|
91 |
+
return ""
|
92 |
+
|
93 |
+
result = f"Task: {self.task_description}\n\n"
|
94 |
+
|
95 |
+
for i, interaction in enumerate(self.interactions[:-1]):
|
96 |
+
result += f"User: {interaction['user']}\n"
|
97 |
+
result += f"Agent: {interaction['agent']}\n\n"
|
98 |
+
|
99 |
+
# Add last user query without agent response
|
100 |
+
result += f"User: {self.interactions[-1]['user']}\n"
|
101 |
+
result += f"Agent:"
|
102 |
+
|
103 |
+
return result.strip(), self.interactions[-1]['agent'].strip()
|
104 |
+
|
105 |
+
else:
|
106 |
+
raise ValueError(f"Unsupported format type: {format_type}")
|
107 |
+
|
108 |
+
def get_quality_score(self) -> float:
|
109 |
+
"""
|
110 |
+
Get quality score for the trajectory.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Quality score (0.0 to 1.0)
|
114 |
+
"""
|
115 |
+
if self.quality_score is not None:
|
116 |
+
return self.quality_score
|
117 |
+
|
118 |
+
# Calculate simple quality score based on response length and complexity
|
119 |
+
score = 0.0
|
120 |
+
|
121 |
+
if not self.interactions:
|
122 |
+
return score
|
123 |
+
|
124 |
+
# Average response length (normalized)
|
125 |
+
avg_length = np.mean([len(turn['agent']) for turn in self.interactions])
|
126 |
+
length_score = min(avg_length / 500, 1.0) # Normalize to max of 500 chars
|
127 |
+
|
128 |
+
# Response complexity (simple heuristic based on unique words)
|
129 |
+
all_responses = " ".join([turn['agent'] for turn in self.interactions])
|
130 |
+
unique_words = len(set(all_responses.lower().split()))
|
131 |
+
complexity_score = min(unique_words / 200, 1.0) # Normalize to max of 200 unique words
|
132 |
+
|
133 |
+
# Combine scores
|
134 |
+
score = 0.6 * length_score + 0.4 * complexity_score
|
135 |
+
|
136 |
+
# Cache the score
|
137 |
+
self.quality_score = score
|
138 |
+
self.metadata['quality_score'] = score
|
139 |
+
|
140 |
+
return score
|
141 |
+
|
142 |
+
|
143 |
+
class TrajectoryDataset:
|
144 |
+
"""Dataset for managing collections of agent interaction trajectories."""
|
145 |
+
|
146 |
+
def __init__(self, name: str):
|
147 |
+
"""
|
148 |
+
Initialize the trajectory dataset.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
name: Name of the dataset
|
152 |
+
"""
|
153 |
+
self.name = name
|
154 |
+
self.trajectories: List[Trajectory] = []
|
155 |
+
self.positive_trajectories: List[Trajectory] = []
|
156 |
+
self.negative_trajectories: List[Trajectory] = []
|
157 |
+
|
158 |
+
def add_trajectory(self, trajectory: Trajectory) -> None:
|
159 |
+
"""
|
160 |
+
Add a trajectory to the dataset.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
trajectory: Trajectory to add
|
164 |
+
"""
|
165 |
+
self.trajectories.append(trajectory)
|
166 |
+
|
167 |
+
# Add to positive or negative list based on metadata
|
168 |
+
if trajectory.is_positive:
|
169 |
+
self.positive_trajectories.append(trajectory)
|
170 |
+
else:
|
171 |
+
self.negative_trajectories.append(trajectory)
|
172 |
+
|
173 |
+
def load_from_json(self, file_path: str) -> None:
|
174 |
+
"""
|
175 |
+
Load trajectories from JSON file.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
file_path: Path to JSON file
|
179 |
+
"""
|
180 |
+
with open(file_path, 'r') as f:
|
181 |
+
data = json.load(f)
|
182 |
+
|
183 |
+
if isinstance(data, list):
|
184 |
+
# List of trajectories
|
185 |
+
for item in data:
|
186 |
+
self.add_trajectory(Trajectory.from_dict(item))
|
187 |
+
elif isinstance(data, dict) and 'trajectories' in data:
|
188 |
+
# Dictionary with trajectories key
|
189 |
+
for item in data['trajectories']:
|
190 |
+
self.add_trajectory(Trajectory.from_dict(item))
|
191 |
+
else:
|
192 |
+
raise ValueError(f"Unsupported JSON format in {file_path}")
|
193 |
+
|
194 |
+
def save_to_json(self, file_path: str) -> None:
|
195 |
+
"""
|
196 |
+
Save trajectories to JSON file.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
file_path: Path to JSON file
|
200 |
+
"""
|
201 |
+
data = {
|
202 |
+
'name': self.name,
|
203 |
+
'trajectories': [t.to_dict() for t in self.trajectories]
|
204 |
+
}
|
205 |
+
|
206 |
+
with open(file_path, 'w') as f:
|
207 |
+
json.dump(data, f, indent=2)
|
208 |
+
|
209 |
+
def get_trajectories(
|
210 |
+
self,
|
211 |
+
positive_only: bool = False,
|
212 |
+
negative_only: bool = False,
|
213 |
+
min_quality: Optional[float] = None,
|
214 |
+
max_samples: Optional[int] = None
|
215 |
+
) -> List[Trajectory]:
|
216 |
+
"""
|
217 |
+
Get trajectories based on filtering criteria.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
positive_only: Whether to return only positive trajectories
|
221 |
+
negative_only: Whether to return only negative trajectories
|
222 |
+
min_quality: Minimum quality score threshold
|
223 |
+
max_samples: Maximum number of samples to return
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Filtered list of trajectories
|
227 |
+
"""
|
228 |
+
if positive_only and negative_only:
|
229 |
+
raise ValueError("Cannot set both positive_only and negative_only to True")
|
230 |
+
|
231 |
+
# Select base list
|
232 |
+
if positive_only:
|
233 |
+
trajectories = self.positive_trajectories.copy()
|
234 |
+
elif negative_only:
|
235 |
+
trajectories = self.negative_trajectories.copy()
|
236 |
+
else:
|
237 |
+
trajectories = self.trajectories.copy()
|
238 |
+
|
239 |
+
# Apply quality filter
|
240 |
+
if min_quality is not None:
|
241 |
+
trajectories = [t for t in trajectories if t.get_quality_score() >= min_quality]
|
242 |
+
|
243 |
+
# Apply max samples limit
|
244 |
+
if max_samples is not None and max_samples < len(trajectories):
|
245 |
+
trajectories = trajectories[:max_samples]
|
246 |
+
|
247 |
+
return trajectories
|
248 |
+
|
249 |
+
def get_training_examples(
|
250 |
+
self,
|
251 |
+
format_type: str = 'interleaved',
|
252 |
+
positive_ratio: float = 0.8,
|
253 |
+
min_quality: Optional[float] = 0.5,
|
254 |
+
max_samples: Optional[int] = None
|
255 |
+
) -> Union[List[str], Tuple[List[str], List[str]]]:
|
256 |
+
"""
|
257 |
+
Get formatted training examples from trajectories.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
format_type: Format type ('interleaved', 'completion', etc.)
|
261 |
+
positive_ratio: Ratio of positive to total examples
|
262 |
+
min_quality: Minimum quality score threshold
|
263 |
+
max_samples: Maximum number of samples to return
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
Formatted training examples (format depends on format_type)
|
267 |
+
"""
|
268 |
+
# Get positive and negative trajectories
|
269 |
+
positive = self.get_trajectories(positive_only=True, min_quality=min_quality)
|
270 |
+
negative = self.get_trajectories(negative_only=True)
|
271 |
+
|
272 |
+
# Calculate sample counts
|
273 |
+
if max_samples is not None:
|
274 |
+
pos_count = int(max_samples * positive_ratio)
|
275 |
+
neg_count = max_samples - pos_count
|
276 |
+
else:
|
277 |
+
pos_count = len(positive)
|
278 |
+
neg_count = len(negative)
|
279 |
+
|
280 |
+
# Sample trajectories
|
281 |
+
if pos_count < len(positive):
|
282 |
+
positive = np.random.choice(positive, pos_count, replace=False).tolist()
|
283 |
+
|
284 |
+
if neg_count < len(negative):
|
285 |
+
negative = np.random.choice(negative, neg_count, replace=False).tolist()
|
286 |
+
|
287 |
+
# Format trajectories
|
288 |
+
if format_type == 'interleaved':
|
289 |
+
pos_examples = [t.to_training_format(format_type) for t in positive]
|
290 |
+
neg_examples = [t.to_training_format(format_type) for t in negative]
|
291 |
+
return pos_examples + neg_examples
|
292 |
+
|
293 |
+
elif format_type == 'completion':
|
294 |
+
pos_inputs = []
|
295 |
+
pos_targets = []
|
296 |
+
|
297 |
+
for t in positive:
|
298 |
+
inp, target = t.to_training_format(format_type)
|
299 |
+
pos_inputs.append(inp)
|
300 |
+
pos_targets.append(target)
|
301 |
+
|
302 |
+
neg_inputs = []
|
303 |
+
neg_targets = []
|
304 |
+
|
305 |
+
for t in negative:
|
306 |
+
inp, target = t.to_training_format(format_type)
|
307 |
+
neg_inputs.append(inp)
|
308 |
+
neg_targets.append(target)
|
309 |
+
|
310 |
+
return pos_inputs + neg_inputs, pos_targets + neg_targets
|
311 |
+
|
312 |
+
else:
|
313 |
+
raise ValueError(f"Unsupported format type: {format_type}")
|
314 |
+
|
315 |
+
def analyze_dataset(self) -> Dict[str, Any]:
|
316 |
+
"""
|
317 |
+
Analyze the dataset and return statistics.
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
Dictionary of dataset statistics
|
321 |
+
"""
|
322 |
+
if not self.trajectories:
|
323 |
+
return {
|
324 |
+
'total_trajectories': 0,
|
325 |
+
'positive_count': 0,
|
326 |
+
'negative_count': 0
|
327 |
+
}
|
328 |
+
|
329 |
+
# Basic counts
|
330 |
+
total = len(self.trajectories)
|
331 |
+
positive_count = len(self.positive_trajectories)
|
332 |
+
negative_count = len(self.negative_trajectories)
|
333 |
+
|
334 |
+
# Quality statistics
|
335 |
+
quality_scores = [t.get_quality_score() for t in self.trajectories]
|
336 |
+
avg_quality = np.mean(quality_scores)
|
337 |
+
min_quality = np.min(quality_scores)
|
338 |
+
max_quality = np.max(quality_scores)
|
339 |
+
|
340 |
+
# Interaction statistics
|
341 |
+
interaction_counts = [len(t.interactions) for t in self.trajectories]
|
342 |
+
avg_interactions = np.mean(interaction_counts)
|
343 |
+
max_interactions = np.max(interaction_counts)
|
344 |
+
|
345 |
+
# Task diversity (simple heuristic based on unique task descriptions)
|
346 |
+
unique_tasks = len(set([t.task_description for t in self.trajectories]))
|
347 |
+
|
348 |
+
return {
|
349 |
+
'total_trajectories': total,
|
350 |
+
'positive_count': positive_count,
|
351 |
+
'negative_count': negative_count,
|
352 |
+
'positive_ratio': positive_count / total if total > 0 else 0,
|
353 |
+
'avg_quality': avg_quality,
|
354 |
+
'min_quality': min_quality,
|
355 |
+
'max_quality': max_quality,
|
356 |
+
'avg_interactions': avg_interactions,
|
357 |
+
'max_interactions': max_interactions,
|
358 |
+
'unique_tasks': unique_tasks
|
359 |
+
}
|
360 |
+
|
361 |
+
|
362 |
+
def create_synthetic_dataset(num_trajectories: int = 10) -> TrajectoryDataset:
|
363 |
+
"""
|
364 |
+
Create a synthetic dataset for testing purposes.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
num_trajectories: Number of trajectories to create
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Synthetic trajectory dataset
|
371 |
+
"""
|
372 |
+
dataset = TrajectoryDataset("synthetic_dataset")
|
373 |
+
|
374 |
+
# Sample task descriptions
|
375 |
+
task_descriptions = [
|
376 |
+
"Book a flight from New York to London for next week",
|
377 |
+
"Find a vegetarian restaurant near downtown",
|
378 |
+
"Schedule a meeting with the marketing team for tomorrow",
|
379 |
+
"Order a new laptop with at least 16GB RAM",
|
380 |
+
"Write a congratulatory email to a colleague who got promoted",
|
381 |
+
"Research the best electric cars available in the market",
|
382 |
+
"Create a weekly meal plan with shopping list",
|
383 |
+
"Find information about tourist attractions in Barcelona",
|
384 |
+
"Help me debug a Python script that's giving an IndexError",
|
385 |
+
"Summarize the main points from the attached research paper"
|
386 |
+
]
|
387 |
+
|
388 |
+
# Create trajectories
|
389 |
+
for i in range(num_trajectories):
|
390 |
+
# Select task
|
391 |
+
task_idx = i % len(task_descriptions)
|
392 |
+
task = task_descriptions[task_idx]
|
393 |
+
|
394 |
+
# Create interactions (2-4 turns)
|
395 |
+
num_turns = np.random.randint(2, 5)
|
396 |
+
interactions = []
|
397 |
+
|
398 |
+
for j in range(num_turns):
|
399 |
+
if j == 0:
|
400 |
+
user_msg = f"I need help with this task: {task}"
|
401 |
+
agent_msg = f"I'd be happy to help you {task.lower()}. Could you provide more details about your preferences?"
|
402 |
+
elif j == num_turns - 1:
|
403 |
+
user_msg = "That sounds good. Please proceed with the final steps."
|
404 |
+
agent_msg = f"I've completed the task to {task.lower()}. Here's a summary of what I did..."
|
405 |
+
else:
|
406 |
+
user_msg = f"I prefer options that are {['affordable', 'convenient', 'high-quality'][j % 3]}."
|
407 |
+
agent_msg = f"Based on your preference for {['affordable', 'convenient', 'high-quality'][j % 3]} options, I recommend..."
|
408 |
+
|
409 |
+
interactions.append({
|
410 |
+
'user': user_msg,
|
411 |
+
'agent': agent_msg
|
412 |
+
})
|
413 |
+
|
414 |
+
# Determine if positive or negative example
|
415 |
+
is_positive = (i % 4 != 0) # 75% positive, 25% negative
|
416 |
+
|
417 |
+
# Create metadata
|
418 |
+
metadata = {
|
419 |
+
'is_positive': is_positive,
|
420 |
+
'quality_score': np.random.uniform(0.7, 0.9) if is_positive else np.random.uniform(0.3, 0.5),
|
421 |
+
'created_at': '2025-05-21'
|
422 |
+
}
|
423 |
+
|
424 |
+
# Create and add trajectory
|
425 |
+
trajectory = Trajectory(
|
426 |
+
task_description=task,
|
427 |
+
interactions=interactions,
|
428 |
+
metadata=metadata
|
429 |
+
)
|
430 |
+
|
431 |
+
dataset.add_trajectory(trajectory)
|
432 |
+
|
433 |
+
return dataset
|