iruno committed on
Commit e9be283 · verified · 1 Parent(s): c8b845e

Delete agent

Files changed (32)
  1. agent/__init__.py +0 -0
  2. agent/checklist.py +0 -18
  3. agent/mini_bench/__init__.py +0 -0
  4. agent/mini_bench/__pycache__/__init__.cpython-311.pyc +0 -0
  5. agent/mini_bench/__pycache__/agent.cpython-311.pyc +0 -0
  6. agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc +0 -0
  7. agent/mini_bench/agent.py +0 -467
  8. agent/mini_bench/checklist_eval.py +0 -95
  9. agent/mini_bench/eval_utils.py +0 -309
  10. agent/mini_bench/inference_utils.py +0 -87
  11. agent/mini_bench/prompts/__init__.py +0 -1
  12. agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc +0 -0
  13. agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc +0 -0
  14. agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc +0 -0
  15. agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc +0 -0
  16. agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc +0 -0
  17. agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc +0 -0
  18. agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc +0 -0
  19. agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc +0 -0
  20. agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc +0 -0
  21. agent/mini_bench/prompts/action.py +0 -93
  22. agent/mini_bench/prompts/checklist_prompt.py +0 -50
  23. agent/mini_bench/prompts/construct_messages.py +0 -309
  24. agent/mini_bench/prompts/eval_type.py +0 -107
  25. agent/mini_bench/prompts/image_utils.py +0 -19
  26. agent/mini_bench/prompts/input_information.py +0 -36
  27. agent/mini_bench/prompts/judge_prompt.py +0 -159
  28. agent/mini_bench/prompts/utils.py +0 -18
  29. agent/mini_bench/reward_agent.py +0 -465
  30. agent/mini_bench/utils.py +0 -269
  31. agent/reward.py +0 -96
  32. agent/reward_postprocessor.py +0 -41
agent/__init__.py DELETED
File without changes
agent/checklist.py DELETED
@@ -1,18 +0,0 @@
1
- from .mini_bench.agent import ChecklistGenerationAgent
2
-
3
- def generate_checklist(**data):
4
- # data: 'intent', 'start_url', 'text_observation'
5
- agent_config = {
6
- 'model_name': 'WPRM/qwen-3b-ar-reward-cot-mtl-checklist-enhanced',
7
- 'base_url': 'http://165.132.144.84:7701/v1',
8
- 'api_key': 'empty',
9
- 'temperature': 0.7,
10
- 'use_log_probs': True,
11
- 'use_checklist': True,
12
- 'use_multimodal': False,
13
- 'num_generate': 1,
14
- }
15
- checklist_generation_agent = ChecklistGenerationAgent(agent_config)
16
- response_list, cost = checklist_generation_agent.generate_response(data, prompt_type='ours', constraint_str_list=["<think>", "</think>", "<answer>", "</answer>"])
17
- response = response_list[0]
18
- return response.split("<answer>")[-1].split("</answer>")[0].strip()
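For reference, a minimal sketch of how the deleted helper was called. The keyword arguments mirror the keys noted in the comment above; the import path assumes the pre-deletion layout, the values are illustrative, and the hard-coded endpoint in agent_config must be reachable for the call to succeed.

from agent.checklist import generate_checklist

checklist = generate_checklist(
    intent="Find the cheapest laptop under $500",       # user goal
    start_url="https://www.example-shop.com",            # illustrative start URL
    text_observation="RootWebArea 'Example Shop' ...",   # AXTree text dump of the page
)
# generate_checklist returns the text between <answer> and </answer>
print(checklist)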
agent/mini_bench/__init__.py DELETED
File without changes
agent/mini_bench/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (186 Bytes)
 
agent/mini_bench/__pycache__/agent.cpython-311.pyc DELETED
Binary file (20.7 kB)
 
agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc DELETED
Binary file (21.3 kB)
 
agent/mini_bench/agent.py DELETED
@@ -1,467 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- import time
3
- import requests
4
- import json
5
- import math
6
- from langsmith import Client
7
- from langchain_openai import ChatOpenAI
8
-
9
- from .prompts import get_messages
10
- from .prompts.judge_prompt import (
11
- JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
12
- JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
13
- JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
14
- JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
15
- )
16
- from .prompts.image_utils import image_to_base64_url
17
-
18
- MAX_RETRY = 3
19
- RETRY_SLEEP = 5
20
- MODEL_COST_MAPPING = {
21
- "gpt-4o-mini": {
22
- "input_token_cost": 0.15,
23
- "output_token_cost": 0.6
24
- },
25
- "gpt-4o": {
26
- "input_token_cost": 2.5,
27
- "output_token_cost": 10
28
- },
29
- }
30
-
31
-
32
- class Agent(ABC):
33
- @abstractmethod
34
- def generate_response(self, inputs: dict) -> str:
35
- pass
36
-
37
- class BaseAgent(Agent):
38
- def __init__(self, agent_config: dict):
39
- self.agent_config = agent_config
40
- self._setup()
41
-
42
- def _setup(self):
43
- use_log_probs = self.agent_config.get("use_log_probs", False)
44
- if use_log_probs:
45
- self.llm = ChatOpenAI(
46
- model=self.agent_config["model_name"],
47
- base_url=self.agent_config["base_url"],
48
- api_key=self.agent_config["api_key"],
49
- temperature=self.agent_config["temperature"],
50
- timeout=300,
51
- logprobs=True,
52
- top_logprobs=10
53
- )
54
- else:
55
- self.llm = ChatOpenAI(
56
- model=self.agent_config["model_name"],
57
- base_url=self.agent_config["base_url"],
58
- api_key=self.agent_config["api_key"],
59
- temperature=self.agent_config["temperature"],
60
- timeout=300
61
- )
62
- self.temperature = self.agent_config["temperature"]
63
- self.num_generate = self.agent_config["num_generate"]
64
- self.use_checklist = self.agent_config.get("use_checklist", False)
65
- self.use_multimodal = self.agent_config.get("use_multimodal", False)
66
-
67
- # setup cost
68
- model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
69
- if model_cost and "api" in self.agent_config["base_url"]:
70
- self.input_token_cost = model_cost["input_token_cost"]
71
- self.output_token_cost = model_cost["output_token_cost"]
72
- else:
73
- self.input_token_cost = 0.0
74
- self.output_token_cost = 0.0
75
-
76
- def generate_with_retry(self, model_input, constraint_str_list: list = None):
77
- total_input_tokens = 0
78
- total_output_tokens = 0
79
- if self.temperature == 0:
80
- response = self.llm.invoke(model_input)
81
- total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
82
- total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
83
- else:
84
- for i in range(MAX_RETRY):
85
- try:
86
- response = self.llm.invoke(model_input)
87
- total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
88
- total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
89
- if constraint_str_list:
90
- pass_constraint_num = 0
91
- for constraint_str in constraint_str_list:
92
- if constraint_str in response.content:
93
- pass_constraint_num += 1
94
- if pass_constraint_num == len(constraint_str_list):
95
- break
96
- else:
97
- print(f"Agent response has a format issue, retrying... {i+1}/{MAX_RETRY}")
98
- print(response.content)
99
- else:
100
- break
101
- except Exception as e:
102
- print(f"Agent returned an Error: {e}")
103
- response = None
104
- time.sleep(RETRY_SLEEP)
105
-
106
- cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
107
-
108
- if response is None:
109
- return "", cost
110
- else:
111
- return response.content, cost
112
-
113
- def prepare_message(self, model_input: dict, prompt_type: str):
114
- message = []
115
- return message
116
-
117
- def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
118
- total_cost = 0
119
- response_list = []
120
- # prepare message
121
- message = self.prepare_message(model_input, prompt_type)
122
- # print(message)
123
-
124
- # n sampling
125
- for i in range(self.num_generate):
126
- response, cost = self.generate_with_retry(message, constraint_str_list)
127
- response_list.append(response)
128
- total_cost += cost
129
-
130
- return response_list, total_cost
131
-
132
-
133
- class GroundingJudgeAgent(BaseAgent):
134
- def __init__(self, agent_config: dict):
135
- super().__init__(agent_config)
136
- self._setup()
137
-
138
- def prepare_message(self, model_input: dict, prompt_type):
139
- message = get_messages(
140
- input_info=model_input,
141
- inference_mode="judge_grounding",
142
- prompt_type=prompt_type,
143
- use_multimodal=self.use_multimodal,
144
- text_obs=self.agent_config["text_obs_type"],
145
- image_obs=self.agent_config["image_obs_type"]
146
- )
147
- return message
148
-
149
-
150
- class ProgressJudgeAgent(BaseAgent):
151
- def __init__(self, agent_config: dict):
152
- super().__init__(agent_config)
153
- self._setup()
154
-
155
- def prepare_message(self, model_input: dict, prompt_type):
156
- if self.agent_config["input_type"]=="text_only":
157
- use_multimodal = False
158
- text_obs = self.agent_config["text_obs_type"]
159
- image_obs = None
160
- elif self.agent_config["input_type"]=="image_only":
161
- use_multimodal = True
162
- text_obs = None
163
- image_obs = self.agent_config["image_obs_type"]
164
- elif self.agent_config["input_type"]=="text_image":
165
- use_multimodal = True
166
- text_obs = self.agent_config["text_obs_type"]
167
- image_obs = self.agent_config["image_obs_type"]
168
- else:
169
- raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
170
-
171
- if self.agent_config["use_in_progress"]:
172
- use_in_progress = True
173
- else:
174
- use_in_progress = False
175
-
176
- message = get_messages(
177
- input_info=model_input,
178
- inference_mode="judge_progress",
179
- prompt_type=prompt_type,
180
- use_checklist=self.use_checklist,
181
- use_multimodal=use_multimodal,
182
- text_obs=text_obs,
183
- image_obs=image_obs,
184
- use_in_progress=use_in_progress
185
- )
186
- return message
187
-
188
- def add_logprob(self, ori_logprob: float, add_logprob: float):
189
- if ori_logprob is None:
190
- return add_logprob
191
- else:
192
- ori_prob = math.exp(ori_logprob)
193
- add_prob = math.exp(add_logprob)
194
- return math.log(ori_prob + add_prob)
195
-
196
- def get_judge_probs(self, logprobs: list):
197
- # target_judge = {
198
- # "yes": [" Yes", "Yes"],
199
- # "no": [" No", "No"],
200
- # "in": [" In", "In"]
201
- # }
202
- target_judge = {
203
- "yes": [
204
- " Yes", "ĠYes", "Yes", "ĊYes",
205
- "Ġyes", "yes", "Ċyes",
206
- "ĠYES", "YES", "ĊYES",
207
- "ĠDone", "Done", "ĊDone",
208
- "ĠCompleted", "Completed", "ĊCompleted",
209
- "ĠCorrect", "Correct", "ĊCorrect"
210
- ],
211
- "no": [
212
- " No", "ĠNo", "No", "ĊNo",
213
- "ĠNO", "NO", "ĊNO",
214
- "ĠNot", "Not", "ĊNot",
215
- "ĠNone", "None", "ĊNone",
216
- "ĠNope", "Nope", "ĊNope",
217
- "ĠUn", "Un", "ĊUn",
218
- "ĠWrong", "Wrong", "ĊWrong"
219
- ],
220
- "in": [
221
- " In", "ĠIn", "In", "ĊIn",
222
- "ĠPending", "Pending", "ĊPending",
223
- "ĠPart", "Part", "ĊPart",
224
- "ĠPartial", "Partial", "ĊPartial",
225
- "ĠInProgress", "InProgress", "ĊInProgress"
226
- ]
227
- }
228
- response_str = ""
229
- judge_probs_list = []
230
- # print(logprobs)
231
- for i, log_prob in enumerate(logprobs):
232
- # Start to find judge string
233
- if "<answer>" in response_str:
234
- find_judge_str = None
235
- for judge_type in target_judge:
236
- if log_prob["token"] in target_judge[judge_type]:
237
- # print(log_prob)
238
- find_judge_str = judge_type
239
- break
240
- if find_judge_str:
241
- # print("find judge str")
242
- token_judge_dict = {
243
- "yes": None,
244
- "no": None,
245
- "in": None
246
- }
247
- if "top_logprobs" in log_prob:
248
- for token_info in log_prob["top_logprobs"]:
249
- for judge_type in target_judge:
250
- for judge_str in target_judge[judge_type]:
251
- # if judge_str in token_info["token"] and token_info["logprob"] > token_judge_dict[judge_type]:
252
- # token_judge_dict[judge_type] = token_info["logprob"]
253
- if judge_str in token_info["token"]:
254
- # print(token_info["logprob"])
255
- token_judge_dict[judge_type] = self.add_logprob(token_judge_dict[judge_type], token_info["logprob"])
256
- # for None case
257
- for judge_type in token_judge_dict:
258
- if token_judge_dict[judge_type] is None:
259
- token_judge_dict[judge_type] = float("-inf")
260
- judge_probs_list.append(token_judge_dict)
261
- else:
262
- # for vllm bugs : no top_logprobs
263
- for judge_type in token_judge_dict:
264
- if judge_type == find_judge_str:
265
- token_judge_dict[judge_type] = log_prob["logprob"]
266
- else:
267
- token_judge_dict[judge_type] = float("-inf")
268
- judge_probs_list.append(token_judge_dict)
269
- # print(token_judge_dict)
270
-
271
- if "</answer>" in response_str:
272
- break
273
-
274
- response_str += log_prob["token"]
275
- # print(response_str.replace("Ġ", " ").replace("Ċ", "\n"))
276
- # print(judge_probs_list)
277
- if len(judge_probs_list) == 0:
278
- return [{
279
- "yes": 0.0,
280
- "no": 0.0,
281
- "in": 0.0
282
- }]
283
- else:
284
- # convert with softmax
285
- final_judge_probs_list = []
286
- for judge_probs in judge_probs_list:
287
- exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
288
- sum_exp_logprobs = sum(exp_logprobs)
289
- softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
290
- final_judge_probs_list.append({
291
- "yes": softmax_probs[0],
292
- "no": softmax_probs[1],
293
- "in": softmax_probs[2]
294
- })
295
- return final_judge_probs_list
296
-
297
- def generate_probs(self, model_input: dict, prompt_type: str):
298
- total_cost = 0
299
- response_list = []
300
- # prepare message
301
- message = self.prepare_message(model_input, prompt_type)
302
- # print(message)
303
-
304
- for i in range(self.num_generate):
305
- try:
306
- response = self.llm.invoke(message)
307
- total_input_tokens = response.response_metadata["token_usage"]["prompt_tokens"]
308
- total_output_tokens = response.response_metadata["token_usage"]["completion_tokens"]
309
- total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
310
- logprobs = response.response_metadata["logprobs"]["content"]
311
- response_list.append(
312
- {
313
- "response": response.content,
314
- "judge_probs": self.get_judge_probs(logprobs)
315
- }
316
- )
317
- except Exception as e:
318
- print(f"Error: {e}")
319
- # print(response.response_metadata["logprobs"])
320
- response_list.append(
321
- {
322
- "response": response.content,
323
- "judge_probs": []
324
- }
325
- )
326
- return response_list, total_cost
327
-
328
-
329
- class ChecklistGenerationAgent(BaseAgent):
330
- def __init__(self, agent_config: dict):
331
- super().__init__(agent_config)
332
- self._setup()
333
-
334
- def prepare_message(self, model_input: dict, prompt_type):
335
- message = get_messages(
336
- input_info=model_input,
337
- inference_mode="checklist_generation",
338
- prompt_type=prompt_type
339
- )
340
- return message
341
-
342
-
343
- class ClassifierRewardAgent(Agent):
344
- def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
345
- self.url = url
346
- self.use_checklist = use_checklist
347
- self.use_multimodal = use_multimodal
348
-
349
- def _process_multimodal_message(self, prompt: str, image_list: list[str]):
350
- multimodal_message = []
351
- text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
352
- text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
353
- multimodal_message = [
354
- {"type": "text", "text": text_prompt_prefix},
355
- # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
356
- {"type": "image", "image": image_to_base64_url(image_list[0])},
357
- {"type": "text", "text": text_prompt_suffix}
358
- ]
359
- return multimodal_message
360
-
361
- def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
362
- if self.use_multimodal:
363
- tmp_user_prompt = user_prompt_template["user"].format(
364
- **model_input
365
- )
366
- user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
367
- else:
368
- user_prompt = user_prompt_template["user"].format(
369
- **model_input
370
- )
371
- assistant_prompt = user_prompt_template["assistant"].format(
372
- **model_input
373
- )
374
- query = [
375
- {"role": "user", "content": user_prompt},
376
- {"role": "assistant", "content": assistant_prompt}
377
- ]
378
- return query
379
-
380
- def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
381
- if self.use_checklist:
382
- if self.use_multimodal:
383
- user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
384
- else:
385
- user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
386
- else:
387
- if self.use_multimodal:
388
- user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
389
- else:
390
- user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
391
-
392
- if self.use_multimodal:
393
- if batch:
394
- message = [self._make_query(user_prompt_template, input) for input in model_input]
395
- else:
396
- message = [self._make_query(user_prompt_template, model_input)]
397
- else:
398
- if batch:
399
- message = {
400
- "query": [self._make_query(user_prompt_template, input) for input in model_input],
401
- "prompts": []
402
- }
403
- else:
404
- message = {
405
- "query": self._make_query(user_prompt_template, model_input),
406
- "prompts": []
407
- }
408
-
409
- return message
410
-
411
- def get_rm_score(self, message: dict | list):
412
- headers = {"Content-Type": "application/json"}
413
-
414
- try:
415
- if self.use_multimodal:
416
- response = requests.post(
417
- self.url,
418
- json={"messages": message},
419
- timeout=600
420
- )
421
- else:
422
- response = requests.post(
423
- self.url,
424
- headers=headers,
425
- data=json.dumps(message),
426
- timeout=300
427
- )
428
- response.raise_for_status()
429
-
430
- response_json = response.json()
431
-
432
- if "rewards" not in response_json:
433
- print(f"Error: 'rewards' key not found in API response: {response_json}")
434
- return []
435
-
436
- if "get_reward" in self.url:
437
- # use openrlhf
438
- return response_json["rewards"]
439
- elif "pooling" in self.url:
440
- # use vllm server
441
- return response_json["reward"]
442
- else:
443
- # error
444
- raise ValueError(f"Invalid URL: {self.url}")
445
-
446
- except requests.exceptions.Timeout:
447
- print(f"Error: Request timed out to {self.url}")
448
- return []
449
- except requests.exceptions.RequestException as e:
450
- print(f"Error during request to {self.url}: {e}")
451
- return []
452
- except json.JSONDecodeError:
453
- print(f"Error: Failed to decode JSON response from {self.url}")
454
- return []
455
- except KeyError as e:
456
- print(f"Error: Missing key {e} in response from {self.url}")
457
- return []
458
-
459
-
460
- def generate_response(self, model_input: dict | list[dict], batch: bool = False):
461
- if batch:
462
- message = self.prepare_message(model_input, batch=True)
463
- else:
464
- message = self.prepare_message(model_input)
465
- rewards = self.get_rm_score(message)
466
-
467
- return rewards, 0
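A minimal sketch of driving the ClassifierRewardAgent above, assuming an OpenRLHF-style reward server. The URL is illustrative, and the exact model_input fields come from the templates in prompts/judge_prompt.py, which are not shown in this diff.

reward_agent = ClassifierRewardAgent(
    url="http://localhost:8000/get_reward",  # "get_reward" in the URL selects the OpenRLHF "rewards" field
    use_checklist=True,
    use_multimodal=False,
)

model_input = {
    # fill with the fields expected by JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
}
rewards, cost = reward_agent.generate_response(model_input)  # cost is always 0 for this agent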
agent/mini_bench/checklist_eval.py DELETED
@@ -1,95 +0,0 @@
1
- import re
2
-
3
- from langchain_openai import ChatOpenAI
4
-
5
- from .agent import BaseAgent
6
-
7
- SYSTEM_PROMPT = "You are an expert evaluator. Your task is to assess how well a Web Agent’s generated checklist aligns with the reference checklist for a given user instruction."
8
-
9
- USER_PROMPT = """# Task Description
10
- Use the provided task description, evaluation criteria, and both checklists to assign a score from 1 to 5. Justify your rating with a brief explanation that considers both content overlap and logical structure.
11
-
12
- ## Score Criteria
13
- - 5: Checklist covers all subgoals, is correct and clearly expressed
14
- - 4: Minor omissions or phrasing issues but mostly accurate and complete
15
- - 3: Partially matches, but with noticeable gaps or errors
16
- - 2: Incomplete or includes incorrect steps
17
- - 1: Mostly irrelevant, incorrect, or missing the task goal
18
-
19
- ## User Instruction:
20
- {intent}
21
-
22
- ## Reference Checklist:
23
- {gt_checklist}
24
-
25
- ## Agent’s Generated Checklist:
26
- {generated_checklist}
27
-
28
- # Output Format
29
- Your response should be in the following format:
30
- REASON: [Write 2–4 sentences explaining how well the generated checklist matches the reference. Mention specific matches, omissions, errors, or strengths.]
31
- SCORE: [1–5]
32
- """
33
-
34
-
35
- class ChecklistEvalAgent(BaseAgent):
36
- def __init__(self, agent_config: dict):
37
- super().__init__(agent_config)
38
- self._setup()
39
-
40
- def prepare_message(self, model_input: dict, prompt_type=None):
41
- message = [
42
- {
43
- "role": "system",
44
- "content": SYSTEM_PROMPT
45
- },
46
- {
47
- "role": "user",
48
- "content": USER_PROMPT.format(
49
- intent=model_input["intent"],
50
- gt_checklist=model_input["gt_checklist"],
51
- generated_checklist=model_input["generated_checklist"]
52
- )
53
- }
54
- ]
55
- return message
56
-
57
- def generate_response(self, model_input: dict):
58
- total_cost = 0
59
- response_list = []
60
- # prepare message
61
- message = self.prepare_message(model_input)
62
-
63
- # n sampling
64
- for _ in range(self.num_generate):
65
- response, cost = self.generate_with_retry(message, ["SCORE"])
66
- response_list.append(response)
67
- total_cost += cost
68
-
69
- return response_list, total_cost
70
-
71
- def parsing_score(response: str):
72
- score = response.split("SCORE:")[-1].split("\n")[0].strip()
73
- match = re.search(r'\d+', score)
74
-
75
- if match:
76
- return int(match.group())
77
- else:
78
- return None
79
-
80
- def average_score(scores: list[int]):
81
- if len(scores) == 0:
82
- return 0
83
- return sum(scores) / len(scores)
84
-
85
- def get_score(results: list[dict]):
86
- score_list = []
87
- for result in results:
88
- tmp_scores = [parsing_score(response) for response in result["response"]]
89
- scores = [score for score in tmp_scores if score is not None]
90
- result["score_list"] = scores
91
- final_score = average_score(scores)
92
- result["score"] = final_score
93
- score_list.append(result)
94
-
95
- return results, score_list
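A worked example of the scoring helpers above, using illustrative responses.

sample_responses = [
    "REASON: Covers every subgoal and matches the reference closely.\nSCORE: 5",
    "REASON: Misses the filter-selection subgoal.\nSCORE: 4",
]
scores = [parsing_score(r) for r in sample_responses]  # -> [5, 4]
print(average_score(scores))                           # -> 4.5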
agent/mini_bench/eval_utils.py DELETED
@@ -1,309 +0,0 @@
1
- import re
2
- import random
3
- from collections import Counter
4
-
5
- from .utils import load_json, save_json, create_html_report
6
-
7
- random.seed(42)
8
- def get_score(response_list: list, indicator: str) -> list:
9
- if len(response_list) == 0:
10
- return [-100]
11
-
12
- if isinstance(response_list[0], float):
13
- return response_list
14
-
15
- if indicator == "prob":
16
- score_list = []
17
- for response in response_list:
18
- total_score = 0
19
- for judge_probs in response:
20
- yes_prob = judge_probs.get("yes", 0)
21
- in_progress_prob = judge_probs.get("in", 0)
22
- total_score += yes_prob + in_progress_prob * 0.5
23
- if len(response) > 0:
24
- score_list.append(total_score / len(response))
25
- else:
26
- score_list.append(0)
27
- return score_list
28
- else:
29
- score_list = []
30
- for response in response_list:
31
- if indicator == "SCORE":
32
- if "SCORE" in response:
33
- try:
34
- score_str = response.split("SCORE:")[1].split("\n")[0].strip()
35
- except:
36
- score_str = response.split("SCORE:")[-1].strip()
37
- # find first integer
38
- try:
39
- score = re.search(r'-?\d+', score_str).group()
40
- score_list.append(int(score))
41
- except:
42
- score_list.append(0)
43
- else:
44
- try:
45
- score_str = response.split("<answer>")[1].split("</answer>")[0].strip()
46
- except:
47
- score_str = response.split("<answer>")[-1].split("</answer>")[0].strip()
48
- # find "Yes" or "No"
49
- if "Yes" in score_str:
50
- score_list.append(1)
51
- elif "In Progress" in score_str:
52
- score_list.append(0.5)
53
- elif "No" in score_str:
54
- score_list.append(0)
55
- else:
56
- score_list.append(0)
57
- elif indicator == "JUDGE":
58
- try:
59
- judge_str = response.split("JUDGE:")[1].split("\n")[0].strip()
60
- except:
61
- judge_str = response.split("JUDGE:")[-1].strip()
62
- if "Yes" in judge_str:
63
- score_list.append(1)
64
- elif "No" in judge_str:
65
- score_list.append(0)
66
- else:
67
- score_list.append(0)
68
- elif indicator == "CHECKLIST EVALUATION":
69
- if "<answer>" in response:
70
- try:
71
- checklist_str = response.split("<answer>")[1].split("</answer>")[0].strip()
72
- except:
73
- checklist_str = response.split("<answer>")[-1].split("</answer>")[0].strip()
74
- else:
75
- checklist_str = response.split("CHECKLIST EVALUATION:")[-1].strip()
76
-
77
- count_yes = checklist_str.count("Yes")
78
- count_no = checklist_str.count("No")
79
- count_in_progress = checklist_str.count("In Progress")
80
- try:
81
- total_score = (count_yes + count_in_progress*0.5) / (count_yes + count_no + count_in_progress)
82
- except:
83
- total_score = 0
84
- score_list.append(total_score)
85
- else:
86
- raise ValueError(f"Invalid indicator: {indicator}")
87
- return score_list
88
-
89
- def get_acc_and_mrr(chosen_score, rejected_scores):
90
- if len(rejected_scores) == 0:
91
- return 0, False
92
-
93
- same_score_num = rejected_scores.count(chosen_score)
94
- all_scores = rejected_scores + [chosen_score]
95
- sorted_scores = sorted(all_scores, reverse=True)
96
- rank = sorted_scores.index(chosen_score) + 1 + same_score_num # draw penalty
97
- if all(chosen_score > r for r in rejected_scores):
98
- accuracy = True
99
- else:
100
- accuracy = False
101
- return 1 / rank, accuracy
102
-
103
- def average_score(score_list: list[float]):
104
- if len(score_list) == 0:
105
- return -100
106
- return sum(score_list) / len(score_list)
107
-
108
- def self_consistency_score(score_list: list[float]):
109
- if len(score_list) == 0:
110
- return -100
111
- counter = Counter(score_list)
112
- return max(counter.values()) / len(score_list)
113
-
114
- def get_chosen_rejected_scores(data: dict, agg_func: str):
115
- if len(data["chosen"]) == 0:
116
- data["chosen"] = [{"score": [-100]}]
117
- if len(data["rejected"]) == 0:
118
- data["rejected"] = [{"score": [-100]}]
119
- if not isinstance(data["chosen"][0], dict):
120
- data["chosen"][0] = {"score": [-100]}
121
- if not isinstance(data["rejected"][0], dict):
122
- data["rejected"][0] = {"score": [-100]}
123
-
124
- if agg_func == "average":
125
- chosen_score = average_score(data["chosen"][0]["score"])
126
- rejected_scores = [average_score(rejected_score["score"]) for rejected_score in data["rejected"]]
127
- elif agg_func == "self_consistency":
128
- chosen_score = self_consistency_score(data["chosen"][0]["score"])
129
- rejected_scores = [self_consistency_score(rejected_score["score"]) for rejected_score in data["rejected"]]
130
- else:
131
- raise ValueError(f"Invalid agg_func: {agg_func}")
132
- return chosen_score, rejected_scores
133
-
134
- def get_score_results(results, agg_func):
135
- score_dict = {"mrr": [], "accuracy": [], "traj_accuracy": []}
136
- task_accuracy = {}
137
- for result in results:
138
- chosen_score, rejected_scores = get_chosen_rejected_scores(result, agg_func)
139
- mrr, accuracy = get_acc_and_mrr(chosen_score, rejected_scores)
140
- score_dict["mrr"].append(mrr)
141
- score_dict["accuracy"].append(accuracy)
142
- if result["task_id"] not in task_accuracy:
143
- task_accuracy[result["task_id"]] = []
144
- task_accuracy[result["task_id"]].append(accuracy)
145
-
146
- for task_id in task_accuracy:
147
- if sum(task_accuracy[task_id]) == len(task_accuracy[task_id]):
148
- score_dict["traj_accuracy"].append(True)
149
- else:
150
- score_dict["traj_accuracy"].append(False)
151
-
152
- return score_dict
153
-
154
- def calculate_stats(results, agg_func: str="average"):
155
- if len(results) == 0:
156
- return {
157
- "MRR": 0,
158
- "Accuracy": 0,
159
- "Traj_Accuracy": 0,
160
- }
161
- total_score = get_score_results(results, agg_func)
162
- stats = {
163
- "MRR": sum(total_score["mrr"]) / len(total_score["mrr"]),
164
- "Accuracy": sum(total_score["accuracy"]) / len(total_score["accuracy"]),
165
- "Traj_Accuracy": sum(total_score["traj_accuracy"]) / len(total_score["traj_accuracy"]),
166
- }
167
-
168
- return stats
169
-
170
- def group_by_task(results, split_indicator: str):
171
- # sort results by task_id and step_id
172
- results.sort(key=lambda x: (x["task_id"], x["step_id"]))
173
- # group by task_name
174
- grouped_task_dict = {}
175
- for result in results:
176
- task_name = "task_" + str(result["task_id"]) + "_step_" + str(result["step_id"])
177
- if task_name not in grouped_task_dict:
178
- grouped_task_dict[task_name] = {
179
- "task_id": result["task_id"],
180
- "step_id": result["step_id"],
181
- "intent": result["intent"],
182
- "start_url": result["start_url"],
183
- "gt_checklist": result["gt_checklist"],
184
- "generated_checklist": result.get("generated_checklist", None) ,
185
- "trajectory": result["trajectory"],
186
- "current_url": result["current_url"],
187
- "text_observation": result["text_observation"],
188
- # "image_list": result["image_list"],
189
- "chosen": [],
190
- "rejected": [],
191
- "source_name": result["source_name"],
192
- }
193
-
194
- response = result["response"] if "response" in result else []
195
- type_data = {
196
- "thought": result["thought"],
197
- "action": result["action"],
198
- "response": response,
199
- "score": get_score(response, split_indicator) if split_indicator != "prob" else get_score(result["judge_probs"], split_indicator),
200
- }
201
- if split_indicator == "prob":
202
- type_data["judge_probs"] = result["judge_probs"]
203
- if result["type"] == "chosen":
204
- grouped_task_dict[task_name]["chosen"].append(type_data)
205
- elif result["type"] == "rejected":
206
- grouped_task_dict[task_name]["rejected"].append(type_data)
207
-
208
- return list(grouped_task_dict.values())
209
-
210
-
211
- def processing_results(results, evaluation_mode: str, num_generate: int, use_batch: bool=False):
212
- if "judge_probs" in results[0]:
213
- split_indicator = "prob"
214
- else:
215
- if evaluation_mode == "judge_with_checklist_generation" or evaluation_mode == "judge_with_gt_checklist":
216
- split_indicator = "CHECKLIST EVALUATION"
217
- else:
218
- split_indicator = "SCORE"
219
-
220
- # if use_batch is True, make it flattened
221
- if use_batch:
222
- tmp_results = []
223
- for result in results:
224
- for d in result:
225
- tmp_results.append(d)
226
- grouped_results = group_by_task(tmp_results, split_indicator)
227
- else:
228
- grouped_results = group_by_task(results, split_indicator)
229
-
230
- mind2web_results = []
231
- webarena_results = []
232
- mind2web_task_results = []
233
- mind2web_website_results = []
234
- mind2web_domain_results = []
235
-
236
- for grouped_result in grouped_results:
237
- if "mind2web" in grouped_result["source_name"]:
238
- mind2web_results.append(grouped_result)
239
- if grouped_result["source_name"] == "mind2web_test_task":
240
- mind2web_task_results.append(grouped_result)
241
- elif grouped_result["source_name"] == "mind2web_test_website":
242
- mind2web_website_results.append(grouped_result)
243
- elif grouped_result["source_name"] == "mind2web_test_domain":
244
- mind2web_domain_results.append(grouped_result)
245
- elif "webarena" in grouped_result["source_name"]:
246
- webarena_results.append(grouped_result)
247
-
248
- try:
249
- final_stats = {
250
- "mind2web": {
251
- "MRR": {},
252
- "Accuracy": {},
253
- "Traj_Accuracy": {},
254
- },
255
- "webarena": {
256
- "MRR": {},
257
- "Accuracy": {},
258
- "Traj_Accuracy": {},
259
- },
260
- "mind2web_task": {
261
- "MRR": {},
262
- "Accuracy": {},
263
- "Traj_Accuracy": {},
264
- },
265
- "mind2web_website": {
266
- "MRR": {},
267
- "Accuracy": {},
268
- "Traj_Accuracy": {},
269
- },
270
- "mind2web_domain": {
271
- "MRR": {},
272
- "Accuracy": {},
273
- "Traj_Accuracy": {},
274
- },
275
- }
276
- for source_results in [
277
- ("mind2web", mind2web_results),
278
- ("webarena", webarena_results),
279
- ("mind2web_task", mind2web_task_results),
280
- ("mind2web_website", mind2web_website_results),
281
- ("mind2web_domain", mind2web_domain_results)
282
- ]:
283
- average_stats = calculate_stats(source_results[1], "average")
284
- self_consistency_stats = calculate_stats(source_results[1], "self_consistency")
285
- for metric in average_stats:
286
- final_stats[source_results[0]][metric]["Average"] = average_stats[metric]
287
- for metric in self_consistency_stats:
288
- final_stats[source_results[0]][metric]["Self_Consistency"] = self_consistency_stats[metric]
289
-
290
- if num_generate == 1:
291
- for source_name in final_stats:
292
- for metric in final_stats[source_name]:
293
- print(f"{round(100 * final_stats[source_name][metric]['Average'], 2)}", end=", ")
294
- print()
295
- else:
296
- for agg_func in ["Average", "Self_Consistency"]:
297
- print(f"{agg_func}")
298
- for source_name in final_stats:
299
- for metric in final_stats[source_name]:
300
- print(f"{round(100 * final_stats[source_name][metric][agg_func], 2)}", end=", ")
301
- print()
302
- except Exception as e:
303
- print(e)
304
- return grouped_results, None
305
-
306
- # add function to convert json format results to html format results
307
- # TODO: implement this function
308
- # create_html_report(results, "results.html")
309
- return grouped_results, final_stats
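A worked example of the ranking metric above (scores are illustrative).

mrr, acc = get_acc_and_mrr(0.9, [0.4, 0.7])
# sorted scores [0.9, 0.7, 0.4], no ties -> rank 1 -> (1.0, True)

mrr, acc = get_acc_and_mrr(0.7, [0.7, 0.4])
# the chosen score ties one rejected score -> rank 1 + 1 tie penalty = 2 -> (0.5, False)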
agent/mini_bench/inference_utils.py DELETED
@@ -1,87 +0,0 @@
1
- import time
2
-
3
- from multiprocessing import Process, Manager
4
- from tqdm import tqdm
5
-
6
-
7
- def worker_main(work_queue, result_queue, process_func, config):
8
- while True:
9
- item = work_queue.get()
10
- if item is None:
11
- result_queue.put(None)
12
- break
13
- try:
14
- results, cost = process_func(config, item)
15
- result_queue.put((results, cost))
16
- except Exception as e:
17
- item_info = item.get('idx', item.get('id', 'unknown item'))
18
- print(f"Error processing item {item_info}: {e}")
19
- result_queue.put(None)
20
- finally:
21
- work_queue.task_done()
22
-
23
- def run_parallel_evaluation(dataset, process_func, config, num_workers, description):
24
- """
25
- Runs parallel evaluation on the given dataset and returns the results.
26
-
27
- Args:
28
- dataset (list or datasets.Dataset): Data to evaluate.
29
- process_func (callable): Function to process each data item.
30
- config (dict): Configuration for the process_func.
31
- num_workers (int): Number of worker processes to use.
32
- description (str): Description to display on the tqdm progress bar.
33
-
34
- Returns:
35
- tuple: (list of evaluation results, total cost)
36
- """
37
- manager = Manager()
38
- work_queue = manager.Queue()
39
- result_queue = manager.Queue()
40
-
41
- # Add data to the work queue
42
- dataset_list = list(dataset) if not isinstance(dataset, list) else dataset
43
- for data in dataset_list:
44
- work_queue.put(data)
45
-
46
- # Add termination signals for workers
47
- for _ in range(num_workers):
48
- work_queue.put(None)
49
-
50
- # Start parallel processing
51
- processes = []
52
- for _ in range(num_workers):
53
- p = Process(target=worker_main, args=(work_queue, result_queue, process_func, config))
54
- p.start()
55
- processes.append(p)
56
-
57
- # Show progress bar and collect results
58
- process_results = []
59
- process_cost = 0
60
- completed_workers = 0
61
-
62
- with tqdm(total=len(dataset_list), desc=description) as pbar:
63
- while completed_workers < num_workers:
64
- result_item = result_queue.get()
65
- if result_item is None:
66
- completed_workers += 1
67
- else:
68
- results, cost = result_item
69
- if results is not None:
70
- process_results.append(results)
71
- process_cost += cost if cost is not None else 0
72
- pbar.update(1)
73
-
74
- # Wait for all processes to finish
75
- for p in processes:
76
- p.join()
77
-
78
- # Collect remaining results
79
- while not result_queue.empty():
80
- result_item = result_queue.get_nowait()
81
- if result_item is not None:
82
- results, cost = result_item
83
- if results is not None:
84
- process_results.append(results)
85
- process_cost += cost if cost is not None else 0
86
-
87
- return process_results, process_cost
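A minimal sketch of driving run_parallel_evaluation with a dummy process_func. As in worker_main above, process_func takes (config, item) and returns (results, cost); it must be defined at module level so it can be pickled, and the call should sit under a __main__ guard on spawn-based platforms.

def dummy_process_func(config, item):
    # pretend evaluation: echo the item id and report zero cost
    return {"idx": item["idx"], "score": 1.0}, 0.0

if __name__ == "__main__":
    dataset = [{"idx": i} for i in range(8)]
    results, total_cost = run_parallel_evaluation(
        dataset, dummy_process_func, config={}, num_workers=2, description="dummy eval"
    )
    print(len(results), total_cost)  # -> 8 0.0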
agent/mini_bench/prompts/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .construct_messages import get_messages
 
 
agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (263 Bytes)
 
agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc DELETED
Binary file (2.85 kB)
 
agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc DELETED
Binary file (3.11 kB)
 
agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc DELETED
Binary file (15 kB)
 
agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc DELETED
Binary file (5.46 kB)
 
agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc DELETED
Binary file (1.71 kB)
 
agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc DELETED
Binary file (1.03 kB)
 
agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc DELETED
Binary file (5.64 kB)
 
agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc DELETED
Binary file (1.19 kB)
 
agent/mini_bench/prompts/action.py DELETED
@@ -1,93 +0,0 @@
1
- ACTION_SPACE_PROMPT = """Note: This action set allows you to interact with your environment. Most of the actions are Python functions executing Playwright code. The primary way of referring to elements on the page is through their bid, which is specified in your observations.
2
-
3
- 15 different types of actions are available.
4
-
5
- noop(wait_ms: float = 1000)
6
- Examples:
7
- noop()
8
-
9
- noop(500)
10
-
11
- scroll(delta_x: float, delta_y: float)
12
- Examples:
13
- scroll(0, 200)
14
-
15
- scroll(-50.2, -100.5)
16
-
17
- keyboard_press(key: str)
18
- Examples:
19
- keyboard_press('Backspace')
20
-
21
- keyboard_press('ControlOrMeta+a')
22
-
23
- keyboard_press('Meta+Shift+t')
24
-
25
- click(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
26
- Examples:
27
- click('a51')
28
-
29
- click('b22', button='right')
30
-
31
- click('48', button='middle', modifiers=['Shift'])
32
-
33
- fill(bid: str, value: str)
34
- Examples:
35
- fill('237', 'example value')
36
-
37
- fill('45', 'multi-line\nexample')
38
-
39
- fill('a12', 'example with "quotes"')
40
-
41
- hover(bid: str)
42
- Examples:
43
- hover('b8')
44
-
45
- tab_focus(index: int)
46
- Examples:
47
- tab_focus(2)
48
-
49
- new_tab()
50
- Examples:
51
- new_tab()
52
-
53
- go_back()
54
- Examples:
55
- go_back()
56
-
57
- go_forward()
58
- Examples:
59
- go_forward()
60
-
61
- goto(url: str)
62
- Examples:
63
- goto('http://www.example.com')
64
-
65
- tab_close()
66
- Examples:
67
- tab_close()
68
-
69
- select_option(bid: str, options: str | list[str])
70
- Examples:
71
- select_option('a48', 'blue')
72
-
73
- select_option('c48', ['red', 'green', 'blue'])
74
-
75
- send_msg_to_user(text: str)
76
- Examples:
77
- send_msg_to_user('Based on the results of my search, the city was built in 1751.')
78
-
79
- report_infeasible(reason: str)
80
- Examples:
81
- report_infeasible('I cannot follow these instructions because there is no email field in this form.')
82
-
83
- Only a single action can be provided at once. Example:
84
- fill('a12', 'example with "quotes"')
85
-
86
- Note:
87
- * Some tasks may be game like and may require to interact with the mouse position in x, y coordinates.
88
- * Some text field might have auto completion. To see it, you have to type a few characters and wait until next step.
89
- * If you have to cut and paste, don't forget to select the text first.
90
- * Coordinate inside an SVG are relative to it's top left corner.
91
- * Make sure to use bid to identify elements when using commands.
92
- * Interacting with combobox, dropdowns and auto-complete fields can be tricky, sometimes you need to use select_option, while other times you need to use fill or click and wait for the reaction of the page.
93
- """
agent/mini_bench/prompts/checklist_prompt.py DELETED
@@ -1,50 +0,0 @@
1
- CHECKLIST_SYSTEM_PROMPT = "You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task."
2
-
3
- CHECKLIST_USER_PROMPT = """## Task Description
4
- User Instruction (Goal): "{intent}"
5
- Start Website URL: {start_url}
6
-
7
- ## Guidelines for Checklist Generation
8
- 1. Identify Essential High-Level Subgoals:
9
- - A subgoal should represent a significant step involving user interaction that leads to noticeable page transitions or meaningful changes in system state.
10
- - Consolidate closely related user actions (such as applying multiple filters or selecting several options) into a single subgoal, rather than separate checklist items for each action.
11
- - Prioritize only the most critical interactions necessary for meaningful progression, avoiding the inclusion of minor or unnecessary steps (e.g., scroll, hover).
12
- 2. Provide a Concise Subgoal Analysis:
13
- - Before creating the checklist, offer a brief paragraph summarizing the main subgoals, emphasizing significant transitions or page-level interactions.
14
- 3. Ensure Clear Goal:
15
- - If multiple related interactions occur (e.g., setting filters 1, 2, and 3), combine them into one subgoal with clear criteria verifying all required conditions.
16
- - The checklist should contain only essential steps, explicitly excluding unnecessary actions, and should not exceed five critical subgoals. It is not necessary to use all five checklist items if fewer steps adequately represent the essential subgoals.
17
-
18
- ### Output Format
19
- Before generating the checklist, first produce a concise subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist following the format below:
20
- [SUBGOAL ANALYSIS]
21
- [One-paragraph summary explaining the key subgoals and their logical sequence in task completion.]
22
-
23
- [CHECKLISTS]
24
- Checklist X: [Short title of the action/goal]
25
- - Goal: [Brief description of the subgoal at this stage, emphasizing the purpose of the action.]
26
- """
27
-
28
- # TODO: implement ours
29
- CHECKLIST_OURS_SYSTEM_PROMPT = ""
30
-
31
- CHECKLIST_OURS_USER_PROMPT = """You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task.
32
-
33
- # Task Description
34
- Generate a checklist of key milestones for achieving the given instruction. First, provide a concise
35
- subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist with a brief description for each item.
36
-
37
- Note: If the target website requires login, assume the user is already logged in and starts from an authenticated session.
38
-
39
- # Given Information
40
- ## User Instruction
41
- {intent}
42
-
43
- ## Current State
44
- ### Current URL
45
- {start_url}
46
-
47
- ### AXTREE
48
- Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
49
- {text_observation}
50
- """
agent/mini_bench/prompts/construct_messages.py DELETED
@@ -1,309 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- from .action import ACTION_SPACE_PROMPT
4
- from .eval_type import (
5
- GROUNDING,
6
- PROGRESS_LIKERT_SCALE,
7
- PROGRESS_THREE_CLASS,
8
- PROGRESS_WITH_CHECKLIST,
9
- PROGRESS_WITH_CHECKLIST_IN_PROGRESS,
10
- PROGRESS_OURS
11
- )
12
- from .input_information import (
13
- USER_INSTRUCTION,
14
- TRAJECTORY,
15
- AGENT_RESPONSE,
16
- CHECKLIST,
17
- CURRENT_URL,
18
- TEXT_OBSERVATION,
19
- SOM_IMAGE_OBSERVATION,
20
- COORD_IMAGE_OBSERVATION
21
- )
22
- from .judge_prompt import (
23
- JUDGE_GROUNDING_PROMPT_TEMPLATE,
24
- JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE,
25
- JUDGE_THREE_CLASS_PROMPT_TEMPLATE,
26
- JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE,
27
- JUDGE_OURS_PROMPT_TEMPLATE,
28
- JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE
29
- )
30
- from .checklist_prompt import (
31
- CHECKLIST_SYSTEM_PROMPT,
32
- CHECKLIST_USER_PROMPT,
33
- CHECKLIST_OURS_SYSTEM_PROMPT,
34
- CHECKLIST_OURS_USER_PROMPT
35
- )
36
- from .image_utils import image_to_base64_url
37
-
38
-
39
- class Message(ABC):
40
- @abstractmethod
41
- def get_messages(self):
42
- pass
43
-
44
- class BaseMessage(Message):
45
- def __init__(self, input_info:dict, use_multimodal:bool=False):
46
- self.input_info = input_info
47
- self.use_multimodal = use_multimodal
48
-
49
- def _get_system_message(self):
50
- system_message = {"role": "system", "content": "You are a helpful assistant."}
51
- return system_message
52
-
53
- def _process_multimodal_message(self, prompt: str, image_list: list[str]):
54
- multimodal_message = []
55
- text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
56
- text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
57
- multimodal_message.append({"type": "text", "text": text_prompt_prefix})
58
- for i, image in enumerate(image_list):
59
- # TODO: text prompt for multiple images
60
- # multimodal_message.append({"type": "text", "text": f"IMAGE {i+1}\n"})
61
- multimodal_message.append({"type": "image_url", "image_url": {"url": image_to_base64_url(image), "detail": "low"}})
62
- multimodal_message.append({"type": "text", "text": text_prompt_suffix})
63
- return {"role": "user", "content": multimodal_message}
64
-
65
- def _get_user_message(self):
66
- user_prompt = "What is the capital of France?"
67
- if self.use_multimodal:
68
- image_list = self.input_info.get("image_list", [])
69
- user_message = self._process_multimodal_message(user_prompt, image_list)
70
- else:
71
- user_message = {"role": "user", "content": user_prompt}
72
- return user_message
73
-
74
- def get_messages(self):
75
- message = []
76
- system_message = self._get_system_message()
77
- user_message = self._get_user_message()
78
-
79
- message.append(system_message)
80
- # message.append({"role": "system", "content": ""})
81
- message.append(user_message)
82
- return message
83
-
84
-
85
- class ProgressMessage(BaseMessage):
86
- '''
87
- Progress Judge Message
88
- '''
89
- def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str, use_checklist:bool, use_in_progress:bool):
90
- super().__init__(input_info, use_multimodal)
91
- self.prompt_type = prompt_type
92
- self.text_obs = text_obs
93
- self.image_obs = image_obs
94
- self.use_checklist = use_checklist
95
- self.use_in_progress = use_in_progress
96
-
97
- def _get_system_message(self):
98
- if self.prompt_type == "likert_scale":
99
- system_message = {"role": "system", "content": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["system"]}
100
- elif self.prompt_type == "three_class":
101
- system_message = {"role": "system", "content": JUDGE_THREE_CLASS_PROMPT_TEMPLATE["system"]}
102
- elif self.prompt_type == "with_checklist":
103
- system_message = {"role": "system", "content": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["system"]}
104
- elif self.prompt_type == "ours":
105
- system_message = {"role": "system", "content": JUDGE_OURS_PROMPT_TEMPLATE["system"]}
106
- else:
107
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
108
- return system_message
109
-
110
- def _setup_input_information(self):
111
- observation = "## Current State\n"
112
-
113
- observation += CURRENT_URL
114
-
115
- # text observation
116
- if self.text_obs:
117
- observation += TEXT_OBSERVATION
118
-
119
- # image observation (som, coord, none)
120
- if self.image_obs == "som":
121
- observation += SOM_IMAGE_OBSERVATION
122
- elif self.image_obs == "coord":
123
- observation += COORD_IMAGE_OBSERVATION
124
-
125
-
126
- if self.use_checklist:
127
- input_information = USER_INSTRUCTION + TRAJECTORY + observation + CHECKLIST + AGENT_RESPONSE
128
- else:
129
- input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE
130
-
131
- return input_information
132
-
133
- def _setup_task_info(self):
134
- if self.prompt_type == "likert_scale":
135
- task_description = PROGRESS_LIKERT_SCALE["task_description"]
136
- output_format = PROGRESS_LIKERT_SCALE["output_format"]
137
- elif self.prompt_type == "three_class":
138
- task_description = PROGRESS_THREE_CLASS["task_description"]
139
- output_format = PROGRESS_THREE_CLASS["output_format"]
140
- elif self.prompt_type == "with_checklist":
141
- if self.use_in_progress:
142
- task_description = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["task_description"]
143
- output_format = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["output_format"]
144
- else:
145
- task_description = PROGRESS_WITH_CHECKLIST["task_description"]
146
- output_format = PROGRESS_WITH_CHECKLIST["output_format"]
147
- else:
148
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
149
- return task_description, output_format
150
-
151
- def _get_user_prompt_template(self):
152
- if self.prompt_type == "likert_scale":
153
- user_prompt = JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["user"]
154
- elif self.prompt_type == "three_class":
155
- user_prompt = JUDGE_THREE_CLASS_PROMPT_TEMPLATE["user"]
156
- elif self.prompt_type == "with_checklist":
157
- user_prompt = JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["user"]
158
- else:
159
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
160
- return user_prompt
161
-
162
- def _get_user_message(self):
163
- # setup input information (user_instruction, trajectory, current_state, agent_response, checklist)
164
- input_information_template = self._setup_input_information()
165
- input_information = input_information_template.format(**self.input_info)
166
-
167
- if self.prompt_type == "ours":
168
- if self.use_checklist:
169
- user_prompt = JUDGE_OURS_PROMPT_TEMPLATE["user"].format(
170
- input_information=input_information,
171
- )
172
- else:
173
- user_prompt = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE["user"].format(
174
- input_information=input_information,
175
- )
176
- else:
177
- task_description, output_format = self._setup_task_info()
178
- # get user prompt template by prompt type
179
- user_prompt_template = self._get_user_prompt_template()
180
- user_prompt = user_prompt_template.format(
181
- action_space=ACTION_SPACE_PROMPT,
182
- task_description=task_description,
183
- input_information=input_information,
184
- output_format=output_format
185
- )
186
-
187
- # process multimodal message
188
- if self.use_multimodal:
189
- image_list = self.input_info.get("image_list", [])
190
- user_message = self._process_multimodal_message(user_prompt, image_list)
191
- else:
192
- user_message = {"role": "user", "content": user_prompt}
193
-
194
- return user_message
195
-
196
-
197
- class GroundingMessage(BaseMessage):
198
- '''
199
- Grounding Judge Message
200
- '''
201
- def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str):
202
- super().__init__(input_info, use_multimodal)
203
- self.prompt_type = prompt_type
204
- self.text_obs = text_obs
205
- self.image_obs = image_obs
206
-
207
- def _get_system_message(self):
208
- if self.prompt_type == "ours":
209
- # TODO: implement ours
210
- system_message = {"role": "system", "content": "You are a helpful assistant."}
211
- elif self.prompt_type == "default":
212
- system_message = {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]}
213
- else:
214
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
215
- return system_message
216
-
217
- def _setup_input_information(self):
218
- observation = "## Current State\n"
219
-
220
- observation += CURRENT_URL
221
-
222
- # text observation
223
- if self.text_obs:
224
- observation += TEXT_OBSERVATION
225
-
226
- # image observation (som, coord, none)
227
- if self.image_obs == "som":
228
- observation += SOM_IMAGE_OBSERVATION
229
- elif self.image_obs == "coord":
230
- observation += COORD_IMAGE_OBSERVATION
231
-
232
- # input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE # with trajectory
233
- input_information = USER_INSTRUCTION + observation + AGENT_RESPONSE # without trajectory
234
-
235
- return input_information
236
-
237
- def _get_user_message(self):
238
- if self.prompt_type == "ours":
239
- # TODO: implement ours
240
- user_message = {"role": "user", "content": "TODO"}
241
- elif self.prompt_type == "default":
242
- action_space = ACTION_SPACE_PROMPT
243
- task_description = GROUNDING["task_description"]
244
- output_format = GROUNDING["output_format"]
245
- input_information_template = self._setup_input_information()
246
- input_information = input_information_template.format(**self.input_info)
247
-
248
- user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format(
249
- action_space=action_space,
250
- task_description=task_description,
251
- input_information=input_information,
252
- output_format=output_format
253
- )
254
-
255
- # process multimodal message
256
- if self.use_multimodal:
257
- image_list = self.input_info.get("image_list", [])
258
- user_message = self._process_multimodal_message(user_prompt, image_list)
259
- else:
260
- user_message = {"role": "user", "content": user_prompt}
261
- else:
262
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
263
- return user_message
264
-
265
-
266
- class ChecklistMessage(BaseMessage):
267
- '''
268
- Checklist Message
269
- '''
270
- def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str):
271
- super().__init__(input_info, use_multimodal)
272
- self.prompt_type = prompt_type
273
-
274
- def _get_system_message(self):
275
- if self.prompt_type == "ours":
276
- # TODO: implement ours
277
- system_message = {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT}
278
- elif self.prompt_type == "default":
279
- system_message = {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT}
280
- else:
281
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
282
- return system_message
283
-
284
- def _get_user_message(self):
285
- if self.prompt_type == "ours":
286
- user_message = {"role": "user", "content": CHECKLIST_OURS_USER_PROMPT.format(**self.input_info)}
287
- elif self.prompt_type == "default":
288
- user_message = {"role": "user", "content": CHECKLIST_USER_PROMPT.format(**self.input_info)}
289
- else:
290
- raise ValueError(f"Invalid prompt type: {self.prompt_type}")
291
- return user_message
292
-
293
-
294
- def get_messages(input_info:dict, inference_mode:str, prompt_type:str, text_obs:str=None, image_obs:str=None, use_multimodal:bool=False, use_checklist:bool=False, use_in_progress:bool=False):
295
- message_list = []
296
- if inference_mode == "judge_grounding":
297
- message = GroundingMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs)
298
- elif inference_mode == "judge_progress":
299
- message = ProgressMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs, use_checklist=use_checklist, use_in_progress=use_in_progress)
300
- elif inference_mode == "checklist_generation":
301
- message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type)
302
- else:
303
- raise ValueError(f"Invalid inference mode: {inference_mode}")
304
-
305
- system_message, user_message = message.get_messages()
306
-
307
- message_list.append(system_message)
308
- message_list.append(user_message)
309
- return message_list
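For reference, a minimal sketch of driving the dispatcher above in judge_progress mode (illustrative, not part of the deleted file). The keys in sample_input and the prompt_type/text_obs values are assumptions about what the underlying ProgressMessage templates expect; all values are placeholders:
sample_input = {
    "intent": "Find the cheapest laptop under $500",
    "trajectory": "Step 1: opened the shop homepage",
    "current_url": "https://example.com/search",
    "text_observation": "[12] button 'Search'",
    "thought": "I should sort the results by price.",
    "action": "click('12')",
    "checklist": "Checklist 1: Open the search page",
}
messages = get_messages(
    sample_input,
    inference_mode="judge_progress",
    prompt_type="default",
    text_obs="axtree",
    use_checklist=True,
)
# messages == [system_message, user_message], ready to send to a chat model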
 
agent/mini_bench/prompts/eval_type.py DELETED
@@ -1,107 +0,0 @@
1
- # Task Description & Output Format
2
- GROUNDING_TASK = """Your task is to evaluate whether the agent's ACTION is properly grounded in its THOUGHT, considering the current state of the webpage.
3
- Use the user instruction, the current webpage state, and the agent's thought and action as evidence for your judgment. Your evaluation should assess whether the ACTION logically follows from the THOUGHT and is feasible and appropriate in the given environment.
4
- Mark the action as 'Yes' only if it is clearly and fully grounded in the thought and current webpage state. If there is any inconsistency, ambiguity, irrelevance, or if the action is not supported by the current page state, mark it as 'No'."""
5
-
6
- GROUNDING_FORMAT = """Please return your response in the following format:
7
- REASON: [Your explanation for whether the action is properly grounded]
8
- JUDGE: [Yes / No]"""
9
-
10
-
11
- PROGRESS_LIKERT_SCALE_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
12
- **Scoring Criteria (1 to 5):**
13
- - **5 (Very Helpful)**: The action directly and effectively moves toward fulfilling a key part of the goal.
14
- - **4 (Helpful)**: The action contributes meaningfully to progress, though it may require follow-up actions.
15
- - **3 (Somewhat Helpful)**: The action is partially relevant or a preparatory step, but doesn’t make immediate progress.
16
- - **2 (Slightly Helpful)**: The action is weakly related to the goal or might only indirectly help.
17
- - **1 (Not Helpful)**: The action is unrelated, redundant, or distracts from the goal."""
18
-
19
- PROGRESS_LIKERT_SCALE_FORMAT = """Please return your response in the following format:
20
- REASON: [Your explanation for the score]
21
- SCORE: [1-5]"""
22
-
23
-
24
- PROGRESS_THREE_CLASS_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
25
- **Scoring Criteria:**
26
- - **1 (Helpful)**: The action clearly contributes to achieving the goal. It takes a necessary or productive step toward completing the task.
27
- - **0 (Neutral)**: The action is neither helpful nor harmful. It may be a placeholder, irrelevant at the current step, or too ambiguous to evaluate.
28
- - **-1 (Not Helpful)**: The action works against the goal, causes confusion, repeats a previous step unnecessarily, or leads the agent off track."""
29
-
30
- PROGRESS_THREE_CLASS_FORMAT = """Please return your response in the following format:
31
- REASON: [Your explanation for the score]
32
- SCORE: [-1 / 0 / 1]"""
33
-
34
-
35
- PROGRESS_WITH_CHECKLIST_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
36
- Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation.
37
- For each checklist item:
38
- - Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
39
- - Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""
40
-
41
- PROGRESS_WITH_CHECKLIST_FORMAT = """Please return your response in the following format:
42
- REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
43
- CHECKLIST EVALUATION:
44
- Checklist X: [Yes / No]
45
- """
46
-
47
- PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
48
- Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation. Clearly consider any items already successfully completed or currently in progress according to the provided trajectory.
49
- For each checklist item:
50
- - Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
51
- - Mark it as 'In Progress' if the agent has made partial but meaningful progress toward completing the item.
52
- - Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""
53
-
54
- PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT = """Please return your response in the following format:
55
- REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
56
- CHECKLIST EVALUATION:
57
- Checklist X: [Yes / In Progress / No]
58
- """
59
-
60
-
61
- GROUNDING_OURS_TASK = """
62
- """
63
-
64
- GROUNDING_OURS_FORMAT = """
65
- """
66
-
67
- PROGRESS_OURS_TASK = """
68
- """
69
-
70
- PROGRESS_OURS_FORMAT = """
71
- """
72
-
73
- ## EVALUATION TYPE
74
- GROUNDING = {
75
- "task_description": GROUNDING_TASK,
76
- "output_format": GROUNDING_FORMAT,
77
- }
78
-
79
- GROUNDING_OURS = {
80
- "task_description": GROUNDING_OURS_TASK,
81
- "output_format": GROUNDING_OURS_FORMAT,
82
- }
83
-
84
- PROGRESS_LIKERT_SCALE = {
85
- "task_description": PROGRESS_LIKERT_SCALE_TASK,
86
- "output_format": PROGRESS_LIKERT_SCALE_FORMAT,
87
- }
88
-
89
- PROGRESS_THREE_CLASS = {
90
- "task_description": PROGRESS_THREE_CLASS_TASK,
91
- "output_format": PROGRESS_THREE_CLASS_FORMAT,
92
- }
93
-
94
- PROGRESS_WITH_CHECKLIST = {
95
- "task_description": PROGRESS_WITH_CHECKLIST_TASK,
96
- "output_format": PROGRESS_WITH_CHECKLIST_FORMAT,
97
- }
98
-
99
- PROGRESS_WITH_CHECKLIST_IN_PROGRESS = {
100
- "task_description": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK,
101
- "output_format": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT,
102
- }
103
-
104
- PROGRESS_OURS = {
105
- "task_description": PROGRESS_OURS_TASK,
106
- "output_format": PROGRESS_OURS_FORMAT,
107
- }
 
agent/mini_bench/prompts/image_utils.py DELETED
@@ -1,19 +0,0 @@
1
- import base64
2
- import io
3
- from PIL import Image
4
-
5
-
6
- def image_to_base64_url(image: str | Image.Image):
7
- if isinstance(image, str):
8
- with open(image, "rb") as f:
9
- image = f.read()
10
- elif isinstance(image, Image.Image):
11
- if image.mode in ("RGBA", "LA"):
12
- image = image.convert("RGB")
13
- with io.BytesIO() as buffer:
14
- image.save(buffer, format="PNG")
15
- image = buffer.getvalue()
16
- else:
17
- raise ValueError(f"Invalid image type: {type(image)}")
18
-
19
- return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")
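A quick usage sketch of the helper above (illustrative, not part of the deleted file):
from PIL import Image

img = Image.new("RGBA", (2, 2), color=(255, 0, 0, 255))  # tiny in-memory image
url = image_to_base64_url(img)                           # RGBA is converted to RGB, then PNG-encoded
assert url.startswith("data:image/png;base64,")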
 
agent/mini_bench/prompts/input_information.py DELETED
@@ -1,36 +0,0 @@
1
- USER_INSTRUCTION = """## User Instruction
2
- {intent}
3
- """
4
-
5
- TRAJECTORY = """## Trajectory
6
- {trajectory}"""
7
-
8
- AGENT_RESPONSE = """## Agent's Response
9
- THOUGHT: {thought}
10
- ACTION: {action}
11
- """
12
-
13
- CHECKLIST = """## Checklist
14
- {checklist}
15
- """
16
-
17
-
18
- # Observation
19
- CURRENT_URL = """### Current URL
20
- {current_url}
21
- """
22
-
23
- TEXT_OBSERVATION = """### AXTREE
24
- Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
25
- {text_observation}
26
- """
27
-
28
- SOM_IMAGE_OBSERVATION = """### SOM Image Screenshot
29
- Here is a current image screenshot of the page, it is annotated with bounding boxes and corresponding bids:
30
- <IMAGE_PLACEHOLDER>
31
- """
32
-
33
- COORD_IMAGE_OBSERVATION = """### Raw Image Screenshot
34
- Here is a screenshot of the page:
35
- <IMAGE_PLACEHOLDER>
36
- """
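A small sketch of how these fragments compose into an observation block; all values are placeholders:
observation = "## Current State\n"
observation += CURRENT_URL.format(current_url="https://example.com/search")
observation += TEXT_OBSERVATION.format(text_observation="[3] link 'Home'\n[12] button 'Search'")
prompt_block = USER_INSTRUCTION.format(intent="Find the cheapest laptop") + observation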
 
agent/mini_bench/prompts/judge_prompt.py DELETED
@@ -1,159 +0,0 @@
1
- # SYSTEM PROMPT
2
- DEFAULT_SYSTEM_PROMPT_FORMAT = "You are an expert evaluator of web agent. {role_description}"
3
-
4
- PROGRESS_WITHOUT_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."
5
- PROGRESS_WITH_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."
6
-
7
- GROUNDING_ROLE = "Your task is to assess whether the ACTION taken by the agent is properly grounded, based on agent's THOUGHT and the current state of the webpage."
8
-
9
- # USER PROMPT
10
- DEFAULT_USER_PROMPT_FORMAT = """# Action space:
11
- {action_space}
12
-
13
- # Task Description
14
- {task_description}
15
-
16
- # Given Information
17
- {input_information}
18
-
19
- # Output Format
20
- {output_format}
21
- """
22
-
23
-
24
- JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT = """You are an expert evaluator of web agent. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.
25
-
26
- # Task Description
27
- Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.
28
-
29
- # Given Information
30
- {input_information}
31
- """
32
-
33
-
34
- JUDGE_OURS_USER_PROMPT_FORMAT = """You are an expert evaluator of web agent. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.
35
-
36
- # Task Description
37
- Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.
38
- Then, assess each checklist item individually using the following labels:
39
- - Yes: The item is fully and clearly satisfied, either in the current response or previously completed.
40
- - In Progress: There is meaningful partial progress toward completing the item.
41
- - No: The item is not satisfied due to ambiguity, insufficient evidence, or lack of progress.
42
-
43
- # Given Information
44
- {input_information}
45
- """
46
-
47
-
48
- JUDGE_OURS_BT_MODELING_USER_PROMPT_FORMAT = """You are an expert web agent that browses internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.
49
-
50
- # Task Description
51
- Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.
52
-
53
- # Given Information
54
- ## User Instruction
55
- {intent}
56
-
57
- ## Trajectory
58
- {trajectory}
59
-
60
- ## Current State
61
- ### Current URL
62
- {current_url}
63
-
64
- ### AXTREE
65
- Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
66
- {text_observation}
67
-
68
- ## Checklist
69
- {checklist}
70
-
71
- ## Agent's Response
72
- """
73
-
74
- JUDGE_OURS_BT_MODELING_BASE_PROMPT = """You are an expert web agent that browses internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.
75
-
76
- # Task Description
77
- Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.
78
-
79
- # Given Information
80
- ## User Instruction
81
- {intent}
82
-
83
- ## Trajectory
84
- {trajectory}
85
-
86
- ## Current State
87
- ### Current URL
88
- {current_url}
89
-
90
- ### AXTREE
91
- Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
92
- {text_observation}
93
- """
94
-
95
- JUDGE_OURS_IMAGE_INPUT = """
96
- ### Image Screenshot
97
- <IMAGE_PLACEHOLDER>
98
- """
99
-
100
- JUDGE_OURS_WITH_CHECKLIST = """
101
- ## Checklist
102
- {checklist}
103
- """
104
-
105
- BT_MODELING_RESPONSE_FORMAT = """
106
- THOUGHT: {thought}
107
- ACTION: {action}
108
- """
109
-
110
- ## PROMPT TEMPLATE
111
- JUDGE_GROUNDING_PROMPT_TEMPLATE = {
112
- "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=GROUNDING_ROLE),
113
- "user": DEFAULT_USER_PROMPT_FORMAT,
114
- }
115
-
116
- JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE = {
117
- "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
118
- "user": DEFAULT_USER_PROMPT_FORMAT
119
- }
120
-
121
- JUDGE_THREE_CLASS_PROMPT_TEMPLATE = {
122
- "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
123
- "user": DEFAULT_USER_PROMPT_FORMAT
124
- }
125
-
126
- JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE = {
127
- "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITH_CHECKLIST_ROLE),
128
- "user": DEFAULT_USER_PROMPT_FORMAT
129
- }
130
-
131
- JUDGE_OURS_PROMPT_TEMPLATE = {
132
- "system": "",
133
- "user": JUDGE_OURS_USER_PROMPT_FORMAT,
134
- }
135
-
136
- JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE = {
137
- "system": "",
138
- "user": JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT,
139
- }
140
-
141
- JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE = {
142
- "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
143
- "assistant": BT_MODELING_RESPONSE_FORMAT,
144
- }
145
-
146
- JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE = {
147
- "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
148
- "assistant": BT_MODELING_RESPONSE_FORMAT,
149
- }
150
-
151
- JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE = {
152
- "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+"\n## Agent's Response\n",
153
- "assistant": BT_MODELING_RESPONSE_FORMAT,
154
- }
155
-
156
- JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE = {
157
- "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+"\n## Agent's Response\n",
158
- "assistant": BT_MODELING_RESPONSE_FORMAT,
159
- }
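For reference, a sketch of assembling one user/assistant pair from the Bradley-Terry-style templates above; every field value is a made-up placeholder:
example = {
    "intent": "Book a table for two",
    "trajectory": "Step 1: opened the booking page",
    "current_url": "https://example.com/reserve",
    "text_observation": "[7] combobox 'Guests'",
    "checklist": "Checklist 1: Select the number of guests",
    "thought": "I need to set the guest count to 2.",
    "action": "select_option('7', '2')",
}
user_msg = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE["user"].format(**example)
assistant_msg = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE["assistant"].format(**example)
pair = [
    {"role": "user", "content": user_msg},
    {"role": "assistant", "content": assistant_msg},
]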
 
agent/mini_bench/prompts/utils.py DELETED
@@ -1,18 +0,0 @@
1
- from langchain.schema import HumanMessage, AIMessage, SystemMessage
2
-
3
- def convert_dict_messages(dict_messages):
4
- message_objs = []
5
- for msg in dict_messages:
6
- role = msg.get("role")
7
- content = msg.get("content", "")
8
-
9
- if role == "user":
10
- message_objs.append(HumanMessage(content=content))
11
- elif role == "assistant":
12
- message_objs.append(AIMessage(content=content))
13
- elif role == "system":
14
- message_objs.append(SystemMessage(content=content))
15
- else:
16
- raise ValueError(f"Unknown role: {role}")
17
-
18
- return message_objs
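A quick sketch of the converter above (illustrative only):
dict_msgs = [
    {"role": "system", "content": "You are an expert evaluator of web agent."},
    {"role": "user", "content": "Judge the following action."},
]
lc_msgs = convert_dict_messages(dict_msgs)
# lc_msgs == [SystemMessage(...), HumanMessage(...)], ready for llm.invoke(lc_msgs)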
 
agent/mini_bench/reward_agent.py DELETED
@@ -1,465 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- import time
3
- import requests
4
- import json
5
- import math
6
- from langsmith import Client
7
- import numpy as np
8
- from langchain_openai import ChatOpenAI
9
-
10
- from .prompts import get_messages
11
- from .prompts.judge_prompt import (
12
- JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
13
- JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
14
- JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
15
- JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
16
- )
17
- from .prompts.image_utils import image_to_base64_url
18
- from .prompts.utils import convert_dict_messages
19
-
20
- MAX_RETRY = 3
21
- RETRY_SLEEP = 5
22
- MODEL_COST_MAPPING = {
23
- "gpt-4o-mini": {
24
- "input_token_cost": 0.15,
25
- "output_token_cost": 0.6
26
- },
27
- "gpt-4o": {
28
- "input_token_cost": 2.5,
29
- "output_token_cost": 10
30
- },
31
- }
32
-
33
-
34
- class Agent(ABC):
35
- @abstractmethod
36
- def generate_response(self, inputs: dict) -> str:
37
- pass
38
-
39
- class BaseAgent(Agent):
40
- def __init__(self, agent_config: dict):
41
- self.agent_config = agent_config
42
- self._setup()
43
-
44
- def _init_llm_object(self, **extra_kwargs):
45
- config = self.agent_config
46
- config.update(extra_kwargs)
47
-
48
- use_log_probs = config.get("use_log_probs", False)
49
- if use_log_probs:
50
- self.llm = ChatOpenAI(
51
- model=config["model_name"],
52
- base_url=config["base_url"],
53
- api_key=config["api_key"],
54
- temperature=config["temperature"],
55
- timeout=300,
56
- logprobs=True,
57
- top_logprobs=10,
58
- n=config.get('n', None)
59
- )
60
- else:
61
- self.llm = ChatOpenAI(
62
- model=config["model_name"],
63
- base_url=config["base_url"],
64
- api_key=config["api_key"],
65
- temperature=config["temperature"],
66
- timeout=300,
67
- n=config.get('n', None)
68
- )
69
-
70
- def _setup(self):
71
- self._init_llm_object()
72
-
73
- self.temperature = self.agent_config["temperature"]
74
- self.num_generate = self.agent_config["num_generate"]
75
- self.use_checklist = self.agent_config.get("use_checklist", False)
76
- self.use_multimodal = self.agent_config.get("use_multimodal", False)
77
-
78
- # setup cost
79
- model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
80
- if model_cost and "api" in self.agent_config["base_url"]:
81
- self.input_token_cost = model_cost["input_token_cost"]
82
- self.output_token_cost = model_cost["output_token_cost"]
83
- else:
84
- self.input_token_cost = 0.0
85
- self.output_token_cost = 0.0
86
-
87
- def generate_with_retry(self, model_input, constraint_str_list: list = None):
88
- total_input_tokens = 0
89
- total_output_tokens = 0
90
- if self.temperature == 0:
91
- response = self.llm.invoke(model_input)
92
- total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
93
- total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
94
- else:
95
- for i in range(MAX_RETRY):
96
- try:
97
- response = self.llm.invoke(model_input)
98
- total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
99
- total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
100
- if constraint_str_list:
101
- pass_constraint_num = 0
102
- for constraint_str in constraint_str_list:
103
- if constraint_str in response.content:
104
- pass_constraint_num += 1
105
- if pass_constraint_num == len(constraint_str_list):
106
- break
107
- else:
108
- print(f"Agent has a format issue, retrying... {i+1}/{MAX_RETRY}")
109
- else:
110
- break
111
- except Exception as e:
112
- print(f"Agent returned an Error: {e}")
113
- response = None
114
- time.sleep(RETRY_SLEEP)
115
-
116
- cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
117
-
118
- if response is None:
119
- return "", cost
120
- else:
121
- return response.content, cost
122
-
123
- def prepare_message(self, model_input: dict, prompt_type: str):
124
- message = []
125
- return message
126
-
127
- def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
128
- total_cost = 0
129
- response_list = []
130
- # prepare message
131
- message = self.prepare_message(model_input, prompt_type)
132
-
133
- # n sampling
134
- for i in range(self.num_generate):
135
- response, cost = self.generate_with_retry(message, constraint_str_list)
136
- response_list.append(response)
137
- total_cost += cost
138
-
139
- return response_list, total_cost
140
-
141
-
142
- class GroundingJudgeAgent(BaseAgent):
143
- def __init__(self, agent_config: dict):
144
- super().__init__(agent_config)
145
- self._setup()
146
-
147
- def prepare_message(self, model_input: dict, prompt_type):
148
- message = get_messages(
149
- input_info=model_input,
150
- inference_mode="judge_grounding",
151
- prompt_type=prompt_type,
152
- use_multimodal=self.use_multimodal,
153
- text_obs=self.agent_config["text_obs_type"],
154
- image_obs=self.agent_config["image_obs_type"]
155
- )
156
- return message
157
-
158
-
159
- class ProgressJudgeAgent(BaseAgent):
160
- def __init__(self, agent_config: dict):
161
- super().__init__(agent_config)
162
- self._setup()
163
-
164
- def prepare_message(self, model_input: dict, prompt_type):
165
- if self.agent_config["input_type"]=="text_only":
166
- use_multimodal = False
167
- text_obs = self.agent_config["text_obs_type"]
168
- image_obs = None
169
- elif self.agent_config["input_type"]=="image_only":
170
- use_multimodal = True
171
- text_obs = None
172
- image_obs = self.agent_config["image_obs_type"]
173
- elif self.agent_config["input_type"]=="text_image":
174
- use_multimodal = True
175
- text_obs = self.agent_config["text_obs_type"]
176
- image_obs = self.agent_config["image_obs_type"]
177
- else:
178
- raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
179
-
180
- if self.agent_config["use_in_progress"]:
181
- use_in_progress = True
182
- else:
183
- use_in_progress = False
184
-
185
- message = get_messages(
186
- input_info=model_input,
187
- inference_mode="judge_progress",
188
- prompt_type=prompt_type,
189
- use_checklist=self.use_checklist,
190
- use_multimodal=use_multimodal,
191
- text_obs=text_obs,
192
- image_obs=image_obs,
193
- use_in_progress=use_in_progress
194
- )
195
- return message
196
-
197
- def get_judge_probs(self, logprobs: list):
198
- # target_judge = {
199
- # "yes": [" Yes", "Yes", "ĠYes", "ĊYes"],
200
- # "no": [" No", "No", "ĠNo", "ĊNo"],
201
- # "in": [" In", "In", "ĠIn", "ĊIn"]
202
- # }
203
- target_judge = {
204
- "yes": [
205
- "ĠYes", "Yes", "ĊYes",
206
- "Ġyes", "yes", "Ċyes",
207
- "ĠYES", "YES", "ĊYES",
208
- "ĠDone", "Done", "ĊDone",
209
- "ĠCompleted", "Completed", "ĊCompleted",
210
- "ĠCorrect", "Correct", "ĊCorrect"
211
- ],
212
- "no": [
213
- "ĠNo", "No", "ĊNo",
214
- "ĠNO", "NO", "ĊNO",
215
- "ĠNot", "Not", "ĊNot",
216
- "ĠNone", "None", "ĊNone",
217
- "ĠNope", "Nope", "ĊNope",
218
- "ĠUn", "Un", "ĊUn",
219
- "ĠWrong", "Wrong", "ĊWrong"
220
- ],
221
- "in": [
222
- "ĠIn", "In", "ĊIn",
223
- "ĠPending", "Pending", "ĊPending",
224
- "ĠPart", "Part", "ĊPart",
225
- "ĠPartial", "Partial", "ĊPartial",
226
- "ĠInProgress", "InProgress", "ĊInProgress"
227
- ]
228
- }
229
- response_str = ""
230
- judge_probs_list = []
231
- for i, log_prob in enumerate(logprobs):
232
- # Start to find judge string
233
- if "<answer>" in response_str:
234
- find_judge_str = False
235
- for judge_type in target_judge:
236
- if log_prob["token"] in target_judge[judge_type]:
237
- # print(log_prob)
238
- find_judge_str = True
239
- break
240
- if find_judge_str:
241
- token_judge_dict = {
242
- "yes": None,
243
- "no": None,
244
- "in": None
245
- }
246
- for token_info in log_prob["top_logprobs"]:
247
- for judge_type in target_judge:
248
- for judge_str in target_judge[judge_type]:
249
- if judge_str in token_info["token"] :
250
- if token_judge_dict[judge_type] is None:
251
- token_judge_dict[judge_type] = math.exp(token_info["logprob"])
252
- else:
253
- token_judge_dict[judge_type] += math.exp(token_info["logprob"])
254
-
255
- token_judge_dict = {
256
- "yes": math.log(token_judge_dict["yes"]) if token_judge_dict["yes"] is not None else -float('inf'),
257
- "no": math.log(token_judge_dict["no"]) if token_judge_dict["no"] is not None else -float('inf'),
258
- "in": math.log(token_judge_dict["in"]) if token_judge_dict["in"] is not None else -float('inf')
259
- }
260
- judge_probs_list.append(token_judge_dict)
261
-
262
- if "</answer>" in response_str:
263
- break
264
-
265
- response_str += log_prob["token"]
266
-
267
- if len(judge_probs_list) == 0:
268
- return [{
269
- "yes": 0.0,
270
- "no": 0.0,
271
- "in": 0.0
272
- }]
273
- else:
274
- # convert with softmax
275
- final_judge_probs_list = []
276
- max_in_prob = -float('inf')
277
- for idx, judge_probs in enumerate(judge_probs_list):
278
- exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
279
- sum_exp_logprobs = sum(exp_logprobs)
280
- softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
281
- if softmax_probs[2] > max_in_prob:
282
- max_in_prob = softmax_probs[2]
283
- final_judge_probs_list.append({
284
- "yes": softmax_probs[0],
285
- "no": softmax_probs[1],
286
- "in": softmax_probs[2]
287
- })
288
- return final_judge_probs_list
289
-
290
- def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None):
291
- total_cost = 0
292
- # prepare message
293
- message = self.prepare_message(model_input, prompt_type)
294
- messages = convert_dict_messages(message)
295
-
296
- kwargs = {'n': n}
297
- if temperature is not None:
298
- kwargs['temperature'] = temperature
299
- self._init_llm_object(**kwargs)
300
-
301
- try:
302
- response = self.llm.generate([messages]) # assume single batch
303
- finally:
304
- print('request url: ', self.agent_config['base_url'])
305
-
306
-
307
- # parse responses
308
- response_list = []
309
- for generation in response.generations[0]: # assume single batch
310
- # parse logprobs
311
- logprobs = generation.message.response_metadata["logprobs"]["content"]
312
- response_list.append(
313
- {
314
- "response": generation.message.content,
315
- "judge_probs": self.get_judge_probs(logprobs)
316
- }
317
- )
318
-
319
- # calculate cost
320
- total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"]
321
- total_output_tokens = response.llm_output["token_usage"]["completion_tokens"]
322
- total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
323
-
324
- return response_list, total_cost
325
-
326
-
327
- class ChecklistGenerationAgent(BaseAgent):
328
- def __init__(self, agent_config: dict):
329
- super().__init__(agent_config)
330
- self._setup()
331
-
332
- def prepare_message(self, model_input: dict, prompt_type):
333
- message = get_messages(
334
- input_info=model_input,
335
- inference_mode="checklist_generation",
336
- prompt_type=prompt_type
337
- )
338
- return message
339
-
340
-
341
- class ClassifierRewardAgent(Agent):
342
- def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
343
- self.url = url
344
- self.use_checklist = use_checklist
345
- self.use_multimodal = use_multimodal
346
-
347
- def _process_multimodal_message(self, prompt: str, image_list: list[str]):
348
- multimodal_message = []
349
- text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
350
- text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
351
- multimodal_message = [
352
- {"type": "text", "text": text_prompt_prefix},
353
- # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
354
- {"type": "image", "image": image_to_base64_url(image_list[0])},
355
- {"type": "text", "text": text_prompt_suffix}
356
- ]
357
- return multimodal_message
358
-
359
- def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
360
- if self.use_multimodal:
361
- tmp_user_prompt = user_prompt_template["user"].format(
362
- **model_input
363
- )
364
- user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
365
- else:
366
- user_prompt = user_prompt_template["user"].format(
367
- **model_input
368
- )
369
- assistant_prompt = user_prompt_template["assistant"].format(
370
- **model_input
371
- )
372
- query = [
373
- {"role": "user", "content": user_prompt},
374
- {"role": "assistant", "content": assistant_prompt}
375
- ]
376
- return query
377
-
378
- def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
379
- if self.use_checklist:
380
- if self.use_multimodal:
381
- user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
382
- else:
383
- user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
384
- else:
385
- if self.use_multimodal:
386
- user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
387
- else:
388
- user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
389
-
390
- if self.use_multimodal:
391
- if batch:
392
- message = [self._make_query(user_prompt_template, input) for input in model_input]
393
- else:
394
- message = [self._make_query(user_prompt_template, model_input)]
395
- else:
396
- if batch:
397
- message = {
398
- "query": [self._make_query(user_prompt_template, input) for input in model_input],
399
- "prompts": []
400
- }
401
- else:
402
- message = {
403
- "query": self._make_query(user_prompt_template, model_input),
404
- "prompts": []
405
- }
406
-
407
- return message
408
-
409
- def get_rm_score(self, message: dict | list):
410
- headers = {"Content-Type": "application/json"}
411
-
412
- try:
413
- if self.use_multimodal:
414
- response = requests.post(
415
- self.url,
416
- json={"messages": message},
417
- timeout=600
418
- )
419
- else:
420
- response = requests.post(
421
- self.url,
422
- headers=headers,
423
- data=json.dumps(message),
424
- timeout=300
425
- )
426
- response.raise_for_status()
427
-
428
- response_json = response.json()
429
-
430
- if "rewards" not in response_json:
431
- print(f"Error: 'rewards' key not found in API response: {response_json}")
432
- return []
433
-
434
- if "get_reward" in self.url:
435
- # use openrlhf
436
- return response_json["rewards"]
437
- elif "pooling" in self.url:
438
- # use vllm server
439
- return response_json["reward"]
440
- else:
441
- # error
442
- raise ValueError(f"Invalid URL: {self.url}")
443
-
444
- except requests.exceptions.Timeout:
445
- print(f"Error: Request timed out to {self.url}")
446
- return []
447
- except requests.exceptions.RequestException as e:
448
- print(f"Error during request to {self.url}: {e}")
449
- return []
450
- except json.JSONDecodeError:
451
- print(f"Error: Failed to decode JSON response from {self.url}")
452
- return []
453
- except KeyError as e:
454
- print(f"Error: Missing key {e} in response from {self.url}")
455
- return []
456
-
457
-
458
- def generate_response(self, model_input: dict | list[dict], batch: bool = False):
459
- if batch:
460
- message = self.prepare_message(model_input, batch=True)
461
- else:
462
- message = self.prepare_message(model_input)
463
- rewards = self.get_rm_score(message)
464
-
465
- return rewards, 0
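For reference, a hedged sketch of scoring a single step with ClassifierRewardAgent (not part of the deleted file). The endpoint URL is a placeholder for a local reward-model server, and the input keys mirror the BT-modeling templates above:
rm = ClassifierRewardAgent(
    url="http://localhost:8000/get_reward",  # placeholder; "get_reward" in the path selects the OpenRLHF-style parsing
    use_checklist=True,
)
sample = {
    "intent": "Book a table for two",
    "trajectory": "Step 1: opened the booking page",
    "current_url": "https://example.com/reserve",
    "text_observation": "[7] combobox 'Guests'",
    "checklist": "Checklist 1: Select the number of guests",
    "thought": "I need to set the guest count to 2.",
    "action": "select_option('7', '2')",
}
rewards, _ = rm.generate_response(sample)  # list of scalar rewards returned by the server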
 
agent/mini_bench/utils.py DELETED
@@ -1,269 +0,0 @@
1
- import json
2
- import base64
3
- import io
4
- import html
5
- from PIL import Image
6
-
7
-
8
- def image_to_base64_url(image: str | Image.Image):
9
- if isinstance(image, str):
10
- with open(image, "rb") as f:
11
- image = f.read()
12
- elif isinstance(image, Image.Image):
13
- if image.mode in ("RGBA", "LA"):
14
- image = image.convert("RGB")
15
- with io.BytesIO() as buffer:
16
- image.save(buffer, format="PNG")
17
- image = buffer.getvalue()
18
- else:
19
- raise ValueError(f"Invalid image type: {type(image)}")
20
-
21
- return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")
22
-
23
-
24
- def load_json(file_path: str) -> dict:
25
- with open(file_path, "r") as f:
26
- return json.load(f)
27
-
28
- def save_json(data: dict, file_path: str):
29
- with open(file_path, "w") as f:
30
- json.dump(data, f, indent=4)
31
-
32
- def str_to_bool(s: str) -> bool:
33
- if s.lower() in ["true", "1", "yes", "y"]:
34
- return True
35
- elif s.lower() in ["false", "0", "no", "n"]:
36
- return False
37
- else:
38
- raise ValueError(f"Invalid boolean string: {s}")
39
-
40
-
41
- def create_html_report(json_path, html_path, checklist_generation=False):
42
- """
43
- Reads the given JSON result file and generates a filterable HTML report.
44
-
45
- Args:
46
- json_path (str): Path to the input JSON file.
47
- html_path (str): Path to the output HTML file.
48
- """
49
- try:
50
- with open(json_path, 'r', encoding='utf-8') as f:
51
- data = json.load(f)
52
- except FileNotFoundError:
53
- print(f"Error: JSON file not found - {json_path}") # Error message in English
54
- return
55
- except json.JSONDecodeError:
56
- print(f"Error: JSON file parsing error - {json_path}") # Error message in English
57
- return
58
- except Exception as e:
59
- print(f"Unexpected error during data loading: {e}") # Error message in English
60
- return
61
-
62
- # Extract unique Task IDs and sort them
63
- task_ids = sorted(list(set(item.get("task_id") for item in data if item.get("task_id") is not None)))
64
-
65
- html_content = """
66
- <!DOCTYPE html>
67
- <html lang="en">
68
- <head>
69
- <meta charset="UTF-8">
70
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
71
- <title>Benchmark Results Report</title>
72
- <style>
73
- body { font-family: sans-serif; line-height: 1.6; padding: 20px; }
74
- .task-step { border: 1px solid #ccc; margin-bottom: 20px; padding: 15px; border-radius: 5px; background-color: #f9f9f9; }
75
- .task-step h2 { margin-top: 0; color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px;}
76
- .task-step h3 { color: #555; margin-top: 15px; margin-bottom: 5px; }
77
- .task-step h4 { color: #777; margin-top: 10px; margin-bottom: 5px; font-style: italic;}
78
- pre { background-color: #eee; padding: 10px; border-radius: 3px; white-space: pre-wrap; word-wrap: break-word; font-size: 0.9em; margin-top: 5px; }
79
- details { margin-top: 10px; border: 1px solid #ddd; border-radius: 3px; background-color: #fff; }
80
- summary { cursor: pointer; padding: 8px; background-color: #f8f9fa; font-weight: bold; border-bottom: 1px solid #ddd; }
81
- details[open] summary { border-bottom: 1px solid #ddd; }
82
- details > pre { border: none; background-color: #fff; padding: 10px 8px; }
83
- .response-item-toggle { margin-top: 10px; }
84
- .chosen-section { border-left: 5px solid #4CAF50; padding-left: 10px; margin-top: 15px; }
85
- .rejected-section { border-left: 5px solid #f44336; padding-left: 10px; margin-top: 15px; }
86
- hr { border: 0; border-top: 1px solid #eee; margin: 15px 0; }
87
- .thought-action { background-color: #f0f0f0; padding: 10px; border-radius: 3px; margin-bottom: 10px; border: 1px solid #e0e0e0;}
88
- .thought-action h4 { margin-top: 0; color: #666; }
89
- .task-container { display: none; }
90
- .filter-controls { margin-bottom: 20px; padding: 10px; background-color: #e9ecef; border-radius: 5px; }
91
- .filter-controls label { margin-right: 10px; font-weight: bold; }
92
- .filter-controls select { padding: 5px; border-radius: 3px; border: 1px solid #ced4da; }
93
- </style>
94
- </head>
95
- <body>
96
- <h1>Benchmark Results Report</h1>
97
-
98
- <!-- Task ID Filter Dropdown -->
99
- <div class="filter-controls">
100
- <label for="taskSelector">Select Task ID:</label>
101
- <select id="taskSelector">
102
- <option value="">-- Show All --</option>
103
- """
104
- # Add dropdown options
105
- for tid in task_ids:
106
- html_content += f' <option value="{html.escape(str(tid))}">{html.escape(str(tid))}</option>\n'
107
-
108
- html_content += """
109
- </select>
110
- </div>
111
-
112
- <!-- Results Display Area -->
113
- <div id="resultsArea">
114
- """
115
-
116
- # Process each Task/Step data
117
- for i, step_data in enumerate(data):
118
- task_id = step_data.get("task_id", "N/A")
119
- step_id = step_data.get("step_id", "N/A")
120
- intent = step_data.get("intent", "N/A")
121
- start_url = step_data.get("start_url", "N/A")
122
- gt_checklist = step_data.get("gt_checklist", "N/A")
123
- generated_checklist = step_data.get("generated_checklist", None)
124
- trajectory = step_data.get("trajectory", "N/A")
125
- text_observation = step_data.get("text_observation", "N/A")
126
- source_name = step_data.get("source_name", "")
127
-
128
- # Wrap each Task/Step in a container with a unique ID (hidden initially)
129
- html_content += f"""
130
- <div class="task-container" data-task-id="{html.escape(str(task_id))}">
131
- <div class="task-step">
132
- <h2>Task ID: {html.escape(str(task_id))} | Step ID: {html.escape(str(step_id))} {f'({html.escape(source_name)})' if source_name else ''}</h2>
133
- <h3>Intent:</h3>
134
- <p>{html.escape(intent)}</p>
135
- <p><strong>Start URL:</strong> <a href="{html.escape(start_url)}" target="_blank">{html.escape(start_url)}</a></p>
136
-
137
- <h3>Ground Truth Checklist:</h3>
138
- <pre>{html.escape(gt_checklist)}</pre>
139
- """
140
- if checklist_generation and generated_checklist is not None:
141
- html_content += f"""
142
- <details>
143
- <summary>Generated Checklist (Click to expand/collapse)</summary>
144
- <pre>{html.escape(str(generated_checklist))}</pre>
145
- </details>
146
- """
147
-
148
- html_content += f"""
149
- <details>
150
- <summary>Trajectory (Click to expand/collapse)</summary>
151
- <pre>{html.escape(trajectory)}</pre>
152
- </details>
153
-
154
- <details>
155
- <summary>Text Observation (Click to expand/collapse)</summary>
156
- <pre>{html.escape(text_observation)}</pre>
157
- </details>
158
- <hr>
159
- """
160
-
161
- # Chosen Responses
162
- if 'chosen' in step_data and step_data['chosen']:
163
- html_content += '<div class="chosen-section"><h3>Chosen Responses:</h3>'
164
- for choice_block in step_data['chosen']:
165
- thought = choice_block.get('thought', 'N/A')
166
- action = choice_block.get('action', 'N/A')
167
- responses = choice_block.get('response', [])
168
- scores = choice_block.get('score', [])
169
-
170
- # Add Thought and Action information
171
- html_content += f"""
172
- <div class="thought-action">
173
- <h4>Thought:</h4>
174
- <pre>{html.escape(thought)}</pre>
175
- <h4>Action:</h4>
176
- <pre>{html.escape(action)}</pre>
177
- </div>"""
178
-
179
- # Loop through responses and create toggles
180
- for idx, (response, score) in enumerate(zip(responses, scores)):
181
- html_content += f"""
182
- <details class="response-item-toggle">
183
- <summary>Judge Response {idx + 1}: {html.escape(str(score))}</summary>
184
- <pre>{html.escape(str(response))}</pre>
185
- </details>"""
186
- html_content += '</div>' # End chosen-section
187
-
188
- # Rejected Responses
189
- if 'rejected' in step_data and step_data['rejected']:
190
- html_content += '<div class="rejected-section"><h3>Rejected Responses:</h3>'
191
- for rejection_block in step_data['rejected']:
192
- thought = rejection_block.get('thought', 'N/A')
193
- action = rejection_block.get('action', 'N/A')
194
- responses = rejection_block.get('response', [])
195
- scores = rejection_block.get('score', [])
196
-
197
- # Add Thought and Action information
198
- html_content += f"""
199
- <div class="thought-action">
200
- <h4>Thought:</h4>
201
- <pre>{html.escape(thought)}</pre>
202
- <h4>Action:</h4>
203
- <pre>{html.escape(action)}</pre>
204
- </div>"""
205
-
206
- # Loop through responses and create toggles
207
- for idx, (response, score) in enumerate(zip(responses, scores)):
208
- html_content += f"""
209
- <details class="response-item-toggle">
210
- <summary>Judge Response {idx + 1}: {html.escape(str(score))}</summary>
211
- <pre>{html.escape(str(response))}</pre>
212
- </details>"""
213
- html_content += '</div>' # End rejected-section
214
-
215
- html_content += """
216
- </div> <!-- End task-step -->
217
- </div> <!-- End task-container -->
218
- """
219
-
220
- # Finalize HTML and add JavaScript
221
- html_content += """
222
- </div> <!-- End resultsArea -->
223
-
224
- <script>
225
- document.addEventListener('DOMContentLoaded', function() {
226
- const taskSelector = document.getElementById('taskSelector');
227
- const taskContainers = document.querySelectorAll('.task-container');
228
-
229
- function filterTasks() {
230
- const selectedTaskId = taskSelector.value;
231
-
232
- taskContainers.forEach(container => {
233
- const containerTaskId = container.getAttribute('data-task-id');
234
- // Show if no Task ID is selected (Show All) or if the container's Task ID matches
235
- if (selectedTaskId === "" || containerTaskId === selectedTaskId) {
236
- container.style.display = 'block';
237
- } else {
238
- // Otherwise, hide it
239
- container.style.display = 'none';
240
- }
241
- });
242
- }
243
-
244
- // Run filter function on dropdown change
245
- taskSelector.addEventListener('change', filterTasks);
246
-
247
- // Run initial filtering on page load (default: Show All)
248
- filterTasks();
249
- });
250
- </script>
251
-
252
- </body>
253
- </html>
254
- """
255
-
256
- # Save the HTML file
257
- try:
258
- with open(html_path, 'w', encoding='utf-8') as f:
259
- f.write(html_content)
260
- print(f"Completed: HTML report created at {html_path}")
261
- except IOError:
262
- print(f"Error: Failed to write HTML file - {html_path}")
263
- except Exception as e:
264
- print(f"Unexpected error during HTML file saving: {e}")
265
-
266
- # --- Example Usage ---
267
- # input_json_file = 'path/to/your/results.json'
268
- # output_html_file = 'trajectory_report.html'
269
- # create_html_report(input_json_file, output_html_file)
 
agent/reward.py DELETED
@@ -1,96 +0,0 @@
1
- import time
2
- from typing import List, Dict, Any, Optional, Union
3
- import numpy as np
4
- from .mini_bench.reward_agent import ProgressJudgeAgent
5
- from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash
6
- import json
7
- import os
8
- from concurrent.futures import ThreadPoolExecutor, as_completed
9
-
10
- def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
11
- """Process a single unit and return (idx, reward, thought)."""
12
- agent = ProgressJudgeAgent(configs)
13
- current_temperature = configs["temperature"]
14
-
15
- rewards = []
16
- n_err = 0
17
- retry_count = 0
18
- judge_hash_count_thought = {}
19
-
20
- while len(rewards) < n_samples and retry_count < max_retries:
21
- # Call the external judge API
22
- responses, _ = agent.generate_probs(
23
- unit, "ours", n=n_samples - len(rewards), temperature=current_temperature
24
- )
25
-
26
- for response in responses:
27
- content = response["response"]
28
- thought = content  # keep the full response for logging
29
- reward = REWARD_PROCESSORS[reward_processor](response)
30
- rewards.append(reward)
31
-
32
- if np.isnan(reward) or reward is None:
33
- n_err += 1
34
- else:
35
- judge_hash = extract_judge_hash(response)
36
- judge_hash_count_thought[judge_hash] = (judge_hash_count_thought.get(judge_hash, (0, None))[0] + 1, thought)
37
-
38
- if n_err > 0:
39
- # On failure, raise the temperature and retry
40
- if n_samples == 1:
41
- current_temperature = 0.5
42
- retry_count += 1
43
-
44
- reward = np.nanmean(rewards)
45
- if np.isnan(reward):
46
- print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
47
- reward = 0.0
48
- print(judge_hash_count_thought)
49
- thought = max(judge_hash_count_thought.values(), key=lambda x: x[0])[1]
50
-
51
- return idx, reward, thought
52
-
53
-
54
- def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
55
- """Threaded drop-in replacement for the original get_ar_reward."""
56
- n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]
57
-
58
- temperature = 0.5 if n_samples > 1 else 0.0
59
-
60
- configs = {
61
- "model_name": model_name,
62
- "base_url": base_url,
63
- "api_key": "empty",
64
- "temperature": temperature,
65
- "num_generate": 1,
66
- "use_checklist": True,
67
- "input_type": "text_only",
68
- "text_obs_type": "axtree",
69
- "image_obs_type": "som",
70
- "use_in_progress": True,
71
- "use_multimodal": False,
72
- "use_log_probs": True,
73
- }
74
-
75
- t_start = time.time()
76
- results = [None] * len(dataset)
77
-
78
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
79
- futures = [
80
- executor.submit(
81
- _process_unit, idx, unit, configs, n_samples, reward_processor
82
- )
83
- for idx, unit in enumerate(dataset)
84
- ]
85
-
86
- for fut in as_completed(futures):
87
- idx, reward, thought = fut.result()
88
- results[idx] = (reward, thought)
89
-
90
- # Split into order-preserving lists
91
- final_rewards = [float(r) for r, _ in results]
92
- thoughts = [t for _, t in results]
93
-
94
- print(f"Time taken (threaded): {time.time() - t_start:.2f} s")
95
- return final_rewards, thoughts
96
-
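A usage sketch for the threaded scorer above (illustrative only). The base URL and model name are placeholders for an OpenAI-compatible judge endpoint, and each dataset entry carries the fields the judge prompt is assumed to expect:
dataset = [{
    "intent": "Book a table for two",
    "trajectory": "Step 1: opened the booking page",
    "current_url": "https://example.com/reserve",
    "text_observation": "[7] combobox 'Guests'",
    "checklist": "Checklist 1: Select the number of guests",
    "thought": "I need to set the guest count to 2.",
    "action": "select_option('7', '2')",
}]
rewards, thoughts = get_ar_reward(
    dataset,
    base_url="http://localhost:8000/v1",      # placeholder endpoint
    model_name="progress-judge-placeholder",  # placeholder model name
    reward_processor="avg_logits",
    max_workers=4,
)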
 
agent/reward_postprocessor.py DELETED
@@ -1,41 +0,0 @@
1
- import numpy as np
2
- import re
3
-
4
-
5
- def extract_judge_hash(response):
6
- """
7
- Convert the per-checklist yes / in-progress / no judgments into a compact hash string and return it.
8
- """
9
- content = response['response']
10
-
11
- try:
12
- judge_content = content.lower().replace(' ', '').split('<answer>')[1].split('</answer>')[0]
13
- except:
14
- import traceback
15
- traceback.print_exc()
16
- return None
17
- pattern = r":yes|:inprogress|:no"
18
- matches = re.findall(pattern, judge_content)
19
- matches = [{':yes': 'y', ':inprogress': 'i', ':no': 'n'}[match] for match in matches]
20
- return ''.join(matches)
21
-
22
- def average_logits(response):
23
- """
24
- Combine the yes / in-progress / no probabilities at the logit level into a scalar reward.
25
- """
26
- judge_probs = response['judge_probs']
27
-
28
- yes_ = np.mean([r['yes'] for r in judge_probs])
29
- in_ = np.mean([r['in'] for r in judge_probs])
30
-
31
- reward = yes_ + 0.5 * in_
32
- return reward
33
-
34
-
35
- REWARD_PROCESSORS = {
36
- 'avg_logits': average_logits
37
- }
38
-
39
- REWARD_PROCESSOR_N_SAMPLES = {
40
- 'avg_logits': 5
41
- }
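A tiny worked example of the two post-processors above on a synthetic judge response (all values made up):
fake_response = {
    "response": "<think>...</think><answer>Checklist 1: Yes\nChecklist 2: In Progress</answer>",
    "judge_probs": [
        {"yes": 0.9, "in": 0.05, "no": 0.05},
        {"yes": 0.2, "in": 0.7, "no": 0.1},
    ],
}
print(extract_judge_hash(fake_response))  # "yi"  (one character per checklist item)
print(average_logits(fake_response))      # 0.55 + 0.5 * 0.375 = 0.7375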