`_.
+
+.. note::
+
+ This project is under active development.
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 2
+
+ usage
+ api
+ tutorials
diff --git a/BrowserGym/docs/src/tutorials.rst b/BrowserGym/docs/src/tutorials.rst
new file mode 100644
index 0000000000000000000000000000000000000000..01f5bcdb67769cd04b4d1e7c01e11fa9d0741bf6
--- /dev/null
+++ b/BrowserGym/docs/src/tutorials.rst
@@ -0,0 +1,26 @@
+Tutorials
+=========
+
+This section provides tutorials to help you build new environments and tasks.
+
+.. grid:: 2
+ :gutter: 2
+
+ .. grid-item-card:: Walkthrough
+ :link: examples/walkthrough.html
+
+ :bdg-primary:`Getting started`
+
+ .. grid-item-card:: Create a custom task
+ :link: examples/create_custom_task.html
+
+ :bdg-primary:`Custom task`
+
+
+
+.. toctree::
+ :maxdepth: 1
+ :hidden:
+
+ examples/walkthrough.rst
+ examples/create_custom_task.rst
diff --git a/BrowserGym/docs/src/usage.rst b/BrowserGym/docs/src/usage.rst
new file mode 100644
index 0000000000000000000000000000000000000000..038ca6d93508754bea95f4f7d8c7ddabbfc8e3f6
--- /dev/null
+++ b/BrowserGym/docs/src/usage.rst
@@ -0,0 +1,42 @@
+Usage
+=====
+
+.. _installation:
+
+Installation
+------------
+
+To use BrowserGym, first install it using pip:
+
+.. code-block:: console
+
+ pip install browsergym
+
+Then, as a required step, set up Playwright by running:
+
+.. code-block:: console
+
+ playwright install chromium
+
+Example code
+------------
+
+Boilerplate code to run an agent on an interactive, open-ended task:
+
+.. code-block:: python
+
+ import gymnasium as gym
+ import browsergym.core # register the openended task as a gym environment
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": "https://www.google.com/"}, # starting URL
+ wait_for_user_message=True, # wait for a user message after each agent message sent to the chat
+ )
+
+ obs, info = env.reset()
+ done = False
+ while not done:
+ action = ... # implement your agent here
+ obs, reward, terminated, truncated, info = env.step(action)
+ done = terminated or truncated
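+
+As an illustration of the action format, a hard-coded "agent" could post a single message
+to the chat (a minimal sketch, assuming the environment's default action mapping, which
+accepts high-level action strings such as ``noop()`` and ``send_msg_to_user(...)``, as
+used in the test suite):
+
+.. code-block:: python
+
+ obs, info = env.reset()
+
+ # a single hard-coded action: post a message to the chat
+ obs, reward, terminated, truncated, info = env.step('send_msg_to_user("Hello!")')
+
+ env.close() # release the underlying browser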
diff --git a/BrowserGym/pyproject.toml b/BrowserGym/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..68b016a511ba078a6a0614b47af3d6026d04f83c
--- /dev/null
+++ b/BrowserGym/pyproject.toml
@@ -0,0 +1,33 @@
+[project]
+name = "browsergym-meta"
+description = "BrowserGym: a gym environment for web task automation in the Chromium browser"
+dynamic = ["version"]
+[tool.setuptools]
+packages = [] # meta distribution, packages are included as dependencies
+[tool.black]
+line-length = 100
+include = '\.pyi?$'
+exclude = '''
+/(
+ \.eggs
+ | \.git
+ | \.hg
+ | \.mypy_cache
+ | \.nox
+ | \.tox
+ | \.venv
+ | _build
+ | buck-out
+ | build
+ | dist
+)/
+'''
+
+[tool.pytest.ini_options]
+filterwarnings = [
+ 'ignore::UserWarning:gymnasium.*:', # too many "The obs is not within the observation space." warnings.
+]
+markers = [
+ "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+ "serial: mark test to be run sequantially (deselect with '-m \"not serial\"')"
+]
diff --git a/BrowserGym/tests/__init__.py b/BrowserGym/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/BrowserGym/tests/assistantbench/__init__.py b/BrowserGym/tests/assistantbench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/BrowserGym/tests/assistantbench/data/fallback_gpt4_seeplanact_predictions.jsonl b/BrowserGym/tests/assistantbench/data/fallback_gpt4_seeplanact_predictions.jsonl
new file mode 100644
index 0000000000000000000000000000000000000000..04c1bf8931e9aca6493316f7f4703ba2f4c95d92
--- /dev/null
+++ b/BrowserGym/tests/assistantbench/data/fallback_gpt4_seeplanact_predictions.jsonl
@@ -0,0 +1,33 @@
+{"id": "2aa5dd83fbcd0dce9a3dd4592106e5b5edf738008d932e357d477bba80e59ccf", "answer": "\\( \\frac{2}{7} \\times 100 \\approx 28.57 \\)", "gold_answer": "14.2", "score": 0, "has_ans": 1.0}
+{"id": "2ddae3b7a208e3c25f14d82d7a1faaaa1832fbf950b4dac345e755c4c361f294", "answer": 800000.0, "gold_answer": "1010000", "score": 0.7669061178326222, "has_ans": 1.0}
+{"id": "4e615af6f0348597b4133cc1ec5418bb3f35328e3d95e23a275027cee97b5e09", "answer": [], "gold_answer": "Adrenalinpark K\u00f6ln", "score": 0.0, "has_ans": 0}
+{"id": "c7afe00869f98cf363fd83677ac41757ed5e57f03eacc3d1304feb0a92084bd1", "answer": "Knives Out", "gold_answer": "Glass Onion: A Knives Out Mystery", "score": 0.5714285714285715, "has_ans": 1.0}
+{"id": "57d9dc6935e8a40b02e7f8ec81768fe70e68a0c05f6866927c9fda38db38a486", "answer": "-$108", "gold_answer": "45", "score": 0, "has_ans": 1.0}
+{"id": "748899d9d70c09beb3bd48ac8a3658bdcfd2f9114fe6dc4c4b8d2f9541ef4607", "answer": [{"sender": "dhl", "price (usd)": "50"}, {"sender": "fedex", "price (usd)": "60"}], "gold_answer": "{\"sender\": \"DHL\", \"price (usd)\": \"55-70\"}\n{\"sender\": \"Fedex\", \"price (usd)\": \"62-95\"}\n{\"sender\": \"USPS\", \"price (usd)\": \"73.4-78.15\"}", "score": 0.3333333333333333, "has_ans": 1.0}
+{"id": "9e31099fffa6a3891c94934fd4fc2f3f522d51c1904ff3561f3a10e4bf245821", "answer": "oshrat binyamin", "gold_answer": "Shiran Nawi, Yoni Osherov, Daniel Lereya", "score": 0.0, "has_ans": 1.0}
+{"id": "291b53e665b4dd4365cde995042db4a6f6fecef3fe3a6f4482f23d61bd673918", "answer": "ftp://ftp.ncbi.nlm.nih.gov/genomes/all/gcf_002288925.1_asm228892v2/gcf_002288925.1_asm228892v2_genomic.gff.gz", "gold_answer": "https://ftp.ensembl.org/pub/release-101/gff3/delphinapterus_leucas/Delphinapterus_leucas.ASM228892v3.101.gff3.gz", "score": 0.0, "has_ans": 1.0}
+{"id": "8fa42360185068216f2919935148d4e1ad28ddc18da0abd0f4bb0b6b6f84b127", "answer": "vgt", "gold_answer": "VGT", "score": 1.0, "has_ans": 1.0}
+{"id": "3af8028c2a59e28ca88baff0e6d91f2a9f170c5ef91003f1c8406755a2760ad4", "answer": "Oko, Thief of Crowns", "gold_answer": "Oko, Thief of Crowns", "score": 1.0, "has_ans": 1.0}
+{"id": "6b06d186921b8b390c65aebd0d16f09f60a47d2f1288ebe36953f734e84c0a3c", "answer": "", "gold_answer": "1148 sqft", "score": 0.0, "has_ans": 0.0}
+{"id": "9bdca8677af1e25cb7b0c7992dc62670c3e58e4afcd5ae60bcaa2483556bba00", "answer": ["'{\"sender\": \"usps\", \"price (usd)\": 25}'"], "gold_answer": "{\"sender\": \"USPS\", \"price (usd)\": \"41.75\"}", "score": 0, "has_ans": 1.0}
+{"id": "557e78eceec08ca8b0da5f9fdaca6e1c7ec6140a8ce600983ee716327dab005e", "answer": "Wolly Mammoth", "gold_answer": "For Pete's Sake", "score": 0.0, "has_ans": 1.0}
+{"id": "fb9ba3ab6a13d0adc677f993e90d54914a5cdf211305a1bba6bf16ec4ccb9b7c", "answer": "Instagram", "gold_answer": "Linkedin", "score": 0.0, "has_ans": 1.0}
+{"id": "52f7224e9c79431e7926afe317782711a0028750693e7456cde22ef6f4bd8bd5", "answer": "Nosferatu the Vampyre", "gold_answer": "Nosferatu the Vampyre", "score": 1.0, "has_ans": 1.0}
+{"id": "0ec4371851b96837b0a81b3dd3df401415061bb532fbafeb4609f3337c358508", "answer": ["anytime fitness", "point pleasant wellness center"], "gold_answer": "The Root Sports & Fitness Center\nMuscle Headz Gym", "score": 0.16666666666666666, "has_ans": 1.0}
+{"id": "6f224e7730ed027cbac73aebb1aea7f954053082041b02b19f4ff126a0a8a208", "answer": "Gina DiGioia", "gold_answer": "Gina DiGioia", "score": 1.0, "has_ans": 1.0}
+{"id": "99da66d8af02491f98b98c56b26c709e773b5a2ad945fb280375951ba600de09", "answer": 250.0, "gold_answer": "395", "score": 0.5425751529611245, "has_ans": 1.0}
+{"id": "ccec2229ced20a4b0cb4897e3a99120a3017ea030903e01c9bda6b13d40b0b14", "answer": "", "gold_answer": "McDonald's", "score": 0.0, "has_ans": 0.0}
+{"id": "9baaa267c95f9d8b75741ee9169c50563d297cfa592c20deaffd30dbc5984c74", "answer": 16.67, "gold_answer": "31.67", "score": 0.3582408362121543, "has_ans": 1.0}
+{"id": "6e3be83d1949fa52cba03fb1ce4b5b3bf7e37a83fd7d67694b10b2e439d90cf8", "answer": "wall street boxing & fitness", "gold_answer": "Renzo Gracie Jiu-Jitsu Wall Street", "score": 0.4, "has_ans": 1.0}
+{"id": "e2dc3a6b10b762e8aba7fa4d4e70f757f6d04dcbc8b56c48fc53fd9928d31d07", "answer": 40.0, "gold_answer": "30", "score": 0.7123179275482192, "has_ans": 1.0}
+{"id": "f88066d274e265edd6cd9d61cd80a41accb3a14baf2297652fdd05cdf716d455", "answer": "lower yosemite fall trail", "gold_answer": "Yosemite Falls\nBridalveil Fall", "score": 0.16666666666666666, "has_ans": 1.0}
+{"id": "e6bc98089608217e45b6956a46518fe3cce64a799b3ac43c6974c449ae14c408", "answer": 2140000.0, "gold_answer": "3080000", "score": 0.635876232048277, "has_ans": 1.0}
+{"id": "8ad84bd6fe38481ba49e7ad1f6fbd43219a999074e5c6fc940003281f55ec65b", "answer": ["trader joe's", "whole foods market", "aldi"], "gold_answer": "Potash Markets - Clark Street", "score": 0.0, "has_ans": 1.0}
+{"id": "55f4258484c5b398956133128a50462a767da211f8f72aa5ac5bbffb9bcbba1a", "answer": "Becker", "gold_answer": "CSI: Cyber", "score": 0.0, "has_ans": 1.0}
+{"id": "4dbedc5e1a0205e14b7ff3ba89bce3060dab15d0ada3b7e1351a6f2aa8287aec", "answer": 95.0, "gold_answer": "$55", "score": 0.4534562936319301, "has_ans": 1.0}
+{"id": "929b45f34805280d77c61d1e093e3d4e551d77ddb6ecd73552b12b1af286388d", "answer": "http://hgdownload.soe.ucsc.edu/goldenpath/canfam3/bigzips/", "gold_answer": "ftp://ftp.broadinstitute.org/distribution/assemblies/mammals/dog/canFam3.1/", "score": 0.0, "has_ans": 1.0}
+{"id": "cca4776df3c73e7f9430a2e624aafad056b14322a0b7ca6c0c22b7e7f3f0890a", "answer": "monica c. lozano", "gold_answer": "Wanda Austin\nRonald D. Sugar\nSue Wagner", "score": 0.0, "has_ans": 1.0}
+{"id": "efc0f3a47e9ed2ecdbcc037c2093865fe6e39f4d413a5d1ccdc7357160a4606b", "answer": "fidelity emerging asia fund (fseax)", "gold_answer": "Fidelity\u00ae Emerging Markets Index Fund (FPADX)", "score": 0.3636363636363636, "has_ans": 1.0}
+{"id": "b36ef2d8f2643b80e74a44ce3403f674ecb2aed7fd36afeaa289061a59feef92", "answer": "crunch fitness - east village", "gold_answer": "CrossFit East River\nAvea Pilates", "score": 0.14285714285714288, "has_ans": 1.0}
+{"id": "a9074997e698f912b9e751779ea19c1e92fa148404e90e0ae997acea3f9559b0", "answer": ["uncle tom's trail", "mount washburn", "fairy falls"], "gold_answer": "Trout lake trail\nArtist Point\nFountain Paint Pot\nLone Star Geyser\nStorm Point Trail", "score": 0.06666666666666667, "has_ans": 1.0}
+{"id": "797f7a5b65ca28b7e7156e7db1e9f117bd4a021de0cd512bfdbb0be897d89eab", "answer": ["red bamboo", "quantum leap"], "gold_answer": "Shanghai villa", "score": 0.0, "has_ans": 1.0}
\ No newline at end of file
diff --git a/BrowserGym/tests/assistantbench/test_env_general.py b/BrowserGym/tests/assistantbench/test_env_general.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a2c2e44f3849ad2b49e36893f6c50a705a34595
--- /dev/null
+++ b/BrowserGym/tests/assistantbench/test_env_general.py
@@ -0,0 +1,49 @@
+import logging
+import os
+import random
+
+import gymnasium as gym
+import playwright.sync_api
+import pytest
+from tenacity import retry, retry_if_exception_type, stop_after_attempt
+
+# register gym environments
+import browsergym.assistantbench
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+from browsergym.assistantbench import TEST_AB_TASK_IDS, VALID_AB_TASK_IDS
+
+rng = random.Random(1)
+valid_task_ids = rng.sample(VALID_AB_TASK_IDS, 10)
+test_task_ids = rng.sample(TEST_AB_TASK_IDS, 10)
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+@pytest.mark.parametrize("task_id", valid_task_ids + test_task_ids)
+@pytest.mark.slow
+def test_valid_env(task_id):
+ env = gym.make(
+ f"browsergym/{task_id}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
+ obs, reward, terminated, truncated, info = env.step("noop(0)")
+ assert not obs["last_action_error"]
+ assert not (terminated or truncated)
+
+ obs, reward, terminated, truncated, info = env.step('send_msg_to_user("something")')
+ assert not obs["last_action_error"]
+ assert terminated
+
+ env.close()
diff --git a/BrowserGym/tests/assistantbench/test_evaluation.py b/BrowserGym/tests/assistantbench/test_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4973d7158f780b0397ff669f7051e44fcfd8d0a5
--- /dev/null
+++ b/BrowserGym/tests/assistantbench/test_evaluation.py
@@ -0,0 +1,77 @@
+import json
+import pathlib
+
+import gymnasium as gym
+import pytest
+
+from browsergym.assistantbench.evaluation.evaluator import question_scorer
+from browsergym.experiments.benchmark.metadata.utils import (
+ task_list_from_metadata,
+ task_metadata,
+)
+
+__DATA_DIR = pathlib.Path(__file__).resolve().parent / "data"
+
+metadata = task_metadata("assistantbench")
+file_path = pathlib.Path(__DATA_DIR) / "fallback_gpt4_seeplanact_predictions.jsonl"
+
+data_points = {}
+
+# Open the JSONL file and read each line as a JSON object
+with open(file_path, "r") as f:
+ for line in f:
+ data_point = json.loads(line)
+
+ original_id = data_point["id"]
+ answer = data_point["answer"]
+ gold_answer = data_point["gold_answer"]
+ score = data_point["score"]
+ has_ans = data_point["has_ans"]
+
+ data_points[original_id] = {
+ "task_id": task_list_from_metadata(metadata, {"original_id": original_id})[0],
+ "answer": answer,
+ "gold_answer": gold_answer,
+ "score": score,
+ "has_ans": has_ans,
+ }
+
+
+@pytest.mark.parametrize("original_id", list(data_points.keys()))
+def test_evaluate(original_id: str):
+
+ answer = data_points[original_id]["answer"]
+ gold_answer = data_points[original_id]["gold_answer"]
+ expected_score = data_points[original_id]["score"]
+ expected_has_ans = data_points[original_id]["has_ans"]
+
+ score, has_ans = question_scorer(answer, gold_answer)
+
+ # Assert that the computed score and answer flag match the expected values
+ assert score == expected_score
+ assert has_ans == expected_has_ans
+
+
+@pytest.mark.parametrize(
+ "original_id",
+ [id for id in data_points.keys() if isinstance(data_points[id]["answer"], (str, float, int))],
+)
+@pytest.mark.slow
+def test_evaluate_within_env(original_id: str):
+
+ task_id = data_points[original_id]["task_id"]
+ answer = data_points[original_id]["answer"]
+ expected_score = data_points[original_id]["score"]
+
+ env = gym.make(
+ f"browsergym/{task_id}",
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
+ obs, reward, terminated, truncated, info = env.step(f"send_msg_to_user({repr(str(answer))})")
+ assert not obs["last_action_error"]
+ assert terminated
+ assert reward == expected_score
+
+ env.close()
diff --git a/BrowserGym/tests/core/__init__.py b/BrowserGym/tests/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f09d6fbde51609da41e1041eb3fb8125d808cb
--- /dev/null
+++ b/BrowserGym/tests/core/__init__.py
@@ -0,0 +1,2 @@
+# bugfix: use same playwright instance in browsergym and pytest
+from ..utils import setup_playwright
diff --git a/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe.html b/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..e2e61c694f20f358274a32f62c0cb74b6a63286b
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe.html
@@ -0,0 +1,37 @@
+
+
+
+ Iframe Example
+
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe_2.html b/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe_2.html
new file mode 100644
index 0000000000000000000000000000000000000000..d8e51b6ce1a4b8deebfd02868dd44e42e3a12158
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_iframe_site/basic_iframe_2.html
@@ -0,0 +1,12 @@
+
+
+
+ Simple Website
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_iframe_site/inner-iframe.html b/BrowserGym/tests/core/data/basic_iframe_site/inner-iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..6cb49db9ca79b79111698aa23d975a1900296298
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_iframe_site/inner-iframe.html
@@ -0,0 +1,23 @@
+
+
+
+
+ Inner Iframe
+
+
+
+
+ Iframe Level 2
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_iframe_site/outer-iframe.html b/BrowserGym/tests/core/data/basic_iframe_site/outer-iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..b71a077f2b374005894c2804aa9bf827e139d213
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_iframe_site/outer-iframe.html
@@ -0,0 +1,30 @@
+
+
+
+ Shadow DOM Example
+
+
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_dom_site/basic_shadow_dom.html b/BrowserGym/tests/core/data/basic_shadow_dom_site/basic_shadow_dom.html
new file mode 100644
index 0000000000000000000000000000000000000000..242678f9696f448afffe5e5523aa36704fe6ec95
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_dom_site/basic_shadow_dom.html
@@ -0,0 +1,52 @@
+
+
+
+ Unit Test with Complex Nested Shadow DOM
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_dom_site/simple_shadow_dom.html b/BrowserGym/tests/core/data/basic_shadow_dom_site/simple_shadow_dom.html
new file mode 100644
index 0000000000000000000000000000000000000000..fdcc8ceca07f897be41996144dd2a895d1a02229
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_dom_site/simple_shadow_dom.html
@@ -0,0 +1,22 @@
+
+
+
+ Unit Test with Complex Nested Shadow DOM
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe.html b/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..e2e61c694f20f358274a32f62c0cb74b6a63286b
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe.html
@@ -0,0 +1,37 @@
+
+
+
+ Iframe Example
+
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe_2.html b/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe_2.html
new file mode 100644
index 0000000000000000000000000000000000000000..dbcd6756822e81b68bdee21ec36944613b682826
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_iframe_site/basic_iframe_2.html
@@ -0,0 +1,12 @@
+
+
+
+ Simple Website
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_iframe_site/inner-iframe.html b/BrowserGym/tests/core/data/basic_shadow_iframe_site/inner-iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..0d480d6701adc7d034f3e05c03b899b206b9f949
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_iframe_site/inner-iframe.html
@@ -0,0 +1,12 @@
+
+
+
+ Inner Iframe
+
+
+ Iframe Level 2
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/basic_shadow_iframe_site/outer-iframe.html b/BrowserGym/tests/core/data/basic_shadow_iframe_site/outer-iframe.html
new file mode 100644
index 0000000000000000000000000000000000000000..eed22ca03938bded8c1408df0a698515fa5068e9
--- /dev/null
+++ b/BrowserGym/tests/core/data/basic_shadow_iframe_site/outer-iframe.html
@@ -0,0 +1,40 @@
+
+
+
+ Shadow DOM Example
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/example.html b/BrowserGym/tests/core/data/example.html
new file mode 100644
index 0000000000000000000000000000000000000000..13552a70b0edc84663a94433a7da6ed525561e65
--- /dev/null
+++ b/BrowserGym/tests/core/data/example.html
@@ -0,0 +1,52 @@
+
+
+
+
+ Example Domain
+
+
+
+
+
+
+
+
+
+
+ Example Domain
+
+ This domain is for use in illustrative examples in documents. You may use this
+ domain in literature without prior coordination or asking for permission.
+
+ More information...
+
+
+
+
diff --git a/BrowserGym/tests/core/data/hover.html b/BrowserGym/tests/core/data/hover.html
new file mode 100644
index 0000000000000000000000000000000000000000..385bf2dc97085ea2a06eefdc22aa1af159bbe077
--- /dev/null
+++ b/BrowserGym/tests/core/data/hover.html
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/button_input.html b/BrowserGym/tests/core/data/input_type/button_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..9d6e6493c7594a0a9cd86cbd3f04fcfbea415c93
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/button_input.html
@@ -0,0 +1,10 @@
+
+
+
+
+Input Button
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/checkbox_input.html b/BrowserGym/tests/core/data/input_type/checkbox_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..ada1f2ff25cc66ed14281a96ca60021da9d173c4
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/checkbox_input.html
@@ -0,0 +1,19 @@
+
+
+
+
+Checkboxes
+The input type="checkbox" defines a checkbox:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/color_picker_input.html b/BrowserGym/tests/core/data/input_type/color_picker_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..e33b957dc62cb351ad6f2af5e4b2b55af5967acf
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/color_picker_input.html
@@ -0,0 +1,18 @@
+
+
+
+
+Show a Color Picker
+
+The input type="color" is used for input fields that should contain a color.
+
+
+
+Note: type="color" is not supported in Internet Explorer 11 or Safari 9.1 (or earlier).
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/date_input.html b/BrowserGym/tests/core/data/input_type/date_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..0e2d6a3fe1155b35651896483e9a072685f2c34d
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/date_input.html
@@ -0,0 +1,18 @@
+
+
+
+
+Date Field
+
+The input type="date" is used for input fields that should contain a date.
+
+
+
+Note: type="date" is not supported in Internet Explorer 11 or prior Safari 14.1.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/date_min_max_input.html b/BrowserGym/tests/core/data/input_type/date_min_max_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..f519df9c130708a26e71c474496538c60d9930f4
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/date_min_max_input.html
@@ -0,0 +1,22 @@
+
+
+
+
+Date Field Restrictions
+
+Use the min and max attributes to add restrictions to dates:
+
+
+
+Note: type="date" is not supported in Internet Explorer 11 or prior Safari 14.1.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/date_time_local_input.html b/BrowserGym/tests/core/data/input_type/date_time_local_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..cc34237bebfa0704a8cc6d1553d5b490fed9dd58
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/date_time_local_input.html
@@ -0,0 +1,18 @@
+
+
+
+
+Local Date Field
+
+The input type="datetime-local" specifies a date and time input field, with no time zone.
+
+
+
+Note: type="datetime-local" is not supported in Internet Explorer 11 or prior Safari 14.1.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/email_input.html b/BrowserGym/tests/core/data/input_type/email_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..0e2f6c3b5db4022e32eef4eb8ac5c0aa79a8ba40
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/email_input.html
@@ -0,0 +1,16 @@
+
+
+
+
+Email Field
+
+The input type="email" is used for input fields that should contain an e-mail address:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/file_input.html b/BrowserGym/tests/core/data/input_type/file_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..5a026e729276c425c23e546cee19a7900cbae84d
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/file_input.html
@@ -0,0 +1,15 @@
+
+
+
+
+File upload
+
+Show a file-select field which allows a file to be chosen for upload:
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/hidden_field_input.html b/BrowserGym/tests/core/data/input_type/hidden_field_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..af16596e12dfde14ecc4f2d3daac9006f0cfb26a
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/hidden_field_input.html
@@ -0,0 +1,17 @@
+
+
+
+
+A Hidden Field (look in source code)
+
+
+
+Note: The hidden field is not shown to the user, but the data is sent when the form is submitted.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/image_input.html b/BrowserGym/tests/core/data/input_type/image_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..502fd2990500a8a21b2ef030db2017a90e0bd02f
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/image_input.html
@@ -0,0 +1,18 @@
+
+
+
+
+Display an Image as the Submit button
+
+
+
+Note: The input type="image" sends the X and Y coordinates of the click that activated the image button.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/number_input.html b/BrowserGym/tests/core/data/input_type/number_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..1158e2baaab595f7ba1f8381fd95811b2fdf9be8
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/number_input.html
@@ -0,0 +1,18 @@
+
+
+
+
+Number Field
+
+The input type="number" defines a numeric input field.
+
+You can use the min and max attributes to add numeric restrictions in the input field:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/number_step_input.html b/BrowserGym/tests/core/data/input_type/number_step_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..8d68505c3c8005bbc6a055bca0210db276120050
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/number_step_input.html
@@ -0,0 +1,16 @@
+
+
+
+
+Numeric Steps
+
+Depending on browser support: Fixed steps will apply in the input field.
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/password_input.html b/BrowserGym/tests/core/data/input_type/password_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..66eb78622aec4c406d77b49bb61e8f7e99503e41
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/password_input.html
@@ -0,0 +1,20 @@
+
+
+
+
+Password field
+
+The input type="password" defines a password field:
+
+
+
+The characters in a password field are masked (shown as asterisks or circles).
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/radio_input.html b/BrowserGym/tests/core/data/input_type/radio_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..125d68f4df1bd3c28d1954519c01b40ff24daa68
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/radio_input.html
@@ -0,0 +1,19 @@
+
+
+
+
+Radio Buttons
+
+The input type="radio" defines a radio button:
+
+Choose your favorite Web language:
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/range_input.html b/BrowserGym/tests/core/data/input_type/range_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..d96b9791994a7546b71c4808ffd1b111546d3323
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/range_input.html
@@ -0,0 +1,16 @@
+
+
+
+
+Range Field
+
+Depending on browser support: The input type "range" can be displayed as a slider control.
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/reset_input.html b/BrowserGym/tests/core/data/input_type/reset_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..d7710c1a8078eceef1f0c3b70503c7aed35ac2da
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/reset_input.html
@@ -0,0 +1,21 @@
+
+
+
+
+Reset Button
+
+The input type="reset" defines a reset button that resets all form values to their default values:
+
+
+
+If you change the input values and then click the "Reset" button, the form-data will be reset to the default values.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/search_input.html b/BrowserGym/tests/core/data/input_type/search_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..db8ab66ed15a758c76d9f9fca7344c91b378dd10
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/search_input.html
@@ -0,0 +1,15 @@
+
+
+
+
+Search Field
+The input type="search" is used for search fields (behaves like a regular text field):
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/submit_input.html b/BrowserGym/tests/core/data/input_type/submit_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..257ebdda7f922079c4cd2648564e6adcd8be4c58
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/submit_input.html
@@ -0,0 +1,20 @@
+
+
+
+
+Submit Button
+
+The input type="submit" defines a button for submitting form data to a form-handler:
+
+
+
+If you click "Submit", the form-data will be sent to a page called "https://www.w3schools.com/action_page.php".
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/submit_nn_input.html b/BrowserGym/tests/core/data/input_type/submit_nn_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..da04e9d3a5413e6d5ca8f08c0c705d3156161569
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/submit_nn_input.html
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/telephone_input.html b/BrowserGym/tests/core/data/input_type/telephone_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..12a0c8a59a1da62b4e578dca976d167dbf1d8dfa
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/telephone_input.html
@@ -0,0 +1,17 @@
+
+
+
+
+Telephone Field
+
+The input type="tel" is used for input fields that should contain a telephone number:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/text_input.html b/BrowserGym/tests/core/data/input_type/text_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..811753a26fd5325d481b2246051dc1d2d153a540
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/text_input.html
@@ -0,0 +1,20 @@
+
+
+
+
+Text field
+The input type="text" defines a one-line text input field:
+
+
+
+Note that the form itself is not visible.
+Also note that the default width of a text field is 20 characters.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/time_input.html b/BrowserGym/tests/core/data/input_type/time_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..8ca605580af8ee58f86aecb20acf5d2d8fa9c263
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/time_input.html
@@ -0,0 +1,20 @@
+
+
+
+
+Show a Time Input Control
+
+The input type="time" allows the user to select a time (no time zone):
+
+If the browser supports it, a time picker pops up when entering the input field.
+
+
+
+Note: type="time" is not supported in Internet Explorer 11 or prior Safari 14.1.
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/url_input.html b/BrowserGym/tests/core/data/input_type/url_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..1f6bdf641d746e582d0ce3dad5f04e483ab7bef4
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/url_input.html
@@ -0,0 +1,16 @@
+
+
+
+
+Display a URL Input Field
+
+The input type="url" is used for input fields that should contain a URL address:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/input_type/week_input.html b/BrowserGym/tests/core/data/input_type/week_input.html
new file mode 100644
index 0000000000000000000000000000000000000000..1f6bdf641d746e582d0ce3dad5f04e483ab7bef4
--- /dev/null
+++ b/BrowserGym/tests/core/data/input_type/week_input.html
@@ -0,0 +1,16 @@
+
+
+
+
+Display a URL Input Field
+
+The input type="url" is used for input fields that should contain a URL address:
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/long_page.html b/BrowserGym/tests/core/data/long_page.html
new file mode 100644
index 0000000000000000000000000000000000000000..8fd6ca357e35581ea09f2c905b36ee9df439f92f
--- /dev/null
+++ b/BrowserGym/tests/core/data/long_page.html
@@ -0,0 +1,211 @@
+
+
+
+
+
+ This is the top
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ This is the bottom
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/lots_of_iframes.html b/BrowserGym/tests/core/data/lots_of_iframes.html
new file mode 100644
index 0000000000000000000000000000000000000000..ba342a9ced3d48364816b1fcb7888f5518a69001
--- /dev/null
+++ b/BrowserGym/tests/core/data/lots_of_iframes.html
@@ -0,0 +1,21 @@
+
+
+
+
+ Lots of Iframes
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/obstructed_checkbox_page.html b/BrowserGym/tests/core/data/obstructed_checkbox_page.html
new file mode 100644
index 0000000000000000000000000000000000000000..a3f9ec1f23c7dad374236eb1a6e19e52ceb56cb5
--- /dev/null
+++ b/BrowserGym/tests/core/data/obstructed_checkbox_page.html
@@ -0,0 +1,93 @@
+
+
+
+
+
+ Checkbox with Label Interception
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/data/test_page.html b/BrowserGym/tests/core/data/test_page.html
new file mode 100644
index 0000000000000000000000000000000000000000..cdb46c801b32395364831e3c0a6dc32149bda067
--- /dev/null
+++ b/BrowserGym/tests/core/data/test_page.html
@@ -0,0 +1,29 @@
+
+
+
+ Simple Form
+
+
+ Simple Form
+
+
+
+
diff --git a/BrowserGym/tests/core/data/test_page_2.html b/BrowserGym/tests/core/data/test_page_2.html
new file mode 100644
index 0000000000000000000000000000000000000000..b3b2a5d69c83f74229fb2589f7c0798f960db9bb
--- /dev/null
+++ b/BrowserGym/tests/core/data/test_page_2.html
@@ -0,0 +1,63 @@
+
+
+
+
+ Simple Form
+
+
+
+ Simple Form
+
+
+
+
+ Text within a non-html tag
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Text that should not be visible
+
+
+
diff --git a/BrowserGym/tests/core/data/textbox.html b/BrowserGym/tests/core/data/textbox.html
new file mode 100644
index 0000000000000000000000000000000000000000..c93bd6f7835a9f11860ce6cd2406794c3376a26b
--- /dev/null
+++ b/BrowserGym/tests/core/data/textbox.html
@@ -0,0 +1,13 @@
+
+
+
+
+ Simple HTML Page
+
+
+
+
+
+
+
+
diff --git a/BrowserGym/tests/core/test_actions_highlevel.py b/BrowserGym/tests/core/test_actions_highlevel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a4f56c6f9f7d579cccee40f2d747c8c42cbdc9
--- /dev/null
+++ b/BrowserGym/tests/core/test_actions_highlevel.py
@@ -0,0 +1,1256 @@
+import ast
+import os
+import pathlib
+import platform
+import re
+
+import bs4
+import gymnasium as gym
+import pytest
+from pyparsing.exceptions import ParseException
+
+# register openended gym environments
+import browsergym.core
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.core.action.parsers import NamedArgument, highlevel_action_parser
+from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
+from browsergym.utils.obs import flatten_dom_to_str
+
+_IS_MAC_OS = platform.system() == "Darwin"
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+__TIMEOUT = 500
+
+__DATA_DIR = pathlib.Path(__file__).resolve().parent / "data"
+
+TEXTBOX_URL = f"file://{__DATA_DIR}/textbox.html"
+EXAMPLE_URL = f"file://{__DATA_DIR}/example.html"
+HOVER_URL = f"file://{__DATA_DIR}/hover.html"
+INEXISTANT_FILE_URL = f"file://{__DATA_DIR}/no_file_here.html"
+LONG_PAGE_URL = f"file://{__DATA_DIR}/long_page.html"
+TEXT_INPUT_URL = f"file://{__DATA_DIR}/input_type/text_input.html"
+URL_INPUT_URL = f"file://{__DATA_DIR}/input_type/url_input.html"
+CHECKBOX_URL = f"file://{__DATA_DIR}/input_type/checkbox_input.html"
+MULTI_IFRAME_URL = f"file://{__DATA_DIR}/basic_iframe_site/basic_iframe_2.html"
+OBSTRUCTED_CHECKBOX_URL = f"file://{__DATA_DIR}/obstructed_checkbox_page.html"
+LOTS_OF_IFRAMES_URL = f"file://{__DATA_DIR}/lots_of_iframes.html"
+
+
+def test_action_parser():
+ parser = highlevel_action_parser
+
+ with pytest.raises(ParseException):
+ function_calls = parser.parse_string("", parseAll=True)
+ assert not function_calls
+
+ function_calls = parser.parse_string("a()", parseAll=True)
+ assert len(function_calls) == 1
+
+ function_calls = parser.parse_string(" a ( ) \n\n\t", parseAll=True)
+ assert len(function_calls) == 1
+
+ function_calls = parser.parse_string(" a ( ) b() \n \tc()", parseAll=True)
+ assert [function_name for function_name, _ in function_calls] == ["a", "b", "c"]
+
+ function_calls = parser.parse_string('a(12, 12.2, "text", (1, 2, 3), ["a", 23])', parseAll=True)
+ _, function_args = function_calls[0]
+ assert function_args == [12, 12.2, "text", (1, 2, 3), ["a", 23]]
+
+ function_calls = parser.parse_string('a(x=12, y = 12.2, other = "text")', parseAll=True)
+ _, function_args = function_calls[0]
+ assert function_args == [
+ NamedArgument(name="x", value=12),
+ NamedArgument(name="y", value=12.2),
+ NamedArgument(name="other", value="text"),
+ ]
+
+ function_calls = parser.parse_string('a(12, y = 12.2, other = "text")', parseAll=True)
+ _, function_args = function_calls[0]
+ assert function_args == [
+ 12,
+ NamedArgument(name="y", value=12.2),
+ NamedArgument(name="other", value="text"),
+ ]
+
+ with pytest.raises(ParseException):
+ function_calls = parser.parse_string('a(x = 12, 12.2, other = "text")', parseAll=True)
+
+ with pytest.raises(ParseException):
+ function_calls = parser.parse_string('a(12, 12.2, 1 = "text")', parseAll=True)
+
+ with pytest.raises(ParseException):
+ function_calls = parser.parse_string("a(1-)", parseAll=True)
+
+ with pytest.raises(ParseException):
+ function_calls = parser.parse_string("a(1/2)", parseAll=True)
+
+ function_calls = parser.parse_string('a("""\nsome\ntext\\"\\"""")', parseAll=True)
+ _, function_args = function_calls[0]
+ assert function_args == ['\nsome\ntext""']
+
+ function_calls = parser.parse_string("a('\"some\\ntext\"')", parseAll=True)
+ _, function_args = function_calls[0]
+ assert function_args == ['"some\ntext"']
+
+ function_calls = parser.parse_string('#comment\na("# not comment") #comment \n ', parseAll=True)
+ assert len(function_calls) == 1
+ function_name, function_args = function_calls[0]
+ assert function_name == "a"
+ assert function_args == ["# not comment"]
+
+ function_calls = parser.parse_string('fun(12, x="val", y={"aaa": 23})', parseAll=True)
+ function_name, function_args = function_calls[0]
+ assert function_name == "fun"
+ assert function_args == [
+ 12,
+ NamedArgument(name="x", value="val"),
+ NamedArgument(name="y", value={"aaa": 23}),
+ ]
+
+
+def test_valid_action():
+ action_set = HighLevelActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def get_checkbox_elem(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "vehicle1"})
+ return checkbox
+
+ obs, info = env.reset()
+ checkbox = get_checkbox_elem(obs)
+
+ # box not checked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # typo in action (unescaped double quotes)
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))}, "17" screen") # typo here
+"""
+ with pytest.raises(ValueError):
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # error and box not checked
+ assert "Received an empty action." in obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # click box 1 time
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 1
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ # click box 2 times
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 2
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box still checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ # click box 3 times
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+click({repr(checkbox.get(BID_ATTR))})
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box unchecked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # click box 3 times, same line ops
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))}) click({repr(checkbox.get(BID_ATTR))}) click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ # click box 3 times, multi line ops, whitespace, tab, comma in-between args
+ action = f"""\
+ click( {repr(checkbox.get(BID_ATTR))} ) click({repr(checkbox.get(BID_ATTR))})\t
+ noop() noop () noop( )
+ # THIS IS A COMMENT
+ noop() # this is a noop() call
+click({repr(checkbox.get(BID_ATTR))}, )
+#click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box unchecked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # click box 3 times, multi line ops, whitespace, tab, comma in-between args, markdown code block
+ action = f"""\
+Below is code
+ ```python
+ click( {repr(checkbox.get(BID_ATTR))} ) click({repr(checkbox.get(BID_ATTR))})\t
+ noop() noop () noop( )
+ # THIS IS A COMMENT
+ noop() # this is a noop() call
+click({repr(checkbox.get(BID_ATTR))}, )
+#click({repr(checkbox.get(BID_ATTR))})
+```
+This is not code, just an explanation
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ # multiple markdown code blocks
+ action = f"""\
+Below is code
+ ```python
+ noop() noop () noop( )
+ # THIS IS A COMMENT
+ noop() # this is a noop() call
+click({repr(checkbox.get(BID_ATTR))}, )
+#click({repr(checkbox.get(BID_ATTR))})
+```
+This is not code, just an explanation
+Below is more code
+ ```python
+ click( {repr(checkbox.get(BID_ATTR))} ) click({repr(checkbox.get(BID_ATTR))})\t
+ noop() noop () noop( )
+ # THIS IS A COMMENT
+ noop() # this is a noop() call
+#click({repr(checkbox.get(BID_ATTR))})
+```
+This is not code, just an explanation
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box unchecked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # multiple function calls in the middle of text
+ action = f"""\
+Let's do a noop(), then noop () noop( ) then click({repr(checkbox.get(BID_ATTR))}, )
+#click({repr(checkbox.get(BID_ATTR))})
+Now let's do two more
+ click( {repr(checkbox.get(BID_ATTR))} ) click({repr(checkbox.get(BID_ATTR))})\t
+ noop() noop () noop( )
+ # THIS IS A COMMENT
+ noop() # this is a noop() call
+#click({repr(checkbox.get(BID_ATTR))})
+```
+This is not code, just an explanation
+This is garbage
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nclick(") == 3
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ env.close()
+
+
+def test_invalid_action():
+ action_set = HighLevelActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+
+ # click inexistant bid
+ action = f"""\
+click("INVALID_BID")
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert "ValueError" in obs["last_action_error"]
+
+ # invalid bid value type
+ action = f"""\
+click(None)
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got None"
+
+ # invalid bid value type
+ action = f"""\
+click(42.7)
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got 42.7"
+
+ # invalid bid value type
+ action = f"""\
+click([])
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got []"
+
+ # invalid bid value type
+ action = f"""\
+click([42, "a", True, None])
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got [42, 'a', True, None]"
+
+ # invalid bid value type
+ action = f"""\
+click({{}})
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got {}"
+
+ # invalid bid value type
+ action = f"""\
+click({{"k": "aaa"}})
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "ValueError: expected a string, got {'k': 'aaa'}"
+
+ # invalid action args (too many)
+ action = f"""\
+click("4", "aa", "bb")
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert obs["last_action_error"] == "Error: Locator.click: modifiers: expected array, got string"
+
+ # invalid action args (not enough)
+ action = f"""\
+click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert (
+ obs["last_action_error"]
+ == "TypeError: click() missing 1 required positional argument: 'bid'"
+ )
+
+ # invalid action args (not enough)
+ action = f"""\
+click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # error
+ assert (
+ obs["last_action_error"]
+ == "TypeError: click() missing 1 required positional argument: 'bid'"
+ )
+
+ # invalid action name
+ with pytest.raises(NameError):
+ action_set.to_python_code(
+ f"""\
+not_a_valid_action()
+"""
+ )
+
+ # forbidden fill action
+ with pytest.raises(NameError):
+ HighLevelActionSet(subsets=["coord"]).to_python_code(
+ f"""\
+fill("INVALID_BID", "some text")
+"""
+ )
+
+ # forbidden import
+ with pytest.raises(ValueError):
+ action_set.to_python_code(
+ f"""\
+import numpy as np
+"""
+ )
+
+ # invalid expression, results in empty action
+ with pytest.raises(ValueError):
+ action_set.to_python_code(
+ f"""\
+[
+"""
+ )
+
+ # invalid expression, results in empty action
+ with pytest.raises(ValueError):
+ action_set.to_python_code(
+ f"""\
+click
+"""
+ )
+
+ env.close()
+
+
+def test_click_through_frames():
+ action_set = HighLevelActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": MULTI_IFRAME_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ obs, info = env.reset()
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "checkbox_2"})
+
+ # box checked
+ assert checkbox.has_attr("checked")
+
+ # click box
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "checkbox_2"})
+
+ # box not checked
+ assert not checkbox.has_attr("checked")
+
+ env.close()
+
+
+def test_fill_through_iframe():
+ action_set = HighLevelActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": MULTI_IFRAME_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ obs, info = env.reset()
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ text_input = soup.find(
+ "input", attrs={"type": "text", "placeholder": "Enter text here in iframe"}
+ )
+
+ # empty input
+ assert text_input.get("value") == ""
+
+ # fill with some text
+ action = f"""\
+fill({repr(text_input.get(BID_ATTR))}, "This is a test value.")
+"""
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ text_input = soup.find(
+ "input", attrs={"type": "text", "placeholder": "Enter text here in iframe"}
+ )
+
+ # input filled to desired value
+ assert text_input.get("value") == "This is a test value."
+
+ env.close()
+
+
+def test_click():
+ action_set = HighLevelActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def get_checkbox_elem(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "vehicle1"})
+ return checkbox
+
+ obs, info = env.reset()
+ checkbox = get_checkbox_elem(obs)
+
+ # box not checked
+ assert not checkbox.has_attr("checked")
+
+ # click box
+ action = f"""
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ # box checked
+ assert checkbox.has_attr("checked")
+
+ # click box
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ # box unchecked
+ assert not checkbox.has_attr("checked")
+
+ # click box twice
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ python_action = action_set.to_python_code(action)
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ # box still unchecked
+ assert not checkbox.has_attr("checked")
+
+ env.close()
+
+
+def test_hover():
+ action_set = HighLevelActionSet(subsets=["bid", "coord"])
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": HOVER_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def get_button_elem(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ button = soup.find("input", attrs={"type": "button"})
+ return button
+
+ obs, info = env.reset()
+ button = get_button_elem(obs)
+
+ assert not obs["last_action_error"]
+ assert button.get("value") == "Hover me"
+
+ action = f"""
+hover({repr(button.get(BID_ATTR))})
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ button = get_button_elem(obs)
+
+ assert not obs["last_action_error"]
+ assert button.get("value") == "Hello world!"
+
+ action = f"""
+mouse_move(0, 0)
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ button = get_button_elem(obs)
+
+ assert not obs["last_action_error"]
+ assert button.get("value") == "Hover me"
+
+ env.close()
+
+
+def test_fill_type_press():
+ action_set = HighLevelActionSet(subsets=["bid", "coord"])
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEXT_INPUT_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def get_fname_lname_elems(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ fname = soup.find("input", attrs={"id": "fname"})
+ lname = soup.find("input", attrs={"id": "lname"})
+ return fname, lname
+
+ obs, info = env.reset()
+ fname, lname = get_fname_lname_elems(obs)
+
+ # type using bid
+ action = f"""
+fill({repr(fname.get(BID_ATTR))}, 'Christian')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Christian"
+ assert lname.get("value") == ""
+
+ # type using bid
+ action = f"""
+fill({repr(lname.get(BID_ATTR))}, 'Clavier')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Christian"
+ assert lname.get("value") == "Clavier"
+
+ # type using focus and keyboard_type
+ action = f"""
+focus({repr(fname.get(BID_ATTR))}) keyboard_type('Gérard')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "ChristianGérard"
+ assert lname.get("value") == "Clavier"
+
+ # type using click and keyboard_insert_text
+ action = f"""
+click({repr(lname.get(BID_ATTR))}) keyboard_insert_text('Jugnot')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "ChristianGérard"
+ assert lname.get("value") == "ClavierJugnot"
+
+ # type using clear and keyboard_insert_text
+ action = f"""
+clear({repr(lname.get(BID_ATTR))}) keyboard_insert_text('Jugnot')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "ChristianGérard"
+ assert lname.get("value") == "Jugnot"
+
+ # type using click, manual clear and keyboard_insert_text
+ action = f"""
+click({repr(fname.get(BID_ATTR))})
+# clear the field
+keyboard_press('End')
+keyboard_down('Shift')
+keyboard_press('Home')
+keyboard_up('Shift')
+keyboard_press('Backspace')
+# insert text
+keyboard_insert_text('Gérard')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Gérard"
+ assert lname.get("value") == "Jugnot"
+
+ # fill empty text
+ action = f"""
+fill({repr(fname.get(BID_ATTR))}, '')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == ""
+ assert lname.get("value") == "Jugnot"
+
+ # type in currently focused element
+ action = f"""
+keyboard_type('Jean')
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Jean"
+ assert lname.get("value") == "Jugnot"
+
+ # de-focus (click 0, 0), then type text
+ action = f"""
+mouse_click(0, 0)
+"""
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Jean"
+ assert lname.get("value") == "Jugnot"
+
+ action = f"""
+keyboard_type('Reno')
+"""
+ obs, reward, terminated, truncated, info = env.step(action)
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert not obs["last_action_error"]
+ assert fname.get("value") == "Jean"
+ assert lname.get("value") == "Jugnot"
+
+ env.close()
+
+
+@pytest.mark.skip(reason="Not implemented yet")
+def test_dblclick():
+ pass
+
+
+# copy/paste text using a sequence of keyboard_press actions
+def test_key_press():
+ action_set = HighLevelActionSet(subsets=["bid", "coord"])
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEXT_INPUT_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ obs, info = env.reset()
+
+ def get_fname_lname_elems(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ fname = soup.find("input", attrs={"id": "fname"})
+ lname = soup.find("input", attrs={"id": "lname"})
+ return fname, lname
+
+ fname, lname = get_fname_lname_elems(obs)
+
+ action = f"""
+ fill({repr(fname.get(BID_ATTR))}, "Christian")
+ keyboard_press({repr("Meta+a" if _IS_MAC_OS else "Control+a")})
+ keyboard_press({repr("Meta+c" if _IS_MAC_OS else "Control+c")})
+ click({repr(lname.get(BID_ATTR))})
+ keyboard_press({repr("Meta+v" if _IS_MAC_OS else "Control+v")})
+ """
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ assert not obs["last_action_error"]
+
+ fname, lname = get_fname_lname_elems(obs)
+
+ assert lname.get("value") == "Christian"
+
+ env.close()
+
+
+def test_goto():
+ url1 = URL_INPUT_URL
+ url2 = TEXT_INPUT_URL
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": url1},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ )
+
+ obs, info = env.reset()
+
+ assert obs["url"] == url1
+
+ action = f"""
+goto({repr(url2)})
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ assert not obs["last_action_error"]
+
+ assert obs["url"] == url2
+
+ action = """
+go_back()
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ assert not obs["last_action_error"]
+
+ assert obs["url"] == url1
+
+ action = """
+go_forward()
+"""
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ assert not obs["last_action_error"]
+
+ assert obs["url"] == url2
+
+ env.close()
+
+
+def test_scroll():
+ action_set = HighLevelActionSet(subsets=["coord"])
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": LONG_PAGE_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def extract_coords_from_elem(elem):
+ return ast.literal_eval(elem.get("center"))
+
+ def get_top_bottom_elems(obs):
+ soup = bs4.BeautifulSoup(
+ flatten_dom_to_str(
+ obs["dom_object"], obs["extra_element_properties"], with_center_coords=True
+ ),
+ "lxml",
+ )
+ top = soup.find("input", attrs={"type": "checkbox", "id": "top"})
+ bottom = soup.find("input", attrs={"type": "checkbox", "id": "bottom"})
+ return top, bottom
+
+ obs, info = env.reset()
+ top, bottom = get_top_bottom_elems(obs)
+ top_x, top_y = extract_coords_from_elem(top)
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+ # top not checked
+ assert not top.has_attr("checked")
+ # bottom not checked
+ assert not bottom.has_attr("checked")
+
+ # click top
+ action = f"mouse_click({repr(top_x)}, {repr(top_y)})"
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ top, bottom = get_top_bottom_elems(obs)
+ top_x, top_y = extract_coords_from_elem(top)
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+ # no error
+ assert not obs["last_action_error"]
+ # top checked
+ assert top.has_attr("checked")
+ # bottom not checked
+ assert not bottom.has_attr("checked")
+
+ top, bottom = get_top_bottom_elems(obs)
+ top_x, top_y = extract_coords_from_elem(top)
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+ # click bottom
+ action = f"mouse_click({repr(bottom_x)}, {repr(bottom_y)})"
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ top, bottom = get_top_bottom_elems(obs)
+ top_x, top_y = extract_coords_from_elem(top)
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+    # no error (clicking coordinates outside the viewport fails silently in Playwright)
+ assert not obs["last_action_error"]
+ # top checked
+ assert top.has_attr("checked")
+ # bottom not checked (click didn't go through)
+ assert not bottom.has_attr("checked")
+
+ # scroll up
+ action = f"scroll(0, -500)"
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ top, bottom = get_top_bottom_elems(obs)
+ prev_top_x, prev_top_y = top_x, top_y
+ top_x, top_y = extract_coords_from_elem(top)
+ prev_bottom_x, prev_bottom_y = bottom_x, bottom_y
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ # no movement
+ assert prev_top_x == top_x and prev_top_y == top_y
+ assert prev_bottom_x == bottom_x and prev_bottom_y == bottom_y
+
+ # scroll down
+ action = f"scroll(0, 500)"
+
+ obs, reward, terminated, truncated, info = env.step(action)
+
+ top, bottom = get_top_bottom_elems(obs)
+ prev_top_x, prev_top_y = top_x, top_y
+ top_x, top_y = extract_coords_from_elem(top)
+ prev_bottom_x, prev_bottom_y = bottom_x, bottom_y
+ bottom_x, bottom_y = extract_coords_from_elem(bottom)
+
+ # no error
+ assert not obs["last_action_error"]
+
+ # movement
+ assert prev_top_x == top_x and prev_top_y > top_y
+ assert prev_bottom_x == bottom_x and prev_bottom_y > bottom_y
+
+ env.close()
+
+
+def test_tab_actions():
+ action_set = HighLevelActionSet(subsets=["tab", "nav"])
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+ assert len(obs["open_pages_urls"]) == 1
+ assert len(obs["open_pages_titles"]) == 1
+ assert obs["active_page_index"] == 0
+ assert obs["open_pages_urls"][obs["active_page_index"][0]] == obs["url"]
+
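+    # opening a new tab should make it the active page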
+ obs, reward, terminated, truncated, info = env.step("new_tab()")
+ assert not obs["last_action_error"]
+ assert len(obs["open_pages_urls"]) == 2
+ assert len(obs["open_pages_titles"]) == 2
+ assert obs["active_page_index"] == 1
+ assert obs["open_pages_urls"][obs["active_page_index"][0]] == obs["url"]
+
+ obs, reward, terminated, truncated, info = env.step(f"goto({repr(TEXTBOX_URL)})")
+ assert not obs["last_action_error"]
+ assert len(obs["open_pages_urls"]) == 2
+ assert len(obs["open_pages_titles"]) == 2
+ assert obs["active_page_index"] == 1
+ assert obs["open_pages_urls"][obs["active_page_index"][0]] == obs["url"]
+
+ obs, reward, terminated, truncated, info = env.step("tab_focus(0)")
+ assert not obs["last_action_error"]
+ assert len(obs["open_pages_urls"]) == 2
+ assert len(obs["open_pages_titles"]) == 2
+ assert obs["active_page_index"] == 0
+ assert obs["open_pages_urls"][obs["active_page_index"][0]] == obs["url"]
+
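+    # closing the active tab should leave a single open tab, which becomes the active page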
+ obs, reward, terminated, truncated, info = env.step("tab_close()")
+ assert not obs["last_action_error"]
+ assert len(obs["open_pages_urls"]) == 1
+ assert len(obs["open_pages_titles"]) == 1
+ assert obs["active_page_index"] == 0
+ assert obs["open_pages_urls"][obs["active_page_index"][0]] == obs["url"]
+
+ env.close()
+
+
+def test_mouse_down_up():
+ action_set = HighLevelActionSet(subsets=["bid", "coord"])
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ def get_checkbox_elem(obs):
+ soup = bs4.BeautifulSoup(
+ flatten_dom_to_str(
+ obs["dom_object"], obs["extra_element_properties"], with_center_coords=True
+ ),
+ "lxml",
+ )
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "vehicle1"})
+ return checkbox
+
+ obs, info = env.reset()
+ checkbox = get_checkbox_elem(obs)
+
+ # box not checked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # click box 1 time
+ x, y = ast.literal_eval(checkbox.get("center"))
+ action = f"""\
+mouse_click({repr(x)}, {repr(y)})
+"""
+ python_action = action_set.to_python_code(action)
+
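+    # the mapped Python code should contain exactly one mouse_* action call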
+ assert python_action.count("\nmouse_") == 1
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box checked
+ assert not obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ # click box 1 time
+ x, y = ast.literal_eval(checkbox.get("center"))
+ action = f"""\
+mouse_move(0, 0)
+mouse_move({repr(x)}, {repr(y)})
+mouse_down({repr(x)}, {repr(y)})
+mouse_up({repr(x)}, {repr(y)})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nmouse_") == 4
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box not checked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+ # click box 2 times
+ x, y = ast.literal_eval(checkbox.get("center"))
+ action = f"""\
+mouse_move(0, 0)
+mouse_move({repr(x)}, {repr(y)})
+mouse_down({repr(x)}, {repr(y)}, button="left")
+mouse_up({repr(x)}, {repr(y)}, "left")
+mouse_down({repr(x)}, {repr(y)})
+mouse_up({repr(x)}, {repr(y)})
+"""
+ python_action = action_set.to_python_code(action)
+
+ assert python_action.count("\nmouse_") == 6
+
+ obs, reward, term, trunc, info = env.step(action)
+ checkbox = get_checkbox_elem(obs)
+
+ # box not checked
+ assert not obs["last_action_error"]
+ assert not checkbox.has_attr("checked")
+
+
+# test that forced actions can click an obstructed element
+@pytest.mark.parametrize("retry_with_force", [True, False])
+def test_forced_actions(retry_with_force):
+ action_set = HighLevelActionSet(subsets=["bid"], retry_with_force=retry_with_force)
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": OBSTRUCTED_CHECKBOX_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ obs, info = env.reset()
+
+ def get_checkbox(obs):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"id": "hobbies-checkbox-1"})
+ return checkbox
+
+ checkbox = get_checkbox(obs)
+
+ action = f"""
+ click({repr(checkbox.get(BID_ATTR))})
+ """
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ checkbox = get_checkbox(obs)
+ if retry_with_force:
+ assert not obs["last_action_error"]
+ assert checkbox.get("checked", False) == False
+ else:
+ assert obs["last_action_error"]
+ assert checkbox.has_attr("checked")
+
+ env.close()
+
+
+# TODO: investigate why marking each frame takes ~1 s even though the frames are very small, and whether this can be improved
+@pytest.mark.slow
+def test_iframe_bid():
+ action_set = HighLevelActionSet(subsets=["bid"])
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": LOTS_OF_IFRAMES_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+
+ obs, info = env.reset()
+
+ def get_checkbox(obs, i):
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"id": f"checkbox{i}"})
+ return checkbox
+
+ # try to click on checkboxes
+ checkboxes = [
+ (0, "a"),
+ # (5, "f"),
+ # (26, "aA"),
+ (29, "aD"),
+ ]
+ for id, iframe_bid in checkboxes:
+
+ # try to click on checkbox
+ checkbox = get_checkbox(obs, id)
+ bid = checkbox.get(BID_ATTR)
+
+ # iframe bid should match
+ assert re.match(f"^{iframe_bid}[0-9]+$", bid)
+
+ action = f"""
+ click({repr(bid)})
+ """
+
+ obs, reward, terminated, truncated, info = env.step(action)
+ assert not obs["last_action_error"]
+
+ # checkbox should get checked
+ checkbox = get_checkbox(obs, id)
+ assert checkbox.has_attr("checked")
+
+ env.close()
diff --git a/BrowserGym/tests/core/test_actions_python.py b/BrowserGym/tests/core/test_actions_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..69cc6237bb1f128709578f7ea84a969cfb33adf8
--- /dev/null
+++ b/BrowserGym/tests/core/test_actions_python.py
@@ -0,0 +1,60 @@
+import pytest
+
+from browsergym.core.action.python import PythonActionSet
+
+
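+# pairs of (raw action string, expected extracted code): markdown fences should be
+# stripped, and multiple code blocks concatenated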
+ACTIONS_TO_TEST = [
+ (
+ """\
+a = 0
+""",
+ """\
+a = 0
+""",
+ ),
+ (
+ """\
+```
+a = 0
+```
+""",
+ """\
+a = 0
+""",
+ ),
+ (
+ """\
+```python
+a = 0
+```
+""",
+ """\
+a = 0
+""",
+ ),
+ (
+ """\
+```python
+a = 0
+```
+This is an explanation
+```python
+b = 3
+```
+More explanations
+""",
+ """\
+a = 0
+
+b = 3
+""",
+ ),
+]
+
+
+@pytest.mark.parametrize("action,expected_code", ACTIONS_TO_TEST)
+def test_action_cleaning(action, expected_code):
+ action_set = PythonActionSet()
+ code = action_set.to_python_code(action)
+
+ assert code == expected_code
diff --git a/BrowserGym/tests/core/test_gym_envs.py b/BrowserGym/tests/core/test_gym_envs.py
new file mode 100644
index 0000000000000000000000000000000000000000..48fca3a32e119ee0a1e58440565c8e10d1fca2d7
--- /dev/null
+++ b/BrowserGym/tests/core/test_gym_envs.py
@@ -0,0 +1,313 @@
+import os
+import pathlib
+from time import time
+
+import bs4
+import gymnasium as gym
+import pytest
+
+# register openended gym environments
+import browsergym.core
+import browsergym.core.action
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.core.action.python import PythonActionSet
+from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
+from browsergym.utils.obs import flatten_dom_to_str
+
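+# set DISPLAY_BROWSER to run the tests in a visible (headful) browser with slow motion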
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+__TIMEOUT = 500
+
+__DATA_DIR = pathlib.Path(__file__).resolve().parent / "data"
+TEST_PAGE = f"file://{__DATA_DIR}/test_page.html"
+BASIC_IFRAME_PAGE = f"file://{__DATA_DIR}/basic_iframe_site/basic_iframe_2.html"
+
+
+def test_gym_env():
+ action_set = PythonActionSet()
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+
+ assert not obs["last_action_error"]
+
+ obs, reward, term, trunc, info = env.step(
+ f"""\
+page.get_by_label("Name:").click()
+page.get_by_label("Name:").fill("Janice")
+page.get_by_label("Name:").press("Tab")
+page.get_by_label("Email:").fill("janice@mail.com")
+page.get_by_label("Email:").press("Tab")
+page.get_by_label("Age:", exact=True).fill("21")
+page.get_by_label("Age:", exact=True).press("Tab")
+"""
+ )
+
+ assert obs["last_action_error"] == ""
+ assert reward == 0
+ assert term == False
+ assert trunc == False
+
+ obs, reward, term, trunc, info = env.step(
+ f"""\
+page.get_by_label("Message:").fill("Hello")
+page.get_by_label("Message:").press("Tab")
+page.get_by_label("Subscribe to newsletter").check()
+page.get_by_label("Subscribe to newsletter").press("Tab")
+page.get_by_role("button", name="Submit").press("Enter")
+"""
+ )
+
+ assert obs["last_action_error"] == ""
+ assert reward == 0
+ assert term == False
+ assert trunc == False
+
+ obs, reward, term, trunc, info = env.step(
+ f"""\
+page.get_by_label("LABEL DOES NOT EXIST:").fill("Hello")
+page.get_by_role("button", name="Submit").press("Enter")
+"""
+ )
+
+ assert obs["last_action_error"] != ""
+ assert reward == 0
+ assert term == False
+ assert trunc == False
+
+ env.close()
+
+
+def test_max_episode_steps():
+ # no max_steps
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ obs, reward, term, trunc, info = env.step("")
+
+ assert term == False
+ assert trunc == False
+
+ obs, reward, term, trunc, info = env.step("")
+
+ assert term == False
+ assert trunc == False
+
+ # max_steps = 2
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ max_episode_steps=2,
+ )
+ obs, info = env.reset()
+
+ obs, reward, term, trunc, info = env.step("")
+
+ assert term == False
+ assert trunc == False
+
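+    # the second step reaches max_episode_steps=2, so the episode should be truncated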
+ obs, reward, term, trunc, info = env.step("")
+
+ assert term == False
+ assert trunc == True
+
+ env.close()
+
+
+def test_active_page():
+ action_set = PythonActionSet()
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+
+ assert len(obs["open_pages_urls"]) == 1
+ assert obs["active_page_index"] == 0
+
+ obs, reward, term, trunc, info = env.step("page.context.new_page()")
+
+ assert len(obs["open_pages_urls"]) == 2
+ assert obs["active_page_index"] == 1
+
+ obs, reward, term, trunc, info = env.step("page.context.pages[0].mouse.click(5, 5)")
+
+ assert len(obs["open_pages_urls"]) == 2
+ assert obs["active_page_index"] == 0
+
+ obs, reward, term, trunc, info = env.step("page.context.pages[1].mouse.click(5, 5)")
+
+ assert len(obs["open_pages_urls"]) == 2
+ assert obs["active_page_index"] == 1
+
+ obs, reward, term, trunc, info = env.step("page.context.pages[1].close()")
+
+ assert len(obs["open_pages_urls"]) == 1
+ assert obs["active_page_index"] == 0
+
+ obs, reward, term, trunc, info = env.step("page.close()")
+
+ assert len(obs["open_pages_urls"]) == 1
+ assert obs["active_page_index"] == 0
+
+ obs, reward, term, trunc, info = env.step("page.context.new_page()")
+
+ assert len(obs["open_pages_urls"]) == 2
+ assert obs["active_page_index"] == 1
+
+ obs, reward, term, trunc, info = env.step("page.close()")
+
+ assert len(obs["open_pages_urls"]) == 1
+ assert obs["active_page_index"] == 0
+
+ env.close()
+
+
+def test_nested_iframes_default_demo_mode():
+ demo_mode = "default"
+ action_set = HighLevelActionSet(demo_mode=demo_mode)
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": BASIC_IFRAME_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ inner_checkbox = soup.find("input", attrs={"id": "checkbox_2"})
+
+ assert inner_checkbox.has_attr("checked")
+ # click box
+ action = f"""\
+click({repr(inner_checkbox.get(BID_ATTR))})
+"""
+ click_start = time()
+ obs, _, _, _, _ = env.step(action)
+ click_end = time()
+ # clicking should be slow in demo mode
+ assert click_end - click_start > 1
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ inner_checkbox = soup.find("input", attrs={"id": "checkbox_2"})
+    # box is no longer checked, meaning the previous click action went through
+ assert not inner_checkbox.has_attr("checked")
+
+ env.close()
+
+
+@pytest.mark.parametrize("global_demo_mode", [True, False])
+@pytest.mark.parametrize("demo_mode", [None, "off", "default", "only_visible_elements", "all_blue"])
+def test_demo_mode(global_demo_mode: bool, demo_mode: str):
+ action_set = HighLevelActionSet(demo_mode=demo_mode)
+ browsergym.core.action.set_global_demo_mode(global_demo_mode)
+
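+    # demo mode is active when a non-"off" mode is set explicitly, or when no explicit
+    # mode is given and the global demo mode is enabled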
+ demo_mode_active = (global_demo_mode and demo_mode is None) or (
+ demo_mode is not None and demo_mode != "off"
+ )
+
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=action_set.to_python_code,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ email_field = soup.find("input", attrs={"id": "email"})
+ checkbox = soup.find("input", attrs={"id": "subscribe"})
+
+ # check that the email field is empty
+ assert email_field.get("value") == ""
+
+ # check that the box is not checked
+ assert not checkbox.has_attr("checked")
+
+ # click box
+ action = f"""\
+click({repr(checkbox.get(BID_ATTR))})
+"""
+ obs, reward, terminated, truncated, info = env.step(action)
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+ checkbox = soup.find("input", attrs={"type": "checkbox", "id": "subscribe"})
+
+ # check that the box is checked
+ assert checkbox.has_attr("checked")
+
+ # clicking should be slow (only in demo mode)
+ action_time = info["action_exec_stop"] - info["action_exec_start"]
+ if demo_mode_active:
+ assert action_time > 2
+ else:
+ assert action_time <= 1.5
+
+ # fill box
+ action = f"""\
+fill({repr(email_field.get(BID_ATTR))}, "test@test")
+"""
+ obs, reward, terminated, truncated, info = env.step(action)
+ assert not obs["last_action_error"]
+
+ soup = bs4.BeautifulSoup(flatten_dom_to_str(obs["dom_object"]), "lxml")
+
+ # email field has been filled correctly
+ email_field = soup.find("input", attrs={"id": "email"})
+ assert email_field.get("value") == "test@test"
+
+ # typing should be slow (only in demo mode)
+ action_time = info["action_exec_stop"] - info["action_exec_start"]
+ if demo_mode_active:
+ assert action_time > 2
+ else:
+ assert action_time <= 1.5
+
+ env.close()
+
+
+@pytest.mark.parametrize("resizeable_window", (True, False))
+@pytest.mark.parametrize("size", ((1600, 1200), (800, 800)))
+def test_resizeable_window(resizeable_window, size):
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ viewport={"width": size[0], "height": size[1]},
+ resizeable_window=resizeable_window,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
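+    # the screenshot dimensions should match the requested viewport size (width, height)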
+ assert (obs["screenshot"].shape[1], obs["screenshot"].shape[0]) == size
+
+ env.close()
diff --git a/BrowserGym/tests/core/test_observation.py b/BrowserGym/tests/core/test_observation.py
new file mode 100644
index 0000000000000000000000000000000000000000..36bb341937b265d8792199ac4722d404b3199049
--- /dev/null
+++ b/BrowserGym/tests/core/test_observation.py
@@ -0,0 +1,819 @@
+import ast
+import os
+from pathlib import Path
+
+import bs4
+import gymnasium as gym
+import numpy as np
+import pytest
+import regex as re
+
+# register gym environments
+import browsergym.core
+from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
+from browsergym.core.observation import (
+ _post_extract,
+ _pre_extract,
+ extract_all_frame_axtrees,
+ extract_dom_snapshot,
+ extract_merged_axtree,
+ extract_screenshot,
+)
+from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+__TIMEOUT = 500
+__VIEWPORT = {"width": 800, "height": 600}
+
+__DATA_DIR = Path(__file__).resolve().parent / "data"
+
+TEST_PAGE = f"file://{__DATA_DIR}/test_page.html"
+TEST_PAGE_2 = f"file://{__DATA_DIR}/test_page_2.html"
+MULTI_IFRAME_URL = f"file://{__DATA_DIR}/basic_iframe_site/basic_iframe_2.html"
+SHADOW_DOM_URL = f"file://{__DATA_DIR}/basic_shadow_dom_site/basic_shadow_dom.html"
+SIMPLE_SHADOW_DOM_URL = f"file://{__DATA_DIR}/basic_shadow_dom_site/simple_shadow_dom.html"
+BASIC_IFRAME_URL = f"file://{__DATA_DIR}/basic_shadow_iframe_site/basic_iframe.html"
+BASIC_IFRAME_2_URL = f"file://{__DATA_DIR}/basic_shadow_iframe_site/basic_iframe_2.html"
+INNER_IFRAME_URL = f"file://{__DATA_DIR}/basic_shadow_iframe_site/inner-iframe.html"
+OUTER_IFRAME_URL = f"file://{__DATA_DIR}/basic_shadow_iframe_site/outer-iframe.html"
+CUSTOM_PAGE_URL = f"file://{__DATA_DIR}/custom_page/basic_iframe.html"
+MULTI_IFRAME_URL = f"file://{__DATA_DIR}/basic_iframe_site/basic_iframe_2.html"
+
+
+@pytest.mark.skip(reason="TODO: how to get the final viewport size right?")
+def test_extract_screenshot():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ _pre_extract(env.unwrapped.page)
+ screenshot = extract_screenshot(env.unwrapped.page)
+ _post_extract(env.unwrapped.page)
+
+ # 3D array (height, width, rgb) of unsigned bytes (between 0 and 255)
+ assert isinstance(screenshot, np.ndarray)
+ assert len(screenshot.shape) == 3
+ assert screenshot.shape[0] == __VIEWPORT["height"]
+ assert screenshot.shape[1] == __VIEWPORT["width"]
+ assert screenshot.shape[2] == 3 # RGB
+ assert screenshot.dtype == np.uint8
+
+ env.close()
+
+
+def test_extract_axtree_simple():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ _pre_extract(env.unwrapped.page)
+ all_frame_axtrees = extract_all_frame_axtrees(env.unwrapped.page)
+ merged_axtree = extract_merged_axtree(env.unwrapped.page)
+ _post_extract(env.unwrapped.page)
+
+ # single frame
+ assert len(all_frame_axtrees) == 1
+ assert len(next(iter(all_frame_axtrees.values()))["nodes"]) == len(merged_axtree["nodes"])
+
+ env.close()
+
+
+def test_extract_axtree_multi_iframe():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": MULTI_IFRAME_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ _pre_extract(env.unwrapped.page)
+ all_frame_axtrees = extract_all_frame_axtrees(env.unwrapped.page)
+ merged_axtree = extract_merged_axtree(env.unwrapped.page)
+ _post_extract(env.unwrapped.page)
+
+ # multiple frames
+ assert len(all_frame_axtrees) == 3
+
+ # total number of nodes in merged and individual frame axtrees should be equal
+ n_nodes = 0
+ for frame_id, frame_axtree in all_frame_axtrees.items():
+ n_nodes += len(frame_axtree["nodes"])
+
+ assert n_nodes == len(merged_axtree["nodes"])
+
+ env.close()
+
+
+def test_extract_dom_simple():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ _pre_extract(env.unwrapped.page)
+ dom_snapshot = extract_dom_snapshot(env.unwrapped.page)
+ _post_extract(env.unwrapped.page)
+
+ # single frame
+ assert len(dom_snapshot["documents"]) == 1
+
+ env.close()
+
+
+def test_extract_dom_multi_iframe():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": MULTI_IFRAME_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ _pre_extract(env.unwrapped.page)
+ dom_snapshot = extract_dom_snapshot(env.unwrapped.page)
+ _post_extract(env.unwrapped.page)
+
+ # multiple frames
+ assert len(dom_snapshot["documents"]) == 3
+
+ env.close()
+
+
+def test_simple_shadowdom():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": SIMPLE_SHADOW_DOM_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ # retrieve an input element inside the shadowDOM
+ elem = env.unwrapped.page.get_by_placeholder("Level 1.1 Text Field 1")
+ assert elem.count() == 1
+
+ # elem should have a browsergym_id in its BID_ATTR attribute
+ elem_id = elem.get_attribute(BID_ATTR)
+ assert elem_id is not None
+
+ # elem should not have an aria-description (it should have been cleaned)
+ aria_description = elem.get_attribute("aria-description")
+ assert aria_description is None
+
+ # elem should not have an aria-roledescription (it should have been cleaned)
+ aria_roledescription = elem.get_attribute("aria-roledescription")
+ assert aria_roledescription is None
+
+ # check that elem can be retrieved correctly using its browsergym_id
+ elem2 = env.unwrapped.page.get_by_test_id(elem_id)
+ assert elem2.count() == 1
+ assert env.unwrapped.page.evaluate(
+ "([node1, node2]) => {return node1.isEqualNode(node2);}",
+ [elem.element_handle(), elem2.element_handle()],
+ )
+
+ env.close()
+
+
+def test_nested_shadowdom():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": SHADOW_DOM_URL},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ # retrieve an input element inside the nested shadowDOM
+ elem = env.unwrapped.page.get_by_placeholder("Level 2.4 Text Field 2")
+ assert elem.count() == 1
+
+ # elem should have a browsergym_id in its BID_ATTR attribute
+ elem_id = elem.get_attribute(BID_ATTR)
+ assert elem_id is not None
+
+ # elem should not have an aria-description (it should have been cleaned)
+ aria_description = elem.get_attribute("aria-description")
+ assert aria_description is None
+
+ # elem should not have an aria-roledescription (it should have been cleaned)
+ aria_roledescription = elem.get_attribute("aria-roledescription")
+ assert aria_roledescription is None
+
+ # check that elem can be retrieved correctly using its browsergym_id
+ elem2 = env.unwrapped.page.get_by_test_id(elem_id)
+ assert elem2.count() == 1
+ assert env.unwrapped.page.evaluate(
+ "([node1, node2]) => {return node1.isEqualNode(node2);}",
+ [elem.element_handle(), elem2.element_handle()],
+ )
+
+ env.close()
+
+
+@pytest.mark.parametrize(
+ "url",
+ [
+ TEST_PAGE,
+ MULTI_IFRAME_URL,
+ SIMPLE_SHADOW_DOM_URL,
+ BASIC_IFRAME_URL,
+ BASIC_IFRAME_2_URL,
+ INNER_IFRAME_URL,
+ OUTER_IFRAME_URL,
+ ],
+)
+def test_dom_has_bids_no_aria(url):
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": url},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ viewport=__VIEWPORT,
+ timeout=__TIMEOUT,
+ )
+ obs, info = env.reset()
+
+ # exceptions
+ dom_node_names_without_bid = ["html", "#text", "#document", "#comment"]
+ axtree_roles_without_bid = ["RootWebArea", "none", "generic", "StaticText", "InlineTextBox"]
+
+ # 1. test the DOM snapshot for BID_ATTR, "aria-description" and "aria-roledescription"
+
+ # check all HTML elements in the DOM for unique browsergym id
+ dom = obs["dom_object"]
+ bids = []
+ for doc in dom["documents"]:
+ for node_name_id, attributes in zip(doc["nodes"]["nodeName"], doc["nodes"]["attributes"]):
+ node_name = dom["strings"][node_name_id]
+ # read the node's attributes
+ j = 0
+ bid = None
+ while j < len(attributes):
+ attr_name = dom["strings"][attributes[j]]
+ attr_value = dom["strings"][attributes[j + 1]]
+
+ # print(f"{node_name} {attr_name}: {attr_value}")
+
+ # check that the "aria-roledescription" attribute is absent (this is specific to this test page)
+ assert attr_name != "aria-roledescription"
+
+ # check that the "aria-description" attribute is absent (this is specific to this test page)
+ assert attr_name != "aria-description"
+
+ # extract the browsergym id from the BID_ATTR attribute
+ if attr_name == BID_ATTR:
+ bid = attr_value
+ j += 2
+
+ # check that all elements (with exceptions) have a browsergym id
+ if node_name not in dom_node_names_without_bid:
+ assert bid is not None
+
+ if bid is not None:
+ bids.append(bid)
+
+ # check that all browsergym ids are unique
+ assert len(bids) == len(set(bids))
+
+ # 2. test the AXTree for "browsergym_id" and "description" properties
+ axtree = obs["axtree_object"]
+ bids = []
+ for node in axtree["nodes"]:
+ bid = node.get("browsergym_id", None)
+
+ # check that the "aria-roledescription" attribute is absent (this is specific to this test page)
+ for property in node.get("properties", []):
+ assert property["name"] != "roledescription"
+
+ # check that the "aria-description" attribute is absent (this is specific to this test page)
+ assert "description" not in node
+
+ # check that all elements (with exceptions) have a browsergym id
+ if node["role"]["value"] not in axtree_roles_without_bid:
+ assert bid is not None
+
+ if bid is not None:
+ bids.append(bid)
+
+ # check that all browsergym ids are unique
+ assert len(bids) == len(set(bids))
+
+ env.close()
+
+
+def test_dom_to_text():
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={"start_url": TEST_PAGE_2},
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ timeout=__TIMEOUT,
+ action_mapping=None,
+ )
+ obs, info = env.reset()
+
+ dom = flatten_dom_to_str(obs["dom_object"])
+ assert isinstance(dom, str)
+ assert "Subscribe to newsletter" in dom
+ assert "Janice" not in dom
+
+ obs, reward, term, trunc, info = env.step(
+ f"""\
+page.get_by_label("Name:").click()
+page.get_by_label("Name:").fill("Janice")
+page.get_by_label("Name:").press("Tab")
+page.get_by_label("Email:").fill("janice@mail.com")
+page.get_by_label("Email:").press("Tab")
+page.get_by_label("Age:", exact=True).fill("21")
+page.get_by_label("Age:", exact=True).press("Tab")
+"""
+ )
+
+ dom = flatten_dom_to_str(obs["dom_object"])
+ assert "Janice" in dom
+ assert "janice@mail.com" in dom
+
+ dom = flatten_dom_to_str(
+ obs["dom_object"],
+ extra_properties=obs["extra_element_properties"],
+ with_visible=True,
+ with_clickable=True,
+ with_center_coords=True,
+ with_bounding_box_coords=True,
+ with_som=True,
+ )
+ assert 'box="(' in dom
+ assert 'center="(' in dom
+ assert 'clickable="" som="" type="submit" value="Submit" visible=""' in dom
+ assert 'head bid="1">' in dom
+ assert 'clickable="" for="email" visible=""' in dom
+ assert "Text within a non-html tag" in dom
+ assert "Text that should not be visible" in dom
+
+ dom = flatten_dom_to_str(
+ obs["dom_object"], extra_properties=obs["extra_element_properties"], filter_som_only=True
+ )
+ assert 'for="email"' not in dom
+ assert 'type="submit" value="Submit"' in dom
+ assert "Text within a non-html tag" not in dom
+ assert "Text that should not be visible" not in dom
+
+ dom = flatten_dom_to_str(
+ obs["dom_object"],
+ extra_properties=obs["extra_element_properties"],
+ filter_visible_only=True,
+ )
+ assert " None:
+ """
+ Args:
+ seed: random seed.
+ start_url: str, the url for the starting page.
+ goal: str, the initial goal.
+
+ """
+ super().__init__(seed)
+ self.start_url = start_url
+ self.goal = [
+ {"type": "text", "text": "This is a mock task with an image goal."},
+ {
+ "type": "image_url",
+ "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=",
+ },
+ ]
+
+ def setup(self, page: playwright.sync_api.Page) -> tuple[str, dict]:
+ page.goto(self.start_url, timeout=10000)
+ return self.goal, {}
+
+ def teardown(self) -> None:
+ pass
+
+ def validate(
+ self, page: playwright.sync_api.Page, chat_messages: list[str]
+ ) -> Tuple[float, bool, str, dict]:
+ reward, done, msg, info = 0, False, "", {}
+
+ for message in chat_messages:
+ if message["role"] == "user" and message["message"] == "exit":
+ done = True
+ break
+
+ return reward, done, msg, info
+
+
+def test_mock_image_goal_task():
+ env = BrowserEnv(MockImageGoalTask)
+ obs, _ = env.reset()
+
+ assert "goal_object" in obs
+ assert len(obs["goal_object"]) == 2
+ assert obs["goal_object"][0]["type"] == "text"
+ assert obs["goal_object"][0]["text"] == "This is a mock task with an image goal."
+ assert obs["goal_object"][1]["type"] == "image_url"
+
+ env.chat.add_message("user", "exit")
+ obs, reward, terminated, _, _ = env.step("send_msg_to_user('bye')")
+
+ assert reward == 0
+ assert terminated is True
+
+ env.close()
+
+
+if __name__ == "__main__":
+ test_mock_image_goal_task()
diff --git a/BrowserGym/tests/experiments/__init__.py b/BrowserGym/tests/experiments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f09d6fbde51609da41e1041eb3fb8125d808cb
--- /dev/null
+++ b/BrowserGym/tests/experiments/__init__.py
@@ -0,0 +1,2 @@
+# bugfix: use same playwright instance in browsergym and pytest
+from ..utils import setup_playwright
diff --git a/BrowserGym/tests/experiments/test_benchmark.py b/BrowserGym/tests/experiments/test_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..9222be11c98c628551499aab0d43cf218a0fcc30
--- /dev/null
+++ b/BrowserGym/tests/experiments/test_benchmark.py
@@ -0,0 +1,351 @@
+import dataclasses
+import os
+import random
+import re
+import tempfile
+
+import numpy as np
+import pytest
+
+from browsergym.core.action.base import AbstractActionSet
+from browsergym.experiments.agent import Agent
+from browsergym.experiments.benchmark import Benchmark, HighLevelActionSetArgs
+from browsergym.experiments.benchmark.configs import DEFAULT_BENCHMARKS
+from browsergym.experiments.benchmark.utils import make_env_args_list_from_fixed_seeds
+from browsergym.experiments.loop import AbstractAgentArgs, ExpArgs, get_exp_result
+from browsergym.utils.obs import flatten_axtree_to_str
+
+
+class MiniwobTestAgent(Agent):
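+    """Minimal oracle agent: clicks the first button found in the AXTree observation."""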
+
+ def __init__(self, action_set: AbstractActionSet):
+ self.action_set = action_set
+
+ def obs_preprocessor(self, obs: dict):
+ return {"axtree_txt": flatten_axtree_to_str(obs["axtree_object"])}
+
+ def get_action(self, obs: dict) -> tuple[str, dict]:
+ match = re.search(r"^\s*\[(\d+)\].*button", obs["axtree_txt"], re.MULTILINE | re.IGNORECASE)
+
+ if match:
+ bid = match.group(1)
+ action = f'click("{bid}")'
+ else:
+ raise Exception("Can't find the button's bid")
+
+ return action, dict(think="I'm clicking the button as requested.")
+
+
+@dataclasses.dataclass
+class MiniwobTestAgentArgs(AbstractAgentArgs):
+ high_level_action_set: HighLevelActionSetArgs = None
+
+ def make_agent(self):
+ return MiniwobTestAgent(action_set=self.high_level_action_set.make_action_set())
+
+
+def test_build_benchmarks():
+ expected_bench_size = {
+ "miniwob": 125 * 5,
+ "miniwob_tiny_test": 2 * 2,
+ "webarena": 812,
+ "webarena_tiny": 6,
+ "visualwebarena": 910,
+ "visualwebarena_tiny": 4,
+ "workarena_l1": 33 * 10,
+ "workarena_l2_agent_curriculum_eval": 235,
+ "workarena_l3_agent_curriculum_eval": 235,
+ "assistantbench": 214,
+ "weblinx": 31586,
+ }
+ for name, benchmark_builder in DEFAULT_BENCHMARKS.items():
+ benchmark = benchmark_builder()
+ assert name == benchmark.name
+ assert benchmark.env_args_list # non-empty
+ assert benchmark.task_metadata is not None
+ assert len(benchmark.env_args_list) == expected_bench_size[name]
+ benchmark_bis = Benchmark.from_json(benchmark.to_json())
+ assert benchmark.to_dict() == benchmark_bis.to_dict()
+
+
+def test_benchmark_subset():
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["miniwob"]()
+
+ benchmark_subset = benchmark.subset_from_regexp(column="task_name", regexp="click")
+ assert len(benchmark_subset.env_args_list) == 31 * 5
+ assert benchmark_subset.name == "miniwob[task_name=/click/]"
+
+ benchmark_subset_1 = benchmark_subset.subset_from_regexp(
+ column="miniwob_category", regexp="original"
+ )
+ benchmark_subset_2 = benchmark_subset.subset_from_glob(
+ column="miniwob_category", glob="original"
+ )
+
+ assert benchmark_subset_1.name == "miniwob[task_name=/click/][miniwob_category=/original/]"
+ assert benchmark_subset_2.name == "miniwob[task_name=/click/][miniwob_category=original]"
+
+ dict_1 = benchmark_subset_1.to_dict()
+ dict_1.pop("name")
+ dict_2 = benchmark_subset_2.to_dict()
+ dict_2.pop("name")
+
+ assert dict_1 == dict_2
+
+
+def test_benchmark_subset_from_task_ratio():
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["webarena"]()
+
+ # Store initial random state
+ initial_state = random.getstate()
+
+ benchmark_subset = benchmark.subset_from_task_ratio(ratio=0.5, seed=1)
+ assert len(benchmark_subset.env_args_list) == 812 // 2
+ assert benchmark_subset.name == "webarena[ratio=0.5, seed=1]"
+
+ # Verify global random state hasn't changed
+ assert random.getstate() == initial_state
+
+ benchmark_subset_1 = benchmark_subset.subset_from_task_ratio(ratio=0.5, seed=1)
+ benchmark_subset_2 = benchmark_subset.subset_from_task_ratio(ratio=0.5, seed=2)
+
+ # Verify global random state still hasn't changed
+ assert random.getstate() == initial_state
+
+ # Check the task lists are different
+ assert not np.all(
+ [
+ env_args.task_name == env_args_2.task_name
+ for env_args, env_args_2 in zip(
+ benchmark_subset_1.env_args_list, benchmark_subset_2.env_args_list
+ )
+ ]
+ )
+
+ dict_1 = benchmark_subset_1.to_dict()
+ dict_1.pop("name")
+ dict_2 = benchmark_subset_2.to_dict()
+ dict_2.pop("name")
+ assert len(dict_1["env_args_list"]) == len(dict_2["env_args_list"])
+ assert dict_1 != dict_2
+
+
+def test_prepare_backend_miniwob():
+ MINIWOB_URL = os.environ["MINIWOB_URL"]
+ try:
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["miniwob"]()
+
+ benchmark.prepare_backends()
+
+ del os.environ["MINIWOB_URL"]
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+
+ os.environ["MINIWOB_URL"] = ""
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+ finally:
+ os.environ["MINIWOB_URL"] = MINIWOB_URL
+
+
+def test_prepare_backend_assistantbench():
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["assistantbench"]()
+ benchmark.prepare_backends()
+
+
+@pytest.mark.skip
+def test_prepare_backend_webarena():
+ WA_FULL_RESET = os.environ["WA_FULL_RESET"]
+ try:
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["webarena"]()
+
+ benchmark.prepare_backends()
+
+ del os.environ["WA_FULL_RESET"]
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+
+ os.environ["WA_FULL_RESET"] = "http://localhost:12345/reset"
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+ finally:
+ os.environ["WA_FULL_RESET"] = WA_FULL_RESET
+
+
+@pytest.mark.skip
+def test_prepare_backend_visualwebarena():
+ VWA_FULL_RESET = os.environ["VWA_FULL_RESET"]
+ try:
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["visualwebarena"]()
+
+ benchmark.prepare_backends()
+
+ del os.environ["VWA_FULL_RESET"]
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+
+ os.environ["VWA_FULL_RESET"] = "http://localhost:12345/reset"
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+ finally:
+ os.environ["VWA_FULL_RESET"] = VWA_FULL_RESET
+
+
+@pytest.mark.skip
+def test_prepare_backend_weblinx():
+ BROWSERGYM_WEBLINX_CACHE_DIR = os.environ["BROWSERGYM_WEBLINX_CACHE_DIR"]
+ try:
+ benchmark: Benchmark = DEFAULT_BENCHMARKS["weblinx"]()
+
+ benchmark.prepare_backends()
+
+ del os.environ["BROWSERGYM_WEBLINX_CACHE_DIR"]
+ with pytest.raises(Exception):
+ benchmark.prepare_backends()
+
+ finally:
+ os.environ["BROWSERGYM_WEBLINX_CACHE_DIR"] = BROWSERGYM_WEBLINX_CACHE_DIR
+
+
+def test_run_mock_benchmark():
+ benchmark = Benchmark(
+ name="miniwob_click_test",
+ high_level_action_set_args=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ strict=False,
+ retry_with_force=True,
+ demo_mode="off",
+ ),
+ is_multi_tab=False,
+ supports_parallel_seeds=True,
+ backends=["miniwob"],
+ env_args_list=make_env_args_list_from_fixed_seeds(
+ task_list=["miniwob.click-test"],
+ max_steps=5,
+ fixed_seeds=[0, 1],
+ ),
+ )
+
+ for env_args in benchmark.env_args_list:
+ agent_args = MiniwobTestAgentArgs(
+ high_level_action_set=benchmark.high_level_action_set_args
+ )
+ exp_args = ExpArgs(
+ agent_args=agent_args,
+ env_args=env_args,
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ exp_args.prepare(tmp_dir)
+ exp_args.run()
+ exp_result = get_exp_result(exp_args.exp_dir)
+ exp_record = exp_result.get_exp_record()
+
+ target = {
+ "env_args.task_name": "miniwob.click-test",
+ "env_args.headless": True,
+ "env_args.record_video": False,
+ "n_steps": 1,
+ "cum_reward": 1.0,
+ "terminated": True,
+ "truncated": False,
+ }
+
+ assert len(exp_result.steps_info) == 2
+
+ for key, target_val in target.items():
+ assert key in exp_record
+ assert exp_record[key] == target_val
+
+
+def test_dependency_graphs():
+ benchmark = Benchmark(
+ name="my_bench",
+ high_level_action_set_args=HighLevelActionSetArgs(
+ subsets=["bid"],
+ multiaction=False,
+ strict=False,
+ retry_with_force=True,
+ demo_mode="off",
+ ),
+ is_multi_tab=False,
+ supports_parallel_seeds=True,
+ backends=["miniwob"],
+ env_args_list=make_env_args_list_from_fixed_seeds(
+ task_list=["miniwob.click-test"],
+ max_steps=5,
+ fixed_seeds=[0, 1],
+ ),
+ )
+
+ # one task, two seeds
+ task_dependencies = benchmark.dependency_graph_over_tasks()
+ assert task_dependencies == {"miniwob.click-test": []}
+
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert env_args_dependencies == [{0: [], 1: []}]
+
+ # change to no parallel seed support
+ benchmark.supports_parallel_seeds = False
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert env_args_dependencies == [{0: []}, {1: []}]
+
+ # webarena, 3 tasks x 1 seed
+ benchmark = DEFAULT_BENCHMARKS["webarena"]().subset_from_regexp(
+ column="task_name", regexp=r"^webarena\.[012]$"
+ )
+
+ task_dependencies = benchmark.dependency_graph_over_tasks()
+ assert task_dependencies == {
+ "webarena.0": [],
+ "webarena.1": ["webarena.0"],
+ "webarena.2": ["webarena.1"],
+ }
+
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert env_args_dependencies == [{0: [], 1: [0], 2: [1]}]
+
+ # workarena L2, 2 task x (2 seeds, 1 seed)
+ benchmark = DEFAULT_BENCHMARKS["workarena_l2_agent_curriculum_eval"]().subset_from_regexp(
+ column="task_name",
+ regexp=r"^workarena\.servicenow\.workload-balancing-small-l2$|^workarena\.servicenow\.easy-expense-management-small-l2$",
+ )
+
+ task_dependencies = benchmark.dependency_graph_over_tasks()
+ assert task_dependencies == {
+ "workarena.servicenow.workload-balancing-small-l2": [],
+ "workarena.servicenow.easy-expense-management-small-l2": [],
+ }
+
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert env_args_dependencies == [{0: [], 1: [], 2: []}]
+
+ # change to no parallel seed support
+ benchmark.supports_parallel_seeds = False
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert env_args_dependencies == [{0: [], 2: []}, {1: []}]
+
+ # webarena, 6 dependent tasks x 1 seed
+ benchmark = DEFAULT_BENCHMARKS["webarena"]().subset_from_regexp(
+ column="task_name",
+ regexp=r"^webarena\.533$|^webarena\.537$|^webarena\.552$|^webarena\.410$|^webarena\.561$|^webarena\.562$",
+ )
+
+ task_dependencies = benchmark.dependency_graph_over_tasks()
+ assert {k: set(v) for k, v in task_dependencies.items()} == {
+ k: set(v)
+ for k, v in {
+ "webarena.410": [],
+ "webarena.533": [],
+ "webarena.537": ["webarena.533"],
+ "webarena.552": ["webarena.410", "webarena.537"],
+ "webarena.561": ["webarena.552"],
+ "webarena.562": ["webarena.552", "webarena.561"],
+ }.items()
+ }
+
+ env_args_dependencies = benchmark.dependency_graphs_over_env_args()
+ assert [{k: set(v) for k, v in deps.items()} for deps in env_args_dependencies] == [
+ {k: set(v) for k, v in {0: [], 1: [], 2: [1], 3: [0, 2], 4: [3], 5: [3, 4]}.items()}
+ ]
diff --git a/BrowserGym/tests/experiments/test_bgym.py b/BrowserGym/tests/experiments/test_bgym.py
new file mode 100644
index 0000000000000000000000000000000000000000..193822caa4dd95bc7774b96ab390514c96a30f21
--- /dev/null
+++ b/BrowserGym/tests/experiments/test_bgym.py
@@ -0,0 +1,9 @@
+import bgym
+import pytest
+
+
+def test_classes():
+ bgym.EnvArgs(task_name="something")
+ bgym.HighLevelActionSet()
+ with pytest.raises(TypeError):
+ bgym.Agent()
diff --git a/BrowserGym/tests/experiments/test_exp_loop.py b/BrowserGym/tests/experiments/test_exp_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..a954f9b7f5e1fedcfd413c4490c762ff23d4aa9a
--- /dev/null
+++ b/BrowserGym/tests/experiments/test_exp_loop.py
@@ -0,0 +1,72 @@
+import re
+import tempfile
+import logging
+import dataclasses
+
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments.agent import Agent
+from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
+from browsergym.utils.obs import flatten_axtree_to_str
+
+
+class MiniwobTestAgent(Agent):
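+    """Minimal oracle agent: clicks the first button found in the AXTree observation."""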
+
+ action_set = HighLevelActionSet(subsets="bid")
+
+ def obs_preprocessor(self, obs: dict):
+ return {"axtree_txt": flatten_axtree_to_str(obs["axtree_object"])}
+
+ def get_action(self, obs: dict) -> tuple[str, dict]:
+ match = re.search(r"^\s*\[(\d+)\].*button", obs["axtree_txt"], re.MULTILINE | re.IGNORECASE)
+
+ if match:
+ bid = match.group(1)
+ action = f'click("{bid}")'
+ else:
+ raise Exception("Can't find the button's bid")
+
+ return action, dict(think="I'm clicking the button as requested.")
+
+
+@dataclasses.dataclass
+class MiniwobTestAgentArgs(AbstractAgentArgs):
+ def make_agent(self):
+ return MiniwobTestAgent()
+
+
+def test_run_exp():
+ exp_args = ExpArgs(
+ agent_args=MiniwobTestAgentArgs(),
+ env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42),
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ exp_args.prepare(tmp_dir)
+ exp_args.run()
+ exp_result = get_exp_result(exp_args.exp_dir)
+ exp_record = exp_result.get_exp_record()
+
+ target = {
+ "env_args.task_name": "miniwob.click-test",
+ "env_args.task_seed": 42,
+ "env_args.headless": True,
+ "env_args.record_video": False,
+ "n_steps": 1,
+ "cum_reward": 1.0,
+ "terminated": True,
+ "truncated": False,
+ }
+
+ assert len(exp_result.steps_info) == 2
+
+ for key, target_val in target.items():
+ assert key in exp_record
+ assert exp_record[key] == target_val
+
+ # TODO investigate why it's taking almost 5 seconds to solve
+ assert exp_record["stats.cum_step_elapsed"] < 5
+ if exp_record["stats.cum_step_elapsed"] > 3:
+ t = exp_record["stats.cum_step_elapsed"]
+ logging.warning(
+ f"miniwob.click-test is taking {t:.2f}s (> 3s) to solve with an oracle."
+ )
diff --git a/BrowserGym/tests/miniwob/__init__.py b/BrowserGym/tests/miniwob/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f09d6fbde51609da41e1041eb3fb8125d808cb
--- /dev/null
+++ b/BrowserGym/tests/miniwob/__init__.py
@@ -0,0 +1,2 @@
+# bugfix: use same playwright instance in browsergym and pytest
+from ..utils import setup_playwright
diff --git a/BrowserGym/tests/miniwob/test_base.py b/BrowserGym/tests/miniwob/test_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0fdf330f7e08c148206ecbed9ad8692135ea2f
--- /dev/null
+++ b/BrowserGym/tests/miniwob/test_base.py
@@ -0,0 +1,196 @@
+import os
+import pytest
+import time
+import gymnasium as gym
+
+# register gym environments
+import browsergym.miniwob
+
+from browsergym.miniwob.all import (
+ ClickButtonTask,
+ ClickOptionTask,
+ DrawLineTask,
+ LoginUserTask,
+)
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+TASKS = [ClickButtonTask, ClickOptionTask, DrawLineTask, LoginUserTask]
+
+
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_validate_teardown(task_cls):
+ pw = browsergym.core._get_global_playwright()
+
+ browser = pw.chromium.launch(headless=__HEADLESS, slow_mo=__SLOW_MO)
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42)
+ task.setup(page=page)
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert done is False
+
+ task.teardown()
+
+ context.close()
+ browser.close()
+
+
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_episode_max_time(task_cls):
+ pw = browsergym.core._get_global_playwright()
+
+ browser = pw.chromium.launch(headless=__HEADLESS, slow_mo=__SLOW_MO)
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42, episode_max_time=0.2)
+ task.setup(page=page)
+
+ time.sleep(0.5)
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert done is True
+ assert reward == 0
+
+ task.teardown()
+
+ context.close()
+ browser.close()
+
+
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_remove_human_display(task_cls):
+ pw = browsergym.core._get_global_playwright()
+
+ browser = pw.chromium.launch(headless=__HEADLESS, slow_mo=__SLOW_MO)
+
+ # remove display
+
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42, remove_human_display=True)
+ task.setup(page=page)
+
+ for element_id in ["reward-display", "click-canvas", "sync-task-cover"]:
+ element_in_dom = page.evaluate(f"!!document.getElementById('{element_id}')")
+ assert not element_in_dom
+
+ assert page.evaluate(f"document.getElementById('query').innerHTML") == ""
+
+ for element_id in ["wrap", "area"]:
+ element_in_dom = page.evaluate(f"!!document.getElementById('{element_id}')")
+ assert element_in_dom
+
+ task.teardown()
+
+ context.close()
+
+ # keep display
+
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42, remove_human_display=False)
+ task.setup(page=page)
+
+ for element_id in ["reward-display", "click-canvas", "sync-task-cover"]:
+ element_in_dom = page.evaluate(f"!!document.getElementById('{element_id}')")
+ assert element_in_dom
+
+ assert page.evaluate(f"document.getElementById('query').innerHTML") != ""
+
+ for element_id in ["wrap", "area"]:
+ element_in_dom = page.evaluate(f"!!document.getElementById('{element_id}')")
+ assert element_in_dom
+
+ task.teardown()
+
+ context.close()
+ browser.close()
+
+
+@pytest.mark.skip(reason="TODO: how to get the final viewport size right?")
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_viewport(task_cls):
+ env = gym.make(
+ f"browsergym/{task_cls.get_task_id()}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset(seed=42)
+
+ screenshot = obs["screenshot"]
+
+ # 3D array (height, width, rgb) of unsigned bytes (between 0 and 255)
+ # Miniwob viewport should be (320x500)
+ assert screenshot.shape[0] == 320
+ assert screenshot.shape[1] == 500
+ assert screenshot.shape[2] == 3 # RGB
+
+ env.close()
+
+
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_forbidden_navigation(task_cls):
+ pw = browsergym.core._get_global_playwright()
+
+ browser = pw.chromium.launch(headless=__HEADLESS, slow_mo=__SLOW_MO)
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42)
+ task.setup(page=page)
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert reward == 0.0 and done == False
+
+ page.goto("http://www.google.com")
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert reward == 0.0 and done == True
+
+ task.teardown()
+
+ context.close()
+ browser.close()
+
+
+@pytest.mark.parametrize("task_cls", TASKS)
+def test_forbidden_navigation_2(task_cls):
+ pw = browsergym.core._get_global_playwright()
+
+ browser = pw.chromium.launch(headless=__HEADLESS, slow_mo=__SLOW_MO)
+ context = browser.new_context()
+ page = context.new_page()
+
+ task = task_cls(seed=42)
+ task.setup(page=page)
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert reward == 0.0 and done == False
+
+ page2 = context.new_page()
+ page2.goto("http://www.google.com")
+
+ reward, done, msg, info = task.validate(page, [])
+
+ assert reward == 0.0 and done == False
+
+ reward, done, msg, info = task.validate(page2, [])
+
+ assert reward == 0.0 and done == True
+
+ task.teardown()
+
+ context.close()
+ browser.close()
diff --git a/BrowserGym/tests/miniwob/test_click-menu-2.py b/BrowserGym/tests/miniwob/test_click-menu-2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8296da1cecff67321892a9b94aca1ba58febbd12
--- /dev/null
+++ b/BrowserGym/tests/miniwob/test_click-menu-2.py
@@ -0,0 +1,81 @@
+import os
+import gymnasium as gym
+import re
+import pytest
+
+# register gym environments
+import browsergym.miniwob
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+@pytest.mark.parametrize("seed", range(5))
+def test_cheat(seed):
+ env = gym.make(
+ "browsergym/miniwob.click-menu-2",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ action_mapping=None,
+ )
+ obs, info = env.reset(seed=seed)
+
+ assert obs["last_action_error"] == ""
+
+ match1 = re.match(
+ 'Click the "Menu" button, and then find and click on the item labeled "(.+)".', obs["goal"]
+ )
+ match2 = re.match(
+ 'Click the "Menu" button, and then find and click on the item with the "(.+)" icon.',
+ obs["goal"],
+ )
+
+ assert match1 or match2
+
+ if match1:
+ item_label = match1.groups()[0]
+ item_classname = {
+ "Save": "ui-icon-disk",
+ "Prev": "ui-icon-seek-start",
+ "Stop": "ui-icon-stop",
+ "Play": "ui-icon-play",
+ "Next": "ui-icon-seek-end",
+ "Zoom In": "ui-icon-zoomin",
+ "Zoom Out": "ui-icon-zoomout",
+ }[item_label]
+ else:
+ item_classname = match2.groups()[0]
+
+ action = f"""\
+page.get_by_text("Menu").click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ assert obs["last_action_error"] == ""
+ assert reward == 0
+ assert term == False
+
+ if item_classname in ("ui-icon-seek-start", "ui-icon-stop", "ui-icon-play", "ui-icon-seek-end"):
+
+ action = f"""\
+page.get_by_text("Playback").click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ assert obs["last_action_error"] == ""
+ assert reward == 0
+ assert term == False
+
+ action = f"""\
+page.locator(".{item_classname}").click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ assert obs["last_action_error"] == ""
+ assert reward == 1
+ assert term == True
+
+ env.close()
diff --git a/BrowserGym/tests/miniwob/test_click-scroll-list.py b/BrowserGym/tests/miniwob/test_click-scroll-list.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f16cd7c2fd14fcf364abc5025d10067689ec8ee
--- /dev/null
+++ b/BrowserGym/tests/miniwob/test_click-scroll-list.py
@@ -0,0 +1,42 @@
+import os
+import gymnasium as gym
+import re
+import pytest
+
+# register gym environments
+import browsergym.miniwob
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+@pytest.mark.parametrize("seed", range(5))
+def test_cheat(seed):
+ env = gym.make(
+ "browsergym/miniwob.click-scroll-list",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ action_mapping=None,
+ )
+ obs, info = env.reset(seed=seed)
+
+ assert obs["last_action_error"] == ""
+
+ match = re.match("Select (.+) from the scroll list and click Submit.", obs["goal"])
+
+ assert match
+
+ options = match.groups()[0].split(", ")
+ options = '", "'.join(options)
+ action = f"""\
+page.locator("#options").select_option(["{options}"])
+page.get_by_role("button", name="Submit").click()
+"""
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ assert obs["last_action_error"] == ""
+ assert reward == 1
+ assert term == True
+
+ env.close()
diff --git a/BrowserGym/tests/miniwob/test_use-colorwheel-2.py b/BrowserGym/tests/miniwob/test_use-colorwheel-2.py
new file mode 100644
index 0000000000000000000000000000000000000000..45d660d431bc97aa152f658f63306a8e14f611b7
--- /dev/null
+++ b/BrowserGym/tests/miniwob/test_use-colorwheel-2.py
@@ -0,0 +1,44 @@
+import os
+import gymnasium as gym
+import re
+import pytest
+
+# register gym environments
+import browsergym.miniwob
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+@pytest.mark.parametrize("seed", range(5))
+def test_cheat(seed):
+ env = gym.make(
+ "browsergym/miniwob.use-colorwheel-2",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ action_mapping=None,
+ )
+ obs, info = env.reset(seed=42)
+
+ assert obs["last_action_error"] == ""
+
+ match = re.match(
+ "Select the following color #(.+) with the color picker and hit Submit.", obs["goal"]
+ )
+
+ assert match
+
+ color = match.groups()[0].upper()
+
+ obs, reward, term, trunc, info = env.step(
+ f"""\
+page.locator("#col").fill("{color}")
+page.get_by_role("button", name="Submit").click()
+"""
+ )
+
+ assert obs["last_action_error"] == ""
+ assert reward == 1
+ assert term == True
+
+ env.close()
diff --git a/BrowserGym/tests/utils.py b/BrowserGym/tests/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..48595751f2a9aa8070bfebe6ae57142ff3d0d653
--- /dev/null
+++ b/BrowserGym/tests/utils.py
@@ -0,0 +1,13 @@
+import browsergym.core
+import logging
+import playwright.sync_api
+import pytest
+
+
+# setup code, executed ahead of first test
+@pytest.fixture(scope="session", autouse=True)
+def setup_playwright(playwright: playwright.sync_api.Playwright):
+ # bugfix: re-use pytest-playwright's playwright instance in browsergym
+ # https://github.com/microsoft/playwright-python/issues/2053
+ browsergym.core._set_global_playwright(playwright)
+ logging.info("Browsergym is using the playwright instance provided by pytest-playwright.")
diff --git a/BrowserGym/tests/visualwebarena/__init__.py b/BrowserGym/tests/visualwebarena/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f09d6fbde51609da41e1041eb3fb8125d808cb
--- /dev/null
+++ b/BrowserGym/tests/visualwebarena/__init__.py
@@ -0,0 +1,2 @@
+# bugfix: use same playwright instance in browsergym and pytest
+from ..utils import setup_playwright
diff --git a/BrowserGym/tests/visualwebarena/test_vwa_domains.py b/BrowserGym/tests/visualwebarena/test_vwa_domains.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d4a4256e8bb523dc447904fd8561b748d75b0e
--- /dev/null
+++ b/BrowserGym/tests/visualwebarena/test_vwa_domains.py
@@ -0,0 +1,25 @@
+import pytest
+import playwright.sync_api
+
+from browsergym.visualwebarena.instance import VisualWebArenaInstance
+
+
+def test_is_reachable():
+ # default URLs
+ instance = VisualWebArenaInstance()
+ instance.check_status()
+
+    # unreachable URL
+ with pytest.raises(RuntimeError):
+ instance = VisualWebArenaInstance()
+ instance.urls["reddit"] = "https://invalid.url"
+ instance.check_status()
+
+
+@pytest.mark.parametrize("site", ["reddit", "shopping", "wikipedia", "classifieds"])
+def test_credentials(page: playwright.sync_api.Page, site: str):
+ # default URLs and credentials
+ instance = VisualWebArenaInstance()
+ instance.ui_login(site=site, page=page)
+
+ # TODO: test this more thoroughly
diff --git a/BrowserGym/tests/visualwebarena/test_vwa_tasks_with_reset.py b/BrowserGym/tests/visualwebarena/test_vwa_tasks_with_reset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e586d2a777934063c196903e30c7beeb503cb6fb
--- /dev/null
+++ b/BrowserGym/tests/visualwebarena/test_vwa_tasks_with_reset.py
@@ -0,0 +1,40 @@
+import logging
+import os
+import random
+
+import gymnasium as gym
+import playwright.sync_api
+import pytest
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+# register gym environments
+import browsergym.visualwebarena
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+from browsergym.visualwebarena import VISUALWEBARENA_TASK_IDS_WITH_RESET
+
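+# fixed seed so the sampled subset of task ids stays the same across test runs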
+rng = random.Random(1)
+task_ids = rng.sample(VISUALWEBARENA_TASK_IDS_WITH_RESET, 10)
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ wait=wait_fixed(2),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+@pytest.mark.parametrize("task_id", task_ids)
+@pytest.mark.slow
+@pytest.mark.serial
+def test_env_generic(task_id):
+ env = gym.make(
+ f"browsergym/{task_id}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+ env.close()
diff --git a/BrowserGym/tests/visualwebarena/test_vwa_tasks_without_reset.py b/BrowserGym/tests/visualwebarena/test_vwa_tasks_without_reset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3fad322381d8ceed88dae3ae3449d6da02fb197
--- /dev/null
+++ b/BrowserGym/tests/visualwebarena/test_vwa_tasks_without_reset.py
@@ -0,0 +1,74 @@
+import logging
+import os
+import random
+
+import gymnasium as gym
+import playwright.sync_api
+import pytest
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
+
+# register gym environments
+import browsergym.visualwebarena
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+from browsergym.visualwebarena import VISUALWEBARENA_TASK_IDS_WITHOUT_RESET
+
+rng = random.Random(1)
+task_ids = rng.sample(VISUALWEBARENA_TASK_IDS_WITHOUT_RESET, 25)
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ wait=wait_fixed(2),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+@pytest.mark.parametrize("task_id", task_ids)
+@pytest.mark.slow
+def test_env_generic(task_id):
+ env = gym.make(
+ f"browsergym/{task_id}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+ env.close()
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ wait=wait_fixed(2),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+def test_domain_safeguard():
+ env = gym.make(
+ f"browsergym/visualwebarena.398",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+ assert not obs["last_action_error"]
+
+ obs, reward, terminated, truncated, info = env.step("new_tab()")
+ assert not obs["last_action_error"]
+ assert not (terminated or truncated)
+
+ obs, reward, terminated, truncated, info = env.step("tab_close()")
+ assert not obs["last_action_error"]
+ assert not (terminated or truncated)
+
+ obs, reward, terminated, truncated, info = env.step("tab_focus(0)")
+ assert not obs["last_action_error"]
+ assert not (terminated or truncated)
+
+ obs, reward, terminated, truncated, info = env.step('goto("http://www.google.com")')
+ assert not obs["last_action_error"]
+ assert terminated
+
+ env.close()
diff --git a/BrowserGym/tests/webarena/__init__.py b/BrowserGym/tests/webarena/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f09d6fbde51609da41e1041eb3fb8125d808cb
--- /dev/null
+++ b/BrowserGym/tests/webarena/__init__.py
@@ -0,0 +1,2 @@
+# bugfix: use same playwright instance in browsergym and pytest
+from ..utils import setup_playwright
diff --git a/BrowserGym/tests/webarena/test_env_general.py b/BrowserGym/tests/webarena/test_env_general.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4a81b23a7a34d376ed8048e71b31606d91f589d
--- /dev/null
+++ b/BrowserGym/tests/webarena/test_env_general.py
@@ -0,0 +1,40 @@
+import gymnasium as gym
+import logging
+import os
+import playwright.sync_api
+import pytest
+import random
+
+from tenacity import retry, stop_after_attempt, retry_if_exception_type
+
+# register gym environments
+import browsergym.webarena
+
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+
+from browsergym.webarena import ALL_WEBARENA_TASK_IDS
+
+rng = random.Random(1)
+task_ids = rng.sample(ALL_WEBARENA_TASK_IDS, 25)
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+@pytest.mark.parametrize("task_id", task_ids)
+@pytest.mark.slow
+def test_env_generic(task_id):
+ env = gym.make(
+ f"browsergym/{task_id}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+
+ env.close()
diff --git a/BrowserGym/tests/webarena/test_infeasible.py b/BrowserGym/tests/webarena/test_infeasible.py
new file mode 100644
index 0000000000000000000000000000000000000000..044b5c404558739529e159d9dd5c357156c90ec8
--- /dev/null
+++ b/BrowserGym/tests/webarena/test_infeasible.py
@@ -0,0 +1,50 @@
+import gymnasium as gym
+import logging
+import os
+import playwright.sync_api
+import pytest
+
+from tenacity import retry, stop_after_attempt, retry_if_exception_type
+
+# register gym environments
+import browsergym.webarena
+
+
+__SLOW_MO = 1000 if "DISPLAY_BROWSER" in os.environ else None
+__HEADLESS = False if "DISPLAY_BROWSER" in os.environ else True
+
+INFEAS_TASK_IDS = [101, 115, 166]
+FEAS_TASK_IDS = [165, 187, 199]
+
+
+@retry(
+ stop=stop_after_attempt(5),
+ retry=retry_if_exception_type(playwright.sync_api.TimeoutError),
+ reraise=True,
+ before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
+)
+@pytest.mark.parametrize(
+ "task_id,infeasible",
+ [(task_id, True) for task_id in INFEAS_TASK_IDS]
+ + [(task_id, False) for task_id in FEAS_TASK_IDS],
+)
+@pytest.mark.slow
+def test_infeasible(task_id, infeasible):
+ env = gym.make(
+ f"browsergym/webarena.{task_id}",
+ headless=__HEADLESS,
+ slow_mo=__SLOW_MO,
+ )
+ obs, info = env.reset()
+
+ action = 'report_infeasible("Unachievable task.")'
+
+ obs, reward, term, trunc, info = env.step(action)
+
+ if infeasible:
+ assert term == True and reward == 1.0
+
+ else:
+ assert term == True and reward == 0.0
+
+ env.close()
diff --git a/BrowserGym/tests/webarena/test_instance.py b/BrowserGym/tests/webarena/test_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..a538a53f97c7372f72a99445b62843ce30d0c9e7
--- /dev/null
+++ b/BrowserGym/tests/webarena/test_instance.py
@@ -0,0 +1,27 @@
+import pytest
+import playwright.sync_api
+
+from browsergym.webarena.instance import WebArenaInstance
+
+
+def test_is_reachable():
+ # default URLs
+ instance = WebArenaInstance()
+ instance.check_status()
+
+    # unreachable URL
+ with pytest.raises(RuntimeError):
+ instance = WebArenaInstance()
+ instance.urls["reddit"] = "https://invalid.url"
+ instance.check_status()
+
+
+@pytest.mark.parametrize(
+ "site", ["reddit", "shopping", "shopping_admin", "gitlab", "wikipedia", "map"]
+)
+def test_credentials(page: playwright.sync_api.Page, site: str):
+ # default URLs and credentials
+ instance = WebArenaInstance()
+ instance.ui_login(site=site, page=page)
+
+ # TODO: test this more thoroughly
diff --git a/README.md b/README.md
index e2672957ef900d4f71497b5fb123dc0b6f70aaa1..74b5140c9663e42f6ae57269cb14513994784aaa 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@ title: Web Shepherd Demo
emoji: 😻
colorFrom: yellow
colorTo: blue
-sdk: gradio
+sdk: docker
pinned: false
---
diff --git a/agent/__init__.py b/agent/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/agent/checklist.py b/agent/checklist.py
new file mode 100644
index 0000000000000000000000000000000000000000..b469ddafb2a0215960c9af75d865f3ee385c583e
--- /dev/null
+++ b/agent/checklist.py
@@ -0,0 +1,18 @@
+from .mini_bench.agent import ChecklistGenerationAgent
+
+def generate_checklist(**data):
+ # data: 'intent', 'start_url', 'text_observation'
+ agent_config = {
+ 'model_name': 'WPRM/qwen-3b-ar-reward-cot-mtl-checklist-enhanced',
+ 'base_url': 'http://165.132.144.84:7701/v1',
+ 'api_key': 'empty',
+ 'temperature': 0.7,
+ 'use_log_probs': True,
+ 'use_checklist': True,
+ 'use_multimodal': False,
+ 'num_generate': 1,
+ }
+ checklist_generation_agent = ChecklistGenerationAgent(agent_config)
+ response_list, cost = checklist_generation_agent.generate_response(data, prompt_type='ours', constraint_str_list=["", "", "", ""])
+ response = response_list[0]
+ return response.split("")[-1].split("")[0].strip()
diff --git a/agent/mini_bench/__init__.py b/agent/mini_bench/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/agent/mini_bench/__pycache__/__init__.cpython-311.pyc b/agent/mini_bench/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7905063310818854817c8a42f6668ac29652395
Binary files /dev/null and b/agent/mini_bench/__pycache__/__init__.cpython-311.pyc differ
diff --git a/agent/mini_bench/__pycache__/agent.cpython-311.pyc b/agent/mini_bench/__pycache__/agent.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06d33c0998bb81fc3f60971b223755623745ad00
Binary files /dev/null and b/agent/mini_bench/__pycache__/agent.cpython-311.pyc differ
diff --git a/agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc b/agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..122657f6884e90f99ed6a82a6a3b0eca24c76fea
Binary files /dev/null and b/agent/mini_bench/__pycache__/reward_agent.cpython-311.pyc differ
diff --git a/agent/mini_bench/agent.py b/agent/mini_bench/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..5354363bbf776d794febd5c5a6953ed48d7e73ea
--- /dev/null
+++ b/agent/mini_bench/agent.py
@@ -0,0 +1,467 @@
+from abc import ABC, abstractmethod
+import time
+import requests
+import json
+import math
+from langsmith import Client
+from langchain_openai import ChatOpenAI
+
+from .prompts import get_messages
+from .prompts.judge_prompt import (
+ JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
+)
+from .prompts.image_utils import image_to_base64_url
+
+MAX_RETRY = 3
+RETRY_SLEEP = 5
+MODEL_COST_MAPPING = {
+ "gpt-4o-mini": {
+ "input_token_cost": 0.15,
+ "output_token_cost": 0.6
+ },
+ "gpt-4o": {
+ "input_token_cost": 2.5,
+ "output_token_cost": 10
+ },
+}
+
+
+class Agent(ABC):
+ @abstractmethod
+ def generate_response(self, inputs: dict) -> str:
+ pass
+
+class BaseAgent(Agent):
+ def __init__(self, agent_config: dict):
+ self.agent_config = agent_config
+ self._setup()
+
+ def _setup(self):
+ use_log_probs = self.agent_config.get("use_log_probs", False)
+ if use_log_probs:
+ self.llm = ChatOpenAI(
+ model=self.agent_config["model_name"],
+ base_url=self.agent_config["base_url"],
+ api_key=self.agent_config["api_key"],
+ temperature=self.agent_config["temperature"],
+ timeout=300,
+ logprobs=True,
+ top_logprobs=10
+ )
+ else:
+ self.llm = ChatOpenAI(
+ model=self.agent_config["model_name"],
+ base_url=self.agent_config["base_url"],
+ api_key=self.agent_config["api_key"],
+ temperature=self.agent_config["temperature"],
+ timeout=300
+ )
+ self.temperature = self.agent_config["temperature"]
+ self.num_generate = self.agent_config["num_generate"]
+ self.use_checklist = self.agent_config.get("use_checklist", False)
+ self.use_multimodal = self.agent_config.get("use_multimodal", False)
+
+ # setup cost
+ model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
+ if model_cost and "api" in self.agent_config["base_url"]:
+ self.input_token_cost = model_cost["input_token_cost"]
+ self.output_token_cost = model_cost["output_token_cost"]
+ else:
+ self.input_token_cost = 0.0
+ self.output_token_cost = 0.0
+
+ def generate_with_retry(self, model_input, constraint_str_list: list = None):
+ total_input_tokens = 0
+ total_output_tokens = 0
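+        # temperature 0 means greedy decoding: a single call suffices, since retries would return the same output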
+ if self.temperature == 0:
+ response = self.llm.invoke(model_input)
+ total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
+ total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
+ else:
+ for i in range(MAX_RETRY):
+ try:
+ response = self.llm.invoke(model_input)
+ total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
+ total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
+ if constraint_str_list:
+ pass_constraint_num = 0
+ for constraint_str in constraint_str_list:
+ if constraint_str in response.content:
+ pass_constraint_num += 1
+ if pass_constraint_num == len(constraint_str_list):
+ break
+ else:
+                            print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
+ print(response.content)
+ else:
+ break
+ except Exception as e:
+ print(f"Agent returned an Error: {e}")
+ response = None
+ time.sleep(RETRY_SLEEP)
+
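+        # token costs in MODEL_COST_MAPPING are interpreted as USD per 1M tokens, hence the division by 1,000,000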
+ cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
+
+ if response is None:
+ return "", cost
+ else:
+ return response.content, cost
+
+ def prepare_message(self, model_input: dict, prompt_type: str):
+ message = []
+ return message
+
+ def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
+ total_cost = 0
+ response_list = []
+ # prepare message
+ message = self.prepare_message(model_input, prompt_type)
+ # print(message)
+
+ # n sampling
+ for i in range(self.num_generate):
+ response, cost = self.generate_with_retry(message, constraint_str_list)
+ response_list.append(response)
+ total_cost += cost
+
+ return response_list, total_cost
+
+
+class GroundingJudgeAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="judge_grounding",
+ prompt_type=prompt_type,
+ use_multimodal=self.use_multimodal,
+ text_obs=self.agent_config["text_obs_type"],
+ image_obs=self.agent_config["image_obs_type"]
+ )
+ return message
+
+
+class ProgressJudgeAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ if self.agent_config["input_type"]=="text_only":
+ use_multimodal = False
+ text_obs = self.agent_config["text_obs_type"]
+ image_obs = None
+ elif self.agent_config["input_type"]=="image_only":
+ use_multimodal = True
+ text_obs = None
+ image_obs = self.agent_config["image_obs_type"]
+ elif self.agent_config["input_type"]=="text_image":
+ use_multimodal = True
+ text_obs = self.agent_config["text_obs_type"]
+ image_obs = self.agent_config["image_obs_type"]
+ else:
+ raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
+
+ if self.agent_config["use_in_progress"]:
+ use_in_progress = True
+ else:
+ use_in_progress = False
+
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="judge_progress",
+ prompt_type=prompt_type,
+ use_checklist=self.use_checklist,
+ use_multimodal=use_multimodal,
+ text_obs=text_obs,
+ image_obs=image_obs,
+ use_in_progress=use_in_progress
+ )
+ return message
+
+ def add_logprob(self, ori_logprob: float, add_logprob: float):
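+        # combine two log-probabilities by adding the underlying probabilities, i.e. return log(p1 + p2)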
+ if ori_logprob is None:
+ return add_logprob
+ else:
+ ori_prob = math.exp(ori_logprob)
+ add_prob = math.exp(add_logprob)
+ return math.log(ori_prob + add_prob)
+
+ def get_judge_probs(self, logprobs: list):
+ # target_judge = {
+ # "yes": [" Yes", "Yes"],
+ # "no": [" No", "No"],
+ # "in": [" In", "In"]
+ # }
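+        # the token spellings below include byte-level BPE markers ("Ġ" for a leading space, "Ċ" for a newline)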
+ target_judge = {
+ "yes": [
+ " Yes", "ĠYes", "Yes", "ĊYes",
+ "Ġyes", "yes", "Ċyes",
+ "ĠYES", "YES", "ĊYES",
+ "ĠDone", "Done", "ĊDone",
+ "ĠCompleted", "Completed", "ĊCompleted",
+ "ĠCorrect", "Correct", "ĊCorrect"
+ ],
+ "no": [
+ " No", "ĠNo", "No", "ĊNo",
+ "ĠNO", "NO", "ĊNO",
+ "ĠNot", "Not", "ĊNot",
+ "ĠNone", "None", "ĊNone",
+ "ĠNope", "Nope", "ĊNope",
+ "ĠUn", "Un", "ĊUn",
+ "ĠWrong", "Wrong", "ĊWrong"
+ ],
+ "in": [
+ " In", "ĠIn", "In", "ĊIn",
+ "ĠPending", "Pending", "ĊPending",
+ "ĠPart", "Part", "ĊPart",
+ "ĠPartial", "Partial", "ĊPartial",
+ "ĠInProgress", "InProgress", "ĊInProgress"
+ ]
+ }
+ response_str = ""
+ judge_probs_list = []
+ # print(logprobs)
+ for i, log_prob in enumerate(logprobs):
+ # Start to find judge string
+ if "" in response_str:
+ find_judge_str = None
+ for judge_type in target_judge:
+ if log_prob["token"] in target_judge[judge_type]:
+ # print(log_prob)
+ find_judge_str = judge_type
+ break
+ if find_judge_str:
+ # print("find judge str")
+ token_judge_dict = {
+ "yes": None,
+ "no": None,
+ "in": None
+ }
+ if "top_logprobs" in log_prob:
+ for token_info in log_prob["top_logprobs"]:
+ for judge_type in target_judge:
+ for judge_str in target_judge[judge_type]:
+ # if judge_str in token_info["token"] and token_info["logprob"] > token_judge_dict[judge_type]:
+ # token_judge_dict[judge_type] = token_info["logprob"]
+ if judge_str in token_info["token"]:
+ # print(token_info["logprob"])
+ token_judge_dict[judge_type] = self.add_logprob(token_judge_dict[judge_type], token_info["logprob"])
+ # for None case
+ for judge_type in token_judge_dict:
+ if token_judge_dict[judge_type] is None:
+ token_judge_dict[judge_type] = float("-inf")
+ judge_probs_list.append(token_judge_dict)
+ else:
+ # for vllm bugs : no top_logprobs
+ for judge_type in token_judge_dict:
+ if judge_type == find_judge_str:
+ token_judge_dict[judge_type] = log_prob["logprob"]
+ else:
+ token_judge_dict[judge_type] = float("-inf")
+ judge_probs_list.append(token_judge_dict)
+ # print(token_judge_dict)
+
+ if "" in response_str:
+ break
+
+ response_str += log_prob["token"]
+ # print(response_str.replace("Ġ", " ").replace("Ċ", "\n"))
+ # print(judge_probs_list)
+ if len(judge_probs_list) == 0:
+ return [{
+ "yes": 0.0,
+ "no": 0.0,
+ "in": 0.0
+ }]
+ else:
+ # convert with softmax
+ final_judge_probs_list = []
+ for judge_probs in judge_probs_list:
+ exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
+ sum_exp_logprobs = sum(exp_logprobs)
+ softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
+ final_judge_probs_list.append({
+ "yes": softmax_probs[0],
+ "no": softmax_probs[1],
+ "in": softmax_probs[2]
+ })
+ return final_judge_probs_list
+
+ def generate_probs(self, model_input: dict, prompt_type: str):
+ total_cost = 0
+ response_list = []
+ # prepare message
+ message = self.prepare_message(model_input, prompt_type)
+ # print(message)
+
+ for i in range(self.num_generate):
+ try:
+ response = self.llm.invoke(message)
+ total_input_tokens = response.response_metadata["token_usage"]["prompt_tokens"]
+ total_output_tokens = response.response_metadata["token_usage"]["completion_tokens"]
+ total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
+ logprobs = response.response_metadata["logprobs"]["content"]
+ response_list.append(
+ {
+ "response": response.content,
+ "judge_probs": self.get_judge_probs(logprobs)
+ }
+ )
+ except Exception as e:
+ print(f"Error: {e}")
+ # print(response.response_metadata["logprobs"])
+ response_list.append(
+ {
+ "response": response.content,
+ "judge_probs": []
+ }
+ )
+ return response_list, total_cost
+
+
+class ChecklistGenerationAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="checklist_generation",
+ prompt_type=prompt_type
+ )
+ return message
+
+
+class ClassifierRewardAgent(Agent):
+ def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
+ self.url = url
+ self.use_checklist = use_checklist
+ self.use_multimodal = use_multimodal
+
+ def _process_multimodal_message(self, prompt: str, image_list: list[str]):
+ multimodal_message = []
+ text_prompt_prefix = prompt.split("")[0]
+ text_prompt_suffix = prompt.split("")[1]
+ multimodal_message = [
+ {"type": "text", "text": text_prompt_prefix},
+ # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
+ {"type": "image", "image": image_to_base64_url(image_list[0])},
+ {"type": "text", "text": text_prompt_suffix}
+ ]
+ return multimodal_message
+
+ def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
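+        # build a user/assistant message pair that the reward endpoint scores as a single query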
+ if self.use_multimodal:
+ tmp_user_prompt = user_prompt_template["user"].format(
+ **model_input
+ )
+ user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
+ else:
+ user_prompt = user_prompt_template["user"].format(
+ **model_input
+ )
+ assistant_prompt = user_prompt_template["assistant"].format(
+ **model_input
+ )
+ query = [
+ {"role": "user", "content": user_prompt},
+ {"role": "assistant", "content": assistant_prompt}
+ ]
+ return query
+
+ def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
+ if self.use_checklist:
+ if self.use_multimodal:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
+ else:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
+ else:
+ if self.use_multimodal:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
+ else:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
+
+ if self.use_multimodal:
+ if batch:
+ message = [self._make_query(user_prompt_template, input) for input in model_input]
+ else:
+ message = [self._make_query(user_prompt_template, model_input)]
+ else:
+ if batch:
+ message = {
+ "query": [self._make_query(user_prompt_template, input) for input in model_input],
+                    "prompts": []
+ }
+ else:
+ message = {
+ "query": self._make_query(user_prompt_template, model_input),
+ "prompts": []
+ }
+
+ return message
+
+    def get_rm_score(self, message: dict | list):
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ if self.use_multimodal:
+ response = requests.post(
+ self.url,
+ json={"messages": message},
+ timeout=600
+ )
+ else:
+ response = requests.post(
+ self.url,
+ headers=headers,
+ data=json.dumps(message),
+ timeout=300
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+
+ if "rewards" not in response_json:
+ print(f"Error: 'rewards' key not found in API response: {response_json}")
+ return []
+
+ if "get_reward" in self.url:
+ # use openrlhf
+ return response_json["rewards"]
+ elif "pooling" in self.url:
+ # use vllm server
+ return response_json["reward"]
+ else:
+ # error
+ raise ValueError(f"Invalid URL: {self.url}")
+
+ except requests.exceptions.Timeout:
+ print(f"Error: Request timed out to {self.url}")
+ return []
+ except requests.exceptions.RequestException as e:
+ print(f"Error during request to {self.url}: {e}")
+ return []
+ except json.JSONDecodeError:
+ print(f"Error: Failed to decode JSON response from {self.url}")
+ return []
+ except KeyError as e:
+ print(f"Error: Missing key {e} in response from {self.url}")
+ return []
+
+
+ def generate_response(self, model_input: dict | list[dict], batch: bool = False):
+ if batch:
+ message = self.prepare_message(model_input, batch=True)
+ else:
+ message = self.prepare_message(model_input)
+        rewards = self.get_rm_score(message)
+
+ return rewards, 0
\ No newline at end of file
diff --git a/agent/mini_bench/checklist_eval.py b/agent/mini_bench/checklist_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3b5a5ddcd7eb0c8efbdc09875391dc8751c0953
--- /dev/null
+++ b/agent/mini_bench/checklist_eval.py
@@ -0,0 +1,95 @@
+import re
+
+from langchain_openai import ChatOpenAI
+
+from .agent import BaseAgent
+
+SYSTEM_PROMPT = "You are an expert evaluator. Your task is to assess how well a Web Agent’s generated checklist aligns with the reference checklist for a given user instruction."
+
+USER_PROMPT = """# Task Description
+Use the provided task description, evaluation criteria, and both checklists to assign a score from 1 to 5. Justify your rating with a brief explanation that considers both content overlap and logical structure.
+
+## Score Criteria
+- 5: Checklist covers all subgoals, is correct and clearly expressed
+- 4: Minor omissions or phrasing issues but mostly accurate and complete
+- 3: Partially matches, but with noticeable gaps or errors
+- 2: Incomplete or includes incorrect steps
+- 1: Mostly irrelevant, incorrect, or missing the task goal
+
+## User Instruction:
+{intent}
+
+## Reference Checklist:
+{gt_checklist}
+
+## Agent’s Generated Checklist:
+{generated_checklist}
+
+# Output Format
+Your response should be in the following format:
+REASON: [Write 2–4 sentences explaining how well the generated checklist matches the reference. Mention specific matches, omissions, errors, or strengths.]
+SCORE: [1–5]
+"""
+
+
+class ChecklistEvalAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+    def prepare_message(self, model_input: dict, prompt_type=None):
+ message = [
+ {
+ "role": "system",
+ "content": SYSTEM_PROMPT
+ },
+ {
+ "role": "user",
+ "content": USER_PROMPT.format(
+ intent=model_input["intent"],
+ gt_checklist=model_input["gt_checklist"],
+ generated_checklist=model_input["generated_checklist"]
+ )
+ }
+ ]
+ return message
+
+ def generate_response(self, model_input: dict):
+ total_cost = 0
+ response_list = []
+ # prepare message
+ message = self.prepare_message(model_input)
+
+ # n sampling
+ for _ in range(self.num_generate):
+ response, cost = self.generate_with_retry(message, ["SCORE"])
+ response_list.append(response)
+ total_cost += cost
+
+ return response_list, total_cost
+
+def parsing_score(response: str):
+ score = response.split("SCORE:")[-1].split("\n")[0].strip()
+ match = re.search(r'\d+', score)
+
+ if match:
+ return int(match.group())
+ else:
+ return None
+
+def average_score(scores: list[int]):
+ if len(scores) == 0:
+ return 0
+ return sum(scores) / len(scores)
+
+def get_score(results: list[dict]):
+ score_list = []
+ for result in results:
+ tmp_scores = [parsing_score(response) for response in result["response"]]
+ scores = [score for score in tmp_scores if score is not None]
+ result["score_list"] = scores
+ final_score = average_score(scores)
+ result["score"] = final_score
+ score_list.append(result)
+
+ return results, score_list
\ No newline at end of file
diff --git a/agent/mini_bench/eval_utils.py b/agent/mini_bench/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b211cd239a627b3f9800ab68ed6d5a6840c3e8
--- /dev/null
+++ b/agent/mini_bench/eval_utils.py
@@ -0,0 +1,309 @@
+import re
+import random
+from collections import Counter
+
+from .utils import load_json, save_json, create_html_report
+
+random.seed(42)
+def get_score(response_list: list, indicator: str) -> int:
+ if len(response_list) == 0:
+ return [-100]
+
+ if isinstance(response_list[0], float):
+ return response_list
+
+ if indicator == "prob":
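+        # probability-based scoring: each "yes" counts fully, each "in progress" counts half, averaged over items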
+ score_list = []
+ for response in response_list:
+ total_score = 0
+ for judge_probs in response:
+ yes_prob = judge_probs.get("yes", 0)
+ in_progress_prob = judge_probs.get("in", 0)
+ total_score += yes_prob + in_progress_prob * 0.5
+ if len(response) > 0:
+ score_list.append(total_score / len(response))
+ else:
+ score_list.append(0)
+ return score_list
+ else:
+ score_list = []
+ for response in response_list:
+ if indicator == "SCORE":
+ if "SCORE" in response:
+ try:
+ score_str = response.split("SCORE:")[1].split("\n")[0].strip()
+ except:
+ score_str = response.split("SCORE:")[-1].strip()
+ # find first integer
+ try:
+ score = re.search(r'-?\d+', score_str).group()
+ score_list.append(int(score))
+ except:
+ score_list.append(0)
+ else:
+ try:
+ score_str = response.split("")[1].split("")[0].strip()
+ except:
+ score_str = response.split("")[-1].split("")[0].strip()
+ # find "Yes" or "No"
+ if "Yes" in score_str:
+ score_list.append(1)
+ elif "In Progress" in score_str:
+ score_list.append(0.5)
+ elif "No" in score_str:
+ score_list.append(0)
+ else:
+ score_list.append(0)
+ elif indicator == "JUDGE":
+ try:
+ judge_str = response.split("JUDGE:")[1].split("\n")[0].strip()
+ except:
+ judge_str = response.split("JUDGE:")[-1].strip()
+ if "Yes" in judge_str:
+ score_list.append(1)
+ elif "No" in judge_str:
+ score_list.append(0)
+ else:
+ score_list.append(0)
+ elif indicator == "CHECKLIST EVALUATION":
+ if "" in response:
+ try:
+ checklist_str = response.split("")[1].split("")[0].strip()
+ except:
+ checklist_str = response.split("")[-1].split("")[0].strip()
+ else:
+ checklist_str = response.split("CHECKLIST EVALUATION:")[-1].strip()
+
+ count_yes = checklist_str.count("Yes")
+ count_no = checklist_str.count("No")
+ count_in_progress = checklist_str.count("In Progress")
+ try:
+ total_score = (count_yes + count_in_progress*0.5) / (count_yes + count_no + count_in_progress)
+ except:
+ total_score = 0
+ score_list.append(total_score)
+ else:
+ raise ValueError(f"Invalid indicator: {indicator}")
+ return score_list
+
+def get_acc_and_mrr(chosen_score, rejected_scores):
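+    # returns (reciprocal rank, accuracy); ties with rejected scores push the chosen rank down (draw penalty)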
+ if len(rejected_scores) == 0:
+ return 0, False
+
+ same_score_num = rejected_scores.count(chosen_score)
+ all_scores = rejected_scores + [chosen_score]
+ sorted_scores = sorted(all_scores, reverse=True)
+ rank = sorted_scores.index(chosen_score) + 1 + same_score_num # draw penalty
+ if all(chosen_score > r for r in rejected_scores):
+ accuracy = True
+ else:
+ accuracy = False
+ return 1 / rank, accuracy
+
+def average_score(score_list: list[float]):
+ if len(score_list) == 0:
+ return -100
+ return sum(score_list) / len(score_list)
+
+def self_consistency_score(score_list: list[float]):
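+    # self-consistency = share of samples that agree with the most frequent score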
+ if len(score_list) == 0:
+ return -100
+ counter = Counter(score_list)
+ return max(counter.values()) / len(score_list)
+
+def get_chosen_rejected_scores(data: dict, agg_func: str):
+ if len(data["chosen"]) == 0:
+ data["chosen"] = [{"score": [-100]}]
+ if len(data["rejected"]) == 0:
+ data["rejected"] = [{"score": [-100]}]
+    if not isinstance(data["chosen"][0], dict):
+        data["chosen"][0] = {"score": [-100]}
+    if not isinstance(data["rejected"][0], dict):
+        data["rejected"][0] = {"score": [-100]}
+
+ if agg_func == "average":
+ chosen_score = average_score(data["chosen"][0]["score"])
+ rejected_scores = [average_score(rejected_score["score"]) for rejected_score in data["rejected"]]
+ elif agg_func == "self_consistency":
+ chosen_score = self_consistency_score(data["chosen"][0]["score"])
+ rejected_scores = [self_consistency_score(rejected_score["score"]) for rejected_score in data["rejected"]]
+ else:
+ raise ValueError(f"Invalid agg_func: {agg_func}")
+ return chosen_score, rejected_scores
+
+def get_score_results(results, agg_func):
+ score_dict = {"mrr": [], "accuracy": [], "traj_accuracy": []}
+ task_accuracy = {}
+ for result in results:
+ chosen_score, rejected_scores = get_chosen_rejected_scores(result, agg_func)
+ mrr, accuracy = get_acc_and_mrr(chosen_score, rejected_scores)
+ score_dict["mrr"].append(mrr)
+ score_dict["accuracy"].append(accuracy)
+ if result["task_id"] not in task_accuracy:
+ task_accuracy[result["task_id"]] = []
+ task_accuracy[result["task_id"]].append(accuracy)
+
+ for task_id in task_accuracy:
+ if sum(task_accuracy[task_id]) == len(task_accuracy[task_id]):
+ score_dict["traj_accuracy"].append(True)
+ else:
+ score_dict["traj_accuracy"].append(False)
+
+ return score_dict
+
+def calculate_stats(results, agg_func: str="average"):
+ if len(results) == 0:
+ return {
+ "MRR": 0,
+ "Accuracy": 0,
+ "Traj_Accuracy": 0,
+ }
+ total_score = get_score_results(results, agg_func)
+ stats = {
+ "MRR": sum(total_score["mrr"]) / len(total_score["mrr"]),
+ "Accuracy": sum(total_score["accuracy"]) / len(total_score["accuracy"]),
+ "Traj_Accuracy": sum(total_score["traj_accuracy"]) / len(total_score["traj_accuracy"]),
+ }
+
+ return stats
+
+def group_by_task(results, split_indicator: str):
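+    # group responses by (task_id, step_id) so each group pairs the chosen response with its rejected alternatives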
+ # sort results by task_id and step_id
+ results.sort(key=lambda x: (x["task_id"], x["step_id"]))
+ # group by task_name
+ grouped_task_dict = {}
+ for result in results:
+ task_name = "task_" + str(result["task_id"]) + "_step_" + str(result["step_id"])
+ if task_name not in grouped_task_dict:
+ grouped_task_dict[task_name] = {
+ "task_id": result["task_id"],
+ "step_id": result["step_id"],
+ "intent": result["intent"],
+ "start_url": result["start_url"],
+ "gt_checklist": result["gt_checklist"],
+            "generated_checklist": result.get("generated_checklist", None),
+ "trajectory": result["trajectory"],
+ "current_url": result["current_url"],
+ "text_observation": result["text_observation"],
+ # "image_list": result["image_list"],
+ "chosen": [],
+ "rejected": [],
+ "source_name": result["source_name"],
+ }
+
+ response = result["response"] if "response" in result else []
+ type_data = {
+ "thought": result["thought"],
+ "action": result["action"],
+ "response": response,
+ "score": get_score(response, split_indicator) if split_indicator != "prob" else get_score(result["judge_probs"], split_indicator),
+ }
+ if split_indicator == "prob":
+ type_data["judge_probs"] = result["judge_probs"]
+ if result["type"] == "chosen":
+ grouped_task_dict[task_name]["chosen"].append(type_data)
+ elif result["type"] == "rejected":
+ grouped_task_dict[task_name]["rejected"].append(type_data)
+
+ return list(grouped_task_dict.values())
+
+
+def processing_results(results, evaluation_mode: str, num_generate: int, use_batch: bool=False):
+ if "judge_probs" in results[0]:
+ split_indicator = "prob"
+ else:
+ if evaluation_mode == "judge_with_checklist_generation" or evaluation_mode == "judge_with_gt_checklist":
+ split_indicator = "CHECKLIST EVALUATION"
+ else:
+ split_indicator = "SCORE"
+
+ # if use_batch is True, make it flattened
+ if use_batch:
+ tmp_results = []
+ for result in results:
+ for d in result:
+ tmp_results.append(d)
+ grouped_results = group_by_task(tmp_results, split_indicator)
+ else:
+ grouped_results = group_by_task(results, split_indicator)
+
+ mind2web_results = []
+ webarena_results = []
+ mind2web_task_results = []
+ mind2web_website_results = []
+ mind2web_domain_results = []
+
+ for grouped_result in grouped_results:
+ if "mind2web" in grouped_result["source_name"]:
+ mind2web_results.append(grouped_result)
+ if grouped_result["source_name"] == "mind2web_test_task":
+ mind2web_task_results.append(grouped_result)
+ elif grouped_result["source_name"] == "mind2web_test_website":
+ mind2web_website_results.append(grouped_result)
+ elif grouped_result["source_name"] == "mind2web_test_domain":
+ mind2web_domain_results.append(grouped_result)
+ elif "webarena" in grouped_result["source_name"]:
+ webarena_results.append(grouped_result)
+
+ try:
+ final_stats = {
+ "mind2web": {
+ "MRR": {},
+ "Accuracy": {},
+ "Traj_Accuracy": {},
+ },
+ "webarena": {
+ "MRR": {},
+ "Accuracy": {},
+ "Traj_Accuracy": {},
+ },
+ "mind2web_task": {
+ "MRR": {},
+ "Accuracy": {},
+ "Traj_Accuracy": {},
+ },
+ "mind2web_website": {
+ "MRR": {},
+ "Accuracy": {},
+ "Traj_Accuracy": {},
+ },
+ "mind2web_domain": {
+ "MRR": {},
+ "Accuracy": {},
+ "Traj_Accuracy": {},
+ },
+ }
+ for source_results in [
+ ("mind2web", mind2web_results),
+ ("webarena", webarena_results),
+ ("mind2web_task", mind2web_task_results),
+ ("mind2web_website", mind2web_website_results),
+ ("mind2web_domain", mind2web_domain_results)
+ ]:
+ average_stats = calculate_stats(source_results[1], "average")
+ self_consistency_stats = calculate_stats(source_results[1], "self_consistency")
+ for metric in average_stats:
+ final_stats[source_results[0]][metric]["Average"] = average_stats[metric]
+ for metric in self_consistency_stats:
+ final_stats[source_results[0]][metric]["Self_Consistency"] = self_consistency_stats[metric]
+
+ if num_generate == 1:
+ for source_name in final_stats:
+ for metric in final_stats[source_name]:
+ print(f"{round(100 * final_stats[source_name][metric]['Average'], 2)}", end=", ")
+ print()
+ else:
+ for agg_func in ["Average", "Self_Consistency"]:
+ print(f"{agg_func}")
+ for source_name in final_stats:
+ for metric in final_stats[source_name]:
+ print(f"{round(100 * final_stats[source_name][metric][agg_func], 2)}", end=", ")
+ print()
+ except Exception as e:
+ print(e)
+ return grouped_results, None
+
+ # add function to convert json format results to html format results
+ # TODO: implement this function
+ # create_html_report(results, "results.html")
+ return grouped_results, final_stats
\ No newline at end of file
diff --git a/agent/mini_bench/inference_utils.py b/agent/mini_bench/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..947d35c13c2245fdbba067272a70f52b72328ec8
--- /dev/null
+++ b/agent/mini_bench/inference_utils.py
@@ -0,0 +1,87 @@
+import time
+
+from multiprocessing import Process, Manager
+from tqdm import tqdm
+
+
+def worker_main(work_queue, result_queue, process_func, config):
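+    # worker loop: process items until a None sentinel is received, then forward the sentinel to the result queue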
+ while True:
+ item = work_queue.get()
+ if item is None:
+ result_queue.put(None)
+ break
+ try:
+ results, cost = process_func(config, item)
+ result_queue.put((results, cost))
+ except Exception as e:
+ item_info = item.get('idx', item.get('id', 'unknown item'))
+ print(f"Error processing item {item_info}: {e}")
+ result_queue.put(None)
+ finally:
+ work_queue.task_done()
+
+def run_parallel_evaluation(dataset, process_func, config, num_workers, description):
+ """
+ Runs parallel evaluation on the given dataset and returns the results.
+
+ Args:
+ dataset (list or datasets.Dataset): Data to evaluate.
+ process_func (callable): Function to process each data item.
+ config (dict): Configuration for the process_func.
+ num_workers (int): Number of worker processes to use.
+ description (str): Description to display on the tqdm progress bar.
+
+ Returns:
+ tuple: (list of evaluation results, total cost)
+ """
+ manager = Manager()
+ work_queue = manager.Queue()
+ result_queue = manager.Queue()
+
+ # Add data to the work queue
+ dataset_list = list(dataset) if not isinstance(dataset, list) else dataset
+ for data in dataset_list:
+ work_queue.put(data)
+
+ # Add termination signals for workers
+ for _ in range(num_workers):
+ work_queue.put(None)
+
+ # Start parallel processing
+ processes = []
+ for _ in range(num_workers):
+ p = Process(target=worker_main, args=(work_queue, result_queue, process_func, config))
+ p.start()
+ processes.append(p)
+
+ # Show progress bar and collect results
+ process_results = []
+ process_cost = 0
+ completed_workers = 0
+
+ with tqdm(total=len(dataset_list), desc=description) as pbar:
+ while completed_workers < num_workers:
+ result_item = result_queue.get()
+ if result_item is None:
+ completed_workers += 1
+ else:
+ results, cost = result_item
+ if results is not None:
+ process_results.append(results)
+ process_cost += cost if cost is not None else 0
+ pbar.update(1)
+
+ # Wait for all processes to finish
+ for p in processes:
+ p.join()
+
+ # Collect remaining results
+ while not result_queue.empty():
+ result_item = result_queue.get_nowait()
+ if result_item is not None:
+ results, cost = result_item
+ if results is not None:
+ process_results.append(results)
+ process_cost += cost if cost is not None else 0
+
+ return process_results, process_cost
\ No newline at end of file
diff --git a/agent/mini_bench/prompts/__init__.py b/agent/mini_bench/prompts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4cbd42a34e971cd6b890e59d67d445b3e1f8030
--- /dev/null
+++ b/agent/mini_bench/prompts/__init__.py
@@ -0,0 +1 @@
+from .construct_messages import get_messages
\ No newline at end of file
diff --git a/agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9be9bcaf9b2b7378f6d7e5e143c51b6dc2ae05ba
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/__init__.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc6e4d6f24370bdb633435da39ab875bbd88ad0d
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/action.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a520e7063b5b181eafad979001e8072e9a3e2cc
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/checklist_prompt.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3a5955988073f5bbcab7b7ec9b148db58448f01
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/construct_messages.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54bab0e0898c44c89d59c5ae2c91058c8877a2ae
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/eval_type.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e08ee689eb62a5f4a2af37e6334980b8d437e58c
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/image_utils.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc2ba33c854ee1f72c1b368b8d8dc4d091f8e78b
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/input_information.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dc6b3ab5f34cde0cdc1ae323c7de13237f7c2fb
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/judge_prompt.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc b/agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31f684751392709b04668c6c86eae85afe18d1e9
Binary files /dev/null and b/agent/mini_bench/prompts/__pycache__/utils.cpython-311.pyc differ
diff --git a/agent/mini_bench/prompts/action.py b/agent/mini_bench/prompts/action.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ed24a7d2d36bd0cdea4088d5d60692543afc40
--- /dev/null
+++ b/agent/mini_bench/prompts/action.py
@@ -0,0 +1,93 @@
+ACTION_SPACE_PROMPT = """Note: This action set allows you to interact with your environment. Most of them are python function executing playwright code. The primary way of referring to elements in the page is through bid which are specified in your observations.
+
+15 different types of actions are available.
+
+noop(wait_ms: float = 1000)
+ Examples:
+ noop()
+
+ noop(500)
+
+scroll(delta_x: float, delta_y: float)
+ Examples:
+ scroll(0, 200)
+
+ scroll(-50.2, -100.5)
+
+keyboard_press(key: str)
+ Examples:
+ keyboard_press('Backspace')
+
+ keyboard_press('ControlOrMeta+a')
+
+ keyboard_press('Meta+Shift+t')
+
+click(bid: str, button: Literal['left', 'middle', 'right'] = 'left', modifiers: list[typing.Literal['Alt', 'Control', 'ControlOrMeta', 'Meta', 'Shift']] = [])
+ Examples:
+ click('a51')
+
+ click('b22', button='right')
+
+ click('48', button='middle', modifiers=['Shift'])
+
+fill(bid: str, value: str)
+ Examples:
+ fill('237', 'example value')
+
+ fill('45', 'multi-line\nexample')
+
+ fill('a12', 'example with "quotes"')
+
+hover(bid: str)
+ Examples:
+ hover('b8')
+
+tab_focus(index: int)
+ Examples:
+ tab_focus(2)
+
+new_tab()
+ Examples:
+ new_tab()
+
+go_back()
+ Examples:
+ go_back()
+
+go_forward()
+ Examples:
+ go_forward()
+
+goto(url: str)
+ Examples:
+ goto('http://www.example.com')
+
+tab_close()
+ Examples:
+ tab_close()
+
+select_option(bid: str, options: str | list[str])
+ Examples:
+ select_option('a48', 'blue')
+
+ select_option('c48', ['red', 'green', 'blue'])
+
+send_msg_to_user(text: str)
+ Examples:
+ send_msg_to_user('Based on the results of my search, the city was built in 1751.')
+
+report_infeasible(reason: str)
+ Examples:
+ report_infeasible('I cannot follow these instructions because there is no email field in this form.')
+
+Only a single action can be provided at once. Example:
+fill('a12', 'example with "quotes"')
+
+Note:
+* Some tasks may be game-like and may require interacting with the mouse position in x, y coordinates.
+* Some text fields might have auto-completion. To see it, you have to type a few characters and wait until the next step.
+* If you have to cut and paste, don't forget to select the text first.
+* Coordinates inside an SVG are relative to its top left corner.
+* Make sure to use bid to identify elements when using commands.
+* Interacting with comboboxes, dropdowns and auto-complete fields can be tricky; sometimes you need to use select_option, while other times you need to use fill or click and wait for the reaction of the page.
+"""
\ No newline at end of file
diff --git a/agent/mini_bench/prompts/checklist_prompt.py b/agent/mini_bench/prompts/checklist_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab78eaafa2b857a00232036bcb42ed47381e2a84
--- /dev/null
+++ b/agent/mini_bench/prompts/checklist_prompt.py
@@ -0,0 +1,50 @@
+CHECKLIST_SYSTEM_PROMPT = "You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task."
+
+CHECKLIST_USER_PROMPT = """## Task Description
+User Instruction (Goal): "{intent}"
+Start Website URL: {start_url}
+
+## Guidelines for Checklist Generation
+1. Identify Essential High-Level Subgoals:
+- A subgoal should represent a significant step involving user interaction that leads to noticeable page transitions or meaningful changes in system state.
+- Consolidate closely related user actions (such as applying multiple filters or selecting several options) into a single subgoal, rather than separate checklist items for each action.
+- Prioritize only the most critical interactions necessary for meaningful progression, avoiding the inclusion of minor or unnecessary steps (e.g., scroll, hover).
+2. Provide a Concise Subgoal Analysis:
+- Before creating the checklist, offer a brief paragraph summarizing the main subgoals, emphasizing significant transitions or page-level interactions.
+3. Ensure Clear Goal:
+- If multiple related interactions occur (e.g., setting filters 1, 2, and 3), combine them into one subgoal with clear criteria verifying all required conditions.
+- The checklist should contain only essential steps, explicitly excluding unnecessary actions, and should not exceed five critical subgoals. It is not necessary to use all five checklist items if fewer steps adequately represent the essential subgoals.
+
+### Output Format
+Before generating the checklist, first produce a concise subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist following the format below:
+[SUBGOAL ANALYSIS]
+[One-paragraph summary explaining the key subgoals and their logical sequence in task completion.]
+
+[CHECKLISTS]
+Checklist X: [Short title of the action/goal]
+- Goal: [Brief description of the subgoal at this stage, emphasizing the purpose of the action.]
+"""
+
+# TODO: implement ours
+CHECKLIST_OURS_SYSTEM_PROMPT = ""
+
+CHECKLIST_OURS_USER_PROMPT = """You are an AI assistant tasked with generating structured checklists that highlight key subgoals necessary to complete a task.
+
+# Task Description
+Generate a checklist of key milestones for achieving the given instruction. First, provide a concise
+subgoal analysis in a single paragraph summarizing the required interactions. Then, based on this, generate the checklist with a brief description.
+
+Note: If the target website requires login, assume the user is already logged in and starts from an authenticated session.
+
+# Given Information
+## User Instruction
+{intent}
+
+## Current State
+### Current URL
+{start_url}
+
+### AXTREE
+Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
+{text_observation}
+"""
\ No newline at end of file
diff --git a/agent/mini_bench/prompts/construct_messages.py b/agent/mini_bench/prompts/construct_messages.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9203e2a6aa3b5db48b7ebc610c9f1123a0473c9
--- /dev/null
+++ b/agent/mini_bench/prompts/construct_messages.py
@@ -0,0 +1,309 @@
+from abc import ABC, abstractmethod
+
+from .action import ACTION_SPACE_PROMPT
+from .eval_type import (
+ GROUNDING,
+ PROGRESS_LIKERT_SCALE,
+ PROGRESS_THREE_CLASS,
+ PROGRESS_WITH_CHECKLIST,
+ PROGRESS_WITH_CHECKLIST_IN_PROGRESS,
+ PROGRESS_OURS
+)
+from .input_information import (
+ USER_INSTRUCTION,
+ TRAJECTORY,
+ AGENT_RESPONSE,
+ CHECKLIST,
+ CURRENT_URL,
+ TEXT_OBSERVATION,
+ SOM_IMAGE_OBSERVATION,
+ COORD_IMAGE_OBSERVATION
+)
+from .judge_prompt import (
+ JUDGE_GROUNDING_PROMPT_TEMPLATE,
+ JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE,
+ JUDGE_THREE_CLASS_PROMPT_TEMPLATE,
+ JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE,
+ JUDGE_OURS_PROMPT_TEMPLATE,
+ JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE
+)
+from .checklist_prompt import (
+ CHECKLIST_SYSTEM_PROMPT,
+ CHECKLIST_USER_PROMPT,
+ CHECKLIST_OURS_SYSTEM_PROMPT,
+ CHECKLIST_OURS_USER_PROMPT
+)
+from .image_utils import image_to_base64_url
+
+
+class Message(ABC):
+ @abstractmethod
+ def get_messages(self):
+ pass
+
+class BaseMessage(Message):
+ def __init__(self, input_info:dict, use_multimodal:bool=False):
+ self.input_info = input_info
+ self.use_multimodal = use_multimodal
+
+ def _get_system_message(self):
+ system_message = {"role": "system", "content": "You are a helpful assistant."}
+ return system_message
+
+ def _process_multimodal_message(self, prompt: str, image_list: list[str]):
+ multimodal_message = []
+ text_prompt_prefix = prompt.split("")[0]
+ text_prompt_suffix = prompt.split("")[1]
+ multimodal_message.append({"type": "text", "text": text_prompt_prefix})
+ for i, image in enumerate(image_list):
+ # TODO: text prompt for multiple images
+ # multimodal_message.append({"type": "text", "text": f"IMAGE {i+1}\n"})
+ multimodal_message.append({"type": "image_url", "image_url": {"url": image_to_base64_url(image), "detail": "low"}})
+ multimodal_message.append({"type": "text", "text": text_prompt_suffix})
+ return {"role": "user", "content": multimodal_message}
+
+ def _get_user_message(self):
+ user_prompt = "What is the capital of France?"
+ if self.use_multimodal:
+ image_list = self.input_info.get("image_list", [])
+ user_message = self._process_multimodal_message(user_prompt, image_list)
+ else:
+ user_message = {"role": "user", "content": user_prompt}
+ return user_message
+
+ def get_messages(self):
+ message = []
+ system_message = self._get_system_message()
+ user_message = self._get_user_message()
+
+ message.append(system_message)
+ # message.append({"role": "system", "content": ""})
+ message.append(user_message)
+ return message
+
+
+class ProgressMessage(BaseMessage):
+ '''
+ Progress Judge Message
+ '''
+ def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str, use_checklist:bool, use_in_progress:bool):
+ super().__init__(input_info, use_multimodal)
+ self.prompt_type = prompt_type
+ self.text_obs = text_obs
+ self.image_obs = image_obs
+ self.use_checklist = use_checklist
+ self.use_in_progress = use_in_progress
+
+ def _get_system_message(self):
+ if self.prompt_type == "likert_scale":
+ system_message = {"role": "system", "content": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["system"]}
+ elif self.prompt_type == "three_class":
+ system_message = {"role": "system", "content": JUDGE_THREE_CLASS_PROMPT_TEMPLATE["system"]}
+ elif self.prompt_type == "with_checklist":
+ system_message = {"role": "system", "content": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["system"]}
+ elif self.prompt_type == "ours":
+ system_message = {"role": "system", "content": JUDGE_OURS_PROMPT_TEMPLATE["system"]}
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return system_message
+
+ def _setup_input_information(self):
+ observation = "## Current State\n"
+
+ observation += CURRENT_URL
+
+ # text observation
+ if self.text_obs:
+ observation += TEXT_OBSERVATION
+
+ # image observation (som, coord, none)
+ if self.image_obs == "som":
+ observation += SOM_IMAGE_OBSERVATION
+ elif self.image_obs == "coord":
+ observation += COORD_IMAGE_OBSERVATION
+
+
+ if self.use_checklist:
+ input_information = USER_INSTRUCTION + TRAJECTORY + observation + CHECKLIST + AGENT_RESPONSE
+ else:
+ input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE
+
+ return input_information
+
+ def _setup_task_info(self):
+ if self.prompt_type == "likert_scale":
+ task_description = PROGRESS_LIKERT_SCALE["task_description"]
+ output_format = PROGRESS_LIKERT_SCALE["output_format"]
+ elif self.prompt_type == "three_class":
+ task_description = PROGRESS_THREE_CLASS["task_description"]
+ output_format = PROGRESS_THREE_CLASS["output_format"]
+ elif self.prompt_type == "with_checklist":
+ if self.use_in_progress:
+ task_description = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["task_description"]
+ output_format = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["output_format"]
+ else:
+ task_description = PROGRESS_WITH_CHECKLIST["task_description"]
+ output_format = PROGRESS_WITH_CHECKLIST["output_format"]
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return task_description, output_format
+
+ def _get_user_prompt_template(self):
+ if self.prompt_type == "likert_scale":
+ user_prompt = JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["user"]
+ elif self.prompt_type == "three_class":
+ user_prompt = JUDGE_THREE_CLASS_PROMPT_TEMPLATE["user"]
+ elif self.prompt_type == "with_checklist":
+ user_prompt = JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["user"]
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return user_prompt
+
+ def _get_user_message(self):
+ # setup input information (user_instruction, trajectory, current_state, agent_response, checklist)
+ input_information_template = self._setup_input_information()
+ input_information = input_information_template.format(**self.input_info)
+
+ if self.prompt_type == "ours":
+ if self.use_checklist:
+ user_prompt = JUDGE_OURS_PROMPT_TEMPLATE["user"].format(
+ input_information=input_information,
+ )
+ else:
+ user_prompt = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE["user"].format(
+ input_information=input_information,
+ )
+ else:
+ task_description, output_format = self._setup_task_info()
+ # get user prompt template by prompt type
+ user_prompt_template = self._get_user_prompt_template()
+ user_prompt = user_prompt_template.format(
+ action_space=ACTION_SPACE_PROMPT,
+ task_description=task_description,
+ input_information=input_information,
+ output_format=output_format
+ )
+
+ # process multimodal message
+ if self.use_multimodal:
+ image_list = self.input_info.get("image_list", [])
+ user_message = self._process_multimodal_message(user_prompt, image_list)
+ else:
+ user_message = {"role": "user", "content": user_prompt}
+
+ return user_message
+
+
+class GroundingMessage(BaseMessage):
+ '''
+ Grounding Judge Message
+ '''
+ def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str):
+ super().__init__(input_info, use_multimodal)
+ self.prompt_type = prompt_type
+ self.text_obs = text_obs
+ self.image_obs = image_obs
+
+ def _get_system_message(self):
+ if self.prompt_type == "ours":
+ # TODO: implement ours
+ system_message = {"role": "system", "content": "You are a helpful assistant."}
+ elif self.prompt_type == "default":
+ system_message = {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]}
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return system_message
+
+ def _setup_input_information(self):
+ observation = "## Current State\n"
+
+ observation += CURRENT_URL
+
+ # text observation
+ if self.text_obs:
+ observation += TEXT_OBSERVATION
+
+ # image observation (som, coord, none)
+ if self.image_obs == "som":
+ observation += SOM_IMAGE_OBSERVATION
+ elif self.image_obs == "coord":
+ observation += COORD_IMAGE_OBSERVATION
+
+ # input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE # with trajectory
+ input_information = USER_INSTRUCTION + observation + AGENT_RESPONSE # without trajectory
+
+ return input_information
+
+ def _get_user_message(self):
+ if self.prompt_type == "ours":
+ # TODO: implement ours
+ user_message = {"role": "user", "content": "TODO"}
+ elif self.prompt_type == "default":
+ action_space = ACTION_SPACE_PROMPT
+ task_description = GROUNDING["task_description"]
+ output_format = GROUNDING["output_format"]
+ input_information_template = self._setup_input_information()
+ input_information = input_information_template.format(**self.input_info)
+
+ user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format(
+ action_space=action_space,
+ task_description=task_description,
+ input_information=input_information,
+ output_format=output_format
+ )
+
+ # process multimodal message
+ if self.use_multimodal:
+ image_list = self.input_info.get("image_list", [])
+ user_message = self._process_multimodal_message(user_prompt, image_list)
+ else:
+ user_message = {"role": "user", "content": user_prompt}
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return user_message
+
+
+class ChecklistMessage(BaseMessage):
+ '''
+ Checklist Message
+ '''
+ def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str):
+ super().__init__(input_info, use_multimodal)
+ self.prompt_type = prompt_type
+
+ def _get_system_message(self):
+ if self.prompt_type == "ours":
+ # TODO: implement ours
+ system_message = {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT}
+ elif self.prompt_type == "default":
+ system_message = {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT}
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return system_message
+
+ def _get_user_message(self):
+ if self.prompt_type == "ours":
+ user_message = {"role": "user", "content": CHECKLIST_OURS_USER_PROMPT.format(**self.input_info)}
+ elif self.prompt_type == "default":
+ user_message = {"role": "user", "content": CHECKLIST_USER_PROMPT.format(**self.input_info)}
+ else:
+ raise ValueError(f"Invalid prompt type: {self.prompt_type}")
+ return user_message
+
+
+def get_messages(input_info:dict, inference_mode:str, prompt_type:str, text_obs:str=None, image_obs:str=None, use_multimodal:bool=False, use_checklist:bool=False, use_in_progress:bool=False):
+ message_list = []
+ if inference_mode == "judge_grounding":
+ message = GroundingMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs)
+ elif inference_mode == "judge_progress":
+ message = ProgressMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs, use_checklist=use_checklist, use_in_progress=use_in_progress)
+ elif inference_mode == "checklist_generation":
+ message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type)
+ else:
+ raise ValueError(f"Invalid inference mode: {inference_mode}")
+
+ system_message, user_message = message.get_messages()
+
+ message_list.append(system_message)
+ message_list.append(user_message)
+ return message_list
\ No newline at end of file
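For reference, a minimal sketch of how get_messages can be driven for checklist-based progress judging. All field values below are placeholders, and the import path assumes the repository root is on PYTHONPATH.

    from agent.mini_bench.prompts.construct_messages import get_messages

    # Placeholder inputs; real runs pass benchmark trajectories and AXTree dumps here.
    input_info = {
        "intent": "Find the rating of Monopoly (1935) on boardgamegeek.com",
        "trajectory": "THOUGHT 1: I should open the site first.\nACTION 1: goto('https://boardgamegeek.com')\n",
        "current_url": "https://boardgamegeek.com",
        "text_observation": "[a1] link 'Board Games'\n[a2] textbox 'Search'",
        "checklist": "Checklist 1: Search for Monopoly (1935)\n- Goal: Locate the game's page via site search.",
        "thought": "I will search for the game by name.",
        "action": "fill('a2', 'Monopoly 1935')",
    }

    messages = get_messages(
        input_info,
        inference_mode="judge_progress",
        prompt_type="with_checklist",
        text_obs="axtree",   # include the AXTree block in the observation
        image_obs=None,      # text-only judging
        use_checklist=True,
        use_in_progress=True,
    )
    # messages is [system_message, user_message], ready for a chat-completions request.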
diff --git a/agent/mini_bench/prompts/eval_type.py b/agent/mini_bench/prompts/eval_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3fd156b2536b02a84bedc35212a23fdc6c2af2
--- /dev/null
+++ b/agent/mini_bench/prompts/eval_type.py
@@ -0,0 +1,107 @@
+# Task Description & Output Format
+GROUNDING_TASK = """Your task is to evaluate whether the agent's ACTION is properly grounded in its THOUGHT, considering the current state of the webpage.
+Use the user instruction, the current webpage state, and the agent's thought and action as evidence for your judgment. Your evaluation should assess whether the ACTION logically follows from the THOUGHT and is feasible and appropriate in the given environment.
+Mark the action as 'Yes' only if it is clearly and fully grounded in the thought and current webpage state. If there is any inconsistency, ambiguity, irrelevance, or if the action is not supported by the current page state, mark it as 'No'."""
+
+GROUNDING_FORMAT = """Please return your response in the following format:
+REASON: [Your explanation for whether the action is properly grounded]
+JUDGE: [Yes / No]"""
+
+
+PROGRESS_LIKERT_SCALE_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
+**Scoring Criteria (1 to 5):**
+- **5 (Very Helpful)**: The action directly and effectively moves toward fulfilling a key part of the goal.
+- **4 (Helpful)**: The action contributes meaningfully to progress, though it may require follow-up actions.
+- **3 (Somewhat Helpful)**: The action is partially relevant or a preparatory step, but doesn’t make immediate progress.
+- **2 (Slightly Helpful)**: The action is weakly related to the goal or might only indirectly help.
+- **1 (Not Helpful)**: The action is unrelated, redundant, or distracts from the goal."""
+
+PROGRESS_LIKERT_SCALE_FORMAT = """Please return your response in the following format:
+REASON: [Your explanation for the score]
+SCORE: [1-5]"""
+
+
+PROGRESS_THREE_CLASS_TASK = """Evaluate how helpful the given thought and action is for achieving the goal. Use the following scale:
+**Scoring Criteria:**
+- **1 (Helpful)**: The action clearly contributes to achieving the goal. It takes a necessary or productive step toward completing the task.
+- **0 (Neutral)**: The action is neither helpful nor harmful. It may be a placeholder, irrelevant at the current step, or too ambiguous to evaluate.
+- **-1 (Not Helpful)**: The action works against the goal, causes confusion, repeats a previous step unnecessarily, or leads the agent off track."""
+
+PROGRESS_THREE_CLASS_FORMAT = """Please return your response in the following format:
+REASON: [Your explanation for the score]
+SCORE: [-1 / 0 / 1]"""
+
+
+PROGRESS_WITH_CHECKLIST_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
+Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation.
+For each checklist item:
+- Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
+- Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""
+
+PROGRESS_WITH_CHECKLIST_FORMAT = """Please return your response in the following format:
+REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
+CHECKLIST EVALUATION:
+Checklist X: [Yes / No]
+"""
+
+PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK = """Your task is to evaluate how well the agent's THOUGHT and ACTION satisfy each item in the checklist.
+Use the task instruction, trajectory (including previously completed steps from history), current webpage state, and the agent's current response as evidence for your evaluation. Clearly consider any items already successfully completed or currently in progress according to the provided trajectory.
+For each checklist item:
+- Mark it as 'Yes' if it is clearly and fully satisfied either in the current response or already completed in the history.
+- Mark it as 'In Progress' if the agent has made partial but meaningful progress toward completing the item.
+- Mark it as 'No' if there is ambiguity, insufficient evidence, or the step is incomplete or not yet started."""
+
+PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT = """Please return your response in the following format:
+REASON: [Write a single, coherent paragraph explaining how well the agent's response satisfies the checklist overall. Use both the history and the agent's current thought/action as evidence. Mention specific strengths or missing elements that influence your decision.]
+CHECKLIST EVALUATION:
+Checklist X: [Yes / In Progress / No]
+"""
+
+
+GROUNDING_OURS_TASK = """
+"""
+
+GROUNDING_OURS_FORMAT = """
+"""
+
+PROGRESS_OURS_TASK = """
+"""
+
+PROGRESS_OURS_FORMAT = """
+"""
+
+## EVALUATION TYPE
+GROUNDING = {
+ "task_description": GROUNDING_TASK,
+ "output_format": GROUNDING_FORMAT,
+}
+
+GROUNDING_OURS = {
+ "task_description": GROUNDING_OURS_TASK,
+ "output_format": GROUNDING_OURS_FORMAT,
+}
+
+PROGRESS_LIKERT_SCALE = {
+ "task_description": PROGRESS_LIKERT_SCALE_TASK,
+ "output_format": PROGRESS_LIKERT_SCALE_FORMAT,
+}
+
+PROGRESS_THREE_CLASS = {
+ "task_description": PROGRESS_THREE_CLASS_TASK,
+ "output_format": PROGRESS_THREE_CLASS_FORMAT,
+}
+
+PROGRESS_WITH_CHECKLIST = {
+ "task_description": PROGRESS_WITH_CHECKLIST_TASK,
+ "output_format": PROGRESS_WITH_CHECKLIST_FORMAT,
+}
+
+PROGRESS_WITH_CHECKLIST_IN_PROGRESS = {
+ "task_description": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_TASK,
+ "output_format": PROGRESS_WITH_CHECKLIST_IN_PROGRESS_FORMAT,
+}
+
+PROGRESS_OURS = {
+ "task_description": PROGRESS_OURS_TASK,
+ "output_format": PROGRESS_OURS_FORMAT,
+}
\ No newline at end of file
diff --git a/agent/mini_bench/prompts/image_utils.py b/agent/mini_bench/prompts/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e3f487bdd2831541870ce5990366ab09270dd9d
--- /dev/null
+++ b/agent/mini_bench/prompts/image_utils.py
@@ -0,0 +1,19 @@
+import base64
+import io
+from PIL import Image
+
+
+def image_to_base64_url(image: str | Image.Image):
+ if isinstance(image, str):
+ with open(image, "rb") as f:
+ image = f.read()
+ elif isinstance(image, Image.Image):
+ if image.mode in ("RGBA", "LA"):
+ image = image.convert("RGB")
+ with io.BytesIO() as buffer:
+ image.save(buffer, format="PNG")
+ image = buffer.getvalue()
+ else:
+ raise ValueError(f"Invalid image type: {type(image)}")
+
+ return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")
\ No newline at end of file
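A quick sanity check for the helper above, assuming Pillow is installed; the image is created in memory rather than loaded from disk.

    from PIL import Image
    from agent.mini_bench.prompts.image_utils import image_to_base64_url

    # In-memory RGB image; RGBA/LA inputs would be converted to RGB before encoding.
    img = Image.new("RGB", (32, 32), color=(200, 30, 30))
    url = image_to_base64_url(img)
    assert url.startswith("data:image/png;base64,")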
diff --git a/agent/mini_bench/prompts/input_information.py b/agent/mini_bench/prompts/input_information.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce90a7d6fd976b3244beb2a13e6800a42476ef7f
--- /dev/null
+++ b/agent/mini_bench/prompts/input_information.py
@@ -0,0 +1,36 @@
+USER_INSTRUCTION = """## User Instruction
+{intent}
+"""
+
+TRAJECTORY = """## Trajectory
+{trajectory}"""
+
+AGENT_RESPONSE = """## Agent's Response
+THOUGHT: {thought}
+ACTION: {action}
+"""
+
+CHECKLIST = """## Checklist
+{checklist}
+"""
+
+
+# Observation
+CURRENT_URL = """### Current URL
+{current_url}
+"""
+
+TEXT_OBSERVATION = """### AXTREE
+Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
+{text_observation}
+"""
+
+SOM_IMAGE_OBSERVATION = """### SOM Image Screenshot
+Here is a current image screenshot of the page, it is annotated with bounding boxes and corresponding bids:
+
+"""
+
+COORD_IMAGE_OBSERVATION = """### Raw Image Screenshot
+Here is a screenshot of the page:
+
+"""
diff --git a/agent/mini_bench/prompts/judge_prompt.py b/agent/mini_bench/prompts/judge_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..78140cfe9606c6005f949ae560bff64e59c212c4
--- /dev/null
+++ b/agent/mini_bench/prompts/judge_prompt.py
@@ -0,0 +1,159 @@
+# SYSTEM PROMPT
+DEFAULT_SYSTEM_PROMPT_FORMAT = "You are an expert evaluator of web agent. {role_description}"
+
+PROGRESS_WITHOUT_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."
+PROGRESS_WITH_CHECKLIST_ROLE = "Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage."
+
+GROUNDING_ROLE = "Your task is to assess whether the ACTION taken by the agent is properly grounded, based on agent's THOUGHT and the current state of the webpage."
+
+# USER PROMPT
+DEFAULT_USER_PROMPT_FORMAT = """# Action space:
+{action_space}
+
+# Task Description
+{task_description}
+
+# Given Information
+{input_information}
+
+# Output Format
+{output_format}
+"""
+
+
+JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT = """You are an expert evaluator of web agents. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.
+
+# Task Description
+Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.
+
+# Given Information
+{input_information}
+"""
+
+
+JUDGE_OURS_USER_PROMPT_FORMAT = """You are an expert evaluator of web agents. Your task is to assess how helpful a given agent's THOUGHT and ACTION is in making progress toward the user's goal, based on the current state of the webpage.
+
+# Task Description
+Evaluate how well the agent’s THOUGHT and ACTION satisfy each item in the checklist using the task instruction, trajectory (including previously completed steps), current webpage state, and the agent’s latest response. Start by writing a concise paragraph summarizing the agent’s overall performance. Refer to the reasoning provided in the trajectory, and discuss whether the THOUGHT is appropriate and the ACTION moves the task forward.
+Then, assess each checklist item individually using the following labels:
+- Yes: The item is fully and clearly satisfied, either in the current response or previously completed.
+- In Progress: There is meaningful partial progress toward completing the item.
+- No: The item is not satisfied due to ambiguity, insufficient evidence, or lack of progress.
+
+# Given Information
+{input_information}
+"""
+
+
+JUDGE_OURS_BT_MODELING_USER_PROMPT_FORMAT = """You are an expert web agent that browses the internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.
+
+# Task Description
+Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.
+
+# Given Information
+## User Instruction
+{intent}
+
+## Trajectory
+{trajectory}
+
+## Current State
+### Current URL
+{current_url}
+
+### AXTREE
+Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
+{text_observation}
+
+## Checklist
+{checklist}
+
+## Agent's Response
+"""
+
+JUDGE_OURS_BT_MODELING_BASE_PROMPT = """You are an expert web agent that browses the internet via GUI actions. Your task is to achieve the user's goal described in the user instruction.
+
+# Task Description
+Generate the most appropriate GUI action to achieve the user's goal. When choosing your action, consider the current webpage state and the checklist which can be interpreted as subtasks.
+
+# Given Information
+## User Instruction
+{intent}
+
+## Trajectory
+{trajectory}
+
+## Current State
+### Current URL
+{current_url}
+
+### AXTREE
+Note: [bid] is the unique alpha-numeric identifier at the beginning of lines for each element in the AXTree. Always use bid to refer to elements in your actions.
+{text_observation}
+"""
+
+JUDGE_OURS_IMAGE_INPUT = """
+### Image Screenshot
+
+"""
+
+JUDGE_OURS_WITH_CHECKLIST = """
+## Checklist
+{checklist}
+"""
+
+BT_MODELING_RESPONSE_FORMAT = """
+THOUGHT: {thought}
+ACTION: {action}
+"""
+
+## PROMPT TEMPLATE
+JUDGE_GROUNDING_PROMPT_TEMPLATE = {
+ "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=GROUNDING_ROLE),
+ "user": DEFAULT_USER_PROMPT_FORMAT,
+}
+
+JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE = {
+ "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
+ "user": DEFAULT_USER_PROMPT_FORMAT
+}
+
+JUDGE_THREE_CLASS_PROMPT_TEMPLATE = {
+ "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITHOUT_CHECKLIST_ROLE),
+ "user": DEFAULT_USER_PROMPT_FORMAT
+}
+
+JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE = {
+ "system": DEFAULT_SYSTEM_PROMPT_FORMAT.format(role_description=PROGRESS_WITH_CHECKLIST_ROLE),
+ "user": DEFAULT_USER_PROMPT_FORMAT
+}
+
+JUDGE_OURS_PROMPT_TEMPLATE = {
+ "system": "",
+ "user": JUDGE_OURS_USER_PROMPT_FORMAT,
+}
+
+JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE = {
+ "system": "",
+ "user": JUDGE_OURS_WO_CHECKLIST_USER_PROMPT_FORMAT,
+}
+
+JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE = {
+ "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
+ "assistant": BT_MODELING_RESPONSE_FORMAT,
+}
+
+JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE = {
+ "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+JUDGE_OURS_WITH_CHECKLIST+"\n## Agent's Response\n",
+ "assistant": BT_MODELING_RESPONSE_FORMAT,
+}
+
+JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE = {
+ "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+"\n## Agent's Response\n",
+ "assistant": BT_MODELING_RESPONSE_FORMAT,
+}
+
+JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE = {
+ "user": JUDGE_OURS_BT_MODELING_BASE_PROMPT+JUDGE_OURS_IMAGE_INPUT+"\n## Agent's Response\n",
+ "assistant": BT_MODELING_RESPONSE_FORMAT,
+}
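As a rough illustration, the BT-modeling templates above are assembled into a reward-model query roughly as follows (this mirrors ClassifierRewardAgent._make_query later in this diff; all field values are placeholders).

    from agent.mini_bench.prompts.judge_prompt import JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE

    sample = {
        "intent": "Find the rating of Monopoly (1935) on boardgamegeek.com",
        "trajectory": "(no previous steps)",
        "current_url": "https://boardgamegeek.com",
        "text_observation": "[a2] textbox 'Search'",
        "checklist": "Checklist 1: Search for the game",
        "thought": "I will search for the game by name.",
        "action": "fill('a2', 'Monopoly 1935')",
    }
    query = [
        {"role": "user", "content": JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE["user"].format(**sample)},
        {"role": "assistant", "content": JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE["assistant"].format(**sample)},
    ]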
diff --git a/agent/mini_bench/prompts/utils.py b/agent/mini_bench/prompts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8694ef624237c32d1d3a722ab12547a8216984f0
--- /dev/null
+++ b/agent/mini_bench/prompts/utils.py
@@ -0,0 +1,18 @@
+from langchain.schema import HumanMessage, AIMessage, SystemMessage
+
+def convert_dict_messages(dict_messages):
+ message_objs = []
+ for msg in dict_messages:
+ role = msg.get("role")
+ content = msg.get("content", "")
+
+ if role == "user":
+ message_objs.append(HumanMessage(content=content))
+ elif role == "assistant":
+ message_objs.append(AIMessage(content=content))
+ elif role == "system":
+ message_objs.append(SystemMessage(content=content))
+ else:
+ raise ValueError(f"Unknown role: {role}")
+
+ return message_objs
\ No newline at end of file
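A small usage sketch, assuming langchain is installed as the import above requires.

    from agent.mini_bench.prompts.utils import convert_dict_messages

    dict_messages = [
        {"role": "system", "content": "You are an expert evaluator of web agents."},
        {"role": "user", "content": "Evaluate the agent's latest THOUGHT and ACTION."},
    ]
    lc_messages = convert_dict_messages(dict_messages)
    # -> [SystemMessage(...), HumanMessage(...)], suitable for llm.generate([lc_messages])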
diff --git a/agent/mini_bench/reward_agent.py b/agent/mini_bench/reward_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..6838b76e60988b436308770defb9e6cc90c6b764
--- /dev/null
+++ b/agent/mini_bench/reward_agent.py
@@ -0,0 +1,465 @@
+from abc import ABC, abstractmethod
+import time
+import requests
+import json
+import math
+from langsmith import Client
+import numpy as np
+from langchain_openai import ChatOpenAI
+
+from .prompts import get_messages
+from .prompts.judge_prompt import (
+ JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
+ JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
+)
+from .prompts.image_utils import image_to_base64_url
+from .prompts.utils import convert_dict_messages
+
+MAX_RETRY = 3
+RETRY_SLEEP = 5
+MODEL_COST_MAPPING = {
+ "gpt-4o-mini": {
+ "input_token_cost": 0.15,
+ "output_token_cost": 0.6
+ },
+ "gpt-4o": {
+ "input_token_cost": 2.5,
+ "output_token_cost": 10
+ },
+}
+
+
+class Agent(ABC):
+ @abstractmethod
+ def generate_response(self, inputs: dict) -> str:
+ pass
+
+class BaseAgent(Agent):
+ def __init__(self, agent_config: dict):
+ self.agent_config = agent_config
+ self._setup()
+
+ def _init_llm_object(self, **extra_kwargs):
+ config = self.agent_config
+ config.update(extra_kwargs)
+
+ use_log_probs = config.get("use_log_probs", False)
+ if use_log_probs:
+ self.llm = ChatOpenAI(
+ model=config["model_name"],
+ base_url=config["base_url"],
+ api_key=config["api_key"],
+ temperature=config["temperature"],
+ timeout=300,
+ logprobs=True,
+ top_logprobs=10,
+ n=config.get('n', None)
+ )
+ else:
+ self.llm = ChatOpenAI(
+ model=config["model_name"],
+ base_url=config["base_url"],
+ api_key=config["api_key"],
+ temperature=config["temperature"],
+ timeout=300,
+ n=config.get('n', None)
+ )
+
+ def _setup(self):
+ self._init_llm_object()
+
+ self.temperature = self.agent_config["temperature"]
+ self.num_generate = self.agent_config["num_generate"]
+ self.use_checklist = self.agent_config.get("use_checklist", False)
+ self.use_multimodal = self.agent_config.get("use_multimodal", False)
+
+ # setup cost
+ model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
+ if model_cost and "api" in self.agent_config["base_url"]:
+ self.input_token_cost = model_cost["input_token_cost"]
+ self.output_token_cost = model_cost["output_token_cost"]
+ else:
+ self.input_token_cost = 0.0
+ self.output_token_cost = 0.0
+
+ def generate_with_retry(self, model_input, constraint_str_list: list = None):
+ total_input_tokens = 0
+ total_output_tokens = 0
+ if self.temperature == 0:
+ response = self.llm.invoke(model_input)
+ total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
+ total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
+ else:
+ for i in range(MAX_RETRY):
+ try:
+ response = self.llm.invoke(model_input)
+ total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
+ total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
+ if constraint_str_list:
+ pass_constraint_num = 0
+ for constraint_str in constraint_str_list:
+ if constraint_str in response.content:
+ pass_constraint_num += 1
+ if pass_constraint_num == len(constraint_str_list):
+ break
+ else:
+                            print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
+ else:
+ break
+ except Exception as e:
+ print(f"Agent returned an Error: {e}")
+ response = None
+ time.sleep(RETRY_SLEEP)
+
+ cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
+
+ if response is None:
+ return "", cost
+ else:
+ return response.content, cost
+
+ def prepare_message(self, model_input: dict, prompt_type: str):
+ message = []
+ return message
+
+ def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,):
+ total_cost = 0
+ response_list = []
+ # prepare message
+ message = self.prepare_message(model_input, prompt_type)
+
+ # n sampling
+ for i in range(self.num_generate):
+ response, cost = self.generate_with_retry(message, constraint_str_list)
+ response_list.append(response)
+ total_cost += cost
+
+ return response_list, total_cost
+
+
+class GroundingJudgeAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="judge_grounding",
+ prompt_type=prompt_type,
+ use_multimodal=self.use_multimodal,
+ text_obs=self.agent_config["text_obs_type"],
+ image_obs=self.agent_config["image_obs_type"]
+ )
+ return message
+
+
+class ProgressJudgeAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ if self.agent_config["input_type"]=="text_only":
+ use_multimodal = False
+ text_obs = self.agent_config["text_obs_type"]
+ image_obs = None
+ elif self.agent_config["input_type"]=="image_only":
+ use_multimodal = True
+ text_obs = None
+ image_obs = self.agent_config["image_obs_type"]
+ elif self.agent_config["input_type"]=="text_image":
+ use_multimodal = True
+ text_obs = self.agent_config["text_obs_type"]
+ image_obs = self.agent_config["image_obs_type"]
+ else:
+ raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
+
+ if self.agent_config["use_in_progress"]:
+ use_in_progress = True
+ else:
+ use_in_progress = False
+
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="judge_progress",
+ prompt_type=prompt_type,
+ use_checklist=self.use_checklist,
+ use_multimodal=use_multimodal,
+ text_obs=text_obs,
+ image_obs=image_obs,
+ use_in_progress=use_in_progress
+ )
+ return message
+
+ def get_judge_probs(self, logprobs: list):
+ # target_judge = {
+ # "yes": [" Yes", "Yes", "ĠYes", "ĊYes"],
+ # "no": [" No", "No", "ĠNo", "ĊNo"],
+ # "in": [" In", "In", "ĠIn", "ĊIn"]
+ # }
+ target_judge = {
+ "yes": [
+ "ĠYes", "Yes", "ĊYes",
+ "Ġyes", "yes", "Ċyes",
+ "ĠYES", "YES", "ĊYES",
+ "ĠDone", "Done", "ĊDone",
+ "ĠCompleted", "Completed", "ĊCompleted",
+ "ĠCorrect", "Correct", "ĊCorrect"
+ ],
+ "no": [
+ "ĠNo", "No", "ĊNo",
+ "ĠNO", "NO", "ĊNO",
+ "ĠNot", "Not", "ĊNot",
+ "ĠNone", "None", "ĊNone",
+ "ĠNope", "Nope", "ĊNope",
+ "ĠUn", "Un", "ĊUn",
+ "ĠWrong", "Wrong", "ĊWrong"
+ ],
+ "in": [
+ "ĠIn", "In", "ĊIn",
+ "ĠPending", "Pending", "ĊPending",
+ "ĠPart", "Part", "ĊPart",
+ "ĠPartial", "Partial", "ĊPartial",
+ "ĠInProgress", "InProgress", "ĊInProgress"
+ ]
+ }
+ response_str = ""
+ judge_probs_list = []
+ for i, log_prob in enumerate(logprobs):
+ # Start to find judge string
+ if "" in response_str:
+ find_judge_str = False
+ for judge_type in target_judge:
+ if log_prob["token"] in target_judge[judge_type]:
+ # print(log_prob)
+ find_judge_str = True
+ break
+ if find_judge_str:
+ token_judge_dict = {
+ "yes": None,
+ "no": None,
+ "in": None
+ }
+ for token_info in log_prob["top_logprobs"]:
+ for judge_type in target_judge:
+ for judge_str in target_judge[judge_type]:
+ if judge_str in token_info["token"] :
+ if token_judge_dict[judge_type] is None:
+ token_judge_dict[judge_type] = math.exp(token_info["logprob"])
+ else:
+ token_judge_dict[judge_type] += math.exp(token_info["logprob"])
+
+ token_judge_dict = {
+ "yes": math.log(token_judge_dict["yes"]) if token_judge_dict["yes"] is not None else -float('inf'),
+ "no": math.log(token_judge_dict["no"]) if token_judge_dict["no"] is not None else -float('inf'),
+ "in": math.log(token_judge_dict["in"]) if token_judge_dict["in"] is not None else -float('inf')
+ }
+ judge_probs_list.append(token_judge_dict)
+
+ if "" in response_str:
+ break
+
+ response_str += log_prob["token"]
+
+ if len(judge_probs_list) == 0:
+ return [{
+ "yes": 0.0,
+ "no": 0.0,
+ "in": 0.0
+ }]
+ else:
+ # convert with softmax
+ final_judge_probs_list = []
+ max_in_prob = -float('inf')
+ for idx, judge_probs in enumerate(judge_probs_list):
+ exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
+ sum_exp_logprobs = sum(exp_logprobs)
+ softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
+ if softmax_probs[2] > max_in_prob:
+ max_in_prob = softmax_probs[2]
+ final_judge_probs_list.append({
+ "yes": softmax_probs[0],
+ "no": softmax_probs[1],
+ "in": softmax_probs[2]
+ })
+ return final_judge_probs_list
+
+ def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None):
+ total_cost = 0
+ # prepare message
+ message = self.prepare_message(model_input, prompt_type)
+ messages = convert_dict_messages(message)
+
+ kwargs = {'n': n}
+ if temperature is not None:
+ kwargs['temperature'] = temperature
+ self._init_llm_object(**kwargs)
+
+ try:
+ response = self.llm.generate([messages]) # assume single batch
+ finally:
+ print('request url: ', self.agent_config['base_url'])
+
+
+ # parse responses
+ response_list = []
+        for generation in response.generations[0]: # assume single batch
+ # parse logprobs
+ logprobs = generation.message.response_metadata["logprobs"]["content"]
+ response_list.append(
+ {
+ "response": generation.message.content,
+ "judge_probs": self.get_judge_probs(logprobs)
+ }
+ )
+
+ # calculate cost
+ total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"]
+ total_output_tokens = response.llm_output["token_usage"]["completion_tokens"]
+ total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
+
+ return response_list, total_cost
+
+
+class ChecklistGenerationAgent(BaseAgent):
+ def __init__(self, agent_config: dict):
+ super().__init__(agent_config)
+ self._setup()
+
+ def prepare_message(self, model_input: dict, prompt_type):
+ message = get_messages(
+ input_info=model_input,
+ inference_mode="checklist_generation",
+ prompt_type=prompt_type
+ )
+ return message
+
+
+class ClassifierRewardAgent(Agent):
+ def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
+ self.url = url
+ self.use_checklist = use_checklist
+ self.use_multimodal = use_multimodal
+
+ def _process_multimodal_message(self, prompt: str, image_list: list[str]):
+ multimodal_message = []
+ text_prompt_prefix = prompt.split("")[0]
+ text_prompt_suffix = prompt.split("")[1]
+ multimodal_message = [
+ {"type": "text", "text": text_prompt_prefix},
+ # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
+ {"type": "image", "image": image_to_base64_url(image_list[0])},
+ {"type": "text", "text": text_prompt_suffix}
+ ]
+ return multimodal_message
+
+ def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
+ if self.use_multimodal:
+ tmp_user_prompt = user_prompt_template["user"].format(
+ **model_input
+ )
+ user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
+ else:
+ user_prompt = user_prompt_template["user"].format(
+ **model_input
+ )
+ assistant_prompt = user_prompt_template["assistant"].format(
+ **model_input
+ )
+ query = [
+ {"role": "user", "content": user_prompt},
+ {"role": "assistant", "content": assistant_prompt}
+ ]
+ return query
+
+ def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
+ if self.use_checklist:
+ if self.use_multimodal:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
+ else:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
+ else:
+ if self.use_multimodal:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
+ else:
+ user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
+
+ if self.use_multimodal:
+ if batch:
+ message = [self._make_query(user_prompt_template, input) for input in model_input]
+ else:
+ message = [self._make_query(user_prompt_template, model_input)]
+ else:
+            if batch:
+                message = {
+                    "query": [self._make_query(user_prompt_template, input) for input in model_input],
+                    "prompts": []
+                }
+ else:
+ message = {
+ "query": self._make_query(user_prompt_template, model_input),
+ "prompts": []
+ }
+
+ return message
+
+    def get_rm_score(self, message: dict | list):
+ headers = {"Content-Type": "application/json"}
+
+ try:
+ if self.use_multimodal:
+ response = requests.post(
+ self.url,
+ json={"messages": message},
+ timeout=600
+ )
+ else:
+ response = requests.post(
+ self.url,
+ headers=headers,
+ data=json.dumps(message),
+ timeout=300
+ )
+ response.raise_for_status()
+
+ response_json = response.json()
+
+ if "rewards" not in response_json:
+ print(f"Error: 'rewards' key not found in API response: {response_json}")
+ return []
+
+ if "get_reward" in self.url:
+ # use openrlhf
+ return response_json["rewards"]
+ elif "pooling" in self.url:
+ # use vllm server
+ return response_json["reward"]
+ else:
+ # error
+ raise ValueError(f"Invalid URL: {self.url}")
+
+ except requests.exceptions.Timeout:
+ print(f"Error: Request timed out to {self.url}")
+ return []
+ except requests.exceptions.RequestException as e:
+ print(f"Error during request to {self.url}: {e}")
+ return []
+ except json.JSONDecodeError:
+ print(f"Error: Failed to decode JSON response from {self.url}")
+ return []
+ except KeyError as e:
+ print(f"Error: Missing key {e} in response from {self.url}")
+ return []
+
+
+ def generate_response(self, model_input: dict | list[dict], batch: bool = False):
+ if batch:
+ message = self.prepare_message(model_input, batch=True)
+ else:
+ message = self.prepare_message(model_input)
+        rewards = self.get_rm_score(message)
+
+ return rewards, 0
\ No newline at end of file
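To show how the pieces fit together, an illustrative ProgressJudgeAgent configuration follows; the model name, endpoint, and API key are placeholders, not real deployments, and the input fields mirror the prompt templates defined earlier in this diff.

    from agent.mini_bench.reward_agent import ProgressJudgeAgent

    agent_config = {
        "model_name": "gpt-4o-mini",              # placeholder judge model
        "base_url": "https://api.openai.com/v1",  # placeholder OpenAI-compatible endpoint
        "api_key": "sk-...",                      # placeholder key
        "temperature": 0.0,
        "num_generate": 1,
        "use_checklist": True,
        "use_multimodal": False,
        "input_type": "text_only",
        "text_obs_type": "axtree",
        "image_obs_type": None,
        "use_in_progress": True,
    }
    judge = ProgressJudgeAgent(agent_config)

    model_input = {
        "intent": "Find the rating of Monopoly (1935) on boardgamegeek.com",
        "trajectory": "(no previous steps)",
        "current_url": "https://boardgamegeek.com",
        "text_observation": "[a2] textbox 'Search'",
        "checklist": "Checklist 1: Search for the game",
        "thought": "I will search for the game by name.",
        "action": "fill('a2', 'Monopoly 1935')",
    }
    responses, cost = judge.generate_response(model_input, prompt_type="with_checklist")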
diff --git a/agent/mini_bench/utils.py b/agent/mini_bench/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bf25e45036e7ff480d2bc88bff6d00faee5bfcb
--- /dev/null
+++ b/agent/mini_bench/utils.py
@@ -0,0 +1,269 @@
+import json
+import base64
+import io
+import html
+from PIL import Image
+
+
+def image_to_base64_url(image: str | Image.Image):
+ if isinstance(image, str):
+ with open(image, "rb") as f:
+ image = f.read()
+ elif isinstance(image, Image.Image):
+ if image.mode in ("RGBA", "LA"):
+ image = image.convert("RGB")
+ with io.BytesIO() as buffer:
+ image.save(buffer, format="PNG")
+ image = buffer.getvalue()
+ else:
+ raise ValueError(f"Invalid image type: {type(image)}")
+
+ return "data:image/png;base64," + base64.b64encode(image).decode("utf-8")
+
+
+def load_json(file_path: str) -> dict:
+ with open(file_path, "r") as f:
+ return json.load(f)
+
+def save_json(data: dict, file_path: str):
+ with open(file_path, "w") as f:
+ json.dump(data, f, indent=4)
+
+def str_to_bool(s: str) -> bool:
+ if s.lower() in ["true", "1", "yes", "y"]:
+ return True
+ elif s.lower() in ["false", "0", "no", "n"]:
+ return False
+ else:
+ raise ValueError(f"Invalid boolean string: {s}")
+
+
+def create_html_report(json_path, html_path, checklist_generation=False):
+ """
+ Reads the given JSON result file and generates a filterable HTML report.
+
+ Args:
+ json_path (str): Path to the input JSON file.
+ html_path (str): Path to the output HTML file.
+ """
+ try:
+ with open(json_path, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ except FileNotFoundError:
+ print(f"Error: JSON file not found - {json_path}") # Error message in English
+ return
+ except json.JSONDecodeError:
+ print(f"Error: JSON file parsing error - {json_path}") # Error message in English
+ return
+ except Exception as e:
+ print(f"Unexpected error during data loading: {e}") # Error message in English
+ return
+
+ # Extract unique Task IDs and sort them
+ task_ids = sorted(list(set(item.get("task_id") for item in data if item.get("task_id") is not None)))
+
+ html_content = """
+
+
+
+
+
+ Benchmark Results Report
+
+
+
+ Benchmark Results Report
+
+
+
+
+
+
+
+
+
+"""
+
+ # Process each Task/Step data
+ for i, step_data in enumerate(data):
+ task_id = step_data.get("task_id", "N/A")
+ step_id = step_data.get("step_id", "N/A")
+ intent = step_data.get("intent", "N/A")
+ start_url = step_data.get("start_url", "N/A")
+ gt_checklist = step_data.get("gt_checklist", "N/A")
+ generated_checklist = step_data.get("generated_checklist", None)
+ trajectory = step_data.get("trajectory", "N/A")
+ text_observation = step_data.get("text_observation", "N/A")
+ source_name = step_data.get("source_name", "")
+
+ # Wrap each Task/Step in a container with a unique ID (hidden initially)
+ html_content += f"""
+
+
+
Task ID: {html.escape(str(task_id))} | Step ID: {html.escape(str(step_id))} {f'({html.escape(source_name)})' if source_name else ''}
+
Intent:
+
{html.escape(intent)}
+
Start URL: {html.escape(start_url)}
+
+
Ground Truth Checklist:
+
{html.escape(gt_checklist)}
+"""
+ if checklist_generation and generated_checklist is not None:
+ html_content += f"""
+
+ Generated Checklist (Click to expand/collapse)
+ {html.escape(str(generated_checklist))}
+
+"""
+
+ html_content += f"""
+
+ Trajectory (Click to expand/collapse)
+ {html.escape(trajectory)}
+
+
+
+ Text Observation (Click to expand/collapse)
+ {html.escape(text_observation)}
+
+
+"""
+
+ # Chosen Responses
+ if 'chosen' in step_data and step_data['chosen']:
+            html_content += '<div class="chosen-section"><h3>Chosen Responses:</h3>'
+ for choice_block in step_data['chosen']:
+ thought = choice_block.get('thought', 'N/A')
+ action = choice_block.get('action', 'N/A')
+ responses = choice_block.get('response', [])
+ scores = choice_block.get('score', [])
+
+ # Add Thought and Action information
+ html_content += f"""
+
+
Thought:
+
{html.escape(thought)}
+
Action:
+
{html.escape(action)}
+
"""
+
+ # Loop through responses and create toggles
+ for idx, (response, score) in enumerate(zip(responses, scores)):
+ html_content += f"""
+
+ Judge Response {idx + 1}: {html.escape(str(score))}
+ {html.escape(str(response))}
+ """
+            html_content += '</div>' # End chosen-section
+
+ # Rejected Responses
+ if 'rejected' in step_data and step_data['rejected']:
+            html_content += '<div class="rejected-section"><h3>Rejected Responses:</h3>'
+ for rejection_block in step_data['rejected']:
+ thought = rejection_block.get('thought', 'N/A')
+ action = rejection_block.get('action', 'N/A')
+ responses = rejection_block.get('response', [])
+ scores = rejection_block.get('score', [])
+
+ # Add Thought and Action information
+ html_content += f"""
+
+
Thought:
+
{html.escape(thought)}
+
Action:
+
{html.escape(action)}
+
"""
+
+ # Loop through responses and create toggles
+ for idx, (response, score) in enumerate(zip(responses, scores)):
+ html_content += f"""
+
+ Judge Response {idx + 1}: {html.escape(str(score))}
+ {html.escape(str(response))}
+ """
+            html_content += '</div>' # End rejected-section
+
+ html_content += """
+
+
+"""
+
+ # Finalize HTML and add JavaScript
+ html_content += """
+
+
+
+
+
+
+"""
+
+ # Save the HTML file
+ try:
+ with open(html_path, 'w', encoding='utf-8') as f:
+ f.write(html_content)
+ print(f"Completed: HTML report created at {html_path}")
+ except IOError:
+ print(f"Error: Failed to write HTML file - {html_path}")
+ except Exception as e:
+ print(f"Unexpected error during HTML file saving: {e}")
+
+# --- Example Usage ---
+# input_json_file = 'path/to/your/results.json'
+# output_html_file = 'trajectory_report.html'
+# create_html_report(input_json_file, output_html_file)
\ No newline at end of file
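Example invocation of the report helper; the paths are placeholders.

    from agent.mini_bench.utils import create_html_report

    # Builds a filterable HTML report from a JSON results file (illustrative paths).
    create_html_report("results/progress_judge.json", "results/progress_judge.html", checklist_generation=True)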
diff --git a/agent/reward.py b/agent/reward.py
new file mode 100644
index 0000000000000000000000000000000000000000..2337bc24110a83634c0ec9d6feba3edde13b85f8
--- /dev/null
+++ b/agent/reward.py
@@ -0,0 +1,96 @@
+import time
+from typing import List, Dict, Any, Optional, Union
+import numpy as np
+from .mini_bench.reward_agent import ProgressJudgeAgent
+from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash
+import json
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
+ """하나의 unit을 처리해 (idx, reward, thought)를 돌려준다."""
+ agent = ProgressJudgeAgent(configs)
+ current_temperature = configs["temperature"]
+
+ rewards = []
+ n_err = 0
+ retry_count = 0
+ judge_hash_count_thought = {}
+
+ while len(rewards) < n_samples and retry_count < max_retries:
+        # Call the external judge API
+ responses, _ = agent.generate_probs(
+ unit, "ours", n=n_samples - len(rewards), temperature=current_temperature
+ )
+
+ for response in responses:
+ content = response["response"]
+            thought = content # keep the full response text for logging
+ reward = REWARD_PROCESSORS[reward_processor](response)
+ rewards.append(reward)
+
+            if reward is None or np.isnan(reward):
+ n_err += 1
+ else:
+ judge_hash = extract_judge_hash(response)
+ judge_hash_count_thought[judge_hash] = (judge_hash_count_thought.get(judge_hash, (0, None))[0] + 1, thought)
+
+ if n_err > 0:
+            # On failure, raise the temperature and retry
+ if n_samples == 1:
+ current_temperature = 0.5
+ retry_count += 1
+
+ reward = np.nanmean(rewards)
+ if np.isnan(reward):
+ print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
+ reward = 0.0
+ print(judge_hash_count_thought)
+ thought = max(judge_hash_count_thought.values(), key=lambda x: x[0])[1]
+
+ return idx, reward, thought
+
+
+def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
+ """원본 get_ar_reward를 스레드 버전으로 교체."""
+ n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]
+
+ temperature = 0.5 if n_samples > 1 else 0.0
+
+ configs = {
+ "model_name": model_name,
+ "base_url": base_url,
+ "api_key": "empty",
+ "temperature": temperature,
+ "num_generate": 1,
+ "use_checklist": True,
+ "input_type": "text_only",
+ "text_obs_type": "axtree",
+ "image_obs_type": "som",
+ "use_in_progress": True,
+ "use_multimodal": False,
+ "use_log_probs": True,
+ }
+
+ t_start = time.time()
+ results = [None] * len(dataset)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = [
+ executor.submit(
+ _process_unit, idx, unit, configs, n_samples, reward_processor
+ )
+ for idx, unit in enumerate(dataset)
+ ]
+
+ for fut in as_completed(futures):
+ idx, reward, thought = fut.result()
+ results[idx] = (reward, thought)
+
+    # Split into order-preserving lists of rewards and thoughts
+ final_rewards = [float(r) for r, _ in results]
+ thoughts = [t for _, t in results]
+
+ print(f"Time taken (threaded): {time.time() - t_start:.2f} s")
+ return final_rewards, thoughts
+
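A usage sketch for the threaded wrapper above; base_url and model_name are placeholders for whatever OpenAI-compatible server hosts the judge model.

    from agent.reward import get_ar_reward

    dataset = [{
        "intent": "Find the rating of Monopoly (1935) on boardgamegeek.com",
        "trajectory": "(no previous steps)",
        "current_url": "https://boardgamegeek.com",
        "text_observation": "[a2] textbox 'Search'",
        "checklist": "Checklist 1: Search for the game",
        "thought": "I will search for the game by name.",
        "action": "fill('a2', 'Monopoly 1935')",
    }]
    rewards, thoughts = get_ar_reward(
        dataset,
        base_url="http://localhost:8000/v1",  # placeholder judge endpoint
        model_name="progress-judge",          # placeholder model name
    )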
diff --git a/agent/reward_postprocessor.py b/agent/reward_postprocessor.py
new file mode 100644
index 0000000000000000000000000000000000000000..4124dba8cf12745a64d756002c3c5e99a17f1123
--- /dev/null
+++ b/agent/reward_postprocessor.py
@@ -0,0 +1,41 @@
+import numpy as np
+import re
+
+
+def extract_judge_hash(response):
+ """
+    Convert the per-checklist-item yes / in-progress / no judgments into a compact hash string and return it.
+ """
+ content = response['response']
+
+ try:
+ judge_content = content.lower().replace(' ', '').split('')[1].split('')[0]
+ except:
+ import traceback
+ traceback.print_exc()
+ return None
+ pattern = r":yes|:inprogress|:no"
+ matches = re.findall(pattern, judge_content)
+ matches = [{':yes': 'y', ':inprogress': 'i', ':no': 'n'}[match] for match in matches]
+ return ''.join(matches)
+
+def average_logits(response):
+ """
+    Compute the reward from the yes / in-progress / no judgments at the logit level.
+ """
+ judge_probs = response['judge_probs']
+
+ yes_ = np.mean([r['yes'] for r in judge_probs])
+ in_ = np.mean([r['in'] for r in judge_probs])
+
+ reward = yes_ + 0.5 * in_
+ return reward
+
+
+REWARD_PROCESSORS = {
+ 'avg_logits': average_logits
+}
+
+REWARD_PROCESSOR_N_SAMPLES = {
+ 'avg_logits': 5
+}
\ No newline at end of file
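A worked example of the reward computation above: with the per-item probabilities below, the reward is mean(yes) + 0.5 * mean(in progress).

    from agent.reward_postprocessor import average_logits

    response = {
        "judge_probs": [
            {"yes": 0.8, "no": 0.1, "in": 0.1},  # item 1: mostly satisfied
            {"yes": 0.2, "no": 0.3, "in": 0.5},  # item 2: partial progress
        ]
    }
    print(average_logits(response))  # mean(yes) = 0.5, mean(in) = 0.3 -> 0.5 + 0.5 * 0.3 = 0.65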
diff --git a/app.py b/app.py
index 0da0319a5b670dce5025888fde58916b96f19869..58fb621ece872ea5d3624e5e1a8d2da7b42359aa 100644
--- a/app.py
+++ b/app.py
@@ -1,64 +1,134 @@
+import os
+import logging
import gradio as gr
-from huggingface_hub import InferenceClient
+import openai
+import multiprocessing
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+from process_run import process_run
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.StreamHandler(),
+ ]
+)
+logger = logging.getLogger(__name__)
+logger.setLevel('INFO')
-def respond(
- message,
- history: list[tuple[str, str]],
- system_message,
- max_tokens,
- temperature,
- top_p,
-):
- messages = [{"role": "system", "content": system_message}]
+# Set your OpenAI API key
+openai.api_key = os.getenv("OPENAI_API_KEY")
- for val in history:
- if val[0]:
- messages.append({"role": "user", "content": val[0]})
- if val[1]:
- messages.append({"role": "assistant", "content": val[1]})
- messages.append({"role": "user", "content": message})
+# Example instructions to display
+EXAMPLES = [
+ "When did the solar system form? Find on wikipedia.",
+ "Find the rating of Monopoly (1935) on boardgamegeek.com",
+]
- response = ""
+URL_EXAMPLES = [
+ "about:blank",
+ "https://www.wikipedia.org",
+ "https://www.boardgamegeek.com"
+]
- for message in client.chat_completion(
- messages,
- max_tokens=max_tokens,
- stream=True,
- temperature=temperature,
- top_p=top_p,
- ):
- token = message.choices[0].delta.content
+def main():
+ logger.info("Starting BrowserGym web agent")
+
+ with gr.Blocks(title="WebShephered Demo") as demo:
+ # Add CSS for outlined groups
+ gr.Markdown("# WebShephered Demo")
+ with gr.Row():
+ with gr.Column(scale=2):
+ with gr.Column():
+ instruction = gr.Textbox(
+ label="Instruction",
+ placeholder="Enter your instruction here",
+ lines=2,
+ )
+ gr.Examples(
+ examples=[[e] for e in EXAMPLES],
+ inputs=instruction,
+ cache_examples=False,
+ )
+
+ gr.Markdown("\n\n")
+
+ with gr.Column():
+ start_url = gr.Textbox(
+ label="Starting URL",
+ placeholder="URL to start the browser at",
+ value="about:blank"
+ )
+ gr.Examples(
+ examples=URL_EXAMPLES,
+ inputs=start_url,
+ cache_examples=False,
+ )
- response += token
- yield response
+ gr.Markdown("\n\n")
+ model_name = gr.Dropdown(
+ label="Agent Model",
+ choices=["gpt-4o"],
+ value="gpt-4o"
+ )
+ run_btn = gr.Button("Run Demo")
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
- respond,
- additional_inputs=[
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
- gr.Slider(
- minimum=0.1,
- maximum=1.0,
- value=0.95,
- step=0.05,
- label="Top-p (nucleus sampling)",
- ),
- ],
-)
+ gr.Markdown("---")
+
+ with gr.Column():
+ gr.Markdown("## Current State")
+ state_view = gr.Markdown()
+ browser_view = gr.Image(label="Browser View")
+
+ gr.Markdown("### Task Checklist from WebShephered")
+ checklist_view = gr.Markdown()
+
+ gr.Markdown("### Action Selection in current step")
+ with gr.Row() as rm_row:
+ rm_cards_container = gr.HTML()
+ with gr.Column(scale=2):
+ gr.Markdown("## Trajectory")
+ trajectory_container = gr.HTML() # Placeholder for our custom trajectory component
+
+
+
+ run_btn.click(
+ fn=process_run,
+ inputs=[instruction, model_name, start_url],
+ outputs=[state_view, browser_view, checklist_view, rm_cards_container, trajectory_container],
+ api_name="run_agent",
+ concurrency_limit=32,
+ show_progress=True
+ )
+ logger.info("Launching Gradio interface")
+ # Set max_threads to allow multiple concurrent requests
+ demo.launch(share=True, max_threads=32)
if __name__ == "__main__":
- demo.launch()
+ import os
+ import subprocess
+ import multiprocessing
+
+ # Install BrowserGym dependencies before running the main application
+ def install_browsergym():
+ try:
+ print("Installing BrowserGym dependencies...")
+ subprocess.run("cd BrowserGym && make install", shell=True, check=True)
+ print("BrowserGym installation completed successfully")
+ except subprocess.CalledProcessError as e:
+ print(f"Error installing BrowserGym: {e}")
+ raise
+
+ # Run the installation before starting the application
+ if os.path.exists("BrowserGym"):
+ install_browsergym()
+ else:
+ print("BrowserGym directory not found. Skipping installation.")
+
+ # Add support for multiprocessing on Windows
+ multiprocessing.freeze_support()
+ main()
diff --git a/browser_agent.py b/browser_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..389a0da2a48b912c8f6ea42f9de119eb0414a403
--- /dev/null
+++ b/browser_agent.py
@@ -0,0 +1,282 @@
+import logging
+import os
+from typing import Any, List, Tuple
+
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.utils.obs import (
+ flatten_axtree_to_str,
+ flatten_dom_to_str,
+ prune_html,
+)
+from browsergym.experiments import Agent
+
+from utils import remove_inline_comments_safe, image_to_jpg_base64_url
+
+import openai
+
+
+logger = logging.getLogger(__name__)
+
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+
+
+class BrowserAgent(Agent):
+ def obs_preprocessor(self, obs: dict) -> dict:
+ return {
+ "chat_messages": obs["chat_messages"],
+ "som_screenshot": obs["som_screenshot"],
+ "goal_object": obs["goal_object"],
+ "last_action": obs["last_action"],
+ "last_action_error": obs["last_action_error"],
+ "open_pages_urls": obs["open_pages_urls"],
+ "open_pages_titles": obs["open_pages_titles"],
+ "active_page_index": obs["active_page_index"],
+ "axtree_txt": flatten_axtree_to_str(obs["axtree_object"], filter_visible_only=True, extra_properties=obs['extra_element_properties'], filter_som_only=True),
+ "pruned_html": prune_html(flatten_dom_to_str(obs["dom_object"])),
+ }
+
+ def __init__(self, model_name: str = "gpt-4o", use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False):
+ super().__init__()
+ logger.info(f"Initializing BrowserAgent with model: {model_name}")
+ logger.info(f"Observation space: HTML={use_html}, AXTree={use_axtree}, Screenshot={use_screenshot}")
+
+ self.model_name = model_name
+ self.use_html = use_html
+ self.use_axtree = use_axtree
+ self.use_screenshot = use_screenshot
+
+ if not (use_html or use_axtree):
+ raise ValueError("Either use_html or use_axtree must be set to True.")
+
+ self.openai_client = openai.OpenAI()
+
+ self.action_set = HighLevelActionSet(
+ subsets=["chat", "tab", "nav", "bid", "infeas"],
+ strict=False,
+ multiaction=False,
+ demo_mode="default"
+ )
+ self.action_history = []
+
+ def get_action(self, obs: dict) -> tuple[str, dict]:
+ logger.debug("Preparing action request")
+
+ system_msgs = [{
+ "type": "text",
+ "text": """\
+# Instructions
+
+You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can
+communicate with the user via a chat, to which the user gives you instructions and to which you
+can send back messages. You have access to a web browser that both you and the user can see,
+and with which only you can interact via specific commands.
+
+Review the instructions from the user, the current state of the page and all other information
+to find the best possible next action to accomplish your goal. Your answer will be interpreted
+and executed by a program, make sure to follow the formatting instructions.
+"""
+ }]
+
+ user_msgs = []
+
+ # Add chat messages
+ user_msgs.append({
+ "type": "text",
+ "text": "# Chat Messages\n"
+ })
+ for msg in obs["chat_messages"]:
+ if msg["role"] in ("user", "assistant", "infeasible"):
+ user_msgs.append({
+ "type": "text",
+ "text": f"- [{msg['role']}] {msg['message']}\n"
+ })
+ logger.debug(f"Added chat message: [{msg['role']}] {msg['message']}")
+ elif msg["role"] == "user_image":
+ user_msgs.append({"type": "image_url", "image_url": msg["message"]})
+ logger.debug("Added user image message")
+
+ # Add open tabs info
+ user_msgs.append({
+ "type": "text",
+ "text": "# Currently open tabs\n"
+ })
+ for page_index, (page_url, page_title) in enumerate(
+ zip(obs["open_pages_urls"], obs["open_pages_titles"])
+ ):
+ user_msgs.append({
+ "type": "text",
+ "text": f"""\
+Tab {page_index}{" (active tab)" if page_index == obs["active_page_index"] else ""}
+ Title: {page_title}
+ URL: {page_url}
+"""
+ })
+ logger.debug(f"Added tab info: {page_title} ({page_url})")
+
+ # Add accessibility tree if enabled
+ if self.use_axtree:
+ user_msgs.append({
+ "type": "text",
+ "text": f"""\
+# Current page Accessibility Tree
+
+{obs["axtree_txt"]}
+
+"""
+ })
+ logger.debug("Added accessibility tree")
+
+ # Add HTML if enabled
+ if self.use_html:
+ user_msgs.append({
+ "type": "text",
+ "text": f"""\
+# Current page DOM
+
+{obs["pruned_html"]}
+
+"""
+ })
+ logger.debug("Added HTML DOM")
+
+ # Add screenshot if enabled
+ if self.use_screenshot:
+ user_msgs.append({
+ "type": "text",
+ "text": "# Current page Screenshot\n"
+ })
+ user_msgs.append({
+ "type": "image_url",
+ "image_url": {
+ "url": image_to_jpg_base64_url(obs["som_screenshot"]),
+ "detail": "auto"
+ }
+ })
+ logger.debug("Added screenshot")
+
+ # Add action space description
+ user_msgs.append({
+ "type": "text",
+ "text": f"""\
+# Action Space
+
+{self.action_set.describe(with_long_description=False, with_examples=True)}
+
+Here are examples of actions with chain-of-thought reasoning:
+
+I now need to click on the Submit button to send the form. I will use the click action on the button, which has bid 12.
+```click("12")```
+
+I found the information requested by the user, I will send it to the chat.
+```send_msg_to_user("The price for a 15\\" laptop is 1499 USD.")```
+
+"""
+ })
+
+ # Add action history and errors
+ if self.action_history:
+ user_msgs.append({
+ "type": "text",
+ "text": "# History of past actions\n"
+ })
+ for action in self.action_history:
+ user_msgs.append({
+ "type": "text",
+ "text": f"\n{action}\n"
+ })
+ logger.debug(f"Added past action: {action}")
+
+ if obs["last_action_error"]:
+ user_msgs.append({
+ "type": "text",
+ "text": f"""\
+# Error message from last action
+
+{obs["last_action_error"]}
+
+"""
+ })
+ logger.warning(f"Last action error: {obs['last_action_error']}")
+
+ # Ask for next action
+ user_msgs.append({
+ "type": "text",
+ "text": """\
+# Next action
+
+You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action.
+Note: you might use the 'goto' action if you are on a blank page.
+"""
+ })
+
+ # Log the full prompt for debugging
+ prompt_text_strings = []
+ for message in system_msgs + user_msgs:
+ match message["type"]:
+ case "text":
+ prompt_text_strings.append(message["text"])
+ case "image_url":
+ image_url = message["image_url"]
+ if isinstance(message["image_url"], dict):
+ image_url = image_url["url"]
+ if image_url.startswith("data:image"):
+ prompt_text_strings.append(
+ "image_url: " + image_url[:30] + "... (truncated)"
+ )
+ else:
+ prompt_text_strings.append("image_url: " + image_url)
+ case _:
+ raise ValueError(
+ f"Unknown message type {repr(message['type'])} in the task goal."
+ )
+ full_prompt_txt = "\n".join(prompt_text_strings)
+ logger.debug(full_prompt_txt)
+
+ # Query OpenAI model
+ logger.info("Sending request to OpenAI")
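+        # Sample several completions (n=20, temperature=0.8) so candidates can later be ranked by frequency and reward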
+ response = self.openai_client.chat.completions.create(
+ model=self.model_name,
+ messages=[
+ {"role": "system", "content": system_msgs},
+ {"role": "user", "content": user_msgs}
+ ],
+ n=20,
+ temperature=0.8
+ )
+ parses = []
+ for i, choice in enumerate(response.choices):
+            content = choice.message.content  # keep the full API response object intact; work on this choice's text
+            try:
+                parses.append({
+                    'response': content,
+                    'thought': content.split('```')[0].strip(),
+                    'action': remove_inline_comments_safe(content.split('```')[1].strip('`').strip()),
+                })
+            except Exception as e:
+                logger.error(f"Error parsing action from choice {i}: {e}")
+                logger.error(f"Response: {content}")
+                logger.error(f"Choice: {choice}")
+
+ candidates = self.get_top_k_actions(parses)
+ logger.info(f"Received action from OpenAI: {[cand['action'] for cand in candidates]}")
+ return candidates, {}
+
+ def get_top_k_actions(self, parses, k=3):
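+        """Group identical parsed actions, count their frequency, and return the k most frequent parses."""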
+ count_dict = {}
+ action_to_parsed = {}
+ for parsed in parses:
+ action = parsed["action"]
+ if action in count_dict:
+ count_dict[action] += 1
+ else:
+ count_dict[action] = 1
+ action_to_parsed[action] = parsed.copy()
+
+ # Get the top_k most frequent actions
+ sorted_actions = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
+ top_k_actions = [action_to_parsed[action] for action, _ in sorted_actions[:k]]
+
+ return top_k_actions
\ No newline at end of file
diff --git a/process_run.py b/process_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a2dd217ab6fcecd566df280d4ae6621e7fcd87
--- /dev/null
+++ b/process_run.py
@@ -0,0 +1,301 @@
+from pathlib import Path
+import multiprocessing
+import logging
+from PIL import Image
+import io
+import base64
+import numpy as np
+import gymnasium as gym
+import os
+
+from agent.checklist import generate_checklist
+from agent.reward import get_ar_reward
+
+from browser_agent import BrowserAgent
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel('INFO')
+
+templates_dir = Path(__file__).parent / "templates"
+CSS_RM_CARDS: str = (templates_dir / "rm_cards.css").read_text()
+CSS_TRAJECTORY: str = (templates_dir / "trajectory.css").read_text()
+CARD_HTML_TEMPLATE: str = (templates_dir / "card.html").read_text()
+
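+# Reward-model endpoint configuration; both environment variables are required (a missing value raises KeyError at import time)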
+RM_BASE_URL = os.environ['RM_BASE_URL']
+RM_MODEL_NAME = os.environ['RM_MODEL_NAME']
+
+def return_state(state, screenshot=None):
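+    """Pack a status message and optional screenshot into the 5-tuple (state, checklist, cards, screenshot, trajectory) yielded to the UI."""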
+ return state, None, None, screenshot, None
+
+def run_agent(instruction: str, model_name: str = "gpt-4o", start_url: str = "about:blank",
+ use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False, max_steps: int = 20):
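+    """Run one browsing episode: sample candidate actions, score them with the reward model, execute the best one, and yield UI updates after every step."""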
+ logger.info(f"Starting agent with instruction: {instruction}")
+ logger.info(f"Configuration: model={model_name}, start_url={start_url}")
+
+ trajectory = []
+ trajectory_str = ''
+ agent = BrowserAgent(
+ model_name=model_name,
+ use_html=use_html,
+ use_axtree=use_axtree,
+ use_screenshot=use_screenshot
+ )
+
+ # Initialize BrowserGym environment
+ logger.info("Initializing BrowserGym environment")
+ yield return_state("## Initializing BrowserGym environment...", None)
+ env = gym.make(
+ "browsergym/openended",
+ task_kwargs={
+ "start_url": start_url,
+ "goal": instruction,
+ },
+ wait_for_user_message=True
+ )
+ obs, info = env.reset()
+ logger.info("Environment initialized")
+
+ # Send user instruction to the environment
+ logger.info("Sending user instruction to environment")
+ obs, reward, terminated, truncated, info = env.step({
+ "type": "send_msg_to_user",
+ "message": instruction
+ })
+ processed_obs = agent.obs_preprocessor(obs)
+ logger.info(f"Obs: {processed_obs.keys()}")
+ logger.info(f"axtree_txt: {processed_obs['axtree_txt']}")
+
+ yield return_state("## Generating checklist...", obs['som_screenshot'])
+ checklist = generate_checklist(intent=instruction, start_url=start_url, text_observation=processed_obs['axtree_txt'])
+
+ # yield initial state
+ current_screenshot = obs['som_screenshot'].copy()
+ yield "## Rollout actions from policy...", checklist, [], current_screenshot, trajectory.copy()
+
+ try:
+ step_count = 0
+ while step_count < max_steps:
+ logger.info(f"Step {step_count}: Getting next action")
+ # Get next action from agent
+ candidates, _ = agent.get_action(processed_obs)
+
+            yield return_state("## Rewarding actions...", current_screenshot)
+
+ total_rewards, total_thoughts = get_ar_reward(
+ dataset=[
+ {
+ 'text_observation': processed_obs['axtree_txt'],
+ 'intent': instruction,
+ 'trajectory': trajectory_str,
+ 'current_url': processed_obs['open_pages_urls'][processed_obs['active_page_index'][0]],
+ 'checklist': checklist,
+ 'thought': cand['thought'],
+ 'action': cand['action'],
+ } for cand in candidates
+ ],
+ base_url=RM_BASE_URL,
+ model_name=RM_MODEL_NAME,
+ )
+
+ # process rewards
+            # reward gap between the highest-reward candidate and the most frequent one (index 0)
+            diff_reward = abs(max(total_rewards) - total_rewards[0])
+            if diff_reward <= 0.01:
+                logger.info(f"diff_reward: {diff_reward} -> most frequent action")
+                max_index = 0  # fall back to the most frequent candidate
+            else:
+                logger.info(f"diff_reward: {diff_reward} -> highest reward")
+                max_index = total_rewards.index(max(total_rewards))  # pick the highest-reward candidate
+
+            # put the chosen candidate first, then order the rest by reward (descending)
+            sorted_indices = sorted(enumerate(total_rewards), key=lambda x: (-1 if x[0] == max_index else 0, -x[1]))
+ new_order = [idx for idx, _ in sorted_indices]
+ candidates = [candidates[idx] for idx in new_order]
+ total_rewards = [total_rewards[idx] for idx in new_order]
+ total_thoughts = [total_thoughts[idx] for idx in new_order]
+
+ best_cand = candidates[0]
+
+ agent.action_history.append(best_cand['response'])
+
+ action = best_cand['action']
+
+ # processing action
+ step_info = {
+ 'thought': best_cand['thought'],
+ 'action': action
+ }
+            current_cards = [
+                {'thought': cand['thought'], 'action': cand['action'], 'feedback': feedback, 'reward': round(reward, 2)}
+                for cand, reward, feedback in zip(candidates, total_rewards, total_thoughts)
+            ]
+
+ trajectory_str += f'THOUGHT {step_count+1}: {step_info["thought"]}\nACTION {step_count+1}: {step_info["action"]}\n\n'
+
+ # Execute action
+ logger.info(f"Step {step_count}: Executing action: {action}")
+ yield f"## Executing action: {action}", checklist, current_cards, current_screenshot, trajectory.copy()
+ if action.startswith('send_msg_to_user'):
+ terminated = True
+ truncated = False
+ else:
+ obs, reward, terminated, truncated, info = env.step(action)
+ trajectory.append((processed_obs['som_screenshot'], [{'action': cand['action'], 'reward': round(reward, 2)} for cand, reward in zip(candidates, total_rewards)]))
+ processed_obs = agent.obs_preprocessor(obs)
+ current_screenshot = processed_obs['som_screenshot'].copy()
+
+ while '\n\n' in step_info['thought']:
+ step_info['thought'] = step_info['thought'].replace('\n\n', '\n')
+
+            # the screenshot is kept in the trajectory as a raw numpy array
+ logger.info(f"Step {step_count}: Saved screenshot and updated trajectory")
+ step_count += 1
+
+ # yield by each step
+ yield "## Rollout actions from policy...", checklist, current_cards, current_screenshot, trajectory.copy()
+
+ if terminated or truncated:
+ logger.info(f"Episode ended: terminated={terminated}, truncated={truncated}")
+ yield return_state("## Episode ended", current_screenshot)
+ break
+
+ finally:
+ logger.info("Finished")
+
+
+def run_agent_worker(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue):
+ """Worker function that runs the agent in a separate process and puts results in a queue."""
+ try:
+ for result in run_agent(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps):
+ return_queue.put(result)
+    except Exception as e:
+        logger.error(f"Error in agent worker process: {e}")
+        import traceback
+        traceback.print_exc()
+        # keep the 5-tuple shape (state, checklist, cards, screenshot, trajectory) expected by the consumer
+        return_queue.put(("## Error occurred in agent process", None, None, None, None))
+ finally:
+ # Signal that the process is done
+ return_queue.put(None)
+
+def run_agent_wrapper(instruction, model_name="gpt-4o", start_url="about:blank",
+ use_html=False, use_axtree=True, use_screenshot=False, max_steps=20):
+ """Wrapper function that runs the agent in a separate process and yields results."""
+ return_queue = multiprocessing.Queue()
+
+ # Start the agent in a separate process
+ p = multiprocessing.Process(
+ target=run_agent_worker,
+ args=(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue)
+ )
+ p.daemon = True # Ensure process terminates when parent terminates
+ p.start()
+
+ # Get results from the queue and yield them
+ while True:
+ result = return_queue.get()
+ if result is None: # End signal
+ break
+ yield result
+
+ # Clean up
+ if p.is_alive():
+ p.terminate()
+ p.join()
+
+def process_run(instruction, model_name, start_url):
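+    """Consume the agent generator and render the checklist, reward-model cards, and step-by-step trajectory as HTML for the UI."""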
+ # Use the wrapper function instead of directly calling run_agent
+ trajectory_generator = run_agent_wrapper(
+ instruction,
+ model_name,
+ start_url,
+ use_html=False,
+ use_axtree=True,
+ use_screenshot=False
+ )
+
+ all_trajectory = []
+ last_checklist_view, last_trajectory_html = None, None
+
+ for state, checklist_view, rm_cards, screenshot, trajectory in trajectory_generator:
+ if checklist_view is None:
+ yield state, screenshot, last_checklist_view, None, last_trajectory_html
+ continue
+ # Create HTML for reward model cards
+        # markup below assumes the classes defined in templates/rm_cards.css
+        rm_cards_html = f"""
+        <style>{CSS_RM_CARDS}</style>
+        <div class="rm-cards-container">
+        """
+
+ for idx, card in enumerate(rm_cards):
+ rm_cards_html += CARD_HTML_TEMPLATE.format(
+ additional_class='top-candidate' if idx == 0 else '',
+ k=idx+1,
+ suffix='(best)' if idx == 0 else '',
+ thought=card['thought'],
+ action=card['action'],
+ reward=card['reward'],
+ feedback=card['feedback']
+ )
+
+        rm_cards_html += "</div>"  # close .rm-cards-container
+ all_trajectory = trajectory
+
+ # Create HTML for trajectory display
+        # markup below assumes the classes defined in templates/trajectory.css
+        trajectory_html = f"""
+        <style>{CSS_TRAJECTORY}</style>
+        <div class="trajectory-container">
+        """
+
+ for idx, (after_img, cands) in enumerate(all_trajectory):
+ # Convert image to base64 if needed
+ img = all_trajectory[idx][0]
+ if isinstance(img, np.ndarray):
+ img = Image.fromarray(img)
+ if isinstance(img, Image.Image):
+ buffer = io.BytesIO()
+ img.save(buffer, format="JPEG")
+ img_str = base64.b64encode(buffer.getvalue()).decode()
+ img_src = f"data:image/jpeg;base64,{img_str}"
+ else:
+ img_src = img
+
+            trajectory_html += f"""
+            <div class="step-container">
+                <div class="step-header">Step {idx+1}</div>
+                <div class="step-content">
+                    <div class="step-image">
+                        <img src="{img_src}" alt="Step {idx+1} screenshot">
+                    </div>
+                    <div class="step-info">
+                        <div class="box-title">Action Candidates:</div>
+                        <div class="action-candidates">
+            """
+
+ # Display all candidates for this step
+ for i, cand in enumerate(cands):
+ action = cand['action']
+ reward = cand['reward']
+
+                trajectory_html += f"""
+                            <div class="candidate-box{' selected' if i == 0 else ''}">
+                                <div class="box-title">
+                                    Action {i+1}{' (Selected)' if i == 0 else ''}
+                                    <span class="reward-text">Reward: {reward}</span>
+                                </div>
+                                <pre>{action}</pre>
+                            </div>
+                """
+
+            trajectory_html += """
+                        </div>
+                    </div>
+                </div>
+            </div>
+            """
+
+        trajectory_html += "</div>"  # close .trajectory-container
+
+ last_checklist_view, last_trajectory_html = checklist_view, trajectory_html
+ yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html
+ yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html
diff --git a/requirements.txt b/requirements.txt
index cfc5b09a68217c6eba8d711a8c995c765049d339..bb8e47a583bb7b74c8504053dbb375c759cbc3ea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,7 @@
-huggingface_hub==0.25.2
\ No newline at end of file
+gradio
+openai
+numpy
+gymnasium
+langsmith
+langchain[openai]
+pillow
\ No newline at end of file
diff --git a/templates/card.html b/templates/card.html
new file mode 100644
index 0000000000000000000000000000000000000000..274a06f43b3b615e0af8f310946c603bbc44a0dd
--- /dev/null
+++ b/templates/card.html
@@ -0,0 +1,33 @@
+<!-- Candidate card (approximate markup): classes come from rm_cards.css, placeholders are filled via str.format in process_run.py -->
+<div class="rm-card {additional_class}">
+    <div class="rm-card-header">
+        <span>Candidate {k} {suffix}</span>
+        <span class="reward-badge">Reward: {reward}</span>
+    </div>
+    <div class="rm-card-body">
+        <div class="card-section">
+            <div class="card-section-title">Thought</div>
+            <div class="thought-content"><pre>{thought}</pre></div>
+        </div>
+        <div class="card-section">
+            <div class="card-section-title">Action</div>
+            <div class="action-content"><pre>{action}</pre></div>
+        </div>
+        <div class="card-section">
+            <details>
+                <summary>Feedback (click to view)</summary>
+                <div class="feedback-content"><pre>{feedback}</pre></div>
+            </details>
+        </div>
+    </div>
+</div>
\ No newline at end of file
diff --git a/templates/rm_cards.css b/templates/rm_cards.css
new file mode 100644
index 0000000000000000000000000000000000000000..e09b97dd6a01ae909031d4e476af4120da0b4d0b
--- /dev/null
+++ b/templates/rm_cards.css
@@ -0,0 +1,106 @@
+.rm-cards-container {
+ display: flex;
+ gap: 15px;
+ padding: 10px 0;
+ overflow-x: auto;
+}
+
+.rm-card {
+ min-width: 300px;
+ max-width: 400px;
+ border: 1px solid #ddd;
+ border-radius: 8px;
+ overflow: hidden;
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
+ background: rgba(255, 255, 255, 0.0);
+}
+
+.rm-card.top-candidate {
+ border: 3px solid #007bff;
+ box-shadow: 0 4px 8px rgba(0, 0, 150, 0.2);
+}
+
+.rm-card-header {
+ background: rgba(240, 240, 240, 0.3);
+ padding: 10px 15px;
+ font-weight: bold;
+ border-bottom: 1px solid #ddd;
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+
+.top-candidate .rm-card-header {
+ background: rgba(230, 242, 255, 0.3);
+}
+
+.reward-badge {
+ background: #007bff;
+ color: white;
+ padding: 3px 8px;
+ border-radius: 12px;
+ font-size: 0.9em;
+}
+
+.rm-card-body {
+ padding: 15px;
+ background-color: rgba(240, 240, 240, 0.0);
+}
+
+.card-section {
+ margin-bottom: 12px;
+}
+
+.card-section-title {
+ font-weight: bold;
+ margin-bottom: 5px;
+ color: #555;
+}
+
+.thought-content {
+ background: rgba(247, 247, 255, 0.3);
+ border: 1px solid #d0d0ff;
+ border-radius: 6px;
+ padding: 8px;
+}
+
+.action-content {
+ background: rgba(240, 255, 240, 0.3);
+ border: 1px solid #d0ffd0;
+ border-radius: 6px;
+ padding: 8px;
+}
+
+.feedback-content {
+ background: rgba(255, 247, 240, 0.3);
+ border: 1px solid #ffd0a0;
+ border-radius: 6px;
+ padding: 8px;
+}
+
+.reward-content {
+ background: rgba(240, 240, 255, 0.3);
+ border: 1px solid #d0d0ff;
+ border-radius: 6px;
+ padding: 8px;
+}
+
+details {
+ margin-top: 5px;
+}
+
+summary {
+ cursor: pointer;
+ font-weight: bold;
+ color: #555;
+}
+
+summary:hover {
+ color: #007bff;
+}
+
+pre {
+ margin: 0;
+ white-space: pre-wrap;
+ word-break: break-word;
+}
\ No newline at end of file
diff --git a/templates/trajectory.css b/templates/trajectory.css
new file mode 100644
index 0000000000000000000000000000000000000000..5795985190beaec7a73168113106e8094a16c0cd
--- /dev/null
+++ b/templates/trajectory.css
@@ -0,0 +1,82 @@
+/* templates/trajectory.css */
+
+.trajectory-container {
+ display: flex;
+ flex-direction: column;
+ gap: 20px;
+}
+
+.step-container {
+ border: 1px solid #ddd;
+ border-radius: 8px;
+ overflow: hidden;
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
+}
+
+.step-header {
+ background: rgba(240, 240, 240, 0.0);
+ padding: 10px 15px;
+ font-weight: bold;
+ border-bottom: 1px solid #ddd;
+}
+
+.step-content {
+ display: flex;
+ gap: 15px;
+ flex-wrap: wrap;
+}
+
+.step-image {
+ flex: 1 1 40%;
+ padding: 10px;
+}
+
+.step-image img {
+ width: 100%;
+ border: 1px solid #eee;
+ border-radius: 4px;
+}
+
+.step-info {
+ flex: 1 1 55%;
+ display: flex;
+ flex-direction: column;
+ padding: 10px;
+}
+
+.action-candidates {
+ display: flex;
+ flex-direction: column;
+ gap: 10px;
+}
+
+.candidate-box {
+ background: rgba(245, 240, 255, 0.0);
+ border: 1px solid #d0d0ff;
+ border-radius: 6px;
+ padding: 10px;
+}
+
+.candidate-box.selected {
+ border: 2px solid #7030a0;
+ box-shadow: 0 2px 4px rgba(112, 48, 160, 0.1);
+}
+
+.box-title {
+ font-weight: bold;
+ margin-bottom: 5px;
+}
+
+.reward-text {
+ display: inline-block;
+ color: #555;
+ font-size: 0.9em;
+ margin-left: 8px;
+ font-style: italic;
+}
+
+pre {
+ margin: 0;
+ white-space: pre-wrap;
+ word-break: break-word;
+}
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c133195e82b192db4b825a1938ba2a1918a1457c
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,41 @@
+import tokenize
+import io
+import base64
+import numpy as np
+from PIL import Image
+
+def remove_inline_comments_safe(code: str) -> str:
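+    """Remove comment tokens from a snippet of Python code using the tokenize module."""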
+ result = []
+ tokens = tokenize.generate_tokens(io.StringIO(code).readline)
+
+ last_line = -1
+ current_line = ''
+ for tok_type, tok_string, (srow, scol), (_, _), _ in tokens:
+ if srow != last_line:
+ if current_line:
+ result.append(current_line.rstrip())
+ current_line = ''
+ last_line = srow
+
+ if tok_type == tokenize.COMMENT:
+            # skip comment tokens (add nothing to the output)
+ continue
+
+ current_line += tok_string
+
+ if current_line:
+ result.append(current_line.rstrip())
+
+ return '\n'.join(result)
+
+
+def image_to_jpg_base64_url(image: Image.Image | np.ndarray) -> str:
+    """Return a base64 JPEG data URL from a PIL image or NumPy array."""
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ if image.mode in {"RGBA", "LA"}:
+ image = image.convert("RGB")
+ with io.BytesIO() as buffer:
+ image.save(buffer, format="JPEG")
+ encoded: str = base64.b64encode(buffer.getvalue()).decode()
+ return f"data:image/jpeg;base64,{encoded}"