iruno committed on
Commit cb35226 · verified · 1 Parent(s): e5605d9

Delete process_run.py

Files changed (1)
  1. process_run.py +0 -301
process_run.py DELETED
@@ -1,301 +0,0 @@
- from pathlib import Path
- import multiprocessing
- import logging
- from PIL import Image
- import io
- import base64
- import numpy as np
- import gymnasium as gym
- import os
-
- from agent.checklist import generate_checklist
- from agent.reward import get_ar_reward
-
- from browser_agent import BrowserAgent
-
-
- logger = logging.getLogger(__name__)
- logger.setLevel('INFO')
-
- templates_dir = Path(__file__).parent / "templates"
- CSS_RM_CARDS: str = (templates_dir / "rm_cards.css").read_text()
- CSS_TRAJECTORY: str = (templates_dir / "trajectory.css").read_text()
- CARD_HTML_TEMPLATE: str = (templates_dir / "card.html").read_text()
-
- RM_BASE_URL = os.environ['RM_BASE_URL']
- RM_MODEL_NAME = os.environ['RM_MODEL_NAME']
-
- def return_state(state, screenshot=None):
-     return state, None, None, screenshot, None
-
- def run_agent(instruction: str, model_name: str = "gpt-4o", start_url: str = "about:blank",
-               use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False, max_steps: int = 20):
-     logger.info(f"Starting agent with instruction: {instruction}")
-     logger.info(f"Configuration: model={model_name}, start_url={start_url}")
-
-     trajectory = []
-     trajectory_str = ''
-     agent = BrowserAgent(
-         model_name=model_name,
-         use_html=use_html,
-         use_axtree=use_axtree,
-         use_screenshot=use_screenshot
-     )
-
-     # Initialize BrowserGym environment
-     logger.info("Initializing BrowserGym environment")
-     yield return_state("## Initializing BrowserGym environment...", None)
-     env = gym.make(
-         "browsergym/openended",
-         task_kwargs={
-             "start_url": start_url,
-             "goal": instruction,
-         },
-         wait_for_user_message=True
-     )
-     obs, info = env.reset()
-     logger.info("Environment initialized")
-
-     # Send user instruction to the environment
-     logger.info("Sending user instruction to environment")
-     obs, reward, terminated, truncated, info = env.step({
-         "type": "send_msg_to_user",
-         "message": instruction
-     })
-     processed_obs = agent.obs_preprocessor(obs)
-     logger.info(f"Obs: {processed_obs.keys()}")
-     logger.info(f"axtree_txt: {processed_obs['axtree_txt']}")
-
-     yield return_state("## Generating checklist...", obs['som_screenshot'])
-     checklist = generate_checklist(intent=instruction, start_url=start_url, text_observation=processed_obs['axtree_txt'])
-
-     # yield initial state
-     current_screenshot = obs['som_screenshot'].copy()
-     yield "## Rollout actions from policy...", checklist, [], current_screenshot, trajectory.copy()
-
-     try:
-         step_count = 0
-         while step_count < max_steps:
-             logger.info(f"Step {step_count}: Getting next action")
-             # Get next action from agent
-             candidates, _ = agent.get_action(processed_obs)
-
-             yield return_state(f"## Rewarding actions...", current_screenshot)
-
-             total_rewards, total_thoughts = get_ar_reward(
-                 dataset=[
-                     {
-                         'text_observation': processed_obs['axtree_txt'],
-                         'intent': instruction,
-                         'trajectory': trajectory_str,
-                         'current_url': processed_obs['open_pages_urls'][processed_obs['active_page_index'][0]],
-                         'checklist': checklist,
-                         'thought': cand['thought'],
-                         'action': cand['action'],
-                     } for cand in candidates
-                 ],
-                 base_url=RM_BASE_URL,
-                 model_name=RM_MODEL_NAME,
-             )
-
-             # process rewards
-             diff_reward = abs(max(total_rewards) - total_rewards[0])  # reward difference between the highest-reward action and the most frequent one
-             if diff_reward <= 0.01:
-                 logger.info(f"diff_reward: {diff_reward} -> most frequent action")
-                 max_index = 0  # most frequent action
-             else:
-                 logger.info(f"diff_reward: {diff_reward} -> highest reward")
-                 max_index = total_rewards.index(max(total_rewards))  # highest reward
-
-             # sort by reward
-             sorted_indices = sorted(list(enumerate(total_rewards)), key=lambda x: (-1 if x[0] == max_index else 0, -x[1]))
-             new_order = [idx for idx, _ in sorted_indices]
-             candidates = [candidates[idx] for idx in new_order]
-             total_rewards = [total_rewards[idx] for idx in new_order]
-             total_thoughts = [total_thoughts[idx] for idx in new_order]
-
-             best_cand = candidates[0]
-
-             agent.action_history.append(best_cand['response'])
-
-             action = best_cand['action']
-
-             # processing action
-             step_info = {
-                 'thought': best_cand['thought'],
-                 'action': action
-             }
-             current_cards = [{'thought': cand['thought'], 'action': cand['action'], 'feedback': feedback, 'reward': round(reward, 2)} for idx, (cand, reward, feedback) in enumerate(zip(candidates, total_rewards, total_thoughts))]
-
-             trajectory_str += f'THOUGHT {step_count+1}: {step_info["thought"]}\nACTION {step_count+1}: {step_info["action"]}\n\n'
-
-             # Execute action
-             logger.info(f"Step {step_count}: Executing action: {action}")
-             yield f"## Executing action: {action}", checklist, current_cards, current_screenshot, trajectory.copy()
-             if action.startswith('send_msg_to_user'):
-                 terminated = True
-                 truncated = False
-             else:
-                 obs, reward, terminated, truncated, info = env.step(action)
-             trajectory.append((processed_obs['som_screenshot'], [{'action': cand['action'], 'reward': round(reward, 2)} for cand, reward in zip(candidates, total_rewards)]))
-             processed_obs = agent.obs_preprocessor(obs)
-             current_screenshot = processed_obs['som_screenshot'].copy()
-
-             while '\n\n' in step_info['thought']:
-                 step_info['thought'] = step_info['thought'].replace('\n\n', '\n')
-
-             # Store the numpy array directly in the trajectory
-             logger.info(f"Step {step_count}: Saved screenshot and updated trajectory")
-             step_count += 1
-
-             # yield by each step
-             yield "## Rollout actions from policy...", checklist, current_cards, current_screenshot, trajectory.copy()
-
-             if terminated or truncated:
-                 logger.info(f"Episode ended: terminated={terminated}, truncated={truncated}")
-                 yield return_state("## Episode ended", current_screenshot)
-                 break
-
-     finally:
-         logger.info("Finished")
-
-
- def run_agent_worker(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue):
-     """Worker function that runs the agent in a separate process and puts results in a queue."""
-     try:
-         for result in run_agent(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps):
-             return_queue.put(result)
-     except Exception as e:
-         logger.error(f"Error in agent worker process: {e}")
-         return_queue.put(("Error occurred in agent process", None, None, None, None))  # 5-tuple matching run_agent's yield shape
-         import traceback
-         traceback.print_exc()
-     finally:
-         # Signal that the process is done
-         return_queue.put(None)
-
- def run_agent_wrapper(instruction, model_name="gpt-4o", start_url="about:blank",
-                       use_html=False, use_axtree=True, use_screenshot=False, max_steps=20):
-     """Wrapper function that runs the agent in a separate process and yields results."""
-     return_queue = multiprocessing.Queue()
-
-     # Start the agent in a separate process
-     p = multiprocessing.Process(
-         target=run_agent_worker,
-         args=(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue)
-     )
-     p.daemon = True  # Ensure process terminates when parent terminates
-     p.start()
-
-     # Get results from the queue and yield them
-     while True:
-         result = return_queue.get()
-         if result is None:  # End signal
-             break
-         yield result
-
-     # Clean up
-     if p.is_alive():
-         p.terminate()
-     p.join()
-
- def process_run(instruction, model_name, start_url):
-     # Use the wrapper function instead of directly calling run_agent
-     trajectory_generator = run_agent_wrapper(
-         instruction,
-         model_name,
-         start_url,
-         use_html=False,
-         use_axtree=True,
-         use_screenshot=False
-     )
-
-     all_trajectory = []
-     last_checklist_view, last_trajectory_html = None, None
-
-     for state, checklist_view, rm_cards, screenshot, trajectory in trajectory_generator:
-         if checklist_view is None:
-             yield state, screenshot, last_checklist_view, None, last_trajectory_html
-             continue
-         # Create HTML for reward model cards
-         rm_cards_html = f"""
-         <style>
-         {CSS_RM_CARDS}
-         </style>
-         <div class="rm-cards-container">
-         """
-
-         for idx, card in enumerate(rm_cards):
-             rm_cards_html += CARD_HTML_TEMPLATE.format(
-                 additional_class='top-candidate' if idx == 0 else '',
-                 k=idx+1,
-                 suffix='(best)' if idx == 0 else '',
-                 thought=card['thought'],
-                 action=card['action'],
-                 reward=card['reward'],
-                 feedback=card['feedback']
-             )
-
-         rm_cards_html += "</div>"
-         all_trajectory = trajectory
-
-         # Create HTML for trajectory display
-         trajectory_html = f"""
-         <style>
-         {CSS_TRAJECTORY}
-         </style>
-         <div class="trajectory-container">
-         """
-
-         for idx, (after_img, cands) in enumerate(all_trajectory):
-             # Convert image to base64 if needed
-             img = all_trajectory[idx][0]
-             if isinstance(img, np.ndarray):
-                 img = Image.fromarray(img)
-             if isinstance(img, Image.Image):
-                 buffer = io.BytesIO()
-                 img.save(buffer, format="JPEG")
-                 img_str = base64.b64encode(buffer.getvalue()).decode()
-                 img_src = f"data:image/jpeg;base64,{img_str}"
-             else:
-                 img_src = img
-
-             trajectory_html += f"""
-             <div class="step-container">
-                 <div class="step-header">Step {idx + 1}</div>
-                 <div class="step-content">
-                     <div class="step-image">
-                         <img src="{img_src}" alt="Browser state">
-                     </div>
-                     <div class="step-info">
-                         <div class="box-title">Action Candidates:</div>
-                         <div class="action-candidates">
-             """
-
-             # Display all candidates for this step
-             for i, cand in enumerate(cands):
-                 action = cand['action']
-                 reward = cand['reward']
-
-                 trajectory_html += f"""
-                 <div class="candidate-box{' selected' if i == 0 else ''}">
-                     <div class="box-title">
-                         Action {i+1}{' (Selected)' if i == 0 else ''}
-                         <span class="reward-text">Reward: {reward}</span>
-                     </div>
-                     <pre>{action}</pre>
-                 </div>
-                 """
-
-             trajectory_html += """
-                         </div>
-                     </div>
-                 </div>
-             </div>
-             """
-
-         trajectory_html += "</div>"
-
-         last_checklist_view, last_trajectory_html = checklist_view, trajectory_html
-         yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html
-     yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html