# File-viewer residue (not part of the module): kyle8581's picture | upload | dd39c08 | raw | history blame | 2.31 kB
import re
import tempfile
import logging
import dataclasses
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import Agent
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
from browsergym.utils.obs import flatten_axtree_to_str
class MiniwobTestAgent(Agent):
    """Scripted test agent that clicks the first button-like AXTree entry.

    The agent scans the flattened accessibility-tree text for the first line
    shaped like ``[<digits>] ... button`` (matched case-insensitively) and
    replies with a ``click`` on the bracketed id.
    """

    # NOTE(review): presumably restricts the agent to bid-based actions --
    # confirm against HighLevelActionSet.
    action_set = HighLevelActionSet(subsets="bid")

    def obs_preprocessor(self, obs: dict):
        """Reduce ``obs`` to a dict holding only its flattened AXTree text.

        Args:
            obs: Raw observation; must contain the key ``"axtree_object"``.

        Returns:
            ``{"axtree_txt": <flattened tree as str>}``.
        """
        axtree_txt = flatten_axtree_to_str(obs["axtree_object"])
        return {"axtree_txt": axtree_txt}

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Produce a ``click`` action for the first line mentioning a button.

        Args:
            obs: Preprocessed observation; must contain ``"axtree_txt"``.

        Returns:
            A pair ``(action, info)`` where ``action`` is ``click("<id>")``
            and ``info`` carries a single ``"think"`` message.

        Raises:
            Exception: If no line of the AXTree text matches the pattern.
        """
        # MULTILINE lets ``^`` anchor at every line start, so the search
        # yields the first line that begins with "[<digits>]" and later
        # contains "button" in any letter case.
        hit = re.search(
            r"^\s*\[(\d+)\].*button",
            obs["axtree_txt"],
            re.MULTILINE | re.IGNORECASE,
        )
        if hit is None:
            raise Exception("Can't find the button's bid")

        bid = hit.group(1)
        action = f'click("{bid}")'
        return action, {"think": "I'm clicking the button as requested."}
@dataclasses.dataclass
class MiniwobTestAgentArgs(AbstractAgentArgs):
    """Agent-args dataclass, declaring no fields of its own, for ``MiniwobTestAgent``."""

    def make_agent(self):
        """Instantiate and return a new ``MiniwobTestAgent``."""
        agent = MiniwobTestAgent()
        return agent
def test_run_exp():
    """Run ``miniwob.click-test`` end to end and verify the recorded result.

    The experiment is prepared and executed inside a temporary directory
    with ``MiniwobTestAgentArgs``; the resulting record must match the
    expected task settings and outcome, and the cumulated step time must
    stay below 5 seconds (a warning is logged above 3 seconds).
    """
    # Values the experiment record is required to contain.
    target = {
        "env_args.task_name": "miniwob.click-test",
        "env_args.task_seed": 42,
        "env_args.headless": True,
        "env_args.record_video": False,
        "n_steps": 1,
        "cum_reward": 1.0,
        "terminated": True,
        "truncated": False,
    }

    agent_args = MiniwobTestAgentArgs()
    env_args = EnvArgs(task_name="miniwob.click-test", task_seed=42)
    exp_args = ExpArgs(agent_args=agent_args, env_args=env_args)

    # Every check stays inside the ``with`` block: the directory is removed
    # on exit, and the result object may still read from it
    # (NOTE(review): presumably lazily -- confirm in get_exp_result).
    with tempfile.TemporaryDirectory() as tmp_dir:
        exp_args.prepare(tmp_dir)
        exp_args.run()

        exp_result = get_exp_result(exp_args.exp_dir)
        exp_record = exp_result.get_exp_record()

        assert len(exp_result.steps_info) == 2

        for field_name, expected_value in target.items():
            assert field_name in exp_record
            assert exp_record[field_name] == expected_value

        # TODO investigate why it's taking almost 5 seconds to solve
        t = exp_record["stats.cum_step_elapsed"]
        assert t < 5
        if t > 3:
            logging.warning(
                f"miniwob.click-test is taking {t:.2f}s (> 3s) to solve with an oracle."
            )