File size: 2,307 Bytes
dd39c08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import re
import tempfile
import logging
import dataclasses

from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import Agent
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
from browsergym.utils.obs import flatten_axtree_to_str


class MiniwobTestAgent(Agent):
    """Scripted agent that locates a button in the accessibility tree and clicks it.

    The agent does no reasoning: it scans the flattened accessibility-tree
    text for the first line mentioning a button and emits a ``click`` action
    on the element id found at the start of that line.
    """

    # Action set built from the "bid" subset only.
    action_set = HighLevelActionSet(subsets="bid")

    def obs_preprocessor(self, obs: dict):
        """Reduce the raw observation to its flattened accessibility tree.

        Args:
            obs: Raw observation; must contain an ``"axtree_object"`` entry.

        Returns:
            A dict with a single ``"axtree_txt"`` key holding the result of
            ``flatten_axtree_to_str`` applied to ``obs["axtree_object"]``.
        """
        axtree_txt = flatten_axtree_to_str(obs["axtree_object"])
        return {"axtree_txt": axtree_txt}

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Build a ``click`` action targeting the first button-like line.

        Args:
            obs: Preprocessed observation with an ``"axtree_txt"`` string.

        Returns:
            A pair ``(action, info)`` where ``action`` is ``click("<bid>")``
            and ``info`` is a dict carrying a ``think`` message.

        Raises:
            Exception: If no line of ``obs["axtree_txt"]`` matches the pattern.
        """
        # A matching line starts (after optional whitespace) with "[<digits>]"
        # and contains "button" further along the same line, in any letter case.
        button_line = re.search(
            r"^\s*\[(\d+)\].*button",
            obs["axtree_txt"],
            re.MULTILINE | re.IGNORECASE,
        )
        if button_line is None:
            raise Exception("Can't find the button's bid")

        button_bid = button_line.group(1)
        return f'click("{button_bid}")', dict(think="I'm clicking the button as requested.")


@dataclasses.dataclass
class MiniwobTestAgentArgs(AbstractAgentArgs):
    """Argument container whose only job is to build a ``MiniwobTestAgent``."""

    def make_agent(self):
        """Return a freshly constructed ``MiniwobTestAgent`` (takes no arguments)."""
        agent = MiniwobTestAgent()
        return agent


def test_run_exp():
    """Run the scripted agent on ``miniwob.click-test`` and check the saved results.

    The experiment is prepared and run inside a temporary directory, then
    loaded back with ``get_exp_result`` and compared against the expected
    record values. The cumulative step time must stay under 5 seconds; a
    warning is logged when it exceeds 3 seconds.
    """
    # Values the experiment record must contain after the run.
    expected_record = {
        "env_args.task_name": "miniwob.click-test",
        "env_args.task_seed": 42,
        "env_args.headless": True,
        "env_args.record_video": False,
        "n_steps": 1,
        "cum_reward": 1.0,
        "terminated": True,
        "truncated": False,
    }

    agent_args = MiniwobTestAgentArgs()
    env_args = EnvArgs(task_name="miniwob.click-test", task_seed=42)
    exp_args = ExpArgs(agent_args=agent_args, env_args=env_args)

    # All result files live in the temporary directory, so every check has to
    # happen before the context manager removes it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        exp_args.prepare(tmp_dir)
        exp_args.run()

        exp_result = get_exp_result(exp_args.exp_dir)
        exp_record = exp_result.get_exp_record()

        step_count = len(exp_result.steps_info)
        assert step_count == 2

        for field, expected in expected_record.items():
            assert field in exp_record
            assert exp_record[field] == expected

        # TODO investigate why it's taking almost 5 seconds to solve
        elapsed = exp_record["stats.cum_step_elapsed"]
        assert elapsed < 5
        if elapsed > 3:
            logging.warning(
                f"miniwob.click-test is taking {elapsed:.2f}s (> 3s) to solve with an oracle."
            )