Spaces:
Running
Running
File size: 2,307 Bytes
dd39c08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import re
import tempfile
import logging
import dataclasses
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import Agent
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
from browsergym.utils.obs import flatten_axtree_to_str
class MiniwobTestAgent(Agent):
    """Scripted agent that clicks the first button it finds in the page.

    The agent works purely from the flattened accessibility-tree text: it
    looks for a line that starts with a bracketed numeric id and mentions
    "button", then emits a ``click`` action on that id.
    """

    # Action set built with the "bid" subset (element-id based actions).
    action_set = HighLevelActionSet(subsets="bid")

    def obs_preprocessor(self, obs: dict):
        """Reduce the raw observation to the flattened accessibility tree.

        Args:
            obs: Raw observation; only ``obs["axtree_object"]`` is read.

        Returns:
            A dict with a single key, ``"axtree_txt"``, holding the output of
            ``flatten_axtree_to_str`` for the accessibility tree.
        """
        axtree_txt = flatten_axtree_to_str(obs["axtree_object"])
        return {"axtree_txt": axtree_txt}

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Return a click action on the first button-like line of the tree.

        Args:
            obs: Preprocessed observation; only ``obs["axtree_txt"]`` is read.

        Returns:
            A ``(action, info)`` pair: the action string ``click("<bid>")`` and
            a dict carrying the agent's ``think`` message.

        Raises:
            Exception: If no line of the tree text matches the button pattern.
        """
        # MULTILINE lets ``^`` anchor at every line start, so the search scans
        # line by line; IGNORECASE accepts "button", "Button", etc.
        button_line = re.search(
            r"^\s*\[(\d+)\].*button",
            obs["axtree_txt"],
            re.MULTILINE | re.IGNORECASE,
        )
        if button_line is None:
            raise Exception("Can't find the button's bid")

        # Group 1 is the numeric id captured between the square brackets.
        bid = button_line.group(1)
        return f'click("{bid}")', {"think": "I'm clicking the button as requested."}
@dataclasses.dataclass
class MiniwobTestAgentArgs(AbstractAgentArgs):
    """Argument holder whose only job is to build a ``MiniwobTestAgent``."""

    def make_agent(self):
        """Create and return a fresh ``MiniwobTestAgent`` instance."""
        agent = MiniwobTestAgent()
        return agent
def test_run_exp():
    """Run one ``miniwob.click-test`` experiment and check its recorded results.

    The experiment is prepared and run inside a temporary directory, then the
    result is loaded back from ``exp_args.exp_dir`` and compared against the
    expected record values. A soft timing check logs a warning above 3 seconds
    and fails at 5 seconds or more.
    """
    # Values the experiment record is expected to contain after the run.
    expected_record = {
        "env_args.task_name": "miniwob.click-test",
        "env_args.task_seed": 42,
        "env_args.headless": True,
        "env_args.record_video": False,
        "n_steps": 1,
        "cum_reward": 1.0,
        "terminated": True,
        "truncated": False,
    }

    exp_args = ExpArgs(
        agent_args=MiniwobTestAgentArgs(),
        env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42),
    )

    # Every check stays inside the context manager: the result object is
    # built from files in the temporary directory, which is deleted on exit.
    with tempfile.TemporaryDirectory() as work_dir:
        exp_args.prepare(work_dir)
        exp_args.run()

        result = get_exp_result(exp_args.exp_dir)
        record = result.get_exp_record()

        # NOTE(review): presumably one entry for the initial observation plus
        # one for the single agent step; confirm against the experiment loop.
        assert len(result.steps_info) == 2

        for field, expected in expected_record.items():
            assert field in record
            assert record[field] == expected

        # TODO investigate why it's taking almost 5 seconds to solve
        t = record["stats.cum_step_elapsed"]
        assert t < 5
        if t > 3:
            logging.warning(
                f"miniwob.click-test is taking {t:.2f}s (> 3s) to solve with an oracle."
            )
|