XanderJC commited on
Commit
f0f6e5c
·
1 Parent(s): 9d0f374
.gitignore CHANGED
@@ -169,3 +169,6 @@ cython_debug/
169
 
170
  # PyPI configuration file
171
  .pypirc
 
 
 
 
169
 
170
  # PyPI configuration file
171
  .pypirc
172
+
173
+ logs/
174
+ local_trajectories/
Makefile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .PHONY: proxy
2
+
3
+ proxy:
4
+ uv venv --python 3.11 --python-preference managed
5
+ uv sync
6
+ uv pip install -e .
7
+ playwright install
README.md CHANGED
@@ -1,2 +1,75 @@
1
- # proxy-lite
2
- A mini, open-weights, version of our Proxy assistant.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![Proxy-Lite Logo](assets/proxy-lite.png)
2
+
3
+ A mini, open-weights version of our Proxy assistant.
4
+
5
+ ![Proxy-Lite Demo](demo.gif)
6
+
7
+ ---
8
+
9
+ ## Getting Started
10
+
11
+ ### Installation
12
+
13
+ Clone the repository:
14
+
15
+ ```bash
16
+ git clone https://github.com/convergence-ai/proxy-lite.git
17
+ ```
18
+
19
+ Set-up the environment with:
20
+
21
+ ```bash
22
+ make proxy
23
+ ```
24
+
25
+ Or do it manually:
26
+
27
+ ```bash
28
+ uv venv --python 3.11 --python-preference managed
29
+ uv sync
30
+ uv pip install -e .
31
+ playwright install
32
+ ```
33
+
34
+
35
+ ### Usage
36
+
37
+ ```bash
38
+ proxy --help
39
+ ```
40
+
41
+ You can directly run the proxy with:
42
+
43
+ ```bash
44
+ proxy "Book a table for 2 at an Italian restaurant in Kings Cross tonight at 7pm."
45
+ ```
46
+
47
+
48
+ ### Proxy-Lite Endpoint
49
+
50
+ By default, Proxy-Lite will point to an endpoint set up on HuggingFace spaces. This is a demo endpoint and is not suitable for production use; it may be very slow when under heavy load.
51
+
52
+ We recommend hosting your own endpoint with vLLM, you can use the following command:
53
+
54
+ ```bash
55
+ vllm serve --model convergence-ai/proxy-lite-7b \
56
+ --trust-remote-code \
57
+ --enable-auto-tool-choice \
58
+ --tool-call-parser hermes \
59
+ --port 8008 \
60
+ ```
61
+
62
+ You can set the `api_base` to point to your local endpoint when calling Proxy-Lite:
63
+
64
+ ```bash
65
+ proxy --api-base http://localhost:8008/v1 "Book a table...
66
+ ```
67
+ or by setting the environment variable:
68
+
69
+ ```bash
70
+ export PROXY_LITE_API_BASE=http://localhost:8008/v1
71
+ ```
72
+
73
+
74
+
75
+
pyproject.toml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "proxy-lite"
3
+ version = "0.1.0"
4
+ description = "Proxy Lite - A mini, open-weights, version of the Convergence AI Proxy assistant."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "omegaconf>=2.3.0",
9
+ "openai>=1.61.1",
10
+ "opencv-python>=4.11.0.86",
11
+ "playwright-stealth>=1.0.6",
12
+ "playwright>=1.50.0",
13
+ "pydantic>=2.10.6",
14
+ "rich>=13.9.4",
15
+ "setuptools>=75.8.0",
16
+ "tenacity>=9.0.0",
17
+ "torch>=2.6.0",
18
+ "torchvision>=0.21.0",
19
+ ]
20
+
21
+ [project.scripts]
22
+ proxy = "proxy_lite.cli:main"
23
+
24
+ [build-system]
25
+ requires = ["setuptools"]
26
+ build-backend = "setuptools.build_meta"
27
+
28
+ [tool.setuptools]
29
+ packages = { find = { where = ["src"] } }
30
+
31
+ [tool.setuptools.package-data]
32
+ proxy_lite = ["**/*.json"]
33
+
34
+ [tool.ruff]
35
+ line-length = 120
36
+
37
+ [tool.ruff.lint]
38
+ select = ["E", "F", "B", "I", "SIM"]
39
+ ignore = [
40
+ "B028",
41
+ "E722", # ignore bare except
42
+ "B904", # ignore raise from requirement
43
+ "FA102",
44
+ ]
45
+ [tool.ruff.lint.flake8-bugbear]
46
+
47
+ extend-immutable-calls = [
48
+ "fastapi.Depends",
49
+ "fastapi.params.Depends",
50
+ "fastapi.Query",
51
+ "fastapi.params.Query",
52
+ ]
53
+
src/proxy_lite/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .runner import Runner, RunnerConfig
2
+
3
+ __all__ = ["Runner", "RunnerConfig"]
src/proxy_lite/agents/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from .agent_base import Agents, BaseAgent, BaseAgentConfig
4
+ from .browser_agent import BrowserAgent, BrowserAgentConfig
5
+ from .proxy_lite_agent import ProxyLiteAgent, ProxyLiteAgentConfig
6
+
7
+ AgentTypes = Union[*list(Agents._agent_registry.values())]
8
+ AgentConfigTypes = Union[*list(Agents._agent_config_registry.values())]
9
+
10
+
11
+ __all__ = [
12
+ "AgentConfigTypes",
13
+ "AgentTypes",
14
+ "Agents",
15
+ "BaseAgent",
16
+ "BaseAgentConfig",
17
+ "BrowserAgent",
18
+ "BrowserAgentConfig",
19
+ "ProxyLiteAgent",
20
+ "ProxyLiteAgentConfig",
21
+ ]
src/proxy_lite/agents/agent_base.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from abc import ABC, abstractmethod
4
+ from contextlib import AsyncExitStack
5
+ from functools import cached_property
6
+ from typing import Any, Optional, Type, cast
7
+
8
+ from pydantic import BaseModel, Field
9
+ from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
10
+
11
+ from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig
12
+ from proxy_lite.history import (
13
+ AssistantMessage,
14
+ MessageHistory,
15
+ MessageLabel,
16
+ SystemMessage,
17
+ Text,
18
+ ToolCall,
19
+ ToolMessage,
20
+ UserMessage,
21
+ )
22
+ from proxy_lite.logger import logger
23
+ from proxy_lite.tools import Tool
24
+
25
+ # if TYPE_CHECKING:
26
+ # from proxy_lite.tools import Tool
27
+
28
+
29
+ class BaseAgentConfig(BaseModel):
30
+ client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
31
+ history_messages_limit: dict[MessageLabel, int] = Field(default_factory=lambda: dict())
32
+ history_messages_include: Optional[dict[MessageLabel, int]] = Field(
33
+ default=None,
34
+ description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
35
+ )
36
+
37
+ def model_post_init(self, __context: Any) -> None:
38
+ if self.history_messages_include is not None:
39
+ self.history_messages_limit = {label: 0 for label in MessageLabel}
40
+ self.history_messages_limit.update(self.history_messages_include)
41
+
42
+
43
+ class BaseAgent(BaseModel, ABC):
44
+ config: BaseAgentConfig
45
+ temperature: float = Field(default=0.7, ge=0, le=2)
46
+ history: MessageHistory = Field(default_factory=MessageHistory)
47
+ client: Optional[BaseClient] = None
48
+ env_tools: list[Tool] = Field(default_factory=list)
49
+ task: Optional[str] = Field(default=None)
50
+ seed: Optional[int] = Field(default=None)
51
+
52
+ class Config:
53
+ arbitrary_types_allowed = True
54
+
55
+ def __init__(self, **data) -> None:
56
+ super().__init__(**data)
57
+ self._exit_stack = AsyncExitStack()
58
+ self._tools_init_task = None
59
+
60
+ def model_post_init(self, __context: Any) -> None:
61
+ super().model_post_init(__context)
62
+ self.client = BaseClient.create(self.config.client)
63
+
64
+ @property
65
+ @abstractmethod
66
+ def system_prompt(self) -> str: ...
67
+
68
+ @cached_property
69
+ @abstractmethod
70
+ def tools(self) -> list[Tool]: ...
71
+
72
+ @cached_property
73
+ def tool_descriptions(self) -> str:
74
+ tool_descriptions = []
75
+ for tool in self.tools:
76
+ func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
77
+ tool_title = f"{tool.__class__.__name__}:\n" if len(self.tools) > 1 else ""
78
+ tool_descriptions.append(f"{tool_title}{func_descriptions}")
79
+ return "\n\n".join(tool_descriptions)
80
+
81
+ async def get_history_view(self) -> MessageHistory:
82
+ return MessageHistory(
83
+ messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
84
+ ) + self.history.history_view(
85
+ limits=self.config.history_messages_limit,
86
+ )
87
+
88
+ @retry(
89
+ wait=wait_exponential(multiplier=1, min=4, max=10),
90
+ stop=stop_after_attempt(3),
91
+ reraise=True,
92
+ before_sleep=before_sleep_log(logger, logging.ERROR),
93
+ )
94
+ async def generate_output(
95
+ self,
96
+ use_tool: bool = False,
97
+ response_format: Optional[type[BaseModel]] = None,
98
+ append_assistant_message: bool = True,
99
+ ) -> AssistantMessage:
100
+ messages: MessageHistory = await self.get_history_view()
101
+ response_content = (
102
+ await self.client.create_completion(
103
+ messages=messages,
104
+ temperature=self.temperature,
105
+ seed=self.seed,
106
+ response_format=response_format,
107
+ tools=self.tools if use_tool else None,
108
+ )
109
+ ).model_dump()
110
+ response_content = response_content["choices"][0]["message"]
111
+ assistant_message = AssistantMessage(
112
+ role=response_content["role"],
113
+ content=[Text(text=response_content["content"])] if response_content["content"] else [],
114
+ tool_calls=response_content["tool_calls"],
115
+ )
116
+ if append_assistant_message:
117
+ self.history.append(message=assistant_message, label=self.message_label)
118
+ return assistant_message
119
+
120
+ def receive_user_message(
121
+ self,
122
+ text: Optional[str] = None,
123
+ image: list[bytes] = None,
124
+ label: MessageLabel = None,
125
+ is_base64: bool = False,
126
+ ) -> None:
127
+ message = UserMessage.from_media(
128
+ text=text,
129
+ image=image,
130
+ is_base64=is_base64,
131
+ )
132
+ self.history.append(message=message, label=label)
133
+
134
+ def receive_system_message(
135
+ self,
136
+ text: Optional[str] = None,
137
+ label: MessageLabel = None,
138
+ ) -> None:
139
+ message = SystemMessage.from_media(text=text)
140
+ self.history.append(message=message, label=label)
141
+
142
+ def receive_assistant_message(
143
+ self,
144
+ content: Optional[str] = None,
145
+ tool_calls: Optional[list[ToolCall]] = None,
146
+ label: MessageLabel = None,
147
+ ) -> None:
148
+ message = AssistantMessage(
149
+ content=[Text(text=content)] if content else [],
150
+ tool_calls=tool_calls,
151
+ )
152
+ self.history.append(message=message, label=label)
153
+
154
+ async def use_tool(self, tool_call: ToolCall):
155
+ function = tool_call.function
156
+ for tool in self.tools:
157
+ if hasattr(tool, function["name"]):
158
+ return await getattr(tool, function["name"])(
159
+ **json.loads(function["arguments"]),
160
+ )
161
+ msg = f'No tool function with name "{function["name"]}"'
162
+ raise ValueError(msg)
163
+
164
+ async def receive_tool_message(
165
+ self,
166
+ text: str,
167
+ tool_id: str,
168
+ label: MessageLabel = None,
169
+ ) -> None:
170
+ self.history.append(
171
+ message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
172
+ label=label,
173
+ )
174
+
175
+
176
+ class Agents:
177
+ _agent_registry: dict[str, type[BaseAgent]] = {}
178
+ _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}
179
+
180
+ @classmethod
181
+ def register_agent(cls, name: str):
182
+ """
183
+ Decorator to register an Agent class under a given name.
184
+
185
+ Example:
186
+ @Agents.register_agent("browser")
187
+ class BrowserAgent(BaseAgent):
188
+ ...
189
+ """
190
+
191
+ def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
192
+ cls._agent_registry[name] = agent_cls
193
+ return agent_cls
194
+
195
+ return decorator
196
+
197
+ @classmethod
198
+ def register_agent_config(cls, name: str):
199
+ """
200
+ Decorator to register a configuration class under a given name.
201
+
202
+ Example:
203
+ @Agents.register_agent_config("browser")
204
+ class BrowserAgentConfig(BaseAgentConfig):
205
+ ...
206
+ """
207
+
208
+ def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
209
+ cls._agent_config_registry[name] = config_cls
210
+ return config_cls
211
+
212
+ return decorator
213
+
214
+ @classmethod
215
+ def get(cls, name: str) -> type[BaseAgent]:
216
+ """
217
+ Retrieve a registered Agent class by its name.
218
+
219
+ Raises:
220
+ ValueError: If no such agent is found.
221
+ """
222
+ try:
223
+ return cast(Type[BaseAgent], cls._agent_registry[name])
224
+ except KeyError:
225
+ raise ValueError(f"Agent '{name}' not found.")
226
+
227
+ @classmethod
228
+ def get_config(cls, name: str) -> type[BaseAgentConfig]:
229
+ """
230
+ Retrieve a registered Agent configuration class by its name.
231
+
232
+ Raises:
233
+ ValueError: If no such config is found.
234
+ """
235
+ try:
236
+ return cast(type[BaseAgentConfig], cls._agent_config_registry[name])
237
+ except KeyError:
238
+ raise ValueError(f"Agent config for '{name}' not found.")
src/proxy_lite/agents/browser_agent.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from functools import cached_property
3
+ from typing import Literal
4
+
5
+ from pydantic import Field
6
+
7
+ from proxy_lite.agents.agent_base import Agents, BaseAgent, BaseAgentConfig
8
+ from proxy_lite.history import MessageHistory, MessageLabel, SystemMessage, Text
9
+ from proxy_lite.tools import Tool
10
+
11
+ BROWSER_AGENT_SYSTEM_PROMPT = """ **You are Proxy Lite, the Web-Browsing Agent.** You are developed by Convergence.
12
+
13
+ **Current date:** {date_time_with_day}.
14
+
15
+ You are given:
16
+
17
+ 1. A user task that you are trying to complete.
18
+ 2. Relevant facts we have at our disposal.
19
+ 3. A high level plan to complete the task.
20
+ 4. A history of previous actions and observations.
21
+ 5. An annotated webpage screenshot and text description of what's visible in the browser before and after the last action.
22
+
23
+ ## Objective
24
+
25
+ You are an expert at controlling the web browser.
26
+ You will be assisting a user with a task they are trying to complete on the web.
27
+
28
+ ## Web Screenshots
29
+
30
+ Each iteration of your browsing loop, you'll be provided with a screenshot of the browser.
31
+
32
+ The screenshot will have red rectangular annotations. These annotations highlight the marked elements you can interact with.
33
+
34
+ ## Mark IDs
35
+
36
+ Each annotated element is labeled with a "mark id" in the top-left corner.
37
+
38
+ When using tools like typing or clicking, specify the "mark id" to indicate which element you want to interact with.
39
+
40
+ If an element is not annotated, you cannot interact with it. This is a limitation of the software. Focus on marked elements only.
41
+
42
+ ## Text Snippets
43
+
44
+ Along with the screenshot, you will receive text snippets describing each annotated element.
45
+
46
+ Here’s an example of different element types:
47
+
48
+ - [0] `<a>text</a>` → Mark 0 is a link (`<a>` tag) containing the text "text".
49
+ - [1] `<button>text</button>` → Mark 1 is a button (`<button>` tag) containing the text "text".
50
+ - [2] `<input value="text"/>` → Mark 2 is an input field (`<input>` tag) with the value "text".
51
+ - [3] `<select>text</select>` → Mark 3 is a dropdown menu (`<select>` tag) with the option "text" selected.
52
+ - [4] `<textarea>text</textarea>` → Mark 4 is a text area (`<textarea>` tag) containing the text "text".
53
+ - [5] `<li>text</li>` → Mark 5 is a list item (`<li>` tag) containing the text "text".
54
+ - [6] `<div scrollable>text</div>` → Mark 6 is a division (`<div>` tag) containing the text "text" and is scrollable.
55
+ - [7] `<td>text</td>` → Mark 7 is a table cell (`<td>` tag) containing the text "text".
56
+
57
+ Note that these text snippets may be incomplete.
58
+
59
+ ## History
60
+
61
+ You will see your past actions and observations but not old annotated webpages.
62
+
63
+ This means annotated webpages showing useful information will not be visible in future actions.
64
+
65
+ To get around this, key details from each webpage are stored in observations.
66
+
67
+ ## Web Browser Actions
68
+
69
+ You can only take the following actions with the web browser:
70
+ {tool_descriptions}
71
+
72
+ ## Important Browsing Tips
73
+
74
+ If there is a modal overlay that is unresponsive on the page try reloading the webpage.
75
+
76
+ If there is a cookie consent form covering part of the page just click accept on the form.
77
+
78
+ When typing into a text field be sure to click one of the dropdown options (when present). Not selecting a dropdown option will result in the field being cleared after the next action.
79
+
80
+ You do not have access any internet accounts (outside of those provided by the user).
81
+
82
+ The browser has a built in CAPTCHA solver, if you are asked to solve one just wait and it will be solved for you.
83
+
84
+ ## Don't Repeat the Same Actions Continuously
85
+
86
+ If you find yourself repeating an action without making progress, try another action.
87
+
88
+ ## Task
89
+
90
+ You will now be connected to the user, who will give you their task.""" # noqa: E501
91
+
92
+ MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
93
+ MessageLabel.SCREENSHOT: 1,
94
+ # MessageLabel.REASONING_INDUCTION: 1,
95
+ # MessageLabel.FORMAT_INSTRUCTIONS: 1,
96
+ # MessageLabel.ACTION: 1,
97
+ }
98
+
99
+
100
+ @Agents.register_agent_config("browser")
101
+ class BrowserAgentConfig(BaseAgentConfig):
102
+ name: Literal["browser"] = "browser"
103
+ history_messages_limit: dict[MessageLabel, int] = Field(
104
+ default_factory=lambda: MAX_MESSAGES_FOR_CONTEXT_WINDOW,
105
+ )
106
+
107
+
108
+ @Agents.register_agent("browser")
109
+ class BrowserAgent(BaseAgent):
110
+ config: BrowserAgentConfig
111
+ message_label: MessageLabel = MessageLabel.AGENT_MODEL_RESPONSE
112
+
113
+ def __init__(self, **data):
114
+ super().__init__(**data)
115
+
116
+ @property
117
+ def system_prompt(self) -> str:
118
+ return BROWSER_AGENT_SYSTEM_PROMPT.format(
119
+ date_time_with_day=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
120
+ tool_descriptions=self.tool_descriptions,
121
+ memories="",
122
+ )
123
+
124
+ @cached_property
125
+ def tools(self) -> list[Tool]:
126
+ return self.env_tools
127
+
128
+ async def get_history_view(self) -> MessageHistory:
129
+ return MessageHistory(
130
+ messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
131
+ ) + self.history.history_view(
132
+ limits=self.config.history_messages_limit,
133
+ )
src/proxy_lite/agents/proxy_lite_agent.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import cached_property
2
+ from typing import Literal
3
+
4
+ from pydantic import Field
5
+
6
+ from proxy_lite.history import MessageHistory, MessageLabel, SystemMessage, Text
7
+ from proxy_lite.tools import Tool
8
+
9
+ from .agent_base import Agents, BaseAgent, BaseAgentConfig
10
+
11
+ MODEL_SYSTEM_PROMPT = """You are Proxy-Lite, an AI assistant that can perform actions on a computer screen.
12
+ You were developed by Convergence AI.
13
+ The user will instuct you to perform a task.
14
+ You will be shown a screen as well as relevant interactable elements highlighted by mark_ids and you will be given a set of tools to use to perform the task.
15
+ You should make observations about the screen, putting them in <observation></observation> tags.
16
+ You should then reason about what needs to be done to complete the task, putting your thoughts in <thinking></thinking> tags.
17
+ You should then use the tools to perform the task, putting the tool calls in <tool_call></tool_call> tags.
18
+ """ # noqa: E501
19
+
20
+ MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
21
+ MessageLabel.SCREENSHOT: 1,
22
+ }
23
+
24
+
25
+ @Agents.register_agent_config("proxy_lite")
26
+ class ProxyLiteAgentConfig(BaseAgentConfig):
27
+ name: Literal["proxy_lite"] = "proxy_lite"
28
+ history_messages_limit: dict[MessageLabel, int] = Field(
29
+ default_factory=lambda: MAX_MESSAGES_FOR_CONTEXT_WINDOW,
30
+ )
31
+
32
+
33
+ @Agents.register_agent("proxy_lite")
34
+ class ProxyLiteAgent(BaseAgent):
35
+ config: ProxyLiteAgentConfig
36
+ message_label: MessageLabel = MessageLabel.AGENT_MODEL_RESPONSE
37
+
38
+ def __init__(self, **data):
39
+ super().__init__(**data)
40
+
41
+ @property
42
+ def system_prompt(self) -> str:
43
+ return MODEL_SYSTEM_PROMPT
44
+
45
+ @cached_property
46
+ def tools(self) -> list[Tool]:
47
+ return self.env_tools
48
+
49
+ async def get_history_view(self) -> MessageHistory:
50
+ return MessageHistory(
51
+ messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
52
+ ) + self.history.history_view(
53
+ limits=self.config.history_messages_limit,
54
+ )
src/proxy_lite/browser/__init__.py ADDED
File without changes
src/proxy_lite/browser/add_custom_select.js ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ handledSelectElementsConvergence = new WeakSet();
2
+
3
+ overwriteDefaultSelectConvergence = (input = null) => {
4
+ let activeSelectElement = null;
5
+
6
+ // Handle iframe input element
7
+ let rootElement = input ? input : document.documentElement;
8
+
9
+ function createCustomSelectElement() {
10
+ // Create the custom select container
11
+ const customSelect = document.createElement('div');
12
+ customSelect.id = 'convergence-custom-select-element-X2EmudtLRN';
13
+ customSelect.style.position = 'absolute'
14
+ customSelect.style.zIndex = 2147483647 - 1;
15
+ customSelect.style.display = 'none';
16
+ document.body.appendChild(customSelect);
17
+
18
+ // Create the select options list
19
+ const optionsList = document.createElement('div');
20
+ optionsList.style.border = '1px solid #ccc';
21
+ optionsList.style.backgroundColor = '#fff';
22
+ optionsList.style.color = 'black';
23
+ customSelect.appendChild(optionsList);
24
+
25
+ return customSelect;
26
+ }
27
+
28
+ function showCustomSelect(select) {
29
+ activeSelectElement = select;
30
+
31
+ // Clear previous options
32
+ const customSelect = rootElement.querySelector('#convergence-custom-select-element-X2EmudtLRN');
33
+ let optionsList = customSelect.firstChild;
34
+ optionsList.innerHTML = '';
35
+
36
+ // Populate with new options
37
+ Array.from(select.options).forEach(option => {
38
+ const customOption = document.createElement('div');
39
+ customOption.className = 'custom-option';
40
+ customOption.style.padding = '8px';
41
+ customOption.style.cursor = 'pointer';
42
+ customOption.textContent = option.text;
43
+ customOption.dataset.value = option.value;
44
+ optionsList.appendChild(customOption);
45
+
46
+ customOption.addEventListener('mouseenter', function () {
47
+ customOption.style.backgroundColor = '#f0f0f0';
48
+ });
49
+
50
+ customOption.addEventListener('mouseleave', function () {
51
+ customOption.style.backgroundColor = '';
52
+ });
53
+
54
+ customOption.addEventListener('mousedown', (e) => {
55
+ e.stopPropagation();
56
+ select.value = customOption.dataset.value;
57
+ customSelect.style.display = 'none';
58
+ activeSelectElement = null;
59
+ // ensure we trigger all potential event listeners
60
+ select.dispatchEvent(new InputEvent('focus', { bubbles: true, cancelable: true }));
61
+ select.dispatchEvent(new InputEvent('input', { bubbles: true, cancelable: true }));
62
+ select.dispatchEvent(new InputEvent('change', { bubbles: true, cancelable: true }));
63
+ select.dispatchEvent(new InputEvent('blur', { bubbles: true, cancelable: true }));
64
+ });
65
+ });
66
+
67
+ // Position and show the custom select
68
+ const selectRect = select.getBoundingClientRect();
69
+ customSelect.style.top = `${selectRect.bottom + window.scrollY}px`;
70
+ customSelect.style.left = `${selectRect.left + window.scrollX}px`;
71
+ customSelect.style.width = `${selectRect.width}px`;
72
+ customSelect.style.display = 'block';
73
+ select.focus();
74
+ select.addEventListener('blur', function (e) {
75
+ customSelect.style.display = 'none';
76
+ activeSelectElement = null;
77
+ });
78
+ select.addEventListener('change', function (e) {
79
+ customSelect.style.display = 'none';
80
+ activeSelectElement = null;
81
+ });
82
+ }
83
+
84
+ // Ensure we have a custom select element
85
+ let customSelect = rootElement.querySelector(`#convergence-custom-select-element-X2EmudtLRN`);
86
+ if (!customSelect) {
87
+ customSelect = createCustomSelectElement();
88
+ }
89
+
90
+ // Find selects in shadow DOMs
91
+ function findSelectInShadowRoot(element) {
92
+ if (element.shadowRoot) {
93
+ return element.shadowRoot.querySelectorAll('select');
94
+ }
95
+ return [];
96
+ }
97
+ let shadowSelects = [];
98
+ rootElement.querySelectorAll('*').forEach(el => {
99
+ shadowSelects.push(...findSelectInShadowRoot(el));
100
+ });
101
+
102
+ // Find selects in the regular (light) DOM
103
+ const lightSelects = Array.from(rootElement.querySelectorAll('select'));
104
+
105
+ // Add event listeners to all select elements
106
+ const allSelects = [...lightSelects, ...shadowSelects];
107
+ allSelects.forEach(select => {
108
+ if (select.hasAttribute('multiple')) {
109
+ // skip special multiple elements as our POI code already handles them
110
+ return;
111
+ }
112
+ if (!handledSelectElementsConvergence.has(select)) {
113
+ select.addEventListener('mousedown', (e) => {
114
+ // only use custom select when the default behaviour is being used
115
+ if (!e.defaultPrevented) {
116
+ showCustomSelect(select);
117
+ e.preventDefault();
118
+ }
119
+ });
120
+ handledSelectElementsConvergence.add(select);
121
+ }
122
+ });
123
+ }
src/proxy_lite/browser/bounding_boxes.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from pydantic import BaseModel, Field, field_validator
7
+
8
+
9
+ class Point(BaseModel):
10
+ x: int
11
+ y: int
12
+
13
+ def __iter__(self):
14
+ return iter((self.x, self.y))
15
+
16
+ def __getitem__(self, index) -> int:
17
+ return (self.x, self.y)[index]
18
+
19
+ def __tuple__(self) -> tuple[int, int]:
20
+ return (self.x, self.y)
21
+
22
+ def __repr__(self) -> str:
23
+ return f"Point(x={self.x}, y={self.y})"
24
+
25
+
26
+ class BoundingBox(BaseModel):
27
+ label: str = Field(..., description="The label that's given for this bounding box")
28
+ left: int = Field(..., description="Left coordinate of the bounding box")
29
+ right: int = Field(..., description="Right coordinate of the bounding box")
30
+ top: int = Field(..., description="Top coordinate of the bounding box")
31
+ bottom: int = Field(..., description="Bottom coordinate of the bounding box")
32
+
33
+ @field_validator("left", "top", mode="before")
34
+ @classmethod
35
+ def round_down(cls, v):
36
+ return math.floor(float(v))
37
+
38
+ @field_validator("right", "bottom", mode="before")
39
+ @classmethod
40
+ def round_up(cls, v):
41
+ return math.ceil(float(v))
42
+
43
+
44
+ class POI(BaseModel):
45
+ info: dict[str, Any]
46
+ element_centroid: Point
47
+ bounding_box: BoundingBox
48
+
49
+
50
+ def calculate_dash_points(start, end, dash_length, gap_length):
51
+ x1, y1 = start
52
+ x2, y2 = end
53
+ dx = x2 - x1
54
+ dy = y2 - y1
55
+ dist = np.sqrt(dx * dx + dy * dy)
56
+
57
+ if dist == 0:
58
+ return []
59
+
60
+ unit_x = dx / dist
61
+ unit_y = dy / dist
62
+
63
+ dash_points = []
64
+ current_dist = 0
65
+ while current_dist < dist:
66
+ dash_end = min(current_dist + dash_length, dist)
67
+ dash_points.extend(
68
+ [
69
+ (int(x1 + unit_x * current_dist), int(y1 + unit_y * current_dist)),
70
+ (int(x1 + unit_x * dash_end), int(y1 + unit_y * dash_end)),
71
+ ],
72
+ )
73
+ current_dist += dash_length + gap_length
74
+
75
+ return dash_points
76
+
77
+
78
+ def draw_dashed_rectangle(
79
+ img,
80
+ bbox: BoundingBox,
81
+ color,
82
+ thickness=1,
83
+ dash_length=10,
84
+ gap_length=5,
85
+ ):
86
+ # Calculate dash points for all sides
87
+ top_points = calculate_dash_points(
88
+ (bbox.left + 25, bbox.top + 25),
89
+ (bbox.right + 25, bbox.top + 25),
90
+ dash_length,
91
+ gap_length,
92
+ )
93
+ right_points = calculate_dash_points(
94
+ (bbox.right + 25, bbox.top + 25),
95
+ (bbox.right + 25, bbox.bottom + 25),
96
+ dash_length,
97
+ gap_length,
98
+ )
99
+ bottom_points = calculate_dash_points(
100
+ (bbox.right + 25, bbox.bottom + 25),
101
+ (bbox.left + 25, bbox.bottom + 25),
102
+ dash_length,
103
+ gap_length,
104
+ )
105
+ left_points = calculate_dash_points(
106
+ (bbox.left + 25, bbox.bottom + 25),
107
+ (bbox.left + 25, bbox.top + 25),
108
+ dash_length,
109
+ gap_length,
110
+ )
111
+
112
+ # Combine all points
113
+ all_points = top_points + right_points + bottom_points + left_points
114
+
115
+ # Draw all lines at once
116
+ if all_points:
117
+ all_points = np.array(all_points).reshape((-1, 2, 2))
118
+ cv2.polylines(img, all_points, False, color, thickness)
119
+
120
+
121
+ # @time_it(name='Annotate bounding box')
122
+ def annotate_bounding_box(image: bytes, bbox: BoundingBox) -> None:
123
+ # Draw dashed bounding box
124
+ draw_dashed_rectangle(
125
+ image,
126
+ bbox,
127
+ color=(0, 0, 255),
128
+ thickness=1,
129
+ dash_length=10,
130
+ gap_length=5,
131
+ )
132
+
133
+ # Prepare label
134
+ font_scale = 0.4 * 4 # Increased by 4x for the larger patch
135
+ font = cv2.FONT_HERSHEY_SIMPLEX
136
+ thickness = 3 # Increased thickness for the larger patch
137
+
138
+ # Get text size for the larger patch
139
+ (label_width, label_height), _ = cv2.getTextSize(
140
+ bbox.label,
141
+ font,
142
+ font_scale,
143
+ thickness,
144
+ )
145
+
146
+ # Create a larger patch (4x)
147
+ large_label_patch = np.zeros(
148
+ (label_height + 20, label_width + 20, 4),
149
+ dtype=np.uint8,
150
+ )
151
+ large_label_patch[:, :, 0:3] = (0, 0, 255) # BGR color format: Red background
152
+ large_label_patch[:, :, 3] = 128 # Alpha channel: 50% opacity (128/255 = 0.5)
153
+
154
+ # Draw text on the larger patch
155
+ cv2.putText(
156
+ large_label_patch,
157
+ bbox.label,
158
+ (8, label_height + 8), # Adjusted position for the larger patch
159
+ font,
160
+ font_scale,
161
+ (255, 255, 255, 128), # White text, 50% opaque (128/255 = 0.5)
162
+ thickness,
163
+ )
164
+
165
+ # Scale down the patch to improve anti-aliasing
166
+ label_patch = cv2.resize(
167
+ large_label_patch,
168
+ (label_width // 4 + 5, label_height // 4 + 5),
169
+ interpolation=cv2.INTER_AREA,
170
+ )
171
+
172
+ # Calculate position for top-left alignment
173
+ offset = 2 # Small offset to prevent touching the bounding box edge
174
+ x = min(image.shape[1], max(0, int(bbox.left + 25) - offset))
175
+ y = min(image.shape[0], max(0, int(bbox.top + 25) - label_patch.shape[0] - offset))
176
+
177
+ # Ensure we're not out of bounds
178
+ x_end = min(image.shape[1], x + label_patch.shape[1])
179
+ y_end = min(image.shape[0], y + label_patch.shape[0])
180
+ label_patch = label_patch[: (y_end - y), : (x_end - x)]
181
+
182
+ # Create a mask for the label patch
183
+ alpha_mask = label_patch[:, :, 3] / 255.0
184
+ alpha_mask = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2)
185
+
186
+ # Blend the label patch with the image
187
+ image_section = image[y:y_end, x:x_end]
188
+ blended = (1 - alpha_mask) * image_section + alpha_mask * label_patch[:, :, 0:3]
189
+ image[y:y_end, x:x_end] = blended.astype(np.uint8)
190
+
191
+
192
+ def annotate_bounding_boxes(image: bytes, bounding_boxes: list[BoundingBox]) -> bytes:
193
+ # Read the image
194
+ nparr = np.frombuffer(image, np.uint8)
195
+ # Decode the image
196
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
197
+ padded_img = cv2.copyMakeBorder(
198
+ img,
199
+ top=25, # Value chosen based on label size
200
+ bottom=25, # Value chosen based on label size
201
+ left=25, # Value chosen based on label size
202
+ right=25, # Value chosen based on label size
203
+ borderType=cv2.BORDER_CONSTANT,
204
+ value=(255, 255, 255),
205
+ )
206
+ for bounding_box in bounding_boxes:
207
+ # Annotate the image in place with the bounding box and the bounding box label
208
+ annotate_bounding_box(padded_img, bounding_box)
209
+ _, buffer = cv2.imencode(".jpeg", padded_img)
210
+ return buffer.tobytes()
src/proxy_lite/browser/browser.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import re
4
+ from contextlib import AsyncExitStack
5
+ from pathlib import Path
6
+ from typing import Literal, Optional, Self
7
+
8
+ from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
9
+ from playwright.async_api import TimeoutError as PlaywrightTimeoutError
10
+ from playwright_stealth import stealth_async
11
+ from pydantic import Field
12
+ from tenacity import before_sleep_log, retry, stop_after_delay, wait_exponential
13
+
14
+ from proxy_lite.browser.bounding_boxes import POI, BoundingBox, Point, annotate_bounding_boxes
15
+ from proxy_lite.logger import logger
16
+
17
+ SELF_CONTAINED_TAGS = [
18
+ # many of these are non-interactive but keeping them anyway
19
+ "area",
20
+ "base",
21
+ "br",
22
+ "col",
23
+ "embed",
24
+ "hr",
25
+ "img",
26
+ "input",
27
+ "link",
28
+ "meta",
29
+ "param",
30
+ "source",
31
+ "track",
32
+ "wbr",
33
+ ]
34
+
35
+
36
+ def element_as_text(
37
+ mark_id: int,
38
+ tag: Optional[str] = None,
39
+ text: Optional[str] = None,
40
+ **raw_attributes,
41
+ ) -> str:
42
+ """Return a text representation of all elements on the page."""
43
+ attributes = []
44
+ for k, v in raw_attributes.items():
45
+ if v is None:
46
+ continue
47
+ if isinstance(v, bool):
48
+ if v:
49
+ attributes.append(k)
50
+ # we ignore False bool attributes
51
+ else:
52
+ v = str(v)
53
+ if len(v) > 2500:
54
+ v = v[: 2500 - 1] + "…"
55
+ attributes.append(f'{k}="{v}"')
56
+ attributes = " ".join(attributes)
57
+ attributes = (" " + attributes).rstrip()
58
+ tag = tag.lower()
59
+ if text is None:
60
+ text = ""
61
+ if len(text) > 2500:
62
+ text = text[: 2500 - 1] + "…"
63
+
64
+ # sub-out line breaks so elements are easier to distinguish
65
+ attributes = re.sub(r"\r\n|\r|\n", "⏎", attributes)
66
+ text = re.sub(r"\r\n|\r|\n", "⏎", text)
67
+
68
+ if tag in SELF_CONTAINED_TAGS:
69
+ if text:
70
+ logger.warning(
71
+ f"Got self-contained element '{tag}' which contained text '{text}'.",
72
+ )
73
+ else:
74
+ return f"- [{mark_id}] <{tag}{attributes}/>"
75
+ return f"- [{mark_id}] <{tag}{attributes}>{text}</{tag}>"
76
+
77
+
78
+ class BrowserSession:
79
+ def __init__(
80
+ self,
81
+ viewport_width: int = 1280,
82
+ viewport_height: int = 720,
83
+ headless: bool = True,
84
+ ):
85
+ self.viewport_width = viewport_width
86
+ self.viewport_height = viewport_height
87
+ self.headless = headless
88
+ self.playwright: Playwright | None = None
89
+ self.browser: Browser | None = None
90
+ self.context: BrowserContext | None = None
91
+ self._exit_stack: AsyncExitStack | None = None
92
+
93
+ self.poi_elements: list = Field(default_factory=list)
94
+ self.poi_centroids: list[Point] = Field(default_factory=list)
95
+ self.bounding_boxes: list[BoundingBox] = Field(default_factory=list)
96
+ self.pois: list[POI] = Field(default_factory=list)
97
+
98
+ async def __aenter__(self) -> Self:
99
+ self._exit_stack = AsyncExitStack()
100
+ self.playwright = await async_playwright().start()
101
+
102
+ self.browser = await self.playwright.chromium.launch(headless=self.headless)
103
+ self.context = await self.browser.new_context(
104
+ viewport={"width": self.viewport_width, "height": self.viewport_height},
105
+ )
106
+ await self.context.new_page()
107
+ self.context.set_default_timeout(60_000)
108
+ self.current_page.set_default_timeout(60_000)
109
+ await stealth_async(self.current_page)
110
+ await self.context.add_init_script(
111
+ path=Path(__file__).with_name("add_custom_select.js"),
112
+ )
113
+ await self.context.add_init_script(
114
+ path=Path(__file__).with_name("find_pois.js"),
115
+ )
116
+
117
+ return self
118
+
119
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
120
+ if self.browser:
121
+ await self.browser.close()
122
+ if self.playwright:
123
+ await self.playwright.stop()
124
+ if self._exit_stack:
125
+ await self._exit_stack.aclose()
126
+
127
+ @property
128
+ def current_page(self) -> Optional[Page]:
129
+ if self.context.pages:
130
+ return self.context.pages[-1]
131
+ return None
132
+
133
+ @property
134
+ def current_url(self) -> Optional[str]:
135
+ if self.current_page:
136
+ return self.current_page.url
137
+ return None
138
+
139
+ # re-run for cases of mid-run redirects
140
+ @retry(
141
+ wait=wait_exponential(multiplier=1, min=1, max=10),
142
+ stop=stop_after_delay(5),
143
+ reraise=True,
144
+ before_sleep=before_sleep_log(logger, logging.ERROR),
145
+ )
146
+ async def process_iframe(self, iframe) -> Optional[tuple[dict, dict]]:
147
+ try:
148
+ # Check iframe visibility and size
149
+ bounding_box = await iframe.bounding_box()
150
+ if not bounding_box:
151
+ return None # Skip if iframe is not visible
152
+
153
+ width, height = bounding_box["width"], bounding_box["height"]
154
+ if width < 50 or height < 50:
155
+ return None
156
+
157
+ frame = await iframe.content_frame()
158
+ if not frame:
159
+ return None
160
+
161
+ poi = await frame.evaluate(
162
+ """() => {
163
+ overwriteDefaultSelectConvergence();
164
+ return findPOIsConvergence();
165
+ }""",
166
+ )
167
+ if not poi:
168
+ return None
169
+
170
+ iframe_offset = {"x": round(bounding_box["x"]), "y": round(bounding_box["y"])}
171
+ return poi, iframe_offset
172
+ except Exception as e:
173
+ logger.error(f"Error processing iframe: {e}")
174
+ return None
175
+
176
+ # re-run for cases of mid-run redirects
177
+ @retry(
178
+ wait=wait_exponential(multiplier=1, min=1, max=10),
179
+ stop=stop_after_delay(5),
180
+ reraise=True,
181
+ before_sleep=before_sleep_log(logger, logging.ERROR),
182
+ )
183
+ async def update_poi(self) -> None:
184
+ try:
185
+ await self.current_page.wait_for_load_state(timeout=60000)
186
+ except PlaywrightTimeoutError:
187
+ logger.error(f"Timeout waiting for website load state: {self.current_url}")
188
+ await self.current_page.wait_for_selector("body", timeout=60000, state="visible")
189
+ # Run the bounding box javascript code to highlight the points of interest on the page
190
+ page_info = await self.current_page.evaluate(
191
+ """() => {
192
+ overwriteDefaultSelectConvergence();
193
+ return findPOIsConvergence();
194
+ }""",
195
+ )
196
+ # Get the points of interest on the page
197
+ self.poi_elements = page_info["element_descriptions"]
198
+ element_centroids = page_info["element_centroids"]
199
+ try:
200
+ # Select all iframes on the page
201
+ iframes = await self.current_page.query_selector_all("iframe")
202
+
203
+ max_iframes = 10
204
+
205
+ # Define an asynchronous function to process and filter each iframe
206
+
207
+ tasks = [asyncio.create_task(self.process_iframe(iframe)) for iframe in iframes[:max_iframes]]
208
+
209
+ results = await asyncio.gather(*tasks)
210
+
211
+ filtered_results = [result for result in results if result is not None]
212
+
213
+ iframes_pois = []
214
+ iframe_offsets = []
215
+
216
+ for poi, offset in filtered_results:
217
+ iframes_pois.append(poi)
218
+ iframe_offsets.append(offset)
219
+
220
+ # Combine the points of interest from the iframes with the main page and adjust the centroids
221
+ for index, iframe_poi in enumerate(iframes_pois):
222
+ self.poi_elements.extend(iframe_poi["element_descriptions"])
223
+ for centroid in iframe_poi["element_centroids"]:
224
+ centroid["x"] += iframe_offsets[index]["x"]
225
+ centroid["y"] += iframe_offsets[index]["y"]
226
+ centroid["left"] += iframe_offsets[index]["x"]
227
+ centroid["top"] += iframe_offsets[index]["y"]
228
+ centroid["right"] += iframe_offsets[index]["x"]
229
+ centroid["bottom"] += iframe_offsets[index]["y"]
230
+ element_centroids.extend(iframe_poi["element_centroids"])
231
+
232
+ except Exception as e:
233
+ logger.error(f"Error in finding iframes: {e}")
234
+
235
+ # Get the centroids of the points of interest
236
+ self.poi_centroids = [Point(x=xy["x"], y=xy["y"]) for xy in element_centroids]
237
+ self.bounding_boxes = [BoundingBox(**xy, label=str(i)) for i, xy in enumerate(element_centroids)]
238
+ self.pois = [
239
+ POI(info=info, element_centroid=centroid, bounding_box=bbox)
240
+ for info, centroid, bbox in zip(
241
+ self.poi_elements,
242
+ self.poi_centroids,
243
+ self.bounding_boxes,
244
+ strict=False,
245
+ )
246
+ ]
247
+
248
+ @property
249
+ def poi_text(self) -> str:
250
+ # Get all points of interest on the page as text
251
+ texts = [element_as_text(mark_id=i, **element) for i, element in enumerate(self.poi_elements)]
252
+ # Return formatted text of points of interest on page
253
+ return "\n".join([txt for txt in texts if txt])
254
+
255
+ async def screenshot(
256
+ self,
257
+ delay: float = 0.0,
258
+ quality: int = 70,
259
+ type: str = "jpeg",
260
+ scale: str = "css",
261
+ ) -> tuple[bytes, bytes]:
262
+ if delay > 0.0:
263
+ await asyncio.sleep(delay)
264
+ await self.update_poi()
265
+ old_poi_positions = [tuple(point) for point in self.poi_centroids]
266
+ img = await self.current_page.screenshot(type=type, quality=quality, scale=scale)
267
+ annotated_img = annotate_bounding_boxes(image=img, bounding_boxes=self.bounding_boxes)
268
+ # check page has not changed since the screenshot was taken
269
+ await self.update_poi()
270
+ new_poi_positions = [tuple(point) for point in self.poi_centroids]
271
+ if new_poi_positions != old_poi_positions:
272
+ # if it has changed, take another
273
+ img = await self.current_page.screenshot(type=type, quality=quality, scale=scale)
274
+ await self.update_poi()
275
+ annotated_img = annotate_bounding_boxes(image=img, bounding_boxes=self.bounding_boxes)
276
+ return img, annotated_img
277
+
278
+ async def goto(self, url: str) -> None:
279
+ await self.current_page.goto(url, wait_until="domcontentloaded")
280
+
281
+ async def reload(self) -> None:
282
+ await self.current_page.reload(wait_until="domcontentloaded")
283
+
284
+ async def click_tab(self, mark_id: int) -> None:
285
+ point: Point = self.poi_centroids[mark_id]
286
+ await self.hover(point)
287
+ await self.current_page.mouse.click(*point, button="middle")
288
+
289
+ async def click(self, mark_id: int) -> None:
290
+ point: Point = self.poi_centroids[mark_id]
291
+ await self.hover(point)
292
+ await self.current_page.mouse.click(*point)
293
+
294
+ async def enter_text(self, mark_id: int, text: str, submit: bool = False) -> None:
295
+ await self.clear_text_field(mark_id)
296
+ await self.click(mark_id)
297
+ await self.current_page.keyboard.type(text)
298
+
299
+ if submit:
300
+ await self.current_page.keyboard.press("Enter")
301
+
302
+ async def scroll(
303
+ self,
304
+ direction: Literal["up", "down", "left", "right"],
305
+ mark_id: Optional[int] = None,
306
+ ) -> None:
307
+ if mark_id is None:
308
+ point = Point(x=-1, y=-1)
309
+ max_scroll_x = self.viewport_width
310
+ max_scroll_y = self.viewport_height
311
+ else:
312
+ point: Point = self.poi_centroids[mark_id]
313
+ bbox: BoundingBox = self.bounding_boxes[mark_id]
314
+ max_scroll_x = bbox.right - bbox.left
315
+ max_scroll_y = bbox.bottom - bbox.top
316
+
317
+ await self.hover(point=point)
318
+ scroll_x = int(max_scroll_x * 0.8)
319
+ scroll_y = int(max_scroll_y * 0.8)
320
+ is_vertical = direction in ("up", "down")
321
+ reverse_scroll = direction in ("up", "left")
322
+ await self.current_page.mouse.wheel(
323
+ scroll_x * (-1 if reverse_scroll else 1) * (not is_vertical),
324
+ scroll_y * (-1 if reverse_scroll else 1) * is_vertical,
325
+ )
326
+
327
+ async def go_back(self) -> None:
328
+ # If there is no tab open then return
329
+ if not self.current_page:
330
+ return
331
+
332
+ await self.current_page.go_back(wait_until="domcontentloaded")
333
+ if self.current_page.url == "about:blank":
334
+ if not len(self.context.pages) > 1:
335
+ await self.current_page.go_forward(wait_until="domcontentloaded")
336
+ raise Exception("There is no previous page to go back to.")
337
+ await self.current_page.close()
338
+
339
+ async def hover(self, point: Point) -> None:
340
+ await self.current_page.mouse.move(*point)
341
+
342
+ async def focus(self, point: Point) -> None:
343
+ # Focus on the element on the page at point (x, y)
344
+ await self.current_page.evaluate(
345
+ """
346
+ ([x, y]) => {
347
+ const element = document.elementFromPoint(x, y);
348
+ if (element && element.focus) {
349
+ element.focus();
350
+ }
351
+ }""",
352
+ tuple(point),
353
+ )
354
+
355
+ async def get_text(self, mark_id: int) -> str:
356
+ return await self.current_page.evaluate(
357
+ """
358
+ (mark_id) => {
359
+ const element = marked_elements_convergence[mark_id];
360
+ if (element && (element.value !== undefined || element.textContent !== undefined)) {
361
+ return element.value || element.textContent;
362
+ }
363
+ return '';
364
+ }
365
+ """,
366
+ (mark_id,),
367
+ )
368
+
369
+ async def clear_text_field(self, mark_id: int) -> None:
370
+ existing_text = await self.get_text(mark_id)
371
+ if existing_text.strip():
372
+ # Clear existing text only if it exists
373
+ await self.click(mark_id)
374
+ await self.current_page.keyboard.press("Control+Home")
375
+ await self.current_page.keyboard.press("Control+Shift+End")
376
+ await self.current_page.keyboard.press("Backspace")
377
+
378
+
379
+ if __name__ == "__main__":
380
+ import json
381
+
382
+ test = """{"name": "return_value", "arguments": {'value': 'The most downloaded French speech recognition model on Hugging Face is DeepSeek-R1. Here are its evaluation metrics:\n\n- Claude-3.5-1022: MMLU 88.3, MMLU-Redux 88.9\n- GPT-4.0-5013: MMLU 87.2, MMLU-Redux 88.0\n- DeepSeek-01-3013: MMLU 88.5, MMLU-Redux 89.1\n- OpenAI-01-mini: MMLU 91.0, MMLU-Redux 88.7\n\nPlease see the attached screenshot for more details.'}}"""
383
+ test = json.loads(test)
384
+ print(test)
385
+ exit()
386
+
387
+ async def dummy_test():
388
+ async with BrowserSession(headless=False) as s:
389
+ page = await s.context.new_page()
390
+ await page.goto("http://google.co.uk")
391
+ await asyncio.sleep(5)
392
+ await page.screenshot(path="example.png")
393
+ await s.update_poi()
394
+ _, annotated_image = await s.screenshot()
395
+ with open("output.png", "wb") as f:
396
+ f.write(annotated_image)
397
+
398
+ asyncio.run(dummy_test())
src/proxy_lite/browser/find_pois.js ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ marked_elements_convergence = [];
2
+
3
+ const interactiveTags = new Set([
4
+ 'a', 'button', 'details', 'embed', 'input', 'label',
5
+ 'menu', 'menuitem', 'object', 'select', 'textarea', 'summary',
6
+ 'video', 'audio', 'option', 'iframe'
7
+ ]);
8
+
9
+ const interactiveRoles = new Set([
10
+ 'button', 'menu', 'menuitem', 'link', 'checkbox', 'radio',
11
+ 'slider', 'tab', 'tabpanel', 'textbox', 'combobox', 'grid',
12
+ 'listbox', 'option', 'progressbar', 'scrollbar', 'searchbox',
13
+ 'switch', 'tree', 'treeitem', 'spinbutton', 'tooltip',
14
+ 'a-button-inner', 'a-dropdown-button', 'click',
15
+ 'menuitemcheckbox', 'menuitemradio', 'a-button-text',
16
+ 'button-text', 'button-icon', 'button-icon-only',
17
+ 'button-text-icon-only', 'dropdown', 'combobox'
18
+ ]);
19
+
20
+ findPOIsConvergence = (input = null) => {
21
+
22
+ let rootElement = input ? input : document.documentElement;
23
+
24
+ function isScrollable(element) {
25
+ if ((input === null) && (element === document.documentElement)) {
26
+ // we can always scroll the full page
27
+ return false;
28
+ }
29
+
30
+ const style = window.getComputedStyle(element);
31
+
32
+ const hasScrollableYContent = element.scrollHeight > element.clientHeight
33
+ const overflowYScroll = style.overflowY === 'scroll' || style.overflowY === 'auto';
34
+
35
+ const hasScrollableXContent = element.scrollWidth > element.clientWidth;
36
+ const overflowXScroll = style.overflowX === 'scroll' || style.overflowX === 'auto';
37
+
38
+ return (hasScrollableYContent && overflowYScroll) || (hasScrollableXContent && overflowXScroll);
39
+ }
40
+
41
+ function getEventListeners(element) {
42
+ try {
43
+ return window.getEventListeners?.(element) || {};
44
+ } catch (e) {
45
+ return {};
46
+ }
47
+ }
48
+
49
+ function isInteractive(element) {
50
+ if (!element) return false;
51
+
52
+ return (hasInteractiveTag(element) ||
53
+ hasInteractiveAttributes(element) ||
54
+ hasInteractiveEventListeners(element)) ||
55
+ isScrollable(element);
56
+ }
57
+
58
+ function hasInteractiveTag(element) {
59
+ return interactiveTags.has(element.tagName.toLowerCase());
60
+ }
61
+
62
+ function hasInteractiveAttributes(element) {
63
+ const role = element.getAttribute('role');
64
+ const ariaRole = element.getAttribute('aria-role');
65
+ const tabIndex = element.getAttribute('tabindex');
66
+ const onAttribute = element.getAttribute('on');
67
+
68
+ if (element.getAttribute('contenteditable') === 'true') return true;
69
+ if ((role && interactiveRoles.has(role)) ||
70
+ (ariaRole && interactiveRoles.has(ariaRole))) return true;
71
+ if (tabIndex !== null && tabIndex !== '-1') return true;
72
+
73
+ // Add check for AMP's 'on' attribute that starts with 'tap:'
74
+ if (onAttribute && onAttribute.startsWith('tap:')) return true;
75
+
76
+ const hasAriaProps = element.hasAttribute('aria-expanded') ||
77
+ element.hasAttribute('aria-pressed') ||
78
+ element.hasAttribute('aria-selected') ||
79
+ element.hasAttribute('aria-checked');
80
+
81
+ return hasAriaProps;
82
+ }
83
+
84
+ function hasInteractiveEventListeners(element) {
85
+ const hasClickHandler = element.onclick !== null ||
86
+ element.getAttribute('onclick') !== null ||
87
+ element.hasAttribute('ng-click') ||
88
+ element.hasAttribute('@click') ||
89
+ element.hasAttribute('v-on:click');
90
+ if (hasClickHandler) return true;
91
+
92
+ const listeners = getEventListeners(element);
93
+ return listeners && (
94
+ listeners.click?.length > 0 ||
95
+ listeners.mousedown?.length > 0 ||
96
+ listeners.mouseup?.length > 0 ||
97
+ listeners.touchstart?.length > 0 ||
98
+ listeners.touchend?.length > 0
99
+ );
100
+ }
101
+
102
+ function calculateArea(rects) {
103
+ return rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
104
+ }
105
+
106
+ function getElementRects(element, context) {
107
+ const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
108
+ const vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
109
+
110
+ let rects = [...element.getClientRects()];
111
+
112
+ // If rects are empty (likely due to Shadow DOM), try to estimate position
113
+ if (rects.length === 0 && element.getBoundingClientRect) {
114
+ rects = [element.getBoundingClientRect()];
115
+ }
116
+
117
+ // Get iframe offset if element is in an iframe
118
+ let iframeOffset = { x: 0, y: 0 };
119
+ if (context !== document && context?.defaultView?.frameElement) {
120
+ const iframe = context.defaultView.frameElement;
121
+ if (iframe) {
122
+ const iframeRect = iframe.getBoundingClientRect();
123
+ iframeOffset = {
124
+ x: iframeRect.left,
125
+ y: iframeRect.top
126
+ };
127
+ }
128
+ }
129
+
130
+ return rects.filter(bb => {
131
+ const center_x = bb.left + bb.width / 2 + iframeOffset.x;
132
+ const center_y = bb.top + bb.height / 2 + iframeOffset.y;
133
+ const elAtCenter = context.elementFromPoint(center_x - iframeOffset.x, center_y - iframeOffset.y);
134
+
135
+ return elAtCenter === element || element.contains(elAtCenter);
136
+ }).map(bb => {
137
+ const rect = {
138
+ left: Math.max(0, bb.left + iframeOffset.x),
139
+ top: Math.max(0, bb.top + iframeOffset.y),
140
+ right: Math.min(vw, bb.right + iframeOffset.x),
141
+ bottom: Math.min(vh, bb.bottom + iframeOffset.y)
142
+ };
143
+ return {
144
+ ...rect,
145
+ width: rect.right - rect.left,
146
+ height: rect.bottom - rect.top
147
+ };
148
+ });
149
+ }
150
+
151
+ function isElementVisible(element) {
152
+ const style = window.getComputedStyle(element);
153
+ return element.offsetWidth > 0 &&
154
+ element.offsetHeight > 0 &&
155
+ style.visibility !== 'hidden' &&
156
+ style.display !== 'none';
157
+ }
158
+
159
+ function isTopElement(element) {
160
+ let doc = element.ownerDocument;
161
+ if (doc !== window.document) {
162
+ // If in an iframe's document, treat as top
163
+ return true;
164
+ }
165
+ const shadowRoot = element.getRootNode();
166
+ if (shadowRoot instanceof ShadowRoot) {
167
+ const rect = element.getBoundingClientRect();
168
+ const point = { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
169
+ try {
170
+ const topEl = shadowRoot.elementFromPoint(point.x, point.y);
171
+ if (!topEl) return false;
172
+ let current = topEl;
173
+ while (current && current !== shadowRoot) {
174
+ if (current === element) return true;
175
+ current = current.parentElement;
176
+ }
177
+ return false;
178
+ } catch (e) {
179
+ return true;
180
+ }
181
+ }
182
+ const rect = element.getBoundingClientRect();
183
+ const point = { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
184
+ try {
185
+ const topEl = document.elementFromPoint(point.x, point.y);
186
+ if (!topEl) return false;
187
+ let current = topEl;
188
+ while (current && current !== document.documentElement) {
189
+ if (current === element) return true;
190
+ current = current.parentElement;
191
+ }
192
+ return false;
193
+ } catch (e) {
194
+ return true;
195
+ }
196
+ }
197
+
198
+ function getVisibleText(element, marked_elements_convergence = []) {
199
+ const blockLikeDisplays = [
200
+ // Basic block elements
201
+ 'block', 'flow-root', 'inline-block',
202
+ // Lists
203
+ 'list-item',
204
+ // Table elements
205
+ 'table', 'inline-table', 'table-row', 'table-cell',
206
+ 'table-caption', 'table-header-group', 'table-footer-group',
207
+ 'table-row-group',
208
+ // Modern layouts
209
+ 'flex', 'inline-flex', 'grid', 'inline-grid'
210
+ ];
211
+
212
+ // Check if element is hidden
213
+ const style = window.getComputedStyle(element);
214
+ if (style.display === 'none' || style.visibility === 'hidden') {
215
+ return '';
216
+ }
217
+
218
+ let collectedText = [];
219
+
220
+ function isMarkedInteractive(el) {
221
+ return marked_elements_convergence.includes(el);
222
+ }
223
+
224
+ function traverse(node) {
225
+ if (
226
+ node.nodeType === Node.ELEMENT_NODE &&
227
+ node !== element &&
228
+ isMarkedInteractive(node)
229
+ ) {
230
+ return false;
231
+ }
232
+
233
+ if (node.nodeType === Node.TEXT_NODE) {
234
+ const trimmed = node.textContent.trim();
235
+ if (trimmed) {
236
+ collectedText.push(trimmed);
237
+ }
238
+ } else if (node.nodeType === Node.ELEMENT_NODE) {
239
+ // Skip noscript elements
240
+ if (node.tagName === 'NOSCRIPT') {
241
+ return true;
242
+ }
243
+
244
+ const nodeStyle = window.getComputedStyle(node);
245
+
246
+ // Skip hidden elements
247
+ if (nodeStyle.display === 'none' || nodeStyle.visibility === 'hidden') {
248
+ return true;
249
+ }
250
+
251
+ // Add newline before block elements if we have text
252
+ if (blockLikeDisplays.includes(nodeStyle.display) && collectedText.length > 0) {
253
+ collectedText.push('\n');
254
+ }
255
+
256
+ if (node.tagName === 'IMG') {
257
+ const textParts = [];
258
+ const alt = node.getAttribute('alt');
259
+ const title = node.getAttribute('title');
260
+ const ariaLabel = node.getAttribute('aria-label');
261
+ // Add more as needed (e.g., 'aria-describedby', 'data-caption', etc.)
262
+
263
+ if (alt) textParts.push(`alt="${alt}"`);
264
+ if (title) textParts.push(`title="${title}"`);
265
+ if (ariaLabel) textParts.push(`aria-label="${ariaLabel}"`);
266
+
267
+ if (textParts.length > 0) {
268
+ collectedText.push(`[img - ${textParts.join(' ')}]`);
269
+ }
270
+ return true;
271
+ }
272
+
273
+ for (const child of node.childNodes) {
274
+ const shouldContinue = traverse(child);
275
+ if (shouldContinue === false) {
276
+ return false;
277
+ }
278
+ }
279
+
280
+ // Add newline after block elements
281
+ if (blockLikeDisplays.includes(nodeStyle.display)) {
282
+ collectedText.push('\n');
283
+ }
284
+ }
285
+
286
+ return true;
287
+ }
288
+
289
+ traverse(element);
290
+
291
+ // Join text and normalize whitespace
292
+ return collectedText.join(' ').trim().replace(/\s{2,}/g, ' ').trim();
293
+ }
294
+
295
+ function extractInteractiveItems(rootElement) {
296
+ const items = [];
297
+
298
+ function processElement(element, context) {
299
+ if (!element) return;
300
+
301
+ // Recursively process elements
302
+ if (element.nodeType === Node.ELEMENT_NODE && isInteractive(element) && isElementVisible(element) && isTopElement(element)) {
303
+ const rects = getElementRects(element, context);
304
+ const area = calculateArea(rects);
305
+ items.push({
306
+ element: element,
307
+ area,
308
+ rects,
309
+ is_scrollable: isScrollable(element),
310
+ });
311
+ }
312
+
313
+ if (element.shadowRoot) {
314
+ // if it's shadow DOM, process elements in the shadow DOM
315
+ Array.from(element.shadowRoot.childNodes || []).forEach(child => {
316
+ processElement(child, element.shadowRoot);
317
+ });
318
+ }
319
+
320
+ if (element.tagName === 'SLOT') {
321
+ // Handle both assigned elements and nodes
322
+ const assigned = element.assignedNodes ? element.assignedNodes() : element.assignedElements();
323
+ assigned.forEach(child => {
324
+ processElement(child, context);
325
+ });
326
+ }
327
+ else if (element.tagName === 'IFRAME') {
328
+ try {
329
+ const iframeDoc = element.contentDocument || element.contentWindow?.document;
330
+ if (iframeDoc && iframeDoc.body) {
331
+ // Process elements inside iframe
332
+ processElement(iframeDoc.body, iframeDoc);
333
+ }
334
+ } catch (e) {
335
+ console.warn('Unable to access iframe contents:', e);
336
+ }
337
+ } else {
338
+ // if it's regular child elements, process regular child elements
339
+ Array.from(element.children || []).forEach(child => {
340
+ processElement(child, context);
341
+ });
342
+ }
343
+ }
344
+
345
+ processElement(rootElement, document);
346
+ return items;
347
+ }
348
+
349
+ if (marked_elements_convergence) {
350
+ marked_elements_convergence = [];
351
+ }
352
+ let mark_centres = [];
353
+ let marked_element_descriptions = [];
354
+ var items = extractInteractiveItems(rootElement);
355
+
356
+ // Lets create a floating border on top of these elements that will always be visible
357
+ let index = 0;
358
+ items.forEach(function (item) {
359
+ item.rects.forEach((bbox) => {
360
+ marked_elements_convergence.push(item.element);
361
+ mark_centres.push({
362
+ x: Math.round((bbox.left + bbox.right) / 2),
363
+ y: Math.round((bbox.top + bbox.bottom) / 2),
364
+ left: bbox.left,
365
+ top: bbox.top,
366
+ right: bbox.right,
367
+ bottom: bbox.bottom,
368
+ });
369
+ marked_element_descriptions.push({
370
+ tag: item.element.tagName,
371
+ text: getVisibleText(item.element),
372
+ // NOTE: all other attributes will be shown to the model when present
373
+ // TODO: incorperate child attributes, e.g. <img alt="..."> when img is a child of the link element
374
+ value: item.element.value,
375
+ placeholder: item.element.getAttribute("placeholder"),
376
+ element_type: item.element.getAttribute("type"),
377
+ aria_label: item.element.getAttribute("aria-label"),
378
+ name: item.element.getAttribute("name"),
379
+ required: item.element.getAttribute("required"),
380
+ disabled: item.element.getAttribute("disabled"),
381
+ pattern: item.element.getAttribute("pattern"),
382
+ checked: item.element.getAttribute("checked"),
383
+ minlength: item.element.getAttribute("minlength"),
384
+ maxlength: item.element.getAttribute("maxlength"),
385
+ role: item.element.getAttribute("role"),
386
+ title: item.element.getAttribute("title"),
387
+ scrollable: item.is_scrollable
388
+ });
389
+ index++;
390
+ });
391
+ });
392
+
393
+ return {
394
+ element_descriptions: marked_element_descriptions,
395
+ element_centroids: mark_centres
396
+ };
397
+ }
src/proxy_lite/cli.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from proxy_lite import Runner, RunnerConfig
8
+ from proxy_lite.logger import logger
9
+
10
+
11
+ def update_config_from_env(config: RunnerConfig) -> RunnerConfig:
12
+ if os.getenv("PROXY_LITE_API_BASE"):
13
+ config.solver.client.api_base = os.getenv("PROXY_LITE_API_BASE")
14
+ if os.getenv("PROXY_LITE_MODEL"):
15
+ config.solver.client.model_id = os.getenv("PROXY_LITE_MODEL")
16
+ return config
17
+
18
+
19
+ def do_command(args):
20
+ do_text = " ".join(args.task)
21
+ logger.info("🤖 Let me help you with that...")
22
+ # Take default config from YAML
23
+ config = RunnerConfig.from_yaml(args.config)
24
+ # Update config from environment variables
25
+ config = update_config_from_env(config)
26
+ # Update config from command-line arguments
27
+ if args.api_base:
28
+ config.solver.client.api_base = args.api_base
29
+ if args.model:
30
+ config.solver.client.model_id = args.model
31
+ if args.homepage:
32
+ config.homepage = args.homepage
33
+ if args.viewport_width:
34
+ config.viewport_width = args.viewport_width
35
+ if args.viewport_height:
36
+ config.viewport_height = args.viewport_height
37
+ o = Runner(config=config)
38
+ asyncio.run(o.run(do_text))
39
+
40
+
41
+ def main():
42
+ parser = argparse.ArgumentParser(description="Proxy-Lite")
43
+ parser.add_argument(
44
+ "task",
45
+ type=str,
46
+ help="The task you want to accomplish",
47
+ nargs="*",
48
+ )
49
+ parser.add_argument(
50
+ "--model",
51
+ type=Optional[str],
52
+ default=None,
53
+ help="The model to use.",
54
+ )
55
+ parser.add_argument(
56
+ "--api_base",
57
+ type=Optional[str],
58
+ default=None,
59
+ help="The API base URL to use.",
60
+ )
61
+ # New option for setting a homepage URL:
62
+ parser.add_argument(
63
+ "--homepage",
64
+ type=Optional[str],
65
+ default=None,
66
+ help="The homepage URL to use.",
67
+ )
68
+ # New viewport controls:
69
+ parser.add_argument(
70
+ "--viewport-width",
71
+ type=Optional[int],
72
+ default=None,
73
+ help="Viewport width in pixels.",
74
+ )
75
+ parser.add_argument(
76
+ "--viewport-height",
77
+ type=Optional[int],
78
+ default=None,
79
+ help="Viewport height in pixels.",
80
+ )
81
+ parser.add_argument(
82
+ "--config",
83
+ type=Path,
84
+ default=Path(__file__).parent / "configs/default.yaml",
85
+ help="Path to config file (default: configs/default.yaml)",
86
+ )
87
+
88
+ args = parser.parse_args()
89
+ do_command(args)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
src/proxy_lite/client.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from functools import cached_property
4
+ from typing import ClassVar, Literal, Optional, Union
5
+
6
+ import httpx
7
+ from httpx import Limits, Timeout
8
+ from openai import AsyncOpenAI
9
+ from openai.types.chat.chat_completion import (
10
+ ChatCompletion,
11
+ )
12
+ from pydantic import BaseModel
13
+
14
+ from proxy_lite.history import MessageHistory
15
+ from proxy_lite.logger import logger
16
+ from proxy_lite.serializer import (
17
+ BaseSerializer,
18
+ OpenAISerializer,
19
+ )
20
+ from proxy_lite.tools import Tool
21
+
22
+
23
+ class BaseClientConfig(BaseModel):
24
+ http_timeout: float = 50
25
+ http_concurrent_connections: int = 50
26
+
27
+
28
+ class BaseClient(BaseModel, ABC):
29
+ config: BaseClientConfig
30
+ serializer: ClassVar[BaseSerializer]
31
+
32
+ @abstractmethod
33
+ async def create_completion(
34
+ self,
35
+ messages: MessageHistory,
36
+ temperature: float = 0.7,
37
+ seed: Optional[int] = None,
38
+ tools: Optional[list[Tool]] = None,
39
+ response_format: Optional[type[BaseModel]] = None,
40
+ ) -> ChatCompletion: ...
41
+
42
+ """
43
+ Create completion from model.
44
+ Expect subclasses to adapt from various endpoints that will handle
45
+ requests differently, make sure to raise appropriate warnings.
46
+
47
+ Returns:
48
+ ChatCompletion: OpenAI ChatCompletion format for consistency
49
+ """
50
+
51
+ @classmethod
52
+ def create(cls, config: BaseClientConfig) -> "BaseClient":
53
+ supported_clients = {
54
+ "openai-azure": OpenAIClient,
55
+ "convergence": ConvergenceClient,
56
+ }
57
+ if config.name not in supported_clients:
58
+ error_message = f"Unsupported model: {config.name}."
59
+ raise ValueError(error_message)
60
+ return supported_clients[config.name](config=config)
61
+
62
+ @property
63
+ def http_client(self) -> httpx.AsyncClient:
64
+ return httpx.AsyncClient(
65
+ timeout=Timeout(self.config.http_timeout),
66
+ limits=Limits(
67
+ max_connections=self.config.http_concurrent_connections,
68
+ max_keepalive_connections=self.config.http_concurrent_connections,
69
+ ),
70
+ )
71
+
72
+
73
+ class OpenAIClientConfig(BaseClientConfig):
74
+ name: Literal["openai"] = "openai"
75
+ model_id: str = "gpt-4o"
76
+ api_key: str = os.environ["OPENAI_API_KEY"]
77
+
78
+
79
+ class OpenAIClient(BaseClient):
80
+ config: OpenAIClientConfig
81
+ serializer: ClassVar[OpenAISerializer] = OpenAISerializer()
82
+
83
+ @cached_property
84
+ def external_client(self) -> AsyncOpenAI:
85
+ return AsyncOpenAI(
86
+ api_key=self.config.api_key,
87
+ http_client=self.http_client,
88
+ )
89
+
90
+ async def create_completion(
91
+ self,
92
+ messages: MessageHistory,
93
+ temperature: float = 0.7,
94
+ seed: Optional[int] = None,
95
+ tools: Optional[list[Tool]] = None,
96
+ response_format: Optional[type[BaseModel]] = None,
97
+ ) -> ChatCompletion:
98
+ base_params = {
99
+ "model": self.config.model_id,
100
+ "messages": self.serializer.serialize_messages(messages),
101
+ "temperature": temperature,
102
+ }
103
+ optional_params = {
104
+ "seed": seed,
105
+ "tools": self.serializer.serialize_tools(tools) if tools else None,
106
+ "tool_choice": "required" if tools else None,
107
+ "response_format": {"type": "json_object"} if response_format else {"type": "text"},
108
+ }
109
+ base_params.update({k: v for k, v in optional_params.items() if v is not None})
110
+ return await self.external_client.chat.completions.create(**base_params)
111
+
112
+
113
+ class ConvergenceClientConfig(BaseClientConfig):
114
+ name: Literal["convergence"] = "convergence"
115
+ model_id: str = "convergence-ai/proxy-lite-7b"
116
+ api_base: str = "http://localhost:8000/v1"
117
+ api_key: str = "none"
118
+
119
+
120
+ class ConvergenceClient(OpenAIClient):
121
+ config: ConvergenceClientConfig
122
+ serializer: ClassVar[OpenAISerializer] = OpenAISerializer()
123
+ _model_validated: bool = False
124
+
125
+ async def _validate_model(self) -> None:
126
+ try:
127
+ await self.external_client.beta.chat.completions.parse(
128
+ model=self.config.model_id,
129
+ messages=[{"role": "user", "content": "Hello"}],
130
+ )
131
+ self._model_validated = True
132
+ logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
133
+ except Exception as e:
134
+ logger.error(f"Error retrieving model: {e}")
135
+ raise e
136
+
137
+ @cached_property
138
+ def external_client(self) -> AsyncOpenAI:
139
+ return AsyncOpenAI(
140
+ base_url=self.config.api_base,
141
+ http_client=self.http_client,
142
+ )
143
+
144
+ async def create_completion(
145
+ self,
146
+ messages: MessageHistory,
147
+ temperature: float = 0.7,
148
+ seed: Optional[int] = None,
149
+ tools: Optional[list[Tool]] = None,
150
+ response_format: Optional[type[BaseModel]] = None,
151
+ ) -> ChatCompletion:
152
+ if not self._model_validated:
153
+ await self._validate_model()
154
+ base_params = {
155
+ "model": self.config.model_id,
156
+ "messages": self.serializer.serialize_messages(messages),
157
+ "temperature": temperature,
158
+ }
159
+ optional_params = {
160
+ "seed": seed,
161
+ "tools": self.serializer.serialize_tools(tools) if tools else None,
162
+ "tool_choice": "auto" if tools else None, # vLLM does not support "required"
163
+ "response_format": response_format if response_format else {"type": "text"},
164
+ }
165
+ base_params.update({k: v for k, v in optional_params.items() if v is not None})
166
+ return await self.external_client.chat.completions.create(**base_params)
167
+
168
+
169
+ ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig]
170
+ ClientTypes = Union[OpenAIClient, ConvergenceClient]
src/proxy_lite/configs/default.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ environment:
2
+ name: webbrowser
3
+ annotate_image: true
4
+ screenshot_delay: 2.0
5
+ viewport_width: 1280
6
+ viewport_height: 1920
7
+ include_poi_text: true
8
+ headless: false
9
+ homepage: https://www.google.co.uk
10
+ solver:
11
+ name: simple
12
+ agent:
13
+ name: proxy_lite
14
+ client:
15
+ name: convergence
16
+ model_id: convergence-ai/subset-distill-tools-7b-15-02-2025
17
+ api_base: http://slurm1-a3nodeset-4-1:8002/v1
18
+ local_view: true
19
+ task_timeout: 1800
20
+ verbose: true
src/proxy_lite/environments/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from .environment_base import (
4
+ Action,
5
+ BaseEnvironment,
6
+ BaseEnvironmentConfig,
7
+ Environments,
8
+ Event,
9
+ EventType,
10
+ Observation,
11
+ )
12
+ from .webbrowser import (
13
+ WebBrowserEnvironment,
14
+ WebBrowserEnvironmentConfig,
15
+ )
16
+
17
+ EnvironmentConfigTypes = Union[*list(Environments._environment_config_registry.values())]
18
+ EnvironmentTypes = Union[*list(Environments._environment_registry.values())]
19
+
20
+
21
+ __all__ = [
22
+ "Action",
23
+ "BaseEnvironment",
24
+ "BaseEnvironmentConfig",
25
+ "EnvironmentConfigTypes",
26
+ "Environments",
27
+ "Event",
28
+ "EventType",
29
+ "Observation",
30
+ "WebBrowserEnvironment",
31
+ "WebBrowserEnvironmentConfig",
32
+ ]
src/proxy_lite/environments/environment_base.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from abc import ABC, abstractmethod
4
+ from enum import Enum
5
+ from functools import cached_property
6
+ from typing import Any, Literal, Optional, Self
7
+
8
+ from pydantic import BaseModel
9
+
10
+ from proxy_lite.history import ToolCall
11
+ from proxy_lite.tools import Tool, ToolExecutionResponse
12
+
13
+
14
+ class EventType(str, Enum):
15
+ OBSERVATION = "observation"
16
+ ACTION = "action"
17
+ MESSAGE = "message"
18
+
19
+
20
+ class Event(BaseModel):
21
+ type: EventType
22
+
23
+
24
+ class State(BaseModel):
25
+ text: Optional[str] = None
26
+ image: Optional[str] = None # base64 encoded image
27
+ html: Optional[str] = None
28
+ tool_responses: Optional[list[ToolExecutionResponse]] = None
29
+
30
+
31
+ class Observation(Event):
32
+ type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
33
+ state: State
34
+ terminated: bool
35
+ reward: Optional[float] = None
36
+ info: Optional[dict[str, Any]] = None
37
+
38
+
39
+ class Action(Event):
40
+ type: Literal[EventType.ACTION] = EventType.ACTION
41
+ text: Optional[str] = None
42
+ tool_calls: Optional[list[ToolCall]] = None
43
+ info: Optional[dict[str, Any]] = None
44
+
45
+
46
+ class BaseEnvironmentConfig(BaseModel): ...
47
+
48
+
49
+ class BaseEnvironment(BaseModel, ABC):
50
+ config: BaseEnvironmentConfig
51
+ logger: logging.Logger | None = None
52
+
53
+ class Config:
54
+ arbitrary_types_allowed = True
55
+
56
+ async def __aenter__(self) -> Self:
57
+ return self
58
+
59
+ async def __aexit__(self, exc_type, exc_value, traceback):
60
+ pass
61
+
62
+ @property
63
+ @abstractmethod
64
+ def info_for_user(self) -> str: ...
65
+
66
+ @cached_property
67
+ @abstractmethod
68
+ def tools(self) -> list[Tool]: ...
69
+
70
+ @abstractmethod
71
+ async def initialise(self) -> Observation: ...
72
+
73
+ @abstractmethod
74
+ async def execute_action(self, action: Action) -> Observation: ...
75
+
76
+ @abstractmethod
77
+ async def observe(self) -> Observation: ...
78
+
79
+ @abstractmethod
80
+ async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]: ...
81
+
82
+ async def execute_tool(self, tool_call: ToolCall) -> None:
83
+ function = tool_call.function
84
+ for tool in self.tools:
85
+ if hasattr(tool, function["name"]):
86
+ arguments = json.loads(function["arguments"])
87
+ if type(arguments) == str:
88
+ arguments = json.loads(arguments)
89
+ return await getattr(tool, function["name"])(
90
+ **arguments,
91
+ )
92
+ msg = f'No tool function with name "{function["name"]}"'
93
+ raise ValueError(msg)
94
+
95
+ async def get_info(self) -> dict[str, Any]:
96
+ return {}
97
+
98
+
99
+ class Environments:
100
+ _environment_registry: dict[str, type[BaseEnvironment]] = {}
101
+ _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}
102
+
103
+ @classmethod
104
+ def register_environment(cls, name: str):
105
+ """
106
+ Decorator to register an Environment class under a given name.
107
+
108
+ Example:
109
+ @Environments.register_environment("my_environment")
110
+ class MyEnvironment(BaseEnvironment):
111
+ ...
112
+ """
113
+
114
+ def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
115
+ cls._environment_registry[name] = env_cls
116
+ return env_cls
117
+
118
+ return decorator
119
+
120
+ @classmethod
121
+ def register_environment_config(cls, name: str):
122
+ """
123
+ Decorator to register an Environment configuration class under a given name.
124
+
125
+ Example:
126
+ @Environments.register_environment_config("my_environment")
127
+ class MyEnvironmentConfig(BaseEnvironmentConfig):
128
+ ...
129
+ """
130
+
131
+ def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
132
+ cls._environment_config_registry[name] = config_cls
133
+ return config_cls
134
+
135
+ return decorator
136
+
137
+ @classmethod
138
+ def get(cls, name: str) -> type[BaseEnvironment]:
139
+ """
140
+ Retrieve a registered Environment class by its name.
141
+
142
+ Raises:
143
+ ValueError: If no such environment is found.
144
+ """
145
+ try:
146
+ return cls._environment_registry[name]
147
+ except KeyError:
148
+ raise ValueError(f"Environment '{name}' not found.")
149
+
150
+ @classmethod
151
+ def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
152
+ """
153
+ Retrieve a registered Environment configuration class by its name.
154
+
155
+ Raises:
156
+ ValueError: If no such configuration is found.
157
+ """
158
+ try:
159
+ return cls._environment_config_registry[name]
160
+ except KeyError:
161
+ raise ValueError(f"Environment config for '{name}' not found.")
src/proxy_lite/environments/webbrowser.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from functools import cached_property
3
+ from typing import Any, Literal, Optional, Self
4
+
5
+ from proxy_lite.browser.browser import BrowserSession
6
+ from proxy_lite.environments.environment_base import (
7
+ Action,
8
+ BaseEnvironment,
9
+ BaseEnvironmentConfig,
10
+ Environments,
11
+ Observation,
12
+ State,
13
+ )
14
+ from proxy_lite.tools import BrowserTool, Tool, ToolExecutionResponse
15
+
16
+
17
+ @Environments.register_environment_config("webbrowser")
18
+ class WebBrowserEnvironmentConfig(BaseEnvironmentConfig):
19
+ name: Literal["webbrowser"] = "webbrowser"
20
+ homepage: str = "https://google.com"
21
+ annotate_image: bool = True
22
+ screenshot_delay: float = 1.0 # seconds
23
+ include_html: bool = True
24
+ include_poi_text: bool = True
25
+ record_pois: bool = True
26
+ viewport_width: int = 1280
27
+ viewport_height: int = 720
28
+ browserbase_timeout: int = 7200
29
+ headless: bool = True
30
+ keep_original_image: bool = False
31
+
32
+
33
+ @Environments.register_environment("webbrowser")
34
+ class WebBrowserEnvironment(BaseEnvironment):
35
+ config: WebBrowserEnvironmentConfig
36
+ browser: Optional[BrowserSession] = None
37
+ cancelled_last_action: bool = False
38
+
39
+ class Config:
40
+ arbitrary_types_allowed = True
41
+
42
+ async def __aenter__(self) -> Self:
43
+ # Initialize the BrowserSession
44
+ self.browser = self.browser_session(
45
+ viewport_width=self.config.viewport_width,
46
+ viewport_height=self.config.viewport_height,
47
+ headless=self.config.headless,
48
+ )
49
+ await self.browser.__aenter__()
50
+ # Initialize other resources if necessary
51
+ if self.cookies:
52
+ await self.browser.context.add_cookies(self.cookies)
53
+ self.logger.info("🌐 [bold blue]Browser session started.[/]")
54
+ return self
55
+
56
+ async def __aexit__(self, exc_type, exc_value, traceback):
57
+ # Clean up the BrowserSession
58
+ await self.browser.__aexit__(exc_type, exc_value, traceback)
59
+
60
+ @property
61
+ def info_for_user(self) -> str:
62
+ return "This is a web browser environment. You can navigate the web, search the web, and perform actions on the web." # noqa: E501
63
+
64
+ @cached_property
65
+ def tools(self) -> list[Tool]:
66
+ return [BrowserTool(session=self.browser)]
67
+
68
+ @cached_property
69
+ def browser_session(self) -> type[BrowserSession]:
70
+ return BrowserSession
71
+
72
+ @property
73
+ def cookies(self) -> list[dict]:
74
+ return []
75
+
76
+ async def initialise(self) -> Observation:
77
+ await self.browser.goto(self.config.homepage)
78
+ original_img, annotated_img = await self.browser.screenshot(
79
+ delay=self.config.screenshot_delay,
80
+ )
81
+
82
+ base64_image = base64.b64encode(annotated_img).decode("utf-8")
83
+
84
+ html_content = await self.browser.current_page.content() if self.config.include_html else None
85
+
86
+ info = {"url": self.browser.current_url}
87
+ if self.config.record_pois:
88
+ info["pois"] = self.browser.pois
89
+ if self.config.keep_original_image:
90
+ info["original_image"] = base64.b64encode(original_img).decode("utf-8")
91
+
92
+ return Observation(
93
+ state=State(
94
+ text=f"URL: {self.browser.current_url}"
95
+ + (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
96
+ image=base64_image,
97
+ html=html_content,
98
+ ),
99
+ terminated=False,
100
+ reward=None,
101
+ info=info,
102
+ )
103
+
104
+ async def should_perform_action(self) -> bool:
105
+ # if cancelled last action, run the action without updating POIs
106
+ if self.cancelled_last_action:
107
+ self.cancelled_last_action = False
108
+ return True
109
+
110
+ # check for page changes
111
+ old_points = [tuple(point) for point in self.browser.poi_centroids]
112
+ await self.browser.update_poi()
113
+ new_points = [tuple(point) for point in self.browser.poi_centroids]
114
+ page_changed_mid_action = old_points != new_points
115
+
116
+ # record if the last action was cancelled
117
+ if page_changed_mid_action:
118
+ self.cancelled_last_action = True
119
+ return False
120
+ return True
121
+
122
+ async def execute_action(self, action: Action) -> Observation:
123
+ responses = []
124
+ cancelled_tools_flag = False
125
+ if await self.should_perform_action():
126
+ for tool_call in action.tool_calls:
127
+ # Perform the chosen action
128
+ try:
129
+ tool_response: ToolExecutionResponse = await self.execute_tool(
130
+ tool_call,
131
+ )
132
+ tool_response.id = tool_call.id
133
+ responses.append(tool_response)
134
+ except Exception as e: # noqa: PERF203
135
+ self.logger.warning("🌐 An error occurred taking action: %s", str(e), exc_info=False)
136
+ tool_response = ToolExecutionResponse(content=str(e), id=tool_call.id)
137
+ responses.append(tool_response)
138
+ else:
139
+ self.logger.warning("🌐 Page changed since last observation, cancelling action.")
140
+ self.cancelled_last_action = True
141
+ for tool_call in action.tool_calls:
142
+ tool_response = ToolExecutionResponse(
143
+ content="The page changed before the action could be executed, instead of being ran it was cancelled.", # noqa: E501
144
+ id=tool_call.id,
145
+ )
146
+ responses.append(tool_response)
147
+ cancelled_tools_flag = True
148
+ original_img, annotated_img = await self.browser.screenshot(
149
+ delay=self.config.screenshot_delay,
150
+ )
151
+
152
+ base64_image = base64.b64encode(annotated_img).decode("utf-8")
153
+
154
+ info = {"url": self.browser.current_url, "cancelled_tools": cancelled_tools_flag}
155
+ if self.config.record_pois:
156
+ info["pois"] = self.browser.pois
157
+ if self.config.keep_original_image:
158
+ info["original_image"] = base64.b64encode(original_img).decode("utf-8")
159
+
160
+ html_content = await self.browser.current_page.content() if self.config.include_html else None
161
+ return Observation(
162
+ state=State(
163
+ text=f"URL: {self.browser.current_url}"
164
+ + (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
165
+ image=base64_image,
166
+ html=html_content,
167
+ tool_responses=responses,
168
+ ),
169
+ terminated=False,
170
+ reward=None,
171
+ info=info,
172
+ )
173
+
174
+ async def observe(self) -> Observation:
175
+ return await self.browser.observe()
176
+
177
+ async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
178
+ return {}
179
+
180
+ async def get_info(self) -> dict[str, Any]:
181
+ info = {}
182
+ return info
src/proxy_lite/history.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from collections.abc import Iterator
5
+ from enum import Enum
6
+ from typing import Any, Literal, Optional, Set, Union
7
+
8
+ from pydantic import BaseModel, Field, TypeAdapter, field_validator
9
+
10
+
11
+ class MessageLabel(str, Enum):
12
+ SYSTEM = "system"
13
+ USER_INPUT = "user_input"
14
+ SCREENSHOT = "screenshot"
15
+ AGENT_MODEL_RESPONSE = "agent_model_response"
16
+
17
+
18
+ MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
19
+ MessageLabel.SCREENSHOT: 1,
20
+ }
21
+
22
+
23
+ class MessageContent(BaseModel):
24
+ pass
25
+
26
+
27
+ class Text(MessageContent):
28
+ type: Literal["text"] = Field(default="text", init=False)
29
+ text: str
30
+
31
+
32
+ class ImageUrl(BaseModel):
33
+ url: str
34
+
35
+
36
+ class Image(MessageContent):
37
+ type: Literal["image_url"] = Field(default="image_url", init=False)
38
+ image_url: ImageUrl
39
+
40
+
41
+ class Message(BaseModel):
42
+ label: Optional[MessageLabel] = None
43
+ content: list[Union[Text, Image]] = Field(default_factory=list)
44
+
45
+ class Config:
46
+ use_enum_values = True
47
+
48
+ @property
49
+ def images(self) -> list[Image]:
50
+ return [content for content in self.content if isinstance(content, Image)]
51
+
52
+ @property
53
+ def texts(self) -> list[Text]:
54
+ return [content for content in self.content if isinstance(content, Text)]
55
+
56
+ @property
57
+ def first_image(self) -> Optional[Image]:
58
+ return self.images[0] if self.images else None
59
+
60
+ @property
61
+ def first_text(self) -> Optional[Text]:
62
+ return self.texts[0] if self.texts else None
63
+
64
+ def __len__(self):
65
+ return len(self.content)
66
+
67
+ @classmethod
68
+ def from_media(
69
+ cls,
70
+ text: Optional[str] = None,
71
+ image: Optional[bytes | str] = None,
72
+ is_base64: bool = False,
73
+ ) -> Message:
74
+ if text is not None:
75
+ text = Text(text=text)
76
+ if image is not None:
77
+ base64_image = image if is_base64 else base64.b64encode(image).decode("utf-8")
78
+ data_url = f"data:image/jpeg;base64,{base64_image}"
79
+ image = Image(image_url=ImageUrl(url=data_url))
80
+ content = [text, image] if text is not None else [image]
81
+ else:
82
+ content = [text]
83
+ return cls(content=content)
84
+
85
+
86
+ class SystemMessage(Message):
87
+ role: Literal["system"] = Field(default="system", init=False)
88
+
89
+
90
+ class UserMessage(Message):
91
+ role: Literal["user"] = Field(default="user", init=False)
92
+
93
+
94
+ class ToolCall(BaseModel):
95
+ id: str
96
+ type: str
97
+ function: dict[str, Any]
98
+
99
+
100
+ class AssistantMessage(Message):
101
+ role: Literal["assistant"] = Field(default="assistant", init=False)
102
+ tool_calls: list[ToolCall] = Field(default_factory=list)
103
+
104
+ def model_dump(self, **kwargs):
105
+ data = super().model_dump(**kwargs)
106
+ if not self.tool_calls:
107
+ data.pop("tool_calls")
108
+ return data
109
+
110
+ @field_validator("tool_calls", mode="before")
111
+ @classmethod
112
+ def ensure_list(cls, v):
113
+ return [] if v is None else v
114
+
115
+
116
+ class ToolMessage(Message):
117
+ role: Literal["tool"] = Field(default="tool", init=False)
118
+ tool_call_id: str
119
+
120
+
121
+ MessageTypes = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
122
+ MessageAdapter = TypeAdapter(MessageTypes)
123
+
124
+
125
+ class MessageHistory(BaseModel):
126
+ messages: list[MessageTypes] = Field(default_factory=list)
127
+
128
+ def append(self, message: MessageTypes, label: Optional[str] = None):
129
+ if label is not None:
130
+ message.label = label
131
+ self.messages.append(message)
132
+
133
+ def pop(self) -> MessageTypes:
134
+ return self.messages.pop()
135
+
136
+ def extend(self, history: MessageHistory):
137
+ self.messages.extend(history.messages)
138
+
139
+ def __reversed__(self):
140
+ return MessageHistory(messages=self.messages[::-1])
141
+
142
+ def __getitem__(self, index):
143
+ return self.messages[index]
144
+
145
+ def __len__(self):
146
+ return len(self.messages)
147
+
148
+ def __iter__(self) -> Iterator[MessageTypes]:
149
+ return iter(self.messages)
150
+
151
+ def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
152
+ exclude = exclude or set()
153
+ return [message.model_dump(exclude=exclude) for message in self.messages]
154
+
155
+ def history_view(
156
+ self,
157
+ limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
158
+ ) -> MessageHistory:
159
+ """Context window management.
160
+
161
+ Filters messages in reverse order, retaining a limited number of recent screenshots and prompts.
162
+ """
163
+ label_counts = {label: 0 for label in limits}
164
+ filtered_messages = []
165
+ for message in reversed(self.messages):
166
+ if message.label in limits:
167
+ maximum_count = limits[message.label]
168
+ if label_counts[message.label] < maximum_count:
169
+ filtered_messages.append(message)
170
+ label_counts[message.label] += 1
171
+ else:
172
+ filtered_messages.append(message)
173
+ return MessageHistory(messages=reversed(filtered_messages))
174
+
175
+ def __add__(self, other: MessageHistory) -> MessageHistory:
176
+ new_history = MessageHistory()
177
+ new_history.extend(self)
178
+ new_history.extend(other)
179
+ return new_history
180
+
181
+ def __iadd__(self, other: MessageHistory) -> MessageHistory:
182
+ self.extend(other)
183
+ return self
src/proxy_lite/logger.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+ from typing import Literal
4
+ from uuid import uuid4
5
+
6
+ from rich.logging import RichHandler
7
+
8
+
9
+ class StructuredLogger(logging.Logger):
10
+ def _log(
11
+ self,
12
+ level,
13
+ msg,
14
+ args,
15
+ exc_info=None,
16
+ extra=None,
17
+ stack_info=False,
18
+ stacklevel=1,
19
+ ):
20
+ if extra is None:
21
+ extra = {}
22
+
23
+ json_fields = {
24
+ "logger_name": self.name,
25
+ "message": msg % args if args else msg,
26
+ }
27
+
28
+ exc_type, exc_value, exc_traceback = sys.exc_info()
29
+ if exc_type is not None:
30
+ json_fields["exception_class"] = exc_type.__name__
31
+ json_fields["exception_message"] = str(exc_value)
32
+
33
+ json_fields.update(extra)
34
+ super()._log(
35
+ level,
36
+ msg,
37
+ args,
38
+ exc_info,
39
+ {"json_fields": json_fields},
40
+ stack_info,
41
+ stacklevel + 1,
42
+ )
43
+
44
+
45
+ def create_logger(
46
+ name: str,
47
+ level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
48
+ detailed_name: bool = False,
49
+ ) -> logging.Logger:
50
+ unique_name = f"{name}-{str(uuid4())[:8]}"
51
+ logger = logging.getLogger(unique_name)
52
+ logger.setLevel(level)
53
+ handler = RichHandler(
54
+ rich_tracebacks=True,
55
+ markup=True,
56
+ show_path=False,
57
+ show_time=False,
58
+ log_time_format="[%s]",
59
+ )
60
+ if detailed_name:
61
+ handler.setFormatter(logging.Formatter("%(name)s:\n%(message)s\n------"))
62
+ else:
63
+ handler.setFormatter(logging.Formatter("%(message)s\n------"))
64
+ logger.addHandler(handler)
65
+ logger.propagate = False
66
+ return logger
67
+
68
+
69
+ # Set StructuredLogger as the default logger class
70
+ logging.setLoggerClass(StructuredLogger)
71
+
72
+ logger = logging.getLogger(__name__)
73
+ logger.setLevel(logging.INFO)
74
+ logger.propagate = True
75
+ handler = RichHandler(
76
+ rich_tracebacks=True,
77
+ markup=True,
78
+ show_path=False,
79
+ show_time=False,
80
+ )
81
+ logger.addHandler(handler)
src/proxy_lite/recorder.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import json
5
+ import os
6
+ import sys
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Any, Optional, Self
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+ from proxy_lite.environments import EnvironmentConfigTypes
14
+ from proxy_lite.environments.environment_base import Action, Observation
15
+ from proxy_lite.history import MessageHistory
16
+ from proxy_lite.solvers import SolverConfigTypes
17
+
18
+
19
+ class Run(BaseModel):
20
+ run_id: str # uuid.UUID
21
+ task: str
22
+ created_at: str # datetime.datetime
23
+ complete: bool = False
24
+ terminated_at: str | None = None # datetime.datetime
25
+ evaluation: dict[str, Any] | None = None
26
+ history: list[Observation | Action] = Field(default_factory=list)
27
+ solver_history: MessageHistory | None = None
28
+ result: str | None = None
29
+ env_info: dict[str, Any] = Field(default_factory=dict)
30
+ environment: Optional[EnvironmentConfigTypes] = None
31
+ solver: Optional[SolverConfigTypes] = None
32
+
33
+ @classmethod
34
+ def initialise(cls, task: str) -> Self:
35
+ run_id = str(uuid.uuid4())
36
+ return cls(
37
+ run_id=run_id,
38
+ task=task,
39
+ created_at=str(datetime.datetime.now(datetime.UTC)),
40
+ )
41
+
42
+ @property
43
+ def observations(self) -> list[Observation]:
44
+ return [h for h in self.history if isinstance(h, Observation)]
45
+
46
+ @property
47
+ def actions(self) -> list[Action]:
48
+ return [h for h in self.history if isinstance(h, Action)]
49
+
50
+ @property
51
+ def last_action(self) -> Action | None:
52
+ return self.actions[-1] if self.actions else None
53
+
54
+ @property
55
+ def last_observation(self) -> Observation | None:
56
+ return self.observations[-1] if self.observations else None
57
+
58
+ def record(
59
+ self,
60
+ observation: Optional[Observation] = None,
61
+ action: Optional[Action] = None,
62
+ solver_history: Optional[MessageHistory] = None,
63
+ ) -> None:
64
+ # expect only one of observation and action to be provided in order to handle ordering
65
+ if observation and action:
66
+ raise ValueError("Only one of observation and action can be provided")
67
+ if observation:
68
+ self.history.append(observation)
69
+ if action:
70
+ self.history.append(action)
71
+ if solver_history:
72
+ self.solver_history = solver_history
73
+
74
+ def terminate(self) -> None:
75
+ self.terminated_at = str(datetime.datetime.now(datetime.UTC))
76
+
77
+
78
+ class DataRecorder:
79
+ def __init__(self, local_folder: str | None = None):
80
+ self.local_folder = local_folder
81
+
82
+ def initialise_run(self, task: str) -> Run:
83
+ self.local_folder = Path(os.path.abspath(sys.path[0])) / "local_trajectories"
84
+ os.makedirs(self.local_folder, exist_ok=True)
85
+ return Run.initialise(task)
86
+
87
+ async def terminate(
88
+ self,
89
+ run: Run,
90
+ save: bool = True,
91
+ ) -> None:
92
+ run.terminate()
93
+ if save:
94
+ await self.save(run)
95
+
96
+ async def save(self, run: Run) -> None:
97
+ json_payload = run.model_dump()
98
+ with open(self.local_folder / f"{run.run_id}.json", "w") as f:
99
+ json.dump(json_payload, f)
src/proxy_lite/runner.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import AsyncIterator
4
+ from contextlib import asynccontextmanager
5
+ from typing import Any, Literal, Self
6
+
7
+ from omegaconf import OmegaConf
8
+ from pydantic import BaseModel
9
+
10
+ from proxy_lite.environments import (
11
+ Action,
12
+ BaseEnvironment,
13
+ EnvironmentConfigTypes,
14
+ Environments,
15
+ EventType,
16
+ Observation,
17
+ )
18
+ from proxy_lite.logger import create_logger
19
+ from proxy_lite.recorder import DataRecorder, Run
20
+ from proxy_lite.solvers import (
21
+ BaseSolver,
22
+ SolverConfigTypes,
23
+ Solvers,
24
+ )
25
+
26
+
27
+ @asynccontextmanager
28
+ async def async_timeout(timeout: float, task_name: str = "timeout"):
29
+ try:
30
+ async with asyncio.TaskGroup() as tg:
31
+
32
+ async def timeout_task():
33
+ await asyncio.sleep(timeout)
34
+ raise TimeoutError(
35
+ f"Operation {task_name} timed out after {timeout} seconds",
36
+ )
37
+
38
+ # Create the timeout task
39
+ timeout_handle = tg.create_task(timeout_task())
40
+
41
+ try:
42
+ yield
43
+ finally:
44
+ timeout_handle.cancel()
45
+ except* asyncio.TimeoutError as eg:
46
+ for e in eg.exceptions:
47
+ raise e
48
+ except* Exception as eg:
49
+ for e in eg.exceptions:
50
+ raise e
51
+
52
+
53
+ class RunnerConfig(BaseModel):
54
+ environment: EnvironmentConfigTypes
55
+ solver: SolverConfigTypes
56
+
57
+ save_every_step: bool = True
58
+ max_steps: int = 100
59
+ action_timeout: float = 60.0
60
+ environment_timeout: float = 30.0
61
+ task_timeout: float = 1800.0
62
+ logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
63
+ detailed_logger_name: bool = False
64
+
65
+ @classmethod
66
+ def from_dict(cls, config_dict: dict) -> Self:
67
+ conf = OmegaConf.create(config_dict)
68
+ config_dict = OmegaConf.to_container(conf, resolve=True)
69
+ return cls(**config_dict)
70
+
71
+ @classmethod
72
+ def from_yaml(cls, yaml_path: str) -> Self:
73
+ conf = OmegaConf.load(yaml_path)
74
+ config_dict = OmegaConf.to_container(conf, resolve=True)
75
+ return cls(**config_dict)
76
+
77
+
78
+ class Runner(BaseModel):
79
+ config: RunnerConfig
80
+ recorder: DataRecorder | None = None
81
+ environment: type[BaseEnvironment] | None = None
82
+ solver: type[BaseSolver] | None = None
83
+ logger: logging.Logger | None = None
84
+ _run: Run | None = None
85
+
86
+ class Config:
87
+ arbitrary_types_allowed = True
88
+
89
+ def model_post_init(self, __context: Any) -> None:
90
+ super().model_post_init(__context)
91
+ self.environment = Environments.get(self.config.environment.name)
92
+ self.solver = Solvers.get(self.config.solver.name)
93
+ self.recorder = DataRecorder()
94
+ self.logger = create_logger(
95
+ name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
96
+ level=self.config.logger_level,
97
+ detailed_name=self.config.detailed_logger_name,
98
+ )
99
+
100
+ async def run_generator(self, task: str) -> AsyncIterator[Run]:
101
+ async with (
102
+ async_timeout(self.config.task_timeout, "Task"),
103
+ ):
104
+ if self.config.logger_level is not None:
105
+ self.logger.setLevel(self.config.logger_level)
106
+ run = self.recorder.initialise_run(task)
107
+ run.environment = self.config.environment
108
+ run.solver = self.config.solver
109
+ self.logger.debug(f"Run intialised: {run.run_id}")
110
+ event_queue = asyncio.Queue()
111
+ async with (
112
+ self.environment(
113
+ config=self.config.environment,
114
+ logger=self.logger,
115
+ ) as environment,
116
+ self.solver(config=self.config.solver, logger=self.logger) as solver,
117
+ ):
118
+ run.env_info = await environment.get_info()
119
+ await solver.initialise(
120
+ task,
121
+ environment.tools,
122
+ environment.info_for_user,
123
+ )
124
+ self.logger.debug("Solver initialised.")
125
+ run.solver_history = solver.history
126
+ observation: Observation = await environment.initialise()
127
+ await event_queue.put(observation)
128
+ self.logger.debug("Environment initialised.")
129
+ step_count = 0
130
+ while step_count < self.config.max_steps:
131
+ event = await event_queue.get()
132
+ self.logger.debug(f"🤖 [bold purple]Processing event:[/] {event.type}")
133
+ match event.type:
134
+ case EventType.OBSERVATION:
135
+ observation: Observation = event
136
+ run.record(
137
+ observation=observation,
138
+ solver_history=solver.history,
139
+ )
140
+ async with async_timeout(
141
+ self.config.action_timeout,
142
+ "Action decision",
143
+ ):
144
+ action: Action = await solver.act(observation)
145
+ await event_queue.put(action)
146
+ case EventType.ACTION:
147
+ action: Action = event
148
+ self.logger.debug(f"Tool calls: {action.tool_calls}")
149
+ run.record(action=action, solver_history=solver.history)
150
+ run.complete = await solver.is_complete(observation)
151
+ if self.config.save_every_step:
152
+ await self.recorder.save(run)
153
+ if run.complete:
154
+ run.result = action.text
155
+ self.logger.info(f"🤖 [bold purple]Task complete.[/] ✨ \n{run.result}")
156
+ break
157
+ async with async_timeout(
158
+ self.config.environment_timeout,
159
+ "Environment response",
160
+ ):
161
+ observation: Observation = await environment.execute_action(action)
162
+ step_count += 1
163
+ await event_queue.put(observation)
164
+ yield run
165
+ if not run.complete:
166
+ self.logger.warning("🤖 [bold purple]Ran out of steps!")
167
+ await self.recorder.terminate(run, save=True)
168
+ yield run
169
+
170
+ async def run(self, task: str) -> Run:
171
+ async for run in self.run_generator(task): # noqa: B007
172
+ self._run = run
173
+ return run
174
+
175
+ def run_concurrent(self, tasks: list[str]) -> list[Run]:
176
+ async def gather_runs():
177
+ return await asyncio.gather(
178
+ *[self.run(task) for task in tasks],
179
+ return_exceptions=True,
180
+ )
181
+
182
+ return asyncio.run(gather_runs())
183
+
184
+ @property
185
+ def complete(self) -> bool:
186
+ if self._run is None:
187
+ raise RuntimeError("Run not initialised")
188
+ return self._run.complete
189
+
190
+ @property
191
+ def run_id(self) -> str:
192
+ if self._run is None:
193
+ raise RuntimeError("Run not initialised")
194
+ return self._run.run_id
195
+
196
+ @property
197
+ def run_result(self) -> str:
198
+ if self._run is None:
199
+ raise RuntimeError("Run not initialised")
200
+ return self._run.result
201
+
202
+
203
+ if __name__ == "__main__":
204
+ from proxy_lite.logger import logger
205
+
206
+ config = RunnerConfig.from_dict(
207
+ {
208
+ "environment": {
209
+ "name": "webbrowser",
210
+ "homepage": "https://www.google.com",
211
+ "viewport_width": 1920,
212
+ "viewport_height": 1080,
213
+ "screenshot_delay": 1,
214
+ "headless": False,
215
+ },
216
+ "solver": {
217
+ "name": "simple",
218
+ "agent": {
219
+ "name": "proxy_lite",
220
+ "client": {
221
+ "name": "convergence",
222
+ "model_id": "convergence-ai/all-distill-tools-7b-16-02-2025",
223
+ "api_base": "http://slurm1-a3nodeset-4-1:8009/v1",
224
+ # # "model_id": "Qwen/Qwen2.5-VL-3B-Instruct",
225
+ # # "api_base": "http://0.0.0.0:8000/v1",
226
+ },
227
+ },
228
+ },
229
+ "max_steps": 150,
230
+ "action_timeout": 1800,
231
+ "environment_timeout": 1800,
232
+ "task_timeout": 18000,
233
+ "logger_level": "DEBUG",
234
+ },
235
+ )
236
+ logger.info(f"🤖 [bold purple]Config:[/] {config}")
237
+
238
+ runner = Runner(config=config)
239
+ result = asyncio.run(
240
+ runner.run(
241
+ "Tell me the tesla stock price" # noqa: E501
242
+ )
243
+ )
244
+ print(runner.run_result)
245
+ print(runner.complete)
src/proxy_lite/serializer.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from abc import ABC, abstractmethod
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from proxy_lite.history import MessageAdapter, MessageHistory
7
+ from proxy_lite.tools import Tool
8
+
9
+
10
+ class BaseSerializer(BaseModel, ABC):
11
+ """Base class for serializers.
12
+
13
+ Serializers are responsible for converting between the internal MessageHistory/Tool
14
+ objects and the external API format. Deserialise is not always possible, so raise
15
+ appropriate warnings.
16
+ """
17
+
18
+ @abstractmethod
19
+ def serialize_messages(self, message_history: MessageHistory) -> list[dict]: ...
20
+
21
+ @abstractmethod
22
+ def deserialize_messages(self, data: list[dict]) -> MessageHistory: ...
23
+
24
+ @abstractmethod
25
+ def serialize_tools(self, tools: list[Tool]) -> list[dict]: ...
26
+
27
+
28
+ class OpenAISerializer(BaseSerializer):
29
+ def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
30
+ return message_history.to_dict(exclude={"label"})
31
+
32
+ def deserialize_messages(self, data: list[dict]) -> MessageHistory:
33
+ return MessageHistory(
34
+ messages=[MessageAdapter.validate_python(message) for message in data],
35
+ )
36
+
37
+ def serialize_tools(self, tools: list[Tool]) -> list[dict]:
38
+ tool_schemas = [[{"type": "function", "function": schema} for schema in tool.schema] for tool in tools]
39
+ return list(itertools.chain.from_iterable(tool_schemas))
src/proxy_lite/solvers/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Union
4
+
5
+ from .simple_solver import SimpleSolver, SimpleSolverConfig
6
+ from .solver_base import BaseSolver, BaseSolverConfig, Solvers
7
+ from .structured_solver import StructuredSolver, StructuredSolverConfig
8
+
9
+ SolverConfigTypes = Union[*Solvers._solver_config_registry.values()]
10
+ SolverTypes = Union[*Solvers._solver_registry.values()]
11
+
12
+
13
+ __all__ = [
14
+ "BaseSolver",
15
+ "BaseSolverConfig",
16
+ "SimpleSolver",
17
+ "SimpleSolverConfig",
18
+ "StructuredSolver",
19
+ "StructuredSolverConfig",
20
+ "SolverConfigTypes",
21
+ "SolverTypes",
22
+ "Solvers",
23
+ ]
src/proxy_lite/solvers/simple_solver.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: E501
2
+ import json
3
+ import re
4
+ from functools import cached_property
5
+ from typing import Literal, Optional
6
+
7
+ from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
8
+ from proxy_lite.environments.environment_base import Action, Observation
9
+ from proxy_lite.history import (
10
+ MessageHistory,
11
+ MessageLabel,
12
+ SystemMessage,
13
+ )
14
+ from proxy_lite.solvers.solver_base import BaseSolver, BaseSolverConfig, Solvers
15
+ from proxy_lite.tools import ReturnValueTool, Tool
16
+
17
+ WEB_TOOL_TURN = """The action has been attempted in the computer."""
18
+
19
+
20
+ @Solvers.register_solver_config("simple")
21
+ class SimpleSolverConfig(BaseSolverConfig):
22
+ name: Literal["simple"] = "simple"
23
+ agent: AgentConfigTypes
24
+
25
+
26
+ @Solvers.register_solver("simple")
27
+ class SimpleSolver(BaseSolver):
28
+ task: Optional[str] = None
29
+ complete: bool = False
30
+
31
+ @cached_property
32
+ def tools(self) -> list[Tool]:
33
+ return [ReturnValueTool()] + self.env_tools
34
+
35
+ @cached_property
36
+ def agent(self) -> BaseAgent:
37
+ self.logger.debug(f"Tools: {self.tools}")
38
+ return Agents.get(self.config.agent.name)(
39
+ config=self.config.agent,
40
+ env_tools=self.tools,
41
+ )
42
+
43
+ @property
44
+ def history(self) -> MessageHistory:
45
+ return MessageHistory(
46
+ messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
47
+ )
48
+
49
+ async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
50
+ self.env_tools = env_tools
51
+ self.task = task
52
+ self.agent.receive_user_message(
53
+ text=f"Task: {task}",
54
+ label=MessageLabel.USER_INPUT,
55
+ )
56
+ self.logger.debug(f"Initialised with task: {task}")
57
+
58
+ async def act(self, observation: Observation) -> Action:
59
+ self.agent.receive_user_message(
60
+ image=observation.state.image,
61
+ text=observation.state.text,
62
+ label=MessageLabel.SCREENSHOT,
63
+ is_base64=True,
64
+ )
65
+
66
+ message = await self.agent.generate_output(use_tool=True)
67
+
68
+ self.logger.debug(f"Assistant message generated: {message}")
69
+
70
+ # check tool calls for return_value
71
+ if any(tool_call.function["name"] == "return_value" for tool_call in message.tool_calls):
72
+ self.complete = True
73
+ arguments = json.loads(message.tool_calls[0].function["arguments"])
74
+ if isinstance(arguments, str):
75
+ arguments = json.loads(arguments)
76
+ return_value = arguments["value"]
77
+ return Action(tool_calls=[], text=return_value)
78
+
79
+ text_content = message.content[0].text
80
+
81
+ observation_match = re.search(r"<observation>(.*?)</observation>", text_content, re.DOTALL)
82
+ observation_content = observation_match.group(1).strip() if observation_match else ""
83
+
84
+ self.logger.info(f"🌐 [bold blue]Observation:[/] {observation_content}")
85
+
86
+ # Extract text between thinking tags if present
87
+ thinking_match = re.search(r"<thinking>(.*?)</thinking>", text_content, re.DOTALL)
88
+ thinking_content = thinking_match.group(1).strip() if thinking_match else text_content
89
+
90
+ self.logger.info(f"🤖 [bold purple]Action:[/] {thinking_content}")
91
+
92
+ return Action(tool_calls=message.tool_calls, text=text_content)
93
+
94
+ async def is_complete(self, observation: Observation) -> bool:
95
+ env_terminated = observation.terminated
96
+ return self.complete or env_terminated
src/proxy_lite/solvers/solver_base.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from functools import cached_property
4
+ from typing import Optional, Self, Type, cast
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+ from proxy_lite.environments.environment_base import Action, Observation
9
+ from proxy_lite.tools import Tool
10
+
11
+
12
+ class BaseSolverConfig(BaseModel):
13
+ pass
14
+
15
+
16
+ class BaseSolver(BaseModel, ABC):
17
+ task: Optional[str] = None
18
+ env_tools: list[Tool] = Field(default_factory=list)
19
+ config: BaseSolverConfig
20
+ logger: logging.Logger | None = None
21
+
22
+ class Config:
23
+ arbitrary_types_allowed = True
24
+
25
+ async def __aenter__(self) -> Self:
26
+ return self
27
+
28
+ async def __aexit__(self, exc_type, exc_value, traceback) -> None:
29
+ pass
30
+
31
+ @cached_property
32
+ @abstractmethod
33
+ def tools(self) -> list[Tool]: ...
34
+
35
+ @abstractmethod
36
+ async def initialise(
37
+ self,
38
+ task: str,
39
+ env_tools: list[Tool],
40
+ env_info: str,
41
+ ) -> None:
42
+ """
43
+ Initialise the solution with the given task.
44
+ """
45
+ ...
46
+
47
+ @abstractmethod
48
+ async def act(self, observation: Observation) -> Action:
49
+ """
50
+ Return an action for interacting with the environment.
51
+ """
52
+ ...
53
+
54
+ async def is_complete(self, observation: Observation) -> bool:
55
+ """
56
+ Return a boolean indicating if the task is complete.
57
+ """
58
+ return observation.terminated
59
+
60
+
61
+ class Solvers:
62
+ _solver_registry: dict[str, type[BaseSolver]] = {}
63
+ _solver_config_registry: dict[str, type[BaseSolverConfig]] = {}
64
+
65
+ @classmethod
66
+ def register_solver(cls, name: str):
67
+ """
68
+ Decorator to register a Solver class under a given name.
69
+
70
+ Example:
71
+ @Solvers.register_solver("my_solver")
72
+ class MySolver(BaseSolver):
73
+ ...
74
+ """
75
+
76
+ def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]:
77
+ cls._solver_registry[name] = solver_cls
78
+ return solver_cls
79
+
80
+ return decorator
81
+
82
+ @classmethod
83
+ def register_solver_config(cls, name: str):
84
+ """
85
+ Decorator to register a Solver configuration class under a given name.
86
+
87
+ Example:
88
+ @Solvers.register_solver_config("my_solver")
89
+ class MySolverConfig(BaseSolverConfig):
90
+ ...
91
+ """
92
+
93
+ def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]:
94
+ cls._solver_config_registry[name] = config_cls
95
+ return config_cls
96
+
97
+ return decorator
98
+
99
+ @classmethod
100
+ def get(cls, name: str) -> type[BaseSolver]:
101
+ """
102
+ Retrieve a registered Solver class by its name.
103
+
104
+ Raises:
105
+ ValueError: If no such solver is found.
106
+ """
107
+ try:
108
+ return cast(Type[BaseSolver], cls._solver_registry[name])
109
+ except KeyError:
110
+ raise ValueError(f"Solver '{name}' not found.")
111
+
112
+ @classmethod
113
+ def get_config(cls, name: str) -> type[BaseSolverConfig]:
114
+ """
115
+ Retrieve a registered Solver configuration class by its name.
116
+
117
+ Raises:
118
+ ValueError: If no such config is found.
119
+ """
120
+ try:
121
+ return cast(Type[BaseSolverConfig], cls._solver_config_registry[name])
122
+ except KeyError:
123
+ raise ValueError(f"Solver config for '{name}' not found.")
src/proxy_lite/solvers/structured_solver.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: E501
2
+
3
+ from functools import cached_property
4
+ from typing import Literal, Optional
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+ from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
9
+ from proxy_lite.environments.environment_base import Action, Observation
10
+ from proxy_lite.history import (
11
+ MessageHistory,
12
+ MessageLabel,
13
+ SystemMessage,
14
+ )
15
+ from proxy_lite.tools import Tool
16
+
17
+ from .solver_base import BaseSolver, BaseSolverConfig, Solvers
18
+
19
+ WEB_TOOL_TURN = """The browser action has been attempted. Please double check if the action was successful."""
20
+ PLAN_USER_PROMPT = "First create a high-level plan to help solve the task on the web."
21
+ ACTION_PROMPT = """Now take the most-promising next action in the browser.
22
+
23
+ Only refer to the latest web elements from the latest screenshot.
24
+
25
+ Using mark ids from older turns will lead to errors as they are no longer valid.
26
+
27
+ Only interact with elements visible on the current webpage. Do not make up numbers or elements."""
28
+ REASONING_PROMPT = """You will now follow these steps.
29
+
30
+ 1. **Make observations about the state of the webpage**:
31
+ - Consider the previous screenshot, your attempted previous action, and the current screenshot.
32
+ - Describe any changes you observe, and try to determine if the previous action succeeded.
33
+ - For example, if a form is being filled out, check whether the correct information is now displayed.
34
+
35
+ 2. **Write down any helpful facts you have gathered**:
36
+ - Describe any useful information on the webpage that might be helpful for completing the task.
37
+ - For example, if you are viewing a document, you may wish to note down any information you want to refer back to later.
38
+
39
+ 3. **Reason about the system's status**:
40
+ - Have you fully completed the task?
41
+
42
+ 4. **Select one of the following statuses**:
43
+ - "complete": if the task has been completed.
44
+ - "continue": if you are ready to continue without information or help.
45
+
46
+ 5. **Reason through next steps**:
47
+ - If the status is "continue", write down your reasoning for the next action you will take. You can only take one action at a time.
48
+ - If the status is not "continue", return an empty string.
49
+
50
+ 6. **Write a message to the user**:
51
+ - If the status is "complete", write a message to the user. If they asked a question in the task, make sure the answer is here. Otherwise, just provide other useful information about how the task went or if there was a problem in completing it.
52
+ - If the status is not "complete", set this to an empty string.
53
+
54
+ Tips:
55
+ - If you have already provided a response, don't provide it again.
56
+ - If you notice you are repeating previous actions, you're likely stuck. Try something different."""
57
+
58
+
59
+ class Reflection(BaseModel):
60
+ observation: str = Field(
61
+ ...,
62
+ description="Observation of the current browser state, including an assessment on the success of the last action (previous actions and observations are often wrong).",
63
+ )
64
+ fact_updates: list[str] = Field(
65
+ "",
66
+ description="List of new information relevant to the task that was found on the page, ignore input fields holding content you wrote.",
67
+ )
68
+ status_reasoning: str = Field(
69
+ ...,
70
+ description="Reasoning about the current state of the task.",
71
+ )
72
+ status: Literal["complete", "continue"] = Field(
73
+ ...,
74
+ description="Choose a system status based on your status reasoning.",
75
+ )
76
+ next_step_reasoning: str = Field(
77
+ ...,
78
+ description='If status is "continue", reason through the next action you will be taking (do not repeat actions over and over). Otherwise set to "".',
79
+ )
80
+ ending_message: str = Field(
81
+ ...,
82
+ description="If status is 'complete', write a message to the user. If they asked a question in the task, make sure the answer is here. Otherwise, just provide other useful information about how the task went or if there was a problem in completing it. If status is 'continue', set to ''.",
83
+ )
84
+
85
+
86
+ @Solvers.register_solver_config("structured")
87
+ class StructuredSolverConfig(BaseSolverConfig):
88
+ name: Literal["structured"] = "structured"
89
+ agent: AgentConfigTypes
90
+ start_with_plan: bool = True
91
+
92
+
93
+ @Solvers.register_solver("structured")
94
+ class StructuredSolver(BaseSolver):
95
+ task: Optional[str] = None
96
+ complete: bool = False
97
+
98
+ @cached_property
99
+ def tools(self) -> list[Tool]:
100
+ return self.env_tools
101
+
102
+ @cached_property
103
+ def local_tools(self) -> list[Tool]:
104
+ if self.sandbox:
105
+ return self.sandbox.tools
106
+ return []
107
+
108
+ @cached_property
109
+ def agent(self) -> BaseAgent:
110
+ self.logger.debug(f"Tools: {self.tools}")
111
+ return Agents.get(self.config.agent.name)(
112
+ config=self.config.agent,
113
+ env_tools=self.tools,
114
+ )
115
+
116
+ @property
117
+ def history(self) -> MessageHistory:
118
+ return MessageHistory(
119
+ messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
120
+ )
121
+
122
+ async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
123
+ self.env_tools = env_tools
124
+ self.agent.receive_user_message(
125
+ text=env_info,
126
+ label=MessageLabel.USER_INPUT,
127
+ )
128
+ self.task = task
129
+ self.agent.receive_user_message(
130
+ text=f"Task: {task}",
131
+ label=MessageLabel.USER_INPUT,
132
+ )
133
+ if self.config.start_with_plan:
134
+ self.agent.receive_user_message(text=PLAN_USER_PROMPT, label=MessageLabel.PLAN)
135
+ await self.agent.generate_output(use_tool=False)
136
+
137
+ async def act(self, observation: Observation) -> Action:
138
+ if observation.state.tool_responses:
139
+ for tool_response in observation.state.tool_responses:
140
+ await self.agent.receive_tool_message(
141
+ text=f"{WEB_TOOL_TURN}\n{tool_response.content}",
142
+ tool_id=tool_response.id,
143
+ label=MessageLabel.TOOL_RESULT_INDUCTION,
144
+ )
145
+
146
+ self.agent.receive_user_message(
147
+ image=observation.state.image,
148
+ text=observation.state.text,
149
+ label=MessageLabel.SCREENSHOT,
150
+ is_base64=True,
151
+ )
152
+
153
+ self.agent.receive_user_message(
154
+ text=REASONING_PROMPT,
155
+ label=MessageLabel.REASONING_INDUCTION,
156
+ )
157
+
158
+ message = await self.agent.generate_structured_output(model=Reflection)
159
+ self.logger.info(f"🌐 [bold blue]Observation:[/] {message.observation}")
160
+
161
+ if message.status == "complete":
162
+ self.complete = True
163
+ return Action(tool_calls=[], text=message.ending_message)
164
+
165
+ next_step = message.next_step_reasoning
166
+
167
+ self.agent.receive_user_message(
168
+ text=ACTION_PROMPT,
169
+ label=MessageLabel.ACTION,
170
+ is_base64=True,
171
+ )
172
+ message = await self.agent.generate_output(use_tool=True)
173
+
174
+ return Action(tool_calls=message.tool_calls, text=next_step)
175
+
176
+ async def is_complete(self, observation: Observation) -> bool:
177
+ env_terminated = observation.terminated
178
+ return self.complete or env_terminated
src/proxy_lite/tools/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .browser_tool import BrowserTool
2
+ from .return_tool import ReturnValueTool
3
+ from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
4
+
5
+ __all__ = ["BrowserTool", "ReturnValueTool", "Tool", "ToolExecutionResponse", "attach_param_schema"]
src/proxy_lite/tools/browser_tool.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from contextlib import AsyncExitStack
3
+ from typing import List, Literal, Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+ from proxy_lite.browser.browser import BrowserSession
8
+ from proxy_lite.logger import logger
9
+
10
+ from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
11
+
12
+ SELF_CONTAINED_TAGS = [
13
+ # many of these are non-interactive but keeping them anyway
14
+ "area",
15
+ "base",
16
+ "br",
17
+ "col",
18
+ "embed",
19
+ "hr",
20
+ "img",
21
+ "input",
22
+ "link",
23
+ "meta",
24
+ "param",
25
+ "source",
26
+ "track",
27
+ "wbr",
28
+ ]
29
+
30
+
31
+ def element_as_text(
32
+ mark_id: int,
33
+ tag: Optional[str] = None,
34
+ text: Optional[str] = None,
35
+ **raw_attributes,
36
+ ) -> str:
37
+ """Return a text representation of all elements on the page"""
38
+ attributes = []
39
+ for k, v in raw_attributes.items():
40
+ if v is None:
41
+ continue
42
+ if isinstance(v, bool):
43
+ if v:
44
+ attributes.append(k)
45
+ # we ignore False bool attributes
46
+ else:
47
+ v = str(v)
48
+ if len(v) > 2500:
49
+ v = v[: 2500 - 1] + "…"
50
+ attributes.append(f'{k}="{v}"')
51
+ attributes = " ".join(attributes)
52
+ attributes = (" " + attributes).rstrip()
53
+ tag = tag.lower()
54
+ if text is None:
55
+ text = ""
56
+ if len(text) > 2500:
57
+ text = text[: 2500 - 1] + "…"
58
+ if tag in SELF_CONTAINED_TAGS:
59
+ if text:
60
+ logger.warning(
61
+ f"Got self-contained element '{tag}' which contained text '{text}'.",
62
+ )
63
+ else:
64
+ return f"<{tag} id={mark_id}{attributes}/>"
65
+ return f"<{tag} id={mark_id}{attributes}>{text}</{tag}>"
66
+
67
+
68
+ class GotoParams(BaseModel):
69
+ url: str = Field(..., description="The web address to visit. Must be a valid URL.")
70
+
71
+
72
+ class GoogleSearchParams(BaseModel):
73
+ query_plan: str = Field(
74
+ ...,
75
+ description="Plan out the query you will make. Re-write queries in a way that will yield the best results.",
76
+ )
77
+ query: str = Field(..., description="The Google search to perform.")
78
+
79
+
80
+ class ClickParams(BaseModel):
81
+ mark_id: int = Field(..., description="Element Mark ID.")
82
+
83
+
84
+ class TypeEntry(BaseModel):
85
+ mark_id: int = Field(..., description="Element Mark ID.")
86
+ content: str = Field(..., description="The text to type into the element.")
87
+
88
+
89
+ class TypeParams(BaseModel):
90
+ entries: List[TypeEntry] = Field(
91
+ ...,
92
+ description="A list of elements and contents to type.",
93
+ )
94
+ submit: bool = Field(
95
+ ...,
96
+ description='Whether to press the "Enter" key after typing in the last entry.',
97
+ )
98
+
99
+
100
+ class ScrollParams(BaseModel):
101
+ direction: Literal["up", "down", "left", "right"] = Field(
102
+ ...,
103
+ description='Direction to scroll. Must be one of "up", "down", "left" or "right".',
104
+ )
105
+ mark_id: int = Field(
106
+ ...,
107
+ description="What to scroll. Use -1 to scroll the whole page otherwise give the mark ID of an element that is `scrollable`.", # noqa: E501
108
+ )
109
+
110
+
111
+ class BackParams(BaseModel):
112
+ pass
113
+
114
+
115
+ class WaitParams(BaseModel):
116
+ pass
117
+
118
+
119
+ class ReloadParams(BaseModel):
120
+ pass
121
+
122
+
123
+ class DoNothingParams(BaseModel):
124
+ pass
125
+
126
+
127
+ class BrowserTool(Tool):
128
+ def __init__(self, session: BrowserSession) -> None:
129
+ super().__init__()
130
+ self.browser = session
131
+
132
+ async def __aenter__(self):
133
+ self._exit_stack = AsyncExitStack()
134
+ await self._exit_stack.enter_async_context(self.browser)
135
+ return self
136
+
137
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
138
+ await self._exit_stack.aclose()
139
+
140
+ @property
141
+ def poi_text(self) -> str:
142
+ # Get all points of interest on the page as text
143
+ texts = [element_as_text(mark_id=i, **element) for i, element in enumerate(self.browser.poi_elements)]
144
+ # Return formatted text of points of interest on page
145
+ return "\n".join([txt for txt in texts if txt])
146
+
147
+ @attach_param_schema(GotoParams)
148
+ async def goto(self, url: str) -> ToolExecutionResponse:
149
+ """Go directly to a specific web url. Specify the exact URL."""
150
+ await self.browser.goto(url)
151
+ return ToolExecutionResponse()
152
+
153
+ @attach_param_schema(GoogleSearchParams)
154
+ async def google_search(self, query_plan: str, query: str) -> ToolExecutionResponse:
155
+ """Perform a generic web search using Google.
156
+ Results may not be relevant. If you see poor results, you can try another query.
157
+ """
158
+ url = f"https://www.google.com/search?q={query}"
159
+ await self.browser.goto(url)
160
+ return ToolExecutionResponse()
161
+
162
+ @attach_param_schema(ClickParams)
163
+ async def click(self, mark_id: int) -> ToolExecutionResponse:
164
+ """Click on an element of the page."""
165
+ await self.browser.click(mark_id=mark_id)
166
+ return ToolExecutionResponse()
167
+
168
+ @attach_param_schema(TypeParams)
169
+ async def type(self, entries: List[dict], submit: bool) -> ToolExecutionResponse:
170
+ """Type text.
171
+ You can type into one or more elements.
172
+ Note that the text inside an element is cleared before typing.
173
+ """
174
+ for i, entry_dict in enumerate(entries):
175
+ entry = TypeEntry(**entry_dict)
176
+ last_entry = i == len(entries) - 1
177
+ old_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
178
+ await self.browser.enter_text(
179
+ mark_id=entry.mark_id,
180
+ text=entry.content,
181
+ submit=submit and last_entry,
182
+ )
183
+ await self.browser.update_poi()
184
+ new_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
185
+ if not last_entry and old_poi_positions != new_poi_positions:
186
+ logger.error(
187
+ "POI positions changed mid-typing, cancelling future type entries.",
188
+ )
189
+ break
190
+ return ToolExecutionResponse()
191
+
192
+ @attach_param_schema(ScrollParams)
193
+ async def scroll(self, direction: str, mark_id: int) -> ToolExecutionResponse:
194
+ """Scroll the page (or a scrollable element) up, down, left or right."""
195
+ if mark_id == -1:
196
+ mark_id = None
197
+ await self.browser.scroll(direction=direction, mark_id=mark_id)
198
+ return ToolExecutionResponse()
199
+
200
+ @attach_param_schema(BackParams)
201
+ async def back(self) -> ToolExecutionResponse:
202
+ """Go back to the previous page."""
203
+ await self.browser.go_back()
204
+ return ToolExecutionResponse()
205
+
206
+ @attach_param_schema(WaitParams)
207
+ async def wait(self) -> ToolExecutionResponse:
208
+ """Wait three seconds. Useful when the page appears to still be loading, or if there are any unfinished webpage processes.""" # noqa: E501
209
+ await asyncio.sleep(3)
210
+ return ToolExecutionResponse()
211
+
212
+ @attach_param_schema(ReloadParams)
213
+ async def reload(self) -> ToolExecutionResponse:
214
+ """Reload the current page. Useful when the page seems unresponsive, broken, outdated, or if you want to reset the page to its initial state.""" # noqa: E501
215
+ await self.browser.reload()
216
+ return ToolExecutionResponse()
217
+
218
+ @attach_param_schema(DoNothingParams)
219
+ async def do_nothing_tool(self) -> ToolExecutionResponse:
220
+ """Do nothing. Use this if you have no need for the browser at this time."""
221
+ return ToolExecutionResponse()
src/proxy_lite/tools/return_tool.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+
3
+ from proxy_lite.tools.tool_base import Tool, attach_param_schema
4
+
5
+
6
+ class ReturnValueParams(BaseModel):
7
+ value: str = Field(description="The value to return to the user.")
8
+
9
+
10
+ class ReturnValueTool(Tool):
11
+ def __init__(self):
12
+ pass
13
+
14
+ @attach_param_schema(ReturnValueParams)
15
+ def return_value(self, value: str):
16
+ """Return a value to the user. Use this tool when you have finished the task in order to provide any information the user has requested."""
17
+ print(value)
src/proxy_lite/tools/tool_base.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from functools import cached_property, wraps
3
+ from typing import Any, Callable, Optional
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class Tool:
9
+ async def __aenter__(self):
10
+ pass
11
+
12
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
13
+ pass
14
+
15
+ @cached_property
16
+ def schema(self) -> list[dict[str, Any]]:
17
+ schema = []
18
+ for name, method in self.__class__.__dict__.items():
19
+ # If function is not callable and isn't decorated using attach_param_schema
20
+ if not isinstance(method, Callable) or not hasattr(method, "param_model"):
21
+ continue
22
+
23
+ docstring = inspect.getdoc(method)
24
+ if not docstring:
25
+ raise ValueError(f"The tool function '{name}' is missing a docstring.")
26
+ # Handle multi-line docstirngs
27
+ description = " ".join(line.strip() for line in docstring.split("\n"))
28
+
29
+ tool_json = {
30
+ "name": name,
31
+ "description": description,
32
+ "parameters": method.param_model.model_json_schema(),
33
+ }
34
+ schema.append(tool_json)
35
+ return schema
36
+
37
+
38
+ def attach_param_schema(param_model: type[BaseModel]):
39
+ def decorator(func: Callable) -> Callable:
40
+ @wraps(func)
41
+ def wrapper(self, **kwargs):
42
+ # Throw an error if there's a mismatch between the function parameters and pydantic model's fields.
43
+ validated_params = param_model(**kwargs)
44
+ return func(self, **validated_params.model_dump())
45
+
46
+ wrapper.param_model = param_model
47
+ return wrapper
48
+
49
+ return decorator
50
+
51
+
52
+ class ToolExecutionResponse(BaseModel):
53
+ content: Optional[str] = None
54
+ id: Optional[str] = None