Delete agent
- agent/__init__.py +0 -0
- agent/checklist.py +0 -18
- agent/mini_bench/__init__.py +0 -0
- agent/mini_bench/__pycache__/__init__.cpython-311.pyc +0 -0
- agent/mini_bench/__pycache__/agent.cpython-311.pyc +0 -0
- agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc +0 -0
- agent/mini_bench/agent.py +0 -467
- agent/mini_bench/checklist_eval.py +0 -95
- agent/mini_bench/eval_utils.py +0 -309
- agent/mini_bench/inference_utils.py +0 -87
- agent/mini_bench/prompts/__init__.py +0 -1
- agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc +0 -0
- agent/mini_bench/prompts/action.py +0 -93
- agent/mini_bench/prompts/checklist_prompt.py +0 -50
- agent/mini_bench/prompts/construct_messages.py +0 -309
- agent/mini_bench/prompts/eval_type.py +0 -107
- agent/mini_bench/prompts/image_utils.py +0 -19
- agent/mini_bench/prompts/input_information.py +0 -36
- agent/mini_bench/prompts/judge_prompt.py +0 -159
- agent/mini_bench/prompts/utils.py +0 -18
- agent/mini_bench/reward_agent.py +0 -465
- agent/mini_bench/utils.py +0 -269
- agent/reward.py +0 -96
- agent/reward_postprocessor.py +0 -41
agent/__init__.py
DELETED
File without changes
agent/checklist.py
DELETED
@@ -1,18 +0,0 @@
-from .mini_bench.agent import ChecklistGenerationAgent
-
-def generate_checklist(**data):
-    # data: 'intent', 'start_url', 'text_observation'
-    agent_config = {
-        'model_name': 'WPRM/qwen-3b-ar-reward-cot-mtl-checklist-enhanced',
-        'base_url': 'http://165.132.144.84:7701/v1',
-        'api_key': 'empty',
-        'temperature': 0.7,
-        'use_log_probs': True,
-        'use_checklist': True,
-        'use_multimodal': False,
-        'num_generate': 1,
-    }
-    checklist_generation_agent = ChecklistGenerationAgent(agent_config)
-    response_list, cost = checklist_generation_agent.generate_response(data, prompt_type='ours', constraint_str_list=["<think>", "</think>", "<answer>", "</answer>"])
-    response = response_list[0]
-    return response.split("<answer>")[-1].split("</answer>")[0].strip()
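Note: a minimal sketch of how the deleted `generate_checklist` helper would have been called — the keyword names mirror the comment inside the function body; the values below are illustrative placeholders, and the call goes to the hard-coded vLLM endpoint above.

checklist = generate_checklist(
    intent="Find the cheapest laptop under $500",                    # placeholder task
    start_url="http://example.com",                                  # placeholder URL
    text_observation="[1] <input> search box\n[2] <button> Search",  # placeholder AXTree
)
# returns the text between <answer> ... </answer> in the model response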
agent/mini_bench/__init__.py
DELETED
File without changes
agent/mini_bench/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (186 Bytes)

agent/mini_bench/__pycache__/agent.cpython-311.pyc
DELETED
Binary file (20.7 kB)

agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc
DELETED
Binary file (21.3 kB)
agent/mini_bench/agent.py
DELETED
@@ -1,467 +0,0 @@
-from abc import ABC, abstractmethod
-import time
-import requests
-import json
-import math
-from langsmith import Client
-from langchain_openai import ChatOpenAI
-
-from .prompts import get_messages
-from .prompts.judge_prompt import (
-    JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
-    JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
-    JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
-    JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
-)
-from .prompts.image_utils import image_to_base64_url
-
-MAX_RETRY = 3
-RETRY_SLEEP = 5
-MODEL_COST_MAPPING = {
-    "gpt-4o-mini": {
-        "input_token_cost": 0.15,
-        "output_token_cost": 0.6
-    },
-    "gpt-4o": {
-        "input_token_cost": 2.5,
-        "output_token_cost": 10
-    },
-}
-
-
-class Agent(ABC):
-    @abstractmethod
-    def generate_response(self, inputs: dict) -> str:
-        pass
-
-class BaseAgent(Agent):
-    def __init__(self, agent_config: dict):
-        self.agent_config = agent_config
-        self._setup()
-
-    def _setup(self):
-        use_log_probs = self.agent_config.get("use_log_probs", False)
-        if use_log_probs:
-            self.llm = ChatOpenAI(
-                model=self.agent_config["model_name"],
-                base_url=self.agent_config["base_url"],
-                api_key=self.agent_config["api_key"],
-                temperature=self.agent_config["temperature"],
-                timeout=300,
-                logprobs=True,
-                top_logprobs=10
-            )
-        else:
-            self.llm = ChatOpenAI(
-                model=self.agent_config["model_name"],
-                base_url=self.agent_config["base_url"],
-                api_key=self.agent_config["api_key"],
-                temperature=self.agent_config["temperature"],
-                timeout=300
-            )
-        self.temperature = self.agent_config["temperature"]
-        self.num_generate = self.agent_config["num_generate"]
-        self.use_checklist = self.agent_config.get("use_checklist", False)
-        self.use_multimodal = self.agent_config.get("use_multimodal", False)
-
-        # setup cost
-        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
-        if model_cost and "api" in self.agent_config["base_url"]:
-            self.input_token_cost = model_cost["input_token_cost"]
-            self.output_token_cost = model_cost["output_token_cost"]
-        else:
-            self.input_token_cost = 0.0
-            self.output_token_cost = 0.0
-
-    def generate_with_retry(self, model_input, constraint_str_list: list = None):
-        total_input_tokens = 0
-        total_output_tokens = 0
-        if self.temperature == 0:
-            response = self.llm.invoke(model_input)
-            total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
-            total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
-        else:
-            for i in range(MAX_RETRY):
-                try:
-                    response = self.llm.invoke(model_input)
-                    total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
-                    total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
-                    if constraint_str_list:
-                        pass_constraint_num = 0
-                        for constraint_str in constraint_str_list:
-                            if constraint_str in response.content:
-                                pass_constraint_num += 1
-                        if pass_constraint_num == len(constraint_str_list):
-                            break
-                        else:
-                            print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
-                            print(response.content)
-                    else:
-                        break
-                except Exception as e:
-                    print(f"Agent returned an Error: {e}")
-                    response = None
-                    time.sleep(RETRY_SLEEP)
-
-        cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
-
-        if response is None:
-            return "", cost
-        else:
-            return response.content, cost
-
-    def prepare_message(self, model_input: dict, prompt_type: str):
-        message = []
-        return message
-
-    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
-        total_cost = 0
-        response_list = []
-        # prepare message
-        message = self.prepare_message(model_input, prompt_type)
-        # print(message)
-
-        # n sampling
-        for i in range(self.num_generate):
-            response, cost = self.generate_with_retry(message, constraint_str_list)
-            response_list.append(response)
-            total_cost += cost
-
-        return response_list, total_cost
-
-
-class GroundingJudgeAgent(BaseAgent):
-    def __init__(self, agent_config: dict):
-        super().__init__(agent_config)
-        self._setup()
-
-    def prepare_message(self, model_input: dict, prompt_type):
-        message = get_messages(
-            input_info=model_input,
-            inference_mode="judge_grounding",
-            prompt_type=prompt_type,
-            use_multimodal=self.use_multimodal,
-            text_obs=self.agent_config["text_obs_type"],
-            image_obs=self.agent_config["image_obs_type"]
-        )
-        return message
-
-
-class ProgressJudgeAgent(BaseAgent):
-    def __init__(self, agent_config: dict):
-        super().__init__(agent_config)
-        self._setup()
-
-    def prepare_message(self, model_input: dict, prompt_type):
-        if self.agent_config["input_type"]=="text_only":
-            use_multimodal = False
-            text_obs = self.agent_config["text_obs_type"]
-            image_obs = None
-        elif self.agent_config["input_type"]=="image_only":
-            use_multimodal = True
-            text_obs = None
-            image_obs = self.agent_config["image_obs_type"]
-        elif self.agent_config["input_type"]=="text_image":
-            use_multimodal = True
-            text_obs = self.agent_config["text_obs_type"]
-            image_obs = self.agent_config["image_obs_type"]
-        else:
-            raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
-
-        if self.agent_config["use_in_progress"]:
-            use_in_progress = True
-        else:
-            use_in_progress = False
-
-        message = get_messages(
-            input_info=model_input,
-            inference_mode="judge_progress",
-            prompt_type=prompt_type,
-            use_checklist=self.use_checklist,
-            use_multimodal=use_multimodal,
-            text_obs=text_obs,
-            image_obs=image_obs,
-            use_in_progress=use_in_progress
-        )
-        return message
-
-    def add_logprob(self, ori_logprob: float, add_logprob: float):
-        if ori_logprob is None:
-            return add_logprob
-        else:
-            ori_prob = math.exp(ori_logprob)
-            add_prob = math.exp(add_logprob)
-            return math.log(ori_prob + add_prob)
-
-    def get_judge_probs(self, logprobs: list):
-        # target_judge = {
-        #     "yes": [" Yes", "Yes"],
-        #     "no": [" No", "No"],
-        #     "in": [" In", "In"]
-        # }
-        target_judge = {
-            "yes": [
-                " Yes", "ĠYes", "Yes", "ĊYes",
-                "Ġyes", "yes", "Ċyes",
-                "ĠYES", "YES", "ĊYES",
-                "ĠDone", "Done", "ĊDone",
-                "ĠCompleted", "Completed", "ĊCompleted",
-                "ĠCorrect", "Correct", "ĊCorrect"
-            ],
-            "no": [
-                " No", "ĠNo", "No", "ĊNo",
-                "ĠNO", "NO", "ĊNO",
-                "ĠNot", "Not", "ĊNot",
-                "ĠNone", "None", "ĊNone",
-                "ĠNope", "Nope", "ĊNope",
-                "ĠUn", "Un", "ĊUn",
-                "ĠWrong", "Wrong", "ĊWrong"
-            ],
-            "in": [
-                " In", "ĠIn", "In", "ĊIn",
-                "ĠPending", "Pending", "ĊPending",
-                "ĠPart", "Part", "ĊPart",
-                "ĠPartial", "Partial", "ĊPartial",
-                "ĠInProgress", "InProgress", "ĊInProgress"
-            ]
-        }
-        response_str = ""
-        judge_probs_list = []
-        # print(logprobs)
-        for i, log_prob in enumerate(logprobs):
-            # Start to find judge string
-            if "<answer>" in response_str:
-                find_judge_str = None
-                for judge_type in target_judge:
-                    if log_prob["token"] in target_judge[judge_type]:
-                        # print(log_prob)
-                        find_judge_str = judge_type
-                        break
-                if find_judge_str:
-                    # print("find judge str")
-                    token_judge_dict = {
-                        "yes": None,
-                        "no": None,
-                        "in": None
-                    }
-                    if "top_logprobs" in log_prob:
-                        for token_info in log_prob["top_logprobs"]:
-                            for judge_type in target_judge:
-                                for judge_str in target_judge[judge_type]:
-                                    # if judge_str in token_info["token"] and token_info["logprob"] > token_judge_dict[judge_type]:
-                                    #     token_judge_dict[judge_type] = token_info["logprob"]
-                                    if judge_str in token_info["token"]:
-                                        # print(token_info["logprob"])
-                                        token_judge_dict[judge_type] = self.add_logprob(token_judge_dict[judge_type], token_info["logprob"])
-                        # for None case
-                        for judge_type in token_judge_dict:
-                            if token_judge_dict[judge_type] is None:
-                                token_judge_dict[judge_type] = float("-inf")
-                        judge_probs_list.append(token_judge_dict)
-                    else:
-                        # for vllm bugs : no top_logprobs
-                        for judge_type in token_judge_dict:
-                            if judge_type == find_judge_str:
-                                token_judge_dict[judge_type] = log_prob["logprob"]
-                            else:
-                                token_judge_dict[judge_type] = float("-inf")
-                        judge_probs_list.append(token_judge_dict)
-                # print(token_judge_dict)
-
-            if "</answer>" in response_str:
-                break
-
-            response_str += log_prob["token"]
-        # print(response_str.replace("Ġ", " ").replace("Ċ", "\n"))
-        # print(judge_probs_list)
-        if len(judge_probs_list) == 0:
-            return [{
-                "yes": 0.0,
-                "no": 0.0,
-                "in": 0.0
-            }]
-        else:
-            # convert with softmax
-            final_judge_probs_list = []
-            for judge_probs in judge_probs_list:
-                exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
-                sum_exp_logprobs = sum(exp_logprobs)
-                softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
-                final_judge_probs_list.append({
-                    "yes": softmax_probs[0],
-                    "no": softmax_probs[1],
-                    "in": softmax_probs[2]
-                })
-            return final_judge_probs_list
-
-    def generate_probs(self, model_input: dict, prompt_type: str):
-        total_cost = 0
-        response_list = []
-        # prepare message
-        message = self.prepare_message(model_input, prompt_type)
-        # print(message)
-
-        for i in range(self.num_generate):
-            try:
-                response = self.llm.invoke(message)
-                total_input_tokens = response.response_metadata["token_usage"]["prompt_tokens"]
-                total_output_tokens = response.response_metadata["token_usage"]["completion_tokens"]
-                total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
-                logprobs = response.response_metadata["logprobs"]["content"]
-                response_list.append(
-                    {
-                        "response": response.content,
-                        "judge_probs": self.get_judge_probs(logprobs)
-                    }
-                )
-            except Exception as e:
-                print(f"Error: {e}")
-                # print(response.response_metadata["logprobs"])
-                response_list.append(
-                    {
-                        "response": response.content,
-                        "judge_probs": []
-                    }
-                )
-        return response_list, total_cost
-
-
-class ChecklistGenerationAgent(BaseAgent):
-    def __init__(self, agent_config: dict):
-        super().__init__(agent_config)
-        self._setup()
-
-    def prepare_message(self, model_input: dict, prompt_type):
-        message = get_messages(
-            input_info=model_input,
-            inference_mode="checklist_generation",
-            prompt_type=prompt_type
-        )
-        return message
-
-
-class ClassifierRewardAgent(Agent):
-    def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
-        self.url = url
-        self.use_checklist = use_checklist
-        self.use_multimodal = use_multimodal
-
-    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
-        multimodal_message = []
-        text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
-        text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
-        multimodal_message = [
-            {"type": "text", "text": text_prompt_prefix},
-            # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
-            {"type": "image", "image": image_to_base64_url(image_list[0])},
-            {"type": "text", "text": text_prompt_suffix}
-        ]
-        return multimodal_message
-
-    def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
-        if self.use_multimodal:
-            tmp_user_prompt = user_prompt_template["user"].format(
-                **model_input
-            )
-            user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
-        else:
-            user_prompt = user_prompt_template["user"].format(
-                **model_input
-            )
-        assistant_prompt = user_prompt_template["assistant"].format(
-            **model_input
-        )
-        query = [
-            {"role": "user", "content": user_prompt},
-            {"role": "assistant", "content": assistant_prompt}
-        ]
-        return query
-
-    def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
-        if self.use_checklist:
-            if self.use_multimodal:
-                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
-            else:
-                user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
-        else:
-            if self.use_multimodal:
-                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
-            else:
-                user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
-
-        if self.use_multimodal:
-            if batch:
-                message = [self._make_query(user_prompt_template, input) for input in model_input]
-            else:
-                message = [self._make_query(user_prompt_template, model_input)]
-        else:
-            if batch:
-                message = {
-                    "query": [self._make_query(user_prompt_template, input) for input in model_input],
-                    "prompts": []
-                }
-            else:
-                message = {
-                    "query": self._make_query(user_prompt_template, model_input),
-                    "prompts": []
-                }
-
-        return message
-
-    def get_rm_score(self, message: dict | list):
-        headers = {"Content-Type": "application/json"}
-
-        try:
-            if self.use_multimodal:
-                response = requests.post(
-                    self.url,
-                    json={"messages": message},
-                    timeout=600
-                )
-            else:
-                response = requests.post(
-                    self.url,
-                    headers=headers,
-                    data=json.dumps(message),
-                    timeout=300
-                )
-            response.raise_for_status()
-
-            response_json = response.json()
-
-            if "rewards" not in response_json:
-                print(f"Error: 'rewards' key not found in API response: {response_json}")
-                return []
-
-            if "get_reward" in self.url:
-                # use openrlhf
-                return response_json["rewards"]
-            elif "pooling" in self.url:
-                # use vllm server
-                return response_json["reward"]
-            else:
-                # error
-                raise ValueError(f"Invalid URL: {self.url}")
-
-        except requests.exceptions.Timeout:
-            print(f"Error: Request timed out to {self.url}")
-            return []
-        except requests.exceptions.RequestException as e:
-            print(f"Error during request to {self.url}: {e}")
-            return []
-        except json.JSONDecodeError:
-            print(f"Error: Failed to decode JSON response from {self.url}")
-            return []
-        except KeyError as e:
-            print(f"Error: Missing key {e} in response from {self.url}")
-            return []
-
-
-    def generate_response(self, model_input: dict | list[dict], batch: bool = False):
-        if batch:
-            message = self.prepare_message(model_input, batch=True)
-        else:
-            message = self.prepare_message(model_input)
-        rewards = self.get_rm_score(message)
-
-        return rewards, 0
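Note: a minimal sketch of how these agents appear to be instantiated, inferred from the config keys read in `_setup` and `ProgressJudgeAgent.prepare_message`; every value below (model name, URL, key, observation-type labels, `model_input` fields) is an illustrative placeholder, not taken from this repository.

judge_config = {
    "model_name": "gpt-4o-mini",       # priced via MODEL_COST_MAPPING when served over an "api" base_url
    "base_url": "https://api.openai.com/v1",
    "api_key": "sk-...",               # placeholder
    "temperature": 0.7,
    "num_generate": 4,                 # number of sampled judgments
    "use_log_probs": True,             # requests logprobs + top_logprobs=10
    "use_checklist": True,
    "use_multimodal": False,
    "input_type": "text_only",         # text_only | image_only | text_image
    "text_obs_type": "axtree",         # assumed label; consumed by get_messages
    "image_obs_type": None,
    "use_in_progress": True,
}
agent = ProgressJudgeAgent(judge_config)
model_input = {...}  # fields expected by the judge prompt templates (intent, trajectory, current URL, observation, agent response)
responses, cost = agent.generate_probs(model_input, prompt_type="ours")
# each item: {"response": <text>, "judge_probs": [{"yes": p, "no": p, "in": p}, ...]}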
agent/mini_bench/checklist_eval.py
DELETED
@@ -1,95 +0,0 @@
-import re
-
-from langchain_openai import ChatOpenAI
-
-from .agent import BaseAgent
-
-SYSTEM_PROMPT = "You are an expert evaluator. Your task is to assess how well a Web Agent’s generated checklist aligns with the reference checklist for a given user instruction."
-
-USER_PROMPT = """# Task Description
-Use the provided task description, evaluation criteria, and both checklists to assign a score from 1 to 5. Justify your rating with a brief explanation that considers both content overlap and logical structure.
-
-## Score Criteria
-- 5: Checklist covers all subgoals, is correct and clearly expressed
-- 4: Minor omissions or phrasing issues but mostly accurate and complete
-- 3: Partially matches, but with noticeable gaps or errors
-- 2: Incomplete or includes incorrect steps
-- 1: Mostly irrelevant, incorrect, or missing the task goal
-
-## User Instruction:
-{intent}
-
-## Reference Checklist:
-{gt_checklist}
-
-## Agent’s Generated Checklist:
-{generated_checklist}
-
-# Output Format
-Your response should be in the following format:
-REASON: [Write 2–4 sentences explaining how well the generated checklist matches the reference. Mention specific matches, omissions, errors, or strengths.]
-SCORE: [1–5]
-"""
-
-
-class ChecklistEvalAgent(BaseAgent):
-    def __init__(self, agent_config: dict):
-        super().__init__(agent_config)
-        self._setup()
-
-    def prepare_message(self, model_input: dict, prompt_type=None):
-        message = [
-            {
-                "role": "system",
-                "content": SYSTEM_PROMPT
-            },
-            {
-                "role": "user",
-                "content": USER_PROMPT.format(
-                    intent=model_input["intent"],
-                    gt_checklist=model_input["gt_checklist"],
-                    generated_checklist=model_input["generated_checklist"]
-                )
-            }
-        ]
-        return message
-
-    def generate_response(self, model_input: dict):
-        total_cost = 0
-        response_list = []
-        # prepare message
-        message = self.prepare_message(model_input)
-
-        # n sampling
-        for _ in range(self.num_generate):
-            response, cost = self.generate_with_retry(message, ["SCORE"])
-            response_list.append(response)
-            total_cost += cost
-
-        return response_list, total_cost
-
-def parsing_score(response: str):
-    score = response.split("SCORE:")[-1].split("\n")[0].strip()
-    match = re.search(r'\d+', score)
-
-    if match:
-        return int(match.group())
-    else:
-        return None
-
-def average_score(scores: list[int]):
-    if len(scores) == 0:
-        return 0
-    return sum(scores) / len(scores)
-
-def get_score(results: list[dict]):
-    score_list = []
-    for result in results:
-        tmp_scores = [parsing_score(response) for response in result["response"]]
-        scores = [score for score in tmp_scores if score is not None]
-        result["score_list"] = scores
-        final_score = average_score(scores)
-        result["score"] = final_score
-        score_list.append(result)
-
-    return results, score_list
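Note: for reference, what the module-level parsers extract from a judge response (made-up response text):

example = "REASON: Covers all subgoals with minor phrasing issues.\nSCORE: 4"
assert parsing_score(example) == 4          # first integer after "SCORE:"
assert average_score([4, 5, 4]) == 13 / 3   # mean over the sampled scores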
agent/mini_bench/eval_utils.py
DELETED
@@ -1,309 +0,0 @@
-import re
-import random
-from collections import Counter
-
-from .utils import load_json, save_json, create_html_report
-
-random.seed(42)
-def get_score(response_list: list, indicator: str) -> list:
-    if len(response_list) == 0:
-        return [-100]
-
-    if isinstance(response_list[0], float):
-        return response_list
-
-    if indicator == "prob":
-        score_list = []
-        for response in response_list:
-            total_score = 0
-            for judge_probs in response:
-                yes_prob = judge_probs.get("yes", 0)
-                in_progress_prob = judge_probs.get("in", 0)
-                total_score += yes_prob + in_progress_prob * 0.5
-            if len(response) > 0:
-                score_list.append(total_score / len(response))
-            else:
-                score_list.append(0)
-        return score_list
-    else:
-        score_list = []
-        for response in response_list:
-            if indicator == "SCORE":
-                if "SCORE" in response:
-                    try:
-                        score_str = response.split("SCORE:")[1].split("\n")[0].strip()
-                    except:
-                        score_str = response.split("SCORE:")[-1].strip()
-                    # find first integer
-                    try:
-                        score = re.search(r'-?\d+', score_str).group()
-                        score_list.append(int(score))
-                    except:
-                        score_list.append(0)
-                else:
-                    try:
-                        score_str = response.split("<answer>")[1].split("</answer>")[0].strip()
-                    except:
-                        score_str = response.split("<answer>")[-1].split("</answer>")[0].strip()
-                    # find "Yes" or "No"
-                    if "Yes" in score_str:
-                        score_list.append(1)
-                    elif "In Progress" in score_str:
-                        score_list.append(0.5)
-                    elif "No" in score_str:
-                        score_list.append(0)
-                    else:
-                        score_list.append(0)
-            elif indicator == "JUDGE":
-                try:
-                    judge_str = response.split("JUDGE:")[1].split("\n")[0].strip()
-                except:
-                    judge_str = response.split("JUDGE:")[-1].strip()
-                if "Yes" in judge_str:
-                    score_list.append(1)
-                elif "No" in judge_str:
-                    score_list.append(0)
-                else:
-                    score_list.append(0)
-            elif indicator == "CHECKLIST EVALUATION":
-                if "<answer>" in response:
-                    try:
-                        checklist_str = response.split("<answer>")[1].split("</answer>")[0].strip()
-                    except:
-                        checklist_str = response.split("<answer>")[-1].split("</answer>")[0].strip()
-                else:
-                    checklist_str = response.split("CHECKLIST EVALUATION:")[-1].strip()
-
-                count_yes = checklist_str.count("Yes")
-                count_no = checklist_str.count("No")
-                count_in_progress = checklist_str.count("In Progress")
-                try:
-                    total_score = (count_yes + count_in_progress*0.5) / (count_yes + count_no + count_in_progress)
-                except:
-                    total_score = 0
-                score_list.append(total_score)
-            else:
-                raise ValueError(f"Invalid indicator: {indicator}")
-        return score_list
-
-def get_acc_and_mrr(chosen_score, rejected_scores):
-    if len(rejected_scores) == 0:
-        return 0, False
-
-    same_score_num = rejected_scores.count(chosen_score)
-    all_scores = rejected_scores + [chosen_score]
-    sorted_scores = sorted(all_scores, reverse=True)
-    rank = sorted_scores.index(chosen_score) + 1 + same_score_num # draw penalty
-    if all(chosen_score > r for r in rejected_scores):
-        accuracy = True
-    else:
-        accuracy = False
-    return 1 / rank, accuracy
-
-def average_score(score_list: list[float]):
-    if len(score_list) == 0:
-        return -100
-    return sum(score_list) / len(score_list)
-
-def self_consistency_score(score_list: list[float]):
-    if len(score_list) == 0:
-        return -100
-    counter = Counter(score_list)
-    return max(counter.values()) / len(score_list)
-
-def get_chosen_rejected_scores(data: dict, agg_func: str):
-    if len(data["chosen"]) == 0:
-        data["chosen"] = [{"score": [-100]}]
-    if len(data["rejected"]) == 0:
-        data["rejected"] = [{"score": [-100]}]
-    if not isinstance(data["chosen"][0], dict):
-        data["chosen"][0] = {"score": [-100]}
-    if not isinstance(data["rejected"][0], dict):
-        data["rejected"][0] = {"score": [-100]}
-
-    if agg_func == "average":
-        chosen_score = average_score(data["chosen"][0]["score"])
-        rejected_scores = [average_score(rejected_score["score"]) for rejected_score in data["rejected"]]
-    elif agg_func == "self_consistency":
-        chosen_score = self_consistency_score(data["chosen"][0]["score"])
-        rejected_scores = [self_consistency_score(rejected_score["score"]) for rejected_score in data["rejected"]]
-    else:
-        raise ValueError(f"Invalid agg_func: {agg_func}")
-    return chosen_score, rejected_scores
-
-def get_score_results(results, agg_func):
-    score_dict = {"mrr": [], "accuracy": [], "traj_accuracy": []}
-    task_accuracy = {}
-    for result in results:
-        chosen_score, rejected_scores = get_chosen_rejected_scores(result, agg_func)
-        mrr, accuracy = get_acc_and_mrr(chosen_score, rejected_scores)
-        score_dict["mrr"].append(mrr)
-        score_dict["accuracy"].append(accuracy)
-        if result["task_id"] not in task_accuracy:
-            task_accuracy[result["task_id"]] = []
-        task_accuracy[result["task_id"]].append(accuracy)
-
-    for task_id in task_accuracy:
-        if sum(task_accuracy[task_id]) == len(task_accuracy[task_id]):
-            score_dict["traj_accuracy"].append(True)
-        else:
-            score_dict["traj_accuracy"].append(False)
-
-    return score_dict
-
-def calculate_stats(results, agg_func: str="average"):
-    if len(results) == 0:
-        return {
-            "MRR": 0,
-            "Accuracy": 0,
-            "Traj_Accuracy": 0,
-        }
-    total_score = get_score_results(results, agg_func)
-    stats = {
-        "MRR": sum(total_score["mrr"]) / len(total_score["mrr"]),
-        "Accuracy": sum(total_score["accuracy"]) / len(total_score["accuracy"]),
-        "Traj_Accuracy": sum(total_score["traj_accuracy"]) / len(total_score["traj_accuracy"]),
-    }
-
-    return stats
-
-def group_by_task(results, split_indicator: str):
-    # sort results by task_id and step_id
-    results.sort(key=lambda x: (x["task_id"], x["step_id"]))
-    # group by task_name
-    grouped_task_dict = {}
-    for result in results:
-        task_name = "task_" + str(result["task_id"]) + "_step_" + str(result["step_id"])
-        if task_name not in grouped_task_dict:
-            grouped_task_dict[task_name] = {
-                "task_id": result["task_id"],
-                "step_id": result["step_id"],
-                "intent": result["intent"],
-                "start_url": result["start_url"],
-                "gt_checklist": result["gt_checklist"],
-                "generated_checklist": result.get("generated_checklist", None),
-                "trajectory": result["trajectory"],
-                "current_url": result["current_url"],
-                "text_observation": result["text_observation"],
-                # "image_list": result["image_list"],
-                "chosen": [],
-                "rejected": [],
-                "source_name": result["source_name"],
-            }
-
-        response = result["response"] if "response" in result else []
-        type_data = {
-            "thought": result["thought"],
-            "action": result["action"],
-            "response": response,
-            "score": get_score(response, split_indicator) if split_indicator != "prob" else get_score(result["judge_probs"], split_indicator),
-        }
-        if split_indicator == "prob":
-            type_data["judge_probs"] = result["judge_probs"]
-        if result["type"] == "chosen":
-            grouped_task_dict[task_name]["chosen"].append(type_data)
-        elif result["type"] == "rejected":
-            grouped_task_dict[task_name]["rejected"].append(type_data)
-
-    return list(grouped_task_dict.values())
-
-
-def processing_results(results, evaluation_mode: str, num_generate: int, use_batch: bool=False):
-    if "judge_probs" in results[0]:
-        split_indicator = "prob"
-    else:
-        if evaluation_mode == "judge_with_checklist_generation" or evaluation_mode == "judge_with_gt_checklist":
-            split_indicator = "CHECKLIST EVALUATION"
-        else:
-            split_indicator = "SCORE"
-
-    # if use_batch is True, make it flattened
-    if use_batch:
-        tmp_results = []
-        for result in results:
-            for d in result:
-                tmp_results.append(d)
-        grouped_results = group_by_task(tmp_results, split_indicator)
-    else:
-        grouped_results = group_by_task(results, split_indicator)
-
-    mind2web_results = []
-    webarena_results = []
-    mind2web_task_results = []
-    mind2web_website_results = []
-    mind2web_domain_results = []
-
-    for grouped_result in grouped_results:
-        if "mind2web" in grouped_result["source_name"]:
-            mind2web_results.append(grouped_result)
-            if grouped_result["source_name"] == "mind2web_test_task":
-                mind2web_task_results.append(grouped_result)
-            elif grouped_result["source_name"] == "mind2web_test_website":
-                mind2web_website_results.append(grouped_result)
-            elif grouped_result["source_name"] == "mind2web_test_domain":
-                mind2web_domain_results.append(grouped_result)
-        elif "webarena" in grouped_result["source_name"]:
-            webarena_results.append(grouped_result)
-
-    try:
-        final_stats = {
-            "mind2web": {
-                "MRR": {},
-                "Accuracy": {},
-                "Traj_Accuracy": {},
-            },
-            "webarena": {
-                "MRR": {},
-                "Accuracy": {},
-                "Traj_Accuracy": {},
-            },
-            "mind2web_task": {
-                "MRR": {},
-                "Accuracy": {},
-                "Traj_Accuracy": {},
-            },
-            "mind2web_website": {
-                "MRR": {},
-                "Accuracy": {},
-                "Traj_Accuracy": {},
-            },
-            "mind2web_domain": {
-                "MRR": {},
-                "Accuracy": {},
-                "Traj_Accuracy": {},
-            },
-        }
-        for source_results in [
-            ("mind2web", mind2web_results),
-            ("webarena", webarena_results),
-            ("mind2web_task", mind2web_task_results),
-            ("mind2web_website", mind2web_website_results),
-            ("mind2web_domain", mind2web_domain_results)
-        ]:
-            average_stats = calculate_stats(source_results[1], "average")
-            self_consistency_stats = calculate_stats(source_results[1], "self_consistency")
-            for metric in average_stats:
-                final_stats[source_results[0]][metric]["Average"] = average_stats[metric]
-            for metric in self_consistency_stats:
-                final_stats[source_results[0]][metric]["Self_Consistency"] = self_consistency_stats[metric]
-
-        if num_generate == 1:
-            for source_name in final_stats:
-                for metric in final_stats[source_name]:
-                    print(f"{round(100 * final_stats[source_name][metric]['Average'], 2)}", end=", ")
-            print()
-        else:
-            for agg_func in ["Average", "Self_Consistency"]:
-                print(f"{agg_func}")
-                for source_name in final_stats:
-                    for metric in final_stats[source_name]:
-                        print(f"{round(100 * final_stats[source_name][metric][agg_func], 2)}", end=", ")
-                print()
-    except Exception as e:
-        print(e)
-        return grouped_results, None
-
-    # add function to convert json format results to html format results
-    # TODO: implement this function
-    # create_html_report(results, "results.html")
-    return grouped_results, final_stats
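Note: a small worked example of the ranking logic in `get_acc_and_mrr` (illustrative numbers). Ties with rejected scores push the chosen rank down, matching the "draw penalty" comment in the code:

# chosen=0.8 vs rejected=[0.5, 0.8, 0.9]: sorted desc [0.9, 0.8, 0.8, 0.5],
# index(0.8)=1 -> rank = 1 + 1 + 1 tie = 3, so MRR = 1/3 and accuracy False
mrr, acc = get_acc_and_mrr(0.8, [0.5, 0.8, 0.9])
assert abs(mrr - 1 / 3) < 1e-9 and acc is False
mrr, acc = get_acc_and_mrr(0.9, [0.5, 0.8])  # strictly best -> rank 1
assert mrr == 1.0 and acc is True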
agent/mini_bench/inference_utils.py
DELETED
@@ -1,87 +0,0 @@
-import time
-
-from multiprocessing import Process, Manager
-from tqdm import tqdm
-
-
-def worker_main(work_queue, result_queue, process_func, config):
-    while True:
-        item = work_queue.get()
-        if item is None:
-            result_queue.put(None)
-            break
-        try:
-            results, cost = process_func(config, item)
-            result_queue.put((results, cost))
-        except Exception as e:
-            item_info = item.get('idx', item.get('id', 'unknown item'))
-            print(f"Error processing item {item_info}: {e}")
-            result_queue.put(None)
-        finally:
-            work_queue.task_done()
-
-def run_parallel_evaluation(dataset, process_func, config, num_workers, description):
-    """
-    Runs parallel evaluation on the given dataset and returns the results.
-
-    Args:
-        dataset (list or datasets.Dataset): Data to evaluate.
-        process_func (callable): Function to process each data item.
-        config (dict): Configuration for the process_func.
-        num_workers (int): Number of worker processes to use.
-        description (str): Description to display on the tqdm progress bar.
-
-    Returns:
-        tuple: (list of evaluation results, total cost)
-    """
-    manager = Manager()
-    work_queue = manager.Queue()
-    result_queue = manager.Queue()
-
-    # Add data to the work queue
-    dataset_list = list(dataset) if not isinstance(dataset, list) else dataset
-    for data in dataset_list:
-        work_queue.put(data)
-
-    # Add termination signals for workers
-    for _ in range(num_workers):
-        work_queue.put(None)
-
-    # Start parallel processing
-    processes = []
-    for _ in range(num_workers):
-        p = Process(target=worker_main, args=(work_queue, result_queue, process_func, config))
-        p.start()
-        processes.append(p)
-
-    # Show progress bar and collect results
-    process_results = []
-    process_cost = 0
-    completed_workers = 0
-
-    with tqdm(total=len(dataset_list), desc=description) as pbar:
-        while completed_workers < num_workers:
-            result_item = result_queue.get()
-            if result_item is None:
-                completed_workers += 1
-            else:
-                results, cost = result_item
-                if results is not None:
-                    process_results.append(results)
-                    process_cost += cost if cost is not None else 0
-                pbar.update(1)
-
-    # Wait for all processes to finish
-    for p in processes:
-        p.join()
-
-    # Collect remaining results
-    while not result_queue.empty():
-        result_item = result_queue.get_nowait()
-        if result_item is not None:
-            results, cost = result_item
-            if results is not None:
-                process_results.append(results)
-                process_cost += cost if cost is not None else 0
-
-    return process_results, process_cost
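Note: a hedged sketch of driving `run_parallel_evaluation`; per `worker_main`, `process_func` receives `(config, item)` and must return `(result, cost)`. The `evaluate_one` helper and its fields are hypothetical:

def evaluate_one(config, item):
    # hypothetical per-item worker: score one trajectory step
    score = len(item.get("text_observation", "")) % 5  # stand-in computation
    return {"task_id": item["task_id"], "score": score}, 0.0

if __name__ == "__main__":  # required where worker processes are spawned rather than forked
    dataset = [{"task_id": 1, "text_observation": "..."},
               {"task_id": 2, "text_observation": "..."}]
    results, cost = run_parallel_evaluation(dataset, evaluate_one, {}, num_workers=2, description="demo")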
agent/mini_bench/prompts/__init__.py
DELETED
@@ -1 +0,0 @@
-from .construct_messages import get_messages
agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (263 Bytes)

agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc
DELETED
Binary file (2.85 kB)

agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc
DELETED
Binary file (3.11 kB)

agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc
DELETED
Binary file (15 kB)

agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc
DELETED
Binary file (5.46 kB)

agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc
DELETED
Binary file (1.71 kB)

agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc
DELETED
Binary file (1.03 kB)

agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc
DELETED
Binary file (5.64 kB)

agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc
DELETED
Binary file (1.19 kB)
agent/mini_bench/prompts/action.py
DELETED
@@ -1,93 +0,0 @@
-ACTION_SPACE_PROMPT = """Note: This action set allows you to interact with your environment. Most of them are python function executing playwright code. The primary way of referring to elements in the page is through bid which are specified in your observations.
-
-15 different types of actions are available.
-
-noop(wait_ms: float = 1000)
-Examples:
-noop()
-
-noop(500)
-
-scroll(delta_x: float, delta_y: float)
-Examples:
-scroll(0, 200)
-
-scroll(-50.2, -100.5)
-
-keyboard_press(key: str)
-Examples:
-keyboard_press('Backspace')
-
-keyboard_press('ControlOrMeta+a')
-
-keyboard_press('Meta+Shift+t')
-
-click(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
-Examples:
-click('a51')
-
-click('b22', button='right')
-
-click('48', button='middle', modifiers=['Shift'])
-
-fill(bid: str, value: str)
-Examples:
-fill('237', 'example value')
-
-fill('45', 'multi-line\nexample')
-
-fill('a12', 'example with "quotes"')
-
-hover(bid: str)
-Examples:
-hover('b8')
-
-tab_focus(index: int)
-Examples:
-tab_focus(2)
-
-new_tab()
-Examples:
-new_tab()
-
-go_back()
-Examples:
-go_back()
-
-go_forward()
-Examples:
-go_forward()
-
-goto(url: str)
-Examples:
-goto('http://www.example.com')
-
-tab_close()
-Examples:
-tab_close()
-
-select_option(bid: str, options: str | list[str])
-Examples:
-select_option('a48', 'blue')
-
-select_option('c48', ['red', 'green', 'blue'])
-
-send_msg_to_user(text: str)
-Examples:
-send_msg_to_user('Based on the results of my search, the city was built in 1751.')
-
-report_infeasible(reason: str)
-Examples:
-report_infeasible('I cannot follow these instructions because there is no email field in this form.')
-
-Only a single action can be provided at once. Example:
-fill('a12', 'example with "quotes"')
-
-Note:
-* Some tasks may be game like and may require to interact with the mouse position in x, y coordinates.
-* Some text field might have auto completion. To see it, you have to type a few characters and wait until next step.
-* If you have to cut and paste, don't forget to select the text first.
-* Coordinate inside an SVG are relative to its top left corner.
-* Make sure to use bid to identify elements when using commands.
-* Interacting with combobox, dropdowns and auto-complete fields can be tricky, sometimes you need to use select_option, while other times you need to use fill or click and wait for the reaction of the page.
-"""
agent/mini_bench/prompts/checklist_prompt.py
DELETED
@@ -1,50 +0,0 @@
-CHECKLIST_SYSTEM_PROMPT = "You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task."
-
-CHECKLIST_USER_PROMPT = """## Task Description
-User Instruction (Goal): "{intent}"
-Start Website URL: {start_url}
-
-## Guidelines for Checklist Generation
-1. Identify Essential High-Level Subgoals:
-   - A subgoal should represent a significant step involving user interaction that leads to noticeable page transitions or meaningful changes in system state.
-   - Consolidate closely related user actions (such as applying multiple filters or selecting several options) into a single subgoal, rather than separate checklist items for each action.
-   - Prioritize only the most critical interactions necessary for meaningful progression, avoiding the inclusion of minor or unnecessary steps (e.g., scroll, hover).
-2. Provide a Concise Subgoal Analysis:
-   - Before creating the checklist, offer a brief paragraph summarizing the main subgoals, emphasizing significant transitions or page-level interactions.
-3. Ensure Clear Goal:
-   - If multiple related interactions occur (e.g., setting filters 1, 2, and 3), combine them into one subgoal with clear criteria verifying all required conditions.
-   - The checklist should contain only essential steps, explicitly excluding unnecessary actions, and should not exceed five critical subgoals. It is not necessary to use all five checklist items if fewer steps adequately represent the essential subgoals.
-
-### Output Format
-Before generating the checklist, first produce a concise subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist following the format below:
-[SUBGOAL ANALYSIS]
-[One-paragraph summary explaining the key subgoals and their logical sequence in task completion.]
-
-[CHECKLISTS]
-Checklist X: [Short title of the action/goal]
-- Goal: [Brief description of the subgoal at this stage, emphasizing the purpose of the action.]
-"""
-
-# TODO: implement ours
-CHECKLIST_OURS_SYSTEM_PROMPT = ""
-
-CHECKLIST_OURS_USER_PROMPT = """You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task.
-
-# Task Description
-Generate a checklist of the key milestones for achieving the given instruction. First, provide a concise
-subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist with a brief description.
-
-Note: If the target website requires login, assume the user is already logged in and starts from an authenticated session.
-
-# Given Information
-## User Instruction
-{intent}
-
-## Current State
-### Current URL
-{start_url}
-
-### AXTREE
-Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
-{text_observation}
-"""
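Note: for illustration, how the "ours" checklist prompt would be instantiated; every value below is a placeholder, not from the repository:

prompt = CHECKLIST_OURS_USER_PROMPT.format(
    intent="Book a one-way flight from SFO to JFK",           # placeholder
    start_url="http://flights.example.com",                   # placeholder
    text_observation="[a1] <input> From\n[a2] <input> To",    # placeholder AXTree
)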
agent/mini_bench/prompts/construct_messages.py
DELETED
@@ -1,309 +0,0 @@
-from abc import ABC, abstractmethod
-
-from .action import ACTION_SPACE_PROMPT
-from .eval_type import (
-    GROUNDING,
-    PROGRESS_LIKERT_SCALE,
-    PROGRESS_THREE_CLASS,
-    PROGRESS_WITH_CHECKLIST,
-    PROGRESS_WITH_CHECKLIST_IN_PROGRESS,
-    PROGRESS_OURS
-)
-from .input_information import (
-    USER_INSTRUCTION,
-    TRAJECTORY,
-    AGENT_RESPONSE,
-    CHECKLIST,
-    CURRENT_URL,
-    TEXT_OBSERVATION,
-    SOM_IMAGE_OBSERVATION,
-    COORD_IMAGE_OBSERVATION
-)
-from .judge_prompt import (
-    JUDGE_GROUNDING_PROMPT_TEMPLATE,
-    JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE,
-    JUDGE_THREE_CLASS_PROMPT_TEMPLATE,
-    JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE,
-    JUDGE_OURS_PROMPT_TEMPLATE,
-    JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE
-)
-from .checklist_prompt import (
-    CHECKLIST_SYSTEM_PROMPT,
-    CHECKLIST_USER_PROMPT,
-    CHECKLIST_OURS_SYSTEM_PROMPT,
-    CHECKLIST_OURS_USER_PROMPT
-)
-from .image_utils import image_to_base64_url
-
-
-class Message(ABC):
-    @abstractmethod
-    def get_messages(self):
-        pass
-
-class BaseMessage(Message):
-    def __init__(self, input_info:dict, use_multimodal:bool=False):
-        self.input_info = input_info
-        self.use_multimodal = use_multimodal
-
-    def _get_system_message(self):
-        system_message = {"role": "system", "content": "You are a helpful assistant."}
-        return system_message
-
-    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
-        multimodal_message = []
-        text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
-        text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
-        multimodal_message.append({"type": "text", "text": text_prompt_prefix})
-        for i, image in enumerate(image_list):
-            # TODO: text prompt for multiple images
-            # multimodal_message.append({"type": "text", "text": f"IMAGE {i+1}\n"})
-            multimodal_message.append({"type": "image_url", "image_url": {"url": image_to_base64_url(image), "detail": "low"}})
-        multimodal_message.append({"type": "text", "text": text_prompt_suffix})
-        return {"role": "user", "content": multimodal_message}
-
-    def _get_user_message(self):
-        user_prompt = "What is the capital of France?"
-        if self.use_multimodal:
-            image_list = self.input_info.get("image_list", [])
-            user_message = self._process_multimodal_message(user_prompt, image_list)
-        else:
-            user_message = {"role": "user", "content": user_prompt}
-        return user_message
-
-    def get_messages(self):
-        message = []
-        system_message = self._get_system_message()
-        user_message = self._get_user_message()
-
-        message.append(system_message)
-        # message.append({"role": "system", "content": ""})
-        message.append(user_message)
-        return message
-
-
-class ProgressMessage(BaseMessage):
-    '''
-    Progress Judge Message
-    '''
-    def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str, use_checklist:bool, use_in_progress:bool):
-        super().__init__(input_info, use_multimodal)
-        self.prompt_type = prompt_type
-        self.text_obs = text_obs
-        self.image_obs = image_obs
-        self.use_checklist = use_checklist
-        self.use_in_progress = use_in_progress
-
-    def _get_system_message(self):
-        if self.prompt_type == "likert_scale":
-            system_message = {"role": "system", "content": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["system"]}
-        elif self.prompt_type == "three_class":
-            system_message = {"role": "system", "content": JUDGE_THREE_CLASS_PROMPT_TEMPLATE["system"]}
-        elif self.prompt_type == "with_checklist":
-            system_message = {"role": "system", "content": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["system"]}
-        elif self.prompt_type == "ours":
-            system_message = {"role": "system", "content": JUDGE_OURS_PROMPT_TEMPLATE["system"]}
-        else:
-            raise ValueError(f"Invalid prompt type: {self.prompt_type}")
-        return system_message
-
-    def _setup_input_information(self):
-        observation = "## Current State\n"
-
-        observation += CURRENT_URL
-
-        # text observation
-        if self.text_obs:
-            observation += TEXT_OBSERVATION
-
-        # image observation (som, coord, none)
-        if self.image_obs == "som":
-            observation += SOM_IMAGE_OBSERVATION
-        elif self.image_obs == "coord":
-            observation += COORD_IMAGE_OBSERVATION
-
-
-        if self.use_checklist:
-            input_information = USER_INSTRUCTION + TRAJECTORY + observation + CHECKLIST + AGENT_RESPONSE
-        else:
-            input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE
-
-        return input_information
-
-    def _setup_task_info(self):
-        if self.prompt_type == "likert_scale":
-            task_description = PROGRESS_LIKERT_SCALE["task_description"]
-            output_format = PROGRESS_LIKERT_SCALE["output_format"]
-        elif self.prompt_type == "three_class":
-            task_description = PROGRESS_THREE_CLASS["task_description"]
-            output_format = PROGRESS_THREE_CLASS["output_format"]
-        elif self.prompt_type == "with_checklist":
-            if self.use_in_progress:
-                task_description = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["task_description"]
-                output_format = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["output_format"]
-            else:
-                task_description = PROGRESS_WITH_CHECKLIST["task_description"]
-                output_format = PROGRESS_WITH_CHECKLIST["output_format"]
-        else:
-            raise ValueError(f"Invalid prompt type: {self.prompt_type}")
-        return task_description, output_format
-
-    def _get_user_prompt_template(self):
-        if self.prompt_type == "likert_scale":
-            user_prompt = JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["user"]
-        elif self.prompt_type == "three_class":
-            user_prompt = JUDGE_THREE_CLASS_PROMPT_TEMPLATE["user"]
-        elif self.prompt_type == "with_checklist":
-            user_prompt = JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["user"]
-        else:
-            raise ValueError(f"Invalid prompt type: {self.prompt_type}")
-        return user_prompt
-
-    def _get_user_message(self):
-        # setup input information (user_instruction, trajectory, current_state, agent_response, checklist)
-        input_information_template = self._setup_input_information()
-        input_information = input_information_template.format(**self.input_info)
-
-        if self.prompt_type == "ours":
-            if self.use_checklist:
-                user_prompt = JUDGE_OURS_PROMPT_TEMPLATE["user"].format(
-                    input_information=input_information,
-                )
-            else:
-                user_prompt = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE["user"].format(
-                    input_information=input_information,
-                )
-        else:
-            task_description, output_format = self._setup_task_info()
-            # get user prompt template by prompt type
-            user_prompt_template = self._get_user_prompt_template()
-            user_prompt = user_prompt_template.format(
-                action_space=ACTION_SPACE_PROMPT,
-                task_description=task_description,
-                input_information=input_information,
-                output_format=output_format
-            )
-
-        # process multimodal message
-        if self.use_multimodal:
-            image_list = self.input_info.get("image_list", [])
-            user_message = self._process_multimodal_message(user_prompt, image_list)
-        else:
-            user_message = {"role": "user", "content": user_prompt}
-
-        return user_message
-
-
-class GroundingMessage(BaseMessage):
-    '''
-    Grounding Judge Message
-    '''
-    def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str):
-        super().__init__(input_info, use_multimodal)
-        self.prompt_type = prompt_type
-        self.text_obs = text_obs
-        self.image_obs = image_obs
-
-    def _get_system_message(self):
|
208 |
-
if self.prompt_type == "ours":
|
209 |
-
# TODO: implement ours
|
210 |
-
system_message = {"role": "system", "content": "You are a helpful assistant."}
|
211 |
-
elif self.prompt_type == "default":
|
212 |
-
system_message = {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]}
|
213 |
-
else:
|
214 |
-
raise ValueError(f"Invalid prompt type: {self.prompt_type}")
|
215 |
-
return system_message
|
216 |
-
|
217 |
-
def _setup_input_information(self):
|
218 |
-
observation = "## Current State\n"
|
219 |
-
|
220 |
-
observation += CURRENT_URL
|
221 |
-
|
222 |
-
# text observation
|
223 |
-
if self.text_obs:
|
224 |
-
observation += TEXT_OBSERVATION
|
225 |
-
|
226 |
-
# image observation (som, coord, none)
|
227 |
-
if self.image_obs == "som":
|
228 |
-
observation += SOM_IMAGE_OBSERVATION
|
229 |
-
elif self.image_obs == "coord":
|
230 |
-
observation += COORD_IMAGE_OBSERVATION
|
231 |
-
|
232 |
-
# input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE # with trajectory
|
233 |
-
input_information = USER_INSTRUCTION + observation + AGENT_RESPONSE # without trajectory
|
234 |
-
|
235 |
-
return input_information
|
236 |
-
|
237 |
-
def _get_user_message(self):
|
238 |
-
if self.prompt_type == "ours":
|
239 |
-
# TODO: implement ours
|
240 |
-
user_message = {"role": "user", "content": "TODO"}
|
241 |
-
elif self.prompt_type == "default":
|
242 |
-
action_space = ACTION_SPACE_PROMPT
|
243 |
-
task_description = GROUNDING["task_description"]
|
244 |
-
output_format = GROUNDING["output_format"]
|
245 |
-
input_information_template = self._setup_input_information()
|
246 |
-
input_information = input_information_template.format(**self.input_info)
|
247 |
-
|
248 |
-
user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format(
|
249 |
-
action_space=action_space,
|
250 |
-
task_description=task_description,
|
251 |
-
input_information=input_information,
|
252 |
-
output_format=output_format
|
253 |
-
)
|
254 |
-
|
255 |
-
# process multimodal message
|
256 |
-
if self.use_multimodal:
|
257 |
-
image_list = self.input_info.get("image_list", [])
|
258 |
-
user_message = self._process_multimodal_message(user_prompt, image_list)
|
259 |
-
else:
|
260 |
-
user_message = {"role": "user", "content": user_prompt}
|
261 |
-
else:
|
262 |
-
raise ValueError(f"Invalid prompt type: {self.prompt_type}")
|
263 |
-
return user_message
|
264 |
-
|
265 |
-
|
266 |
-
class ChecklistMessage(BaseMessage):
|
267 |
-
'''
|
268 |
-
Checklist Message
|
269 |
-
'''
|
270 |
-
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str):
|
271 |
-
super().__init__(input_info, use_multimodal)
|
272 |
-
self.prompt_type = prompt_type
|
273 |
-
|
274 |
-
def _get_system_message(self):
|
275 |
-
if self.prompt_type == "ours":
|
276 |
-
# TODO: implement ours
|
277 |
-
system_message = {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT}
|
278 |
-
elif self.prompt_type == "default":
|
279 |
-
system_message = {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT}
|
280 |
-
else:
|
281 |
-
raise ValueError(f"Invalid prompt type: {self.prompt_type}")
|
282 |
-
return system_message
|
283 |
-
|
284 |
-
def _get_user_message(self):
|
285 |
-
if self.prompt_type == "ours":
|
286 |
-
user_message = {"role": "user", "content": CHECKLIST_OURS_USER_PROMPT.format(**self.input_info)}
|
287 |
-
elif self.prompt_type == "default":
|
288 |
-
user_message = {"role": "user", "content": CHECKLIST_USER_PROMPT.format(**self.input_info)}
|
289 |
-
else:
|
290 |
-
raise ValueError(f"Invalid prompt type: {self.prompt_type}")
|
291 |
-
return user_message
|
292 |
-
|
293 |
-
|
294 |
-
def get_messages(input_info:dict, inference_mode:str, prompt_type:str, text_obs:str=None, image_obs:str=None, use_multimodal:bool=False, use_checklist:bool=False, use_in_progress:bool=False):
|
295 |
-
message_list = []
|
296 |
-
if inference_mode == "judge_grounding":
|
297 |
-
message = GroundingMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs)
|
298 |
-
elif inference_mode == "judge_progress":
|
299 |
-
message = ProgressMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs, use_checklist=use_checklist, use_in_progress=use_in_progress)
|
300 |
-
elif inference_mode == "checklist_generation":
|
301 |
-
message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type)
|
302 |
-
else:
|
303 |
-
raise ValueError(f"Invalid inference mode: {inference_mode}")
|
304 |
-
|
305 |
-
system_message, user_message = message.get_messages()
|
306 |
-
|
307 |
-
message_list.append(system_message)
|
308 |
-
message_list.append(user_message)
|
309 |
-
return message_list
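Note: a minimal usage sketch of the `get_messages` entry point above. The `input_info` keys and values here are illustrative assumptions; they must match the placeholders of the selected prompt template.

# Hypothetical call; key names and values are placeholders, not fixtures from this repo.
input_info = {
    "intent": "Find the cheapest laptop under $500",
    "start_url": "https://example-shop.com",
    "text_observation": "[12] link 'Laptops' ...",
}
messages = get_messages(
    input_info,
    inference_mode="checklist_generation",
    prompt_type="ours",
)
# messages == [system_message, user_message], ready for an OpenAI-style chat endpoint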
agent/mini_bench/prompts/eval_type.py
DELETED
@@ -1,107 +0,0 @@
# Task Description & Output Format
GROUNDING_TASK = """Your task is to evaluate whether the agent's ACTION is properly grounded in its THOUGHT, considering the current state of the webpage.
Use the user instruction, the current webpage state, and the agent's thought and action as evidence for your judgment. Your evaluation should assess whether the ACTION logically follows from the THOUGHT and is feasible and appropriate in the given environment.
Mark the action as 'Yes' only if it is clearly and fully grounded in the thought and current webpage state. If there is any inconsistency, ambiguity, irrelevance, or if the action is not supported by the current page state, mark it as 'No'."""

GROUNDING_FORMAT = """Please return your response in the following format:
REASON: [Your explanation for whether the action is properly grounded]
JUDGE: [Yes / No]"""


PROGRESS_LIKERT_SCALE_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
**Scoring Criteria (1 to 5):**
- **5 (Very Helpful)**: The action directly and effectively moves toward fulfilling a key part of the goal.
- **4 (Helpful)**: The action contributes meaningfully to progress, though it may require follow-up actions.
- **3 (Somewhat Helpful)**: The action is partially relevant or a preparatory step, but doesn’t make immediate progress.
- **2 (Slightly Helpful)**: The action is weakly related to the goal or might only indirectly help.
- **1 (Not Helpful)**: The action is unrelated, redundant, or distracts from the goal."""

PROGRESS_LIKERT_SCALE_FORMAT = """Please return your response in the following format:
REASON: [Your explanation for the score]
SCORE: [1-5]"""


PROGRESS_THREE_CLASS_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
**Scoring Criteria:**
- **1 (Helpful)**: The action clearly contributes to achieving the goal. It takes a necessary or productive step toward completing the task.
- **0 (Neutral)**: The action is neither helpful nor harmful. It may be a placeholder, irrelevant at the current step, or too ambiguous to evaluate.
- **-1 (Not Helpful)**: The action works against the goal, causes confusion, repeats a previous step unnecessarily, or leads the agent off track."""

PROGRESS_THREE_CLASS_FORMAT = """Please return your response in the following format:
REASON: [Your explanation for the score]
SCORE: [-1 / 0 / 1]"""


PROGRESS_WITH_CHECKLIST_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation.
For each checklist item:
- Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
- Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""

PROGRESS_WITH_CHECKLIST_FORMAT = """Please return your response in the following format:
REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
CHECKLIST EVALUATION:
Checklist X: [Yes / No]
"""

PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation. Clearly consider any items already successfully completed or currently in progress according to the provided trajectory.
For each checklist item:
- Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
- Mark it as 'In Progress' if the agent has made partial but meaningful progress toward completing the item.
- Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""

PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT = """Please return your response in the following format:
REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
CHECKLIST EVALUATION:
Checklist X: [Yes / In Progress / No]
"""


GROUNDING_OURS_TASK = """
"""

GROUNDING_OURS_FORMAT = """
"""

PROGRESS_OURS_TASK = """
"""

PROGRESS_OURS_FORMAT = """
"""

## EVALUATION TYPE
GROUNDING = {
    "task_description": GROUNDING_TASK,
    "output_format": GROUNDING_FORMAT,
}

GROUNDING_OURS = {
    "task_description": GROUNDING_OURS_TASK,
    "output_format": GROUNDING_OURS_FORMAT,
}

PROGRESS_LIKERT_SCALE = {
    "task_description": PROGRESS_LIKERT_SCALE_TASK,
    "output_format": PROGRESS_LIKERT_SCALE_FORMAT,
}

PROGRESS_THREE_CLASS = {
    "task_description": PROGRESS_THREE_CLASS_TASK,
    "output_format": PROGRESS_THREE_CLASS_FORMAT,
}

PROGRESS_WITH_CHECKLIST = {
    "task_description": PROGRESS_WITH_CHECKLIST_TASK,
    "output_format": PROGRESS_WITH_CHECKLIST_FORMAT,
}

PROGRESS_WITH_CHECKLIST_IN_PROGRESS = {
    "task_description": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK,
    "output_format": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT,
}

PROGRESS_OURS = {
    "task_description": PROGRESS_OURS_TASK,
    "output_format": PROGRESS_OURS_FORMAT,
}
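Note: a sketch of how these task/format pairs plug into a judge prompt, mirroring the `_setup_task_info` branch in construct_messages.py above; the skeleton string here is invented for illustration.

# Illustrative only: combine one evaluation type with a minimal prompt skeleton.
skeleton = "# Task Description\n{task_description}\n\n# Output Format\n{output_format}\n"
prompt = skeleton.format(
    task_description=PROGRESS_THREE_CLASS["task_description"],
    output_format=PROGRESS_THREE_CLASS["output_format"],
)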
agent/mini_bench/prompts/image_utils.py
DELETED
@@ -1,19 +0,0 @@
import base64
import io
from PIL import Image


def image_to_base64_url(image: str | Image.Image):
    if isinstance(image, str):
        with open(image, "rb") as f:
            image = f.read()
    elif isinstance(image, Image.Image):
        if image.mode in ("RGBA", "LA"):
            image = image.convert("RGB")
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image = buffer.getvalue()
    else:
        raise ValueError(f"Invalid image type: {type(image)}")

    return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")
agent/mini_bench/prompts/input_information.py
DELETED
@@ -1,36 +0,0 @@
USER_INSTRUCTION = """## User Instruction
{intent}
"""

TRAJECTORY = """## Trajectory
{trajectory}"""

AGENT_RESPONSE = """## Agent's Response
THOUGHT: {thought}
ACTION: {action}
"""

CHECKLIST = """## Checklist
{checklist}
"""


# Observation
CURRENT_URL = """### Current URL
{current_url}
"""

TEXT_OBSERVATION = """### AXTREE
Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
{text_observation}
"""

SOM_IMAGE_OBSERVATION = """### SOM Image Screenshot
Here is a current image screenshot of the page, it is annotated with bounding boxes and corresponding bids:
<IMAGE_PLACEHOLDER>
"""

COORD_IMAGE_OBSERVATION = """### Raw Image Screenshot
Here is a screenshot of the page:
<IMAGE_PLACEHOLDER>
"""
agent/mini_bench/prompts/judge_prompt.py
DELETED
@@ -1,159 +0,0 @@
# SYSTEM PROMPT
DEFAULT_SYSTEM_PROMPT_FORMAT = "You are an expert evaluator of web agent. {role_description}"

PROGRESS_WITHOUT_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."
PROGRESS_WITH_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."

GROUNDING_ROLE = "Your task is to assess whether the ACTION taken by the agent is properly grounded, based on agent's THOUGHT and the current state of the webpage."

# USER PROMPT
DEFAULT_USER_PROMPT_FORMAT = """# Action space:
{action_space}

# Task Description
{task_description}

# Given Information
{input_information}

# Output Format
{output_format}
"""


JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT = """You are an expert evaluator of web agent. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.

# Task Description
Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.

# Given Information
{input_information}
"""


JUDGE_OURS_USER_PROMPT_FORMAT = """You are an expert evaluator of web agent. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.

# Task Description
Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.
Then, assess each checklist item individually using the following labels:
- Yes: The item is fully and clearly satisfied, either in the current response or previously completed.
- In Progress: There is meaningful partial progress toward completing the item.
- No: The item is not satisfied due to ambiguity, insufficient evidence, or lack of progress.

# Given Information
{input_information}
"""


JUDGE_OURS_BT_MODELING_USER_PROMPT_FORMAT = """You are an expert web agent that browses internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.

# Task Description
Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.

# Given Information
## User Instruction
{intent}

## Trajectory
{trajectory}

## Current State
### Current URL
{current_url}

### AXTREE
Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
{text_observation}

## Checklist
{checklist}

## Agent's Response
"""

JUDGE_OURS_BT_MODELING_BASE_PROMPT = """You are an expert web agent that browses internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.

# Task Description
Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.

# Given Information
## User Instruction
{intent}

## Trajectory
{trajectory}

## Current State
### Current URL
{current_url}

### AXTREE
Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
{text_observation}
"""

JUDGE_OURS_IMAGE_INPUT = """
### Image Screenshot
<IMAGE_PLACEHOLDER>
"""

JUDGE_OURS_WITH_CHECKLIST = """
## Checklist
{checklist}
"""

BT_MODELING_RESPONSE_FORMAT = """
THOUGHT: {thought}
ACTION: {action}
"""

## PROMPT TEMPLATE
JUDGE_GROUNDING_PROMPT_TEMPLATE = {
    "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=GROUNDING_ROLE),
    "user": DEFAULT_USER_PROMPT_FORMAT,
}

JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE = {
    "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
    "user": DEFAULT_USER_PROMPT_FORMAT
}

JUDGE_THREE_CLASS_PROMPT_TEMPLATE = {
    "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
    "user": DEFAULT_USER_PROMPT_FORMAT
}

JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE = {
    "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITH_CHECKLIST_ROLE),
    "user": DEFAULT_USER_PROMPT_FORMAT
}

JUDGE_OURS_PROMPT_TEMPLATE = {
    "system": "",
    "user": JUDGE_OURS_USER_PROMPT_FORMAT,
}

JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE = {
    "system": "",
    "user": JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT,
}

JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE = {
    "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
    "assistant": BT_MODELING_RESPONSE_FORMAT,
}

JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE = {
    "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
    "assistant": BT_MODELING_RESPONSE_FORMAT,
}

JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE = {
    "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+"\n## Agent's Response\n",
    "assistant": BT_MODELING_RESPONSE_FORMAT,
}

JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE = {
    "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+"\n## Agent's Response\n",
    "assistant": BT_MODELING_RESPONSE_FORMAT,
}
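Note: a sketch of a BT-modeling (Bradley-Terry-style) reward query assembled from these templates, matching what `ClassifierRewardAgent._make_query` does in reward_agent.py below; all field values are invented.

template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
query = [
    {"role": "user", "content": template["user"].format(
        intent="Book a table for two",
        trajectory="step 1: opened the homepage",
        current_url="https://example.com",
        text_observation="[23] button 'Reserve'",
        checklist="1. Open the reservation form",
    )},
    {"role": "assistant", "content": template["assistant"].format(
        thought="I should open the reservation form",
        action="click('23')",
    )},
]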
agent/mini_bench/prompts/utils.py
DELETED
@@ -1,18 +0,0 @@
from langchain.schema import HumanMessage, AIMessage, SystemMessage

def convert_dict_messages(dict_messages):
    message_objs = []
    for msg in dict_messages:
        role = msg.get("role")
        content = msg.get("content", "")

        if role == "user":
            message_objs.append(HumanMessage(content=content))
        elif role == "assistant":
            message_objs.append(AIMessage(content=content))
        elif role == "system":
            message_objs.append(SystemMessage(content=content))
        else:
            raise ValueError(f"Unknown role: {role}")

    return message_objs
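Note: an illustrative round-trip from dict messages to LangChain message objects.

dict_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
lc_messages = convert_dict_messages(dict_messages)
# -> [SystemMessage(content='You are a helpful assistant.'), HumanMessage(content='Hello!')]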
agent/mini_bench/reward_agent.py
DELETED
@@ -1,465 +0,0 @@
from abc import ABC, abstractmethod
import time
import requests
import json
import math
from langsmith import Client
import numpy as np
from langchain_openai import ChatOpenAI

from .prompts import get_messages
from .prompts.judge_prompt import (
    JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
)
from .prompts.image_utils import image_to_base64_url
from .prompts.utils import convert_dict_messages

MAX_RETRY = 3
RETRY_SLEEP = 5
MODEL_COST_MAPPING = {
    "gpt-4o-mini": {
        "input_token_cost": 0.15,
        "output_token_cost": 0.6
    },
    "gpt-4o": {
        "input_token_cost": 2.5,
        "output_token_cost": 10
    },
}


class Agent(ABC):
    @abstractmethod
    def generate_response(self, inputs: dict) -> str:
        pass

class BaseAgent(Agent):
    def __init__(self, agent_config: dict):
        self.agent_config = agent_config
        self._setup()

    def _init_llm_object(self, **extra_kwargs):
        config = self.agent_config
        config.update(extra_kwargs)

        use_log_probs = config.get("use_log_probs", False)
        if use_log_probs:
            self.llm = ChatOpenAI(
                model=config["model_name"],
                base_url=config["base_url"],
                api_key=config["api_key"],
                temperature=config["temperature"],
                timeout=300,
                logprobs=True,
                top_logprobs=10,
                n=config.get('n', None)
            )
        else:
            self.llm = ChatOpenAI(
                model=config["model_name"],
                base_url=config["base_url"],
                api_key=config["api_key"],
                temperature=config["temperature"],
                timeout=300,
                n=config.get('n', None)
            )

    def _setup(self):
        self._init_llm_object()

        self.temperature = self.agent_config["temperature"]
        self.num_generate = self.agent_config["num_generate"]
        self.use_checklist = self.agent_config.get("use_checklist", False)
        self.use_multimodal = self.agent_config.get("use_multimodal", False)

        # setup cost
        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
        if model_cost and "api" in self.agent_config["base_url"]:
            self.input_token_cost = model_cost["input_token_cost"]
            self.output_token_cost = model_cost["output_token_cost"]
        else:
            self.input_token_cost = 0.0
            self.output_token_cost = 0.0

    def generate_with_retry(self, model_input, constraint_str_list: list = None):
        total_input_tokens = 0
        total_output_tokens = 0
        if self.temperature == 0:
            response = self.llm.invoke(model_input)
            total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
            total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
        else:
            for i in range(MAX_RETRY):
                try:
                    response = self.llm.invoke(model_input)
                    total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
                    total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
                    if constraint_str_list:
                        pass_constraint_num = 0
                        for constraint_str in constraint_str_list:
                            if constraint_str in response.content:
                                pass_constraint_num += 1
                        if pass_constraint_num == len(constraint_str_list):
                            break
                        else:
                            print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
                    else:
                        break
                except Exception as e:
                    print(f"Agent returned an Error: {e}")
                    response = None
                    time.sleep(RETRY_SLEEP)

        cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000

        if response is None:
            return "", cost
        else:
            return response.content, cost

    def prepare_message(self, model_input: dict, prompt_type: str):
        message = []
        return message

    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
        total_cost = 0
        response_list = []
        # prepare message
        message = self.prepare_message(model_input, prompt_type)

        # n sampling
        for i in range(self.num_generate):
            response, cost = self.generate_with_retry(message, constraint_str_list)
            response_list.append(response)
            total_cost += cost

        return response_list, total_cost


class GroundingJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="judge_grounding",
            prompt_type=prompt_type,
            use_multimodal=self.use_multimodal,
            text_obs=self.agent_config["text_obs_type"],
            image_obs=self.agent_config["image_obs_type"]
        )
        return message


class ProgressJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        if self.agent_config["input_type"]=="text_only":
            use_multimodal = False
            text_obs = self.agent_config["text_obs_type"]
            image_obs = None
        elif self.agent_config["input_type"]=="image_only":
            use_multimodal = True
            text_obs = None
            image_obs = self.agent_config["image_obs_type"]
        elif self.agent_config["input_type"]=="text_image":
            use_multimodal = True
            text_obs = self.agent_config["text_obs_type"]
            image_obs = self.agent_config["image_obs_type"]
        else:
            raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")

        if self.agent_config["use_in_progress"]:
            use_in_progress = True
        else:
            use_in_progress = False

        message = get_messages(
            input_info=model_input,
            inference_mode="judge_progress",
            prompt_type=prompt_type,
            use_checklist=self.use_checklist,
            use_multimodal=use_multimodal,
            text_obs=text_obs,
            image_obs=image_obs,
            use_in_progress=use_in_progress
        )
        return message

    def get_judge_probs(self, logprobs: list):
        # target_judge = {
        #     "yes": [" Yes", "Yes", "ĠYes", "ĊYes"],
        #     "no": [" No", "No", "ĠNo", "ĊNo"],
        #     "in": [" In", "In", "ĠIn", "ĊIn"]
        # }
        target_judge = {
            "yes": [
                "ĠYes", "Yes", "ĊYes",
                "Ġyes", "yes", "Ċyes",
                "ĠYES", "YES", "ĊYES",
                "ĠDone", "Done", "ĊDone",
                "ĠCompleted", "Completed", "ĊCompleted",
                "ĠCorrect", "Correct", "ĊCorrect"
            ],
            "no": [
                "ĠNo", "No", "ĊNo",
                "ĠNO", "NO", "ĊNO",
                "ĠNot", "Not", "ĊNot",
                "ĠNone", "None", "ĊNone",
                "ĠNope", "Nope", "ĊNope",
                "ĠUn", "Un", "ĊUn",
                "ĠWrong", "Wrong", "ĊWrong"
            ],
            "in": [
                "ĠIn", "In", "ĊIn",
                "ĠPending", "Pending", "ĊPending",
                "ĠPart", "Part", "ĊPart",
                "ĠPartial", "Partial", "ĊPartial",
                "ĠInProgress", "InProgress", "ĊInProgress"
            ]
        }
        response_str = ""
        judge_probs_list = []
        for i, log_prob in enumerate(logprobs):
            # Start to find judge string
            if "<answer>" in response_str:
                find_judge_str = False
                for judge_type in target_judge:
                    if log_prob["token"] in target_judge[judge_type]:
                        # print(log_prob)
                        find_judge_str = True
                        break
                if find_judge_str:
                    token_judge_dict = {
                        "yes": None,
                        "no": None,
                        "in": None
                    }
                    for token_info in log_prob["top_logprobs"]:
                        for judge_type in target_judge:
                            for judge_str in target_judge[judge_type]:
                                if judge_str in token_info["token"]:
                                    if token_judge_dict[judge_type] is None:
                                        token_judge_dict[judge_type] = math.exp(token_info["logprob"])
                                    else:
                                        token_judge_dict[judge_type] += math.exp(token_info["logprob"])

                    token_judge_dict = {
                        "yes": math.log(token_judge_dict["yes"]) if token_judge_dict["yes"] is not None else -float('inf'),
                        "no": math.log(token_judge_dict["no"]) if token_judge_dict["no"] is not None else -float('inf'),
                        "in": math.log(token_judge_dict["in"]) if token_judge_dict["in"] is not None else -float('inf')
                    }
                    judge_probs_list.append(token_judge_dict)

            if "</answer>" in response_str:
                break

            response_str += log_prob["token"]

        if len(judge_probs_list) == 0:
            return [{
                "yes": 0.0,
                "no": 0.0,
                "in": 0.0
            }]
        else:
            # convert with softmax
            final_judge_probs_list = []
            max_in_prob = -float('inf')
            for idx, judge_probs in enumerate(judge_probs_list):
                exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
                sum_exp_logprobs = sum(exp_logprobs)
                softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
                if softmax_probs[2] > max_in_prob:
                    max_in_prob = softmax_probs[2]
                final_judge_probs_list.append({
                    "yes": softmax_probs[0],
                    "no": softmax_probs[1],
                    "in": softmax_probs[2]
                })
            return final_judge_probs_list

    def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None):
        total_cost = 0
        # prepare message
        message = self.prepare_message(model_input, prompt_type)
        messages = convert_dict_messages(message)

        kwargs = {'n': n}
        if temperature is not None:
            kwargs['temperature'] = temperature
        self._init_llm_object(**kwargs)

        try:
            response = self.llm.generate([messages])  # assume single batch
        finally:
            print('request url: ', self.agent_config['base_url'])

        # parse responses
        response_list = []
        for generation in response.generations[0]:  # assume single batch
            # parse logprobs
            logprobs = generation.message.response_metadata["logprobs"]["content"]
            response_list.append(
                {
                    "response": generation.message.content,
                    "judge_probs": self.get_judge_probs(logprobs)
                }
            )

        # calculate cost
        total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"]
        total_output_tokens = response.llm_output["token_usage"]["completion_tokens"]
        total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000

        return response_list, total_cost


class ChecklistGenerationAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="checklist_generation",
            prompt_type=prompt_type
        )
        return message


class ClassifierRewardAgent(Agent):
    def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
        self.url = url
        self.use_checklist = use_checklist
        self.use_multimodal = use_multimodal

    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
        multimodal_message = []
        text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
        text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
        multimodal_message = [
            {"type": "text", "text": text_prompt_prefix},
            # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
            {"type": "image", "image": image_to_base64_url(image_list[0])},
            {"type": "text", "text": text_prompt_suffix}
        ]
        return multimodal_message

    def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
        if self.use_multimodal:
            tmp_user_prompt = user_prompt_template["user"].format(
                **model_input
            )
            user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
        else:
            user_prompt = user_prompt_template["user"].format(
                **model_input
            )
        assistant_prompt = user_prompt_template["assistant"].format(
            **model_input
        )
        query = [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_prompt}
        ]
        return query

    def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
        if self.use_checklist:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
        else:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE

        if self.use_multimodal:
            if batch:
                message = [self._make_query(user_prompt_template, input) for input in model_input]
            else:
                message = [self._make_query(user_prompt_template, model_input)]
        else:
            if batch:
                message = {
                    "query": [self._make_query(user_prompt_template, input) for input in model_input],
                    "prompts": []
                }
            else:
                message = {
                    "query": self._make_query(user_prompt_template, model_input),
                    "prompts": []
                }

        return message

    def get_rm_score(self, message: dict | list):
        headers = {"Content-Type": "application/json"}

        try:
            if self.use_multimodal:
                response = requests.post(
                    self.url,
                    json={"messages": message},
                    timeout=600
                )
            else:
                response = requests.post(
                    self.url,
                    headers=headers,
                    data=json.dumps(message),
                    timeout=300
                )
            response.raise_for_status()

            response_json = response.json()

            if "rewards" not in response_json:
                print(f"Error: 'rewards' key not found in API response: {response_json}")
                return []

            if "get_reward" in self.url:
                # use openrlhf
                return response_json["rewards"]
            elif "pooling" in self.url:
                # use vllm server
                return response_json["reward"]
            else:
                # error
                raise ValueError(f"Invalid URL: {self.url}")

        except requests.exceptions.Timeout:
            print(f"Error: Request timed out to {self.url}")
            return []
        except requests.exceptions.RequestException as e:
            print(f"Error during request to {self.url}: {e}")
            return []
        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON response from {self.url}")
            return []
        except KeyError as e:
            print(f"Error: Missing key {e} in response from {self.url}")
            return []


    def generate_response(self, model_input: dict | list[dict], batch: bool = False):
        if batch:
            message = self.prepare_message(model_input, batch=True)
        else:
            message = self.prepare_message(model_input)
        rewards = self.get_rm_score(message)

        return rewards, 0
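Note: a hypothetical configuration for `ProgressJudgeAgent.generate_probs` above; the endpoint, model name, and observation-type values are placeholders, and `model_input` must carry the placeholders of the chosen prompt template.

config = {
    "model_name": "gpt-4o-mini",               # placeholder model
    "base_url": "https://api.openai.com/v1",   # placeholder endpoint
    "api_key": "sk-...",
    "temperature": 0.7,
    "num_generate": 1,
    "use_log_probs": True,
    "use_checklist": True,
    "use_multimodal": False,
    "input_type": "text_only",
    "text_obs_type": "axtree",                 # assumed value; only truthiness is checked
    "image_obs_type": None,
    "use_in_progress": True,
}
judge = ProgressJudgeAgent(config)
model_input = {
    "intent": "...", "trajectory": "...", "current_url": "...",
    "text_observation": "...", "checklist": "...",
}
response_list, cost = judge.generate_probs(model_input, prompt_type="ours", n=4)
# each entry: {"response": <text>, "judge_probs": [{"yes": p, "no": p, "in": p}, ...]}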
agent/mini_bench/utils.py
DELETED
@@ -1,269 +0,0 @@
import json
import base64
import io
import html
from PIL import Image


def image_to_base64_url(image: str | Image.Image):
    if isinstance(image, str):
        with open(image, "rb") as f:
            image = f.read()
    elif isinstance(image, Image.Image):
        if image.mode in ("RGBA", "LA"):
            image = image.convert("RGB")
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image = buffer.getvalue()
    else:
        raise ValueError(f"Invalid image type: {type(image)}")

    return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")


def load_json(file_path: str) -> dict:
    with open(file_path, "r") as f:
        return json.load(f)

def save_json(data: dict, file_path: str):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)

def str_to_bool(s: str) -> bool:
    if s.lower() in ["true", "1", "yes", "y"]:
        return True
    elif s.lower() in ["false", "0", "no", "n"]:
        return False
    else:
        raise ValueError(f"Invalid boolean string: {s}")


def create_html_report(json_path, html_path, checklist_generation=False):
    """
    Reads the given JSON result file and generates a filterable HTML report.

    Args:
        json_path (str): Path to the input JSON file.
        html_path (str): Path to the output HTML file.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: JSON file not found - {json_path}")  # Error message in English
        return
    except json.JSONDecodeError:
        print(f"Error: JSON file parsing error - {json_path}")  # Error message in English
        return
    except Exception as e:
        print(f"Unexpected error during data loading: {e}")  # Error message in English
        return

    # Extract unique Task IDs and sort them
    task_ids = sorted(list(set(item.get("task_id") for item in data if item.get("task_id") is not None)))

    html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Benchmark Results Report</title>
    <style>
        body { font-family: sans-serif; line-height: 1.6; padding: 20px; }
        .task-step { border: 1px solid #ccc; margin-bottom: 20px; padding: 15px; border-radius: 5px; background-color: #f9f9f9; }
        .task-step h2 { margin-top: 0; color: #333; border-bottom: 1px solid #eee; padding-bottom: 5px;}
        .task-step h3 { color: #555; margin-top: 15px; margin-bottom: 5px; }
        .task-step h4 { color: #777; margin-top: 10px; margin-bottom: 5px; font-style: italic;}
        pre { background-color: #eee; padding: 10px; border-radius: 3px; white-space: pre-wrap; word-wrap: break-word; font-size: 0.9em; margin-top: 5px; }
        details { margin-top: 10px; border: 1px solid #ddd; border-radius: 3px; background-color: #fff; }
        summary { cursor: pointer; padding: 8px; background-color: #f8f9fa; font-weight: bold; border-bottom: 1px solid #ddd; }
        details[open] summary { border-bottom: 1px solid #ddd; }
        details > pre { border: none; background-color: #fff; padding: 10px 8px; }
        .response-item-toggle { margin-top: 10px; }
        .chosen-section { border-left: 5px solid #4CAF50; padding-left: 10px; margin-top: 15px; }
        .rejected-section { border-left: 5px solid #f44336; padding-left: 10px; margin-top: 15px; }
        hr { border: 0; border-top: 1px solid #eee; margin: 15px 0; }
        .thought-action { background-color: #f0f0f0; padding: 10px; border-radius: 3px; margin-bottom: 10px; border: 1px solid #e0e0e0;}
        .thought-action h4 { margin-top: 0; color: #666; }
        .task-container { display: none; }
        .filter-controls { margin-bottom: 20px; padding: 10px; background-color: #e9ecef; border-radius: 5px; }
        .filter-controls label { margin-right: 10px; font-weight: bold; }
        .filter-controls select { padding: 5px; border-radius: 3px; border: 1px solid #ced4da; }
    </style>
</head>
<body>
    <h1>Benchmark Results Report</h1>

    <!-- Task ID Filter Dropdown -->
    <div class="filter-controls">
        <label for="taskSelector">Select Task ID:</label>
        <select id="taskSelector">
            <option value="">-- Show All --</option>
"""
    # Add dropdown options
    for tid in task_ids:
        html_content += f'            <option value="{html.escape(str(tid))}">{html.escape(str(tid))}</option>\n'

    html_content += """
        </select>
    </div>

    <!-- Results Display Area -->
    <div id="resultsArea">
"""

    # Process each Task/Step data
    for i, step_data in enumerate(data):
        task_id = step_data.get("task_id", "N/A")
        step_id = step_data.get("step_id", "N/A")
        intent = step_data.get("intent", "N/A")
        start_url = step_data.get("start_url", "N/A")
        gt_checklist = step_data.get("gt_checklist", "N/A")
        generated_checklist = step_data.get("generated_checklist", None)
        trajectory = step_data.get("trajectory", "N/A")
        text_observation = step_data.get("text_observation", "N/A")
        source_name = step_data.get("source_name", "")

        # Wrap each Task/Step in a container with a unique ID (hidden initially)
        html_content += f"""
    <div class="task-container" data-task-id="{html.escape(str(task_id))}">
        <div class="task-step">
            <h2>Task ID: {html.escape(str(task_id))} | Step ID: {html.escape(str(step_id))} {f'({html.escape(source_name)})' if source_name else ''}</h2>
            <h3>Intent:</h3>
            <p>{html.escape(intent)}</p>
            <p><strong>Start URL:</strong> <a href="{html.escape(start_url)}" target="_blank">{html.escape(start_url)}</a></p>

            <h3>Ground Truth Checklist:</h3>
            <pre>{html.escape(gt_checklist)}</pre>
"""
        if checklist_generation and generated_checklist is not None:
            html_content += f"""
            <details>
                <summary>Generated Checklist (Click to expand/collapse)</summary>
                <pre>{html.escape(str(generated_checklist))}</pre>
            </details>
"""

        html_content += f"""
            <details>
                <summary>Trajectory (Click to expand/collapse)</summary>
                <pre>{html.escape(trajectory)}</pre>
            </details>

            <details>
                <summary>Text Observation (Click to expand/collapse)</summary>
                <pre>{html.escape(text_observation)}</pre>
            </details>
            <hr>
"""

        # Chosen Responses
        if 'chosen' in step_data and step_data['chosen']:
            html_content += '<div class="chosen-section"><h3>Chosen Responses:</h3>'
            for choice_block in step_data['chosen']:
                thought = choice_block.get('thought', 'N/A')
                action = choice_block.get('action', 'N/A')
                responses = choice_block.get('response', [])
                scores = choice_block.get('score', [])

                # Add Thought and Action information
                html_content += f"""
                <div class="thought-action">
                    <h4>Thought:</h4>
                    <pre>{html.escape(thought)}</pre>
                    <h4>Action:</h4>
                    <pre>{html.escape(action)}</pre>
                </div>"""

                # Loop through responses and create toggles
                for idx, (response, score) in enumerate(zip(responses, scores)):
                    html_content += f"""
                <details class="response-item-toggle">
                    <summary>Judge Response {idx + 1}: {html.escape(str(score))}</summary>
                    <pre>{html.escape(str(response))}</pre>
                </details>"""
            html_content += '</div>'  # End chosen-section

        # Rejected Responses
        if 'rejected' in step_data and step_data['rejected']:
            html_content += '<div class="rejected-section"><h3>Rejected Responses:</h3>'
            for rejection_block in step_data['rejected']:
                thought = rejection_block.get('thought', 'N/A')
                action = rejection_block.get('action', 'N/A')
                responses = rejection_block.get('response', [])
                scores = rejection_block.get('score', [])

                # Add Thought and Action information
                html_content += f"""
                <div class="thought-action">
                    <h4>Thought:</h4>
                    <pre>{html.escape(thought)}</pre>
                    <h4>Action:</h4>
                    <pre>{html.escape(action)}</pre>
                </div>"""

                # Loop through responses and create toggles
                for idx, (response, score) in enumerate(zip(responses, scores)):
                    html_content += f"""
                <details class="response-item-toggle">
                    <summary>Judge Response {idx + 1}: {html.escape(str(score))}</summary>
                    <pre>{html.escape(str(response))}</pre>
                </details>"""
            html_content += '</div>'  # End rejected-section

        html_content += """
        </div> <!-- End task-step -->
    </div> <!-- End task-container -->
"""

    # Finalize HTML and add JavaScript
    html_content += """
    </div> <!-- End resultsArea -->

    <script>
        document.addEventListener('DOMContentLoaded', function() {
            const taskSelector = document.getElementById('taskSelector');
            const taskContainers = document.querySelectorAll('.task-container');

            function filterTasks() {
                const selectedTaskId = taskSelector.value;

                taskContainers.forEach(container => {
                    const containerTaskId = container.getAttribute('data-task-id');
                    // Show if no Task ID is selected (Show All) or if the container's Task ID matches
                    if (selectedTaskId === "" || containerTaskId === selectedTaskId) {
                        container.style.display = 'block';
                    } else {
                        // Otherwise, hide it
                        container.style.display = 'none';
                    }
                });
            }

            // Run filter function on dropdown change
            taskSelector.addEventListener('change', filterTasks);

            // Run initial filtering on page load (default: Show All)
            filterTasks();
        });
    </script>

</body>
</html>
"""

    # Save the HTML file
    try:
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        print(f"Completed: HTML report created at {html_path}")
    except IOError:
        print(f"Error: Failed to write HTML file - {html_path}")
    except Exception as e:
        print(f"Unexpected error during HTML file saving: {e}")

# --- Example Usage ---
# input_json_file = 'path/to/your/results.json'
# output_html_file = 'trajectory_report.html'
# create_html_report(input_json_file, output_html_file)
|
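For orientation, the accessors in the deleted renderer imply a particular record shape. Below is a minimal sketch of one step record it could display; the key names ('chosen', 'rejected', 'thought', 'action', 'response', 'score') come straight from the .get(...) calls above, while the values and the variable name example_step are illustrative only.

# Hypothetical step record for the deleted create_html_report renderer.
# Keys are taken from the accessors in the deleted code; values are made up.
example_step = {
    "chosen": [{
        "thought": "The search box is focused; type the query next.",
        "action": "type [42] 'wireless keyboard'",
        "response": ["<think>...</think><answer>item1: Yes</answer>"],
        "score": [1.0],
    }],
    "rejected": [{
        "thought": "Click an unrelated banner.",
        "action": "click [7]",
        "response": ["<think>...</think><answer>item1: No</answer>"],
        "score": [0.0],
    }],
}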
agent/reward.py
DELETED
@@ -1,96 +0,0 @@

import time
from typing import List, Dict, Any, Optional, Union
import numpy as np
from .mini_bench.reward_agent import ProgressJudgeAgent
from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
    """Process a single unit and return (idx, reward, thought)."""
    agent = ProgressJudgeAgent(configs)
    current_temperature = configs["temperature"]

    rewards = []
    n_err = 0
    retry_count = 0
    judge_hash_count_thought = {}

    while len(rewards) < n_samples and retry_count < max_retries:
        # Call the external API
        responses, _ = agent.generate_probs(
            unit, "ours", n=n_samples - len(rewards), temperature=current_temperature
        )

        for response in responses:
            content = response["response"]
            thought = content  # keep the full response as the log
            reward = REWARD_PROCESSORS[reward_processor](response)
            rewards.append(reward)

            # Check None before isnan: np.isnan(None) raises TypeError
            if reward is None or np.isnan(reward):
                n_err += 1
            else:
                judge_hash = extract_judge_hash(response)
                judge_hash_count_thought[judge_hash] = (judge_hash_count_thought.get(judge_hash, (0, None))[0] + 1, thought)

        if n_err > 0:
            # On failure, retry with a higher temperature
            if n_samples == 1:
                current_temperature = 0.5
        retry_count += 1

    reward = np.nanmean(rewards)
    if np.isnan(reward):
        print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
        reward = 0.0
    print(judge_hash_count_thought)
    thought = max(judge_hash_count_thought.values(), key=lambda x: x[0])[1]

    return idx, reward, thought


def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
    """Threaded replacement for the original get_ar_reward."""
    n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]

    temperature = 0.5 if n_samples > 1 else 0.0

    configs = {
        "model_name": model_name,
        "base_url": base_url,
        "api_key": "empty",
        "temperature": temperature,
        "num_generate": 1,
        "use_checklist": True,
        "input_type": "text_only",
        "text_obs_type": "axtree",
        "image_obs_type": "som",
        "use_in_progress": True,
        "use_multimodal": False,
        "use_log_probs": True,
    }

    t_start = time.time()
    results = [None] * len(dataset)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_unit, idx, unit, configs, n_samples, reward_processor
            )
            for idx, unit in enumerate(dataset)
        ]

        for fut in as_completed(futures):
            idx, reward, thought = fut.result()
            results[idx] = (reward, thought)

    # Split into order-preserving lists
    final_rewards = [float(r) for r, _ in results]
    thoughts = [t for _, t in results]

    print(f"Time taken (threaded): {time.time() - t_start:.2f} s")
    return final_rewards, thoughts
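A minimal usage sketch for the deleted get_ar_reward helper follows. Only the call signature and return shape come from the code above; the endpoint URL, model name, and dataset fields are placeholders.

# Hypothetical invocation of the deleted get_ar_reward helper.
# base_url, model_name, and the dataset unit fields are placeholders.
from agent.reward import get_ar_reward  # module as it existed before this commit

dataset = [{
    "intent": "Find the cheapest flight to Paris",
    "start_url": "https://example.com",
    "text_observation": "...",
}]

rewards, thoughts = get_ar_reward(
    dataset,
    base_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
    model_name="my-reward-model",         # placeholder model id
    reward_processor="avg_logits",        # the only processor registered in reward_postprocessor.py
    max_workers=8,
)
print(rewards)   # one float reward per unit, NaN-averaged over 5 samples
print(thoughts)  # the most frequent judge response (by judge hash) per unit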
agent/reward_postprocessor.py
DELETED
@@ -1,41 +0,0 @@

import numpy as np
import re


def extract_judge_hash(response):
    """
    Convert the per-checklist-item yes / in-progress / no judgments
    into a compact hash string and return it.
    """
    content = response['response']

    try:
        judge_content = content.lower().replace(' ', '').split('<answer>')[1].split('</answer>')[0]
    except Exception:
        import traceback
        traceback.print_exc()
        return None
    pattern = r":yes|:inprogress|:no"
    matches = re.findall(pattern, judge_content)
    matches = [{':yes': 'y', ':inprogress': 'i', ':no': 'n'}[match] for match in matches]
    return ''.join(matches)

def average_logits(response):
    """
    Compute the yes / in-progress / no judgment at the logit level.
    """
    judge_probs = response['judge_probs']

    yes_ = np.mean([r['yes'] for r in judge_probs])
    in_ = np.mean([r['in'] for r in judge_probs])

    reward = yes_ + 0.5 * in_
    return reward


REWARD_PROCESSORS = {
    'avg_logits': average_logits
}

REWARD_PROCESSOR_N_SAMPLES = {
    'avg_logits': 5
}
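To make the scoring rule concrete, here is a worked sketch of both deleted postprocessors on fabricated inputs. The response dict mirrors the fields the code reads ('response' text and 'judge_probs'), but every value is illustrative only.

# Worked example for the deleted postprocessors; inputs are fabricated.
from agent.reward_postprocessor import extract_judge_hash, average_logits  # pre-deletion module

response = {
    "response": "<think>...</think>"
                "<answer>item1: Yes\nitem2: In Progress\nitem3: No</answer>",
    "judge_probs": [
        {"yes": 0.8, "in": 0.1, "no": 0.1},
        {"yes": 0.2, "in": 0.6, "no": 0.2},
        {"yes": 0.1, "in": 0.1, "no": 0.8},
    ],
}

# One character per checklist item: y / i / n
print(extract_judge_hash(response))  # -> 'yin'

# reward = mean(yes) + 0.5 * mean(in)
#        = (0.8 + 0.2 + 0.1)/3 + 0.5 * (0.1 + 0.6 + 0.1)/3 = 0.5
print(average_logits(response))      # -> 0.5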