Spaces:
Running
Running
setup
Browse files- .gitignore +3 -0
- Makefile +7 -0
- README.md +75 -2
- pyproject.toml +53 -0
- src/proxy_lite/__init__.py +3 -0
- src/proxy_lite/agents/__init__.py +21 -0
- src/proxy_lite/agents/agent_base.py +238 -0
- src/proxy_lite/agents/browser_agent.py +133 -0
- src/proxy_lite/agents/proxy_lite_agent.py +54 -0
- src/proxy_lite/browser/__init__.py +0 -0
- src/proxy_lite/browser/add_custom_select.js +123 -0
- src/proxy_lite/browser/bounding_boxes.py +210 -0
- src/proxy_lite/browser/browser.py +398 -0
- src/proxy_lite/browser/find_pois.js +397 -0
- src/proxy_lite/cli.py +93 -0
- src/proxy_lite/client.py +170 -0
- src/proxy_lite/configs/default.yaml +20 -0
- src/proxy_lite/environments/__init__.py +32 -0
- src/proxy_lite/environments/environment_base.py +161 -0
- src/proxy_lite/environments/webbrowser.py +182 -0
- src/proxy_lite/history.py +183 -0
- src/proxy_lite/logger.py +81 -0
- src/proxy_lite/recorder.py +99 -0
- src/proxy_lite/runner.py +245 -0
- src/proxy_lite/serializer.py +39 -0
- src/proxy_lite/solvers/__init__.py +23 -0
- src/proxy_lite/solvers/simple_solver.py +96 -0
- src/proxy_lite/solvers/solver_base.py +123 -0
- src/proxy_lite/solvers/structured_solver.py +178 -0
- src/proxy_lite/tools/__init__.py +5 -0
- src/proxy_lite/tools/browser_tool.py +221 -0
- src/proxy_lite/tools/return_tool.py +17 -0
- src/proxy_lite/tools/tool_base.py +54 -0
.gitignore
CHANGED
@@ -169,3 +169,6 @@ cython_debug/
|
|
169 |
|
170 |
# PyPI configuration file
|
171 |
.pypirc
|
|
|
|
|
|
|
|
169 |
|
170 |
# PyPI configuration file
|
171 |
.pypirc
|
172 |
+
|
173 |
+
logs/
|
174 |
+
local_trajectories/
|
Makefile
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: proxy
|
2 |
+
|
3 |
+
proxy:
|
4 |
+
uv venv --python 3.11 --python-preference managed
|
5 |
+
uv sync
|
6 |
+
uv pip install -e .
|
7 |
+
playwright install
|
README.md
CHANGED
@@ -1,2 +1,75 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
|
3 |
+
A mini, open-weights version of our Proxy assistant.
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
---
|
8 |
+
|
9 |
+
## Getting Started
|
10 |
+
|
11 |
+
### Installation
|
12 |
+
|
13 |
+
Clone the repository:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
git clone https://github.com/convergence-ai/proxy-lite.git
|
17 |
+
```
|
18 |
+
|
19 |
+
Set up the environment with:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
make proxy
|
23 |
+
```
|
24 |
+
|
25 |
+
Or do it manually:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
uv venv --python 3.11 --python-preference managed
|
29 |
+
uv sync
|
30 |
+
uv pip install -e .
|
31 |
+
playwright install
|
32 |
+
```
|
33 |
+
|
34 |
+
|
35 |
+
### Usage
|
36 |
+
|
37 |
+
```bash
|
38 |
+
proxy --help
|
39 |
+
```
|
40 |
+
|
41 |
+
You can directly run the proxy with:
|
42 |
+
|
43 |
+
```bash
|
44 |
+
proxy "Book a table for 2 at an Italian restaurant in Kings Cross tonight at 7pm."
|
45 |
+
```
|
46 |
+
|
47 |
+
|
48 |
+
### Proxy-Lite Endpoint
|
49 |
+
|
50 |
+
By default, Proxy-Lite will point to an endpoint set up on HuggingFace spaces. This is a demo endpoint and is not suitable for production use; it may be very slow when under heavy load.
|
51 |
+
|
52 |
+
We recommend hosting your own endpoint with vLLM. You can use the following command:
|
53 |
+
|
54 |
+
```bash
|
55 |
+
vllm serve --model convergence-ai/proxy-lite-7b \
|
56 |
+
--trust-remote-code \
|
57 |
+
--enable-auto-tool-choice \
|
58 |
+
--tool-call-parser hermes \
|
59 |
+
--port 8008
|
60 |
+
```
|
61 |
+
|
62 |
+
You can set the `api_base` to point to your local endpoint when calling Proxy-Lite:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
proxy --api-base http://localhost:8008/v1 "Book a table..."
|
66 |
+
```
|
67 |
+
or by setting the environment variable:
|
68 |
+
|
69 |
+
```bash
|
70 |
+
export PROXY_LITE_API_BASE=http://localhost:8008/v1
|
71 |
+
```
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
pyproject.toml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "proxy-lite"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Proxy Lite - A mini, open-weights, version of the Convergence AI Proxy assistant."
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.11"
|
7 |
+
dependencies = [
|
8 |
+
"omegaconf>=2.3.0",
|
9 |
+
"openai>=1.61.1",
|
10 |
+
"opencv-python>=4.11.0.86",
|
11 |
+
"playwright-stealth>=1.0.6",
|
12 |
+
"playwright>=1.50.0",
|
13 |
+
"pydantic>=2.10.6",
|
14 |
+
"rich>=13.9.4",
|
15 |
+
"setuptools>=75.8.0",
|
16 |
+
"tenacity>=9.0.0",
|
17 |
+
"torch>=2.6.0",
|
18 |
+
"torchvision>=0.21.0",
|
19 |
+
]
|
20 |
+
|
21 |
+
[project.scripts]
|
22 |
+
proxy = "proxy_lite.cli:main"
|
23 |
+
|
24 |
+
[build-system]
|
25 |
+
requires = ["setuptools"]
|
26 |
+
build-backend = "setuptools.build_meta"
|
27 |
+
|
28 |
+
[tool.setuptools]
|
29 |
+
packages = { find = { where = ["src"] } }
|
30 |
+
|
31 |
+
[tool.setuptools.package-data]
|
32 |
+
proxy_lite = ["**/*.json"]
|
33 |
+
|
34 |
+
[tool.ruff]
|
35 |
+
line-length = 120
|
36 |
+
|
37 |
+
[tool.ruff.lint]
|
38 |
+
select = ["E", "F", "B", "I", "SIM"]
|
39 |
+
ignore = [
|
40 |
+
"B028",
|
41 |
+
"E722", # ignore bare except
|
42 |
+
"B904", # ignore raise from requirement
|
43 |
+
"FA102",
|
44 |
+
]
|
45 |
+
[tool.ruff.lint.flake8-bugbear]
|
46 |
+
|
47 |
+
extend-immutable-calls = [
|
48 |
+
"fastapi.Depends",
|
49 |
+
"fastapi.params.Depends",
|
50 |
+
"fastapi.Query",
|
51 |
+
"fastapi.params.Query",
|
52 |
+
]
|
53 |
+
|
src/proxy_lite/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .runner import Runner, RunnerConfig
|
2 |
+
|
3 |
+
__all__ = ["Runner", "RunnerConfig"]
|
src/proxy_lite/agents/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from .agent_base import Agents, BaseAgent, BaseAgentConfig
|
4 |
+
from .browser_agent import BrowserAgent, BrowserAgentConfig
|
5 |
+
from .proxy_lite_agent import ProxyLiteAgent, ProxyLiteAgentConfig
|
6 |
+
|
7 |
+
AgentTypes = Union[*list(Agents._agent_registry.values())]
|
8 |
+
AgentConfigTypes = Union[*list(Agents._agent_config_registry.values())]
|
9 |
+
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
"AgentConfigTypes",
|
13 |
+
"AgentTypes",
|
14 |
+
"Agents",
|
15 |
+
"BaseAgent",
|
16 |
+
"BaseAgentConfig",
|
17 |
+
"BrowserAgent",
|
18 |
+
"BrowserAgentConfig",
|
19 |
+
"ProxyLiteAgent",
|
20 |
+
"ProxyLiteAgentConfig",
|
21 |
+
]
|
src/proxy_lite/agents/agent_base.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from contextlib import AsyncExitStack
|
5 |
+
from functools import cached_property
|
6 |
+
from typing import Any, Optional, Type, cast
|
7 |
+
|
8 |
+
from pydantic import BaseModel, Field
|
9 |
+
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
|
10 |
+
|
11 |
+
from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig
|
12 |
+
from proxy_lite.history import (
|
13 |
+
AssistantMessage,
|
14 |
+
MessageHistory,
|
15 |
+
MessageLabel,
|
16 |
+
SystemMessage,
|
17 |
+
Text,
|
18 |
+
ToolCall,
|
19 |
+
ToolMessage,
|
20 |
+
UserMessage,
|
21 |
+
)
|
22 |
+
from proxy_lite.logger import logger
|
23 |
+
from proxy_lite.tools import Tool
|
24 |
+
|
25 |
+
# if TYPE_CHECKING:
|
26 |
+
# from proxy_lite.tools import Tool
|
27 |
+
|
28 |
+
|
29 |
+
class BaseAgentConfig(BaseModel):
    """Settings shared by every agent: the completion client and history limits.

    ``history_messages_limit`` maps a message label to a count that is handed
    to ``MessageHistory.history_view`` by the agent. When
    ``history_messages_include`` is provided it replaces that mapping: every
    label is reset to 0 and only the listed labels keep the counts given.
    """

    client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
    history_messages_limit: dict[MessageLabel, int] = Field(default_factory=dict)
    history_messages_include: Optional[dict[MessageLabel, int]] = Field(
        default=None,
        description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
    )

    def model_post_init(self, __context: Any) -> None:
        """Rebuild ``history_messages_limit`` from ``history_messages_include`` when it is set."""
        included = self.history_messages_include
        if included is None:
            return
        # Start from "0 for every label", then let the include mapping win.
        self.history_messages_limit = {**dict.fromkeys(MessageLabel, 0), **included}
|
41 |
+
|
42 |
+
|
43 |
+
class BaseAgent(BaseModel, ABC):
    """Abstract base class for agents that talk to a completion client.

    An agent owns a ``MessageHistory``, a client created from ``config.client``
    and the tools provided by the environment (``env_tools``). Subclasses must
    implement ``system_prompt`` and ``tools``. ``generate_output`` also reads
    ``self.message_label``, which is not declared on this class and therefore
    has to be defined by the concrete subclass.
    """

    # pydantic v2 spelling of the former nested ``class Config`` (deprecated in v2).
    # ``ConfigDict`` is a TypedDict, so a plain dict is accepted here.
    model_config = {"arbitrary_types_allowed": True}

    config: BaseAgentConfig
    temperature: float = Field(default=0.7, ge=0, le=2)
    history: MessageHistory = Field(default_factory=MessageHistory)
    client: Optional[BaseClient] = None
    env_tools: list[Tool] = Field(default_factory=list)
    task: Optional[str] = Field(default=None)
    seed: Optional[int] = Field(default=None)

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Plain instance attributes (not model fields); set after validation.
        self._exit_stack = AsyncExitStack()
        self._tools_init_task = None

    def model_post_init(self, __context: Any) -> None:
        """Create the completion client from ``config.client`` once the model is built."""
        super().model_post_init(__context)
        self.client = BaseClient.create(self.config.client)

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        """System prompt placed at the start of every history view."""
        ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools the agent may call."""
        ...

    @cached_property
    def tool_descriptions(self) -> str:
        """Render ``- name: description`` lines for every function of every tool.

        Each tool's section is prefixed with the tool's class name only when the
        agent has more than one tool; sections are separated by a blank line.
        """
        tool_descriptions = []
        for tool in self.tools:
            func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
            tool_title = f"{tool.__class__.__name__}:\n" if len(self.tools) > 1 else ""
            tool_descriptions.append(f"{tool_title}{func_descriptions}")
        return "\n\n".join(tool_descriptions)

    async def get_history_view(self) -> MessageHistory:
        """Return the system prompt followed by the limited view of the stored history."""
        return MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        ) + self.history.history_view(
            limits=self.config.history_messages_limit,
        )

    @retry(
        wait=wait_exponential(multiplier=1, min=4, max=10),
        stop=stop_after_attempt(3),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def generate_output(
        self,
        use_tool: bool = False,
        response_format: Optional[type[BaseModel]] = None,
        append_assistant_message: bool = True,
    ) -> AssistantMessage:
        """Request a completion for the current history view and wrap it as an ``AssistantMessage``.

        The call is attempted up to three times with exponential back-off; the
        last exception is re-raised.

        Args:
            use_tool: When true, ``self.tools`` is passed to the client; otherwise no tools are sent.
            response_format: Optional pydantic model forwarded to the client as ``response_format``.
            append_assistant_message: When true, the resulting message is appended to
                ``self.history`` under ``self.message_label``.

        Returns:
            The assistant message built from the first choice of the completion.
        """
        messages: MessageHistory = await self.get_history_view()
        response_content = (
            await self.client.create_completion(
                messages=messages,
                temperature=self.temperature,
                seed=self.seed,
                response_format=response_format,
                tools=self.tools if use_tool else None,
            )
        ).model_dump()
        response_content = response_content["choices"][0]["message"]
        assistant_message = AssistantMessage(
            role=response_content["role"],
            # Empty / missing text content becomes an empty content list.
            content=[Text(text=response_content["content"])] if response_content["content"] else [],
            tool_calls=response_content["tool_calls"],
        )
        if append_assistant_message:
            self.history.append(message=assistant_message, label=self.message_label)
        return assistant_message

    def receive_user_message(
        self,
        text: Optional[str] = None,
        image: Optional[list[bytes]] = None,
        label: Optional[MessageLabel] = None,
        is_base64: bool = False,
    ) -> None:
        """Build a user message from text and/or images and append it to the history."""
        message = UserMessage.from_media(
            text=text,
            image=image,
            is_base64=is_base64,
        )
        self.history.append(message=message, label=label)

    def receive_system_message(
        self,
        text: Optional[str] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Build a system message from text and append it to the history."""
        message = SystemMessage.from_media(text=text)
        self.history.append(message=message, label=label)

    def receive_assistant_message(
        self,
        content: Optional[str] = None,
        tool_calls: Optional[list[ToolCall]] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append an assistant message (optional text plus optional tool calls) to the history."""
        message = AssistantMessage(
            content=[Text(text=content)] if content else [],
            tool_calls=tool_calls,
        )
        self.history.append(message=message, label=label)

    async def use_tool(self, tool_call: ToolCall):
        """Dispatch a tool call to the first tool exposing an attribute with the requested name.

        Args:
            tool_call: Tool call whose ``function`` mapping carries ``name`` and JSON ``arguments``.

        Returns:
            Whatever the awaited tool function returns.

        Raises:
            ValueError: If none of ``self.tools`` has an attribute with that name.
        """
        function = tool_call.function
        name = function["name"]
        # NOTE(review): the name comes from model output and is matched against *any*
        # attribute of the tool; consider restricting it to names listed in ``tool.schema``.
        for tool in self.tools:
            if hasattr(tool, name):
                # An empty arguments payload means "no arguments"; json.loads("") would raise.
                arguments = json.loads(function["arguments"] or "{}")
                return await getattr(tool, name)(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def receive_tool_message(
        self,
        text: str,
        tool_id: str,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append the textual result of a tool call, tagged with its ``tool_call_id``, to the history."""
        self.history.append(
            message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
            label=label,
        )
|
174 |
+
|
175 |
+
|
176 |
+
class Agents:
    """Name-based registry of agent classes and their configuration classes."""

    _agent_registry: dict[str, type[BaseAgent]] = {}
    _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}

    @classmethod
    def register_agent(cls, name: str):
        """
        Return a class decorator that stores the decorated Agent class under ``name``.

        Example:
            @Agents.register_agent("browser")
            class BrowserAgent(BaseAgent):
                ...
        """

        def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
            # Later registrations under the same name overwrite earlier ones.
            cls._agent_registry[name] = agent_cls
            return agent_cls

        return decorator

    @classmethod
    def register_agent_config(cls, name: str):
        """
        Return a class decorator that stores the decorated config class under ``name``.

        Example:
            @Agents.register_agent_config("browser")
            class BrowserAgentConfig(BaseAgentConfig):
                ...
        """

        def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
            cls._agent_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseAgent]:
        """
        Look up the Agent class registered under ``name``.

        Raises:
            ValueError: If nothing is registered under that name.
        """
        try:
            agent_cls = cls._agent_registry[name]
        except KeyError:
            raise ValueError(f"Agent '{name}' not found.")
        return cast(Type[BaseAgent], agent_cls)

    @classmethod
    def get_config(cls, name: str) -> type[BaseAgentConfig]:
        """
        Look up the Agent configuration class registered under ``name``.

        Raises:
            ValueError: If nothing is registered under that name.
        """
        try:
            config_cls = cls._agent_config_registry[name]
        except KeyError:
            raise ValueError(f"Agent config for '{name}' not found.")
        return cast(type[BaseAgentConfig], config_cls)
|
src/proxy_lite/agents/browser_agent.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from functools import cached_property
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
from pydantic import Field
|
6 |
+
|
7 |
+
from proxy_lite.agents.agent_base import Agents, BaseAgent, BaseAgentConfig
|
8 |
+
from proxy_lite.history import MessageHistory, MessageLabel, SystemMessage, Text
|
9 |
+
from proxy_lite.tools import Tool
|
10 |
+
|
11 |
+
BROWSER_AGENT_SYSTEM_PROMPT = """ **You are Proxy Lite, the Web-Browsing Agent.** You are developed by Convergence.
|
12 |
+
|
13 |
+
**Current date:** {date_time_with_day}.
|
14 |
+
|
15 |
+
You are given:
|
16 |
+
|
17 |
+
1. A user task that you are trying to complete.
|
18 |
+
2. Relevant facts we have at our disposal.
|
19 |
+
3. A high level plan to complete the task.
|
20 |
+
4. A history of previous actions and observations.
|
21 |
+
5. An annotated webpage screenshot and text description of what's visible in the browser before and after the last action.
|
22 |
+
|
23 |
+
## Objective
|
24 |
+
|
25 |
+
You are an expert at controlling the web browser.
|
26 |
+
You will be assisting a user with a task they are trying to complete on the web.
|
27 |
+
|
28 |
+
## Web Screenshots
|
29 |
+
|
30 |
+
Each iteration of your browsing loop, you'll be provided with a screenshot of the browser.
|
31 |
+
|
32 |
+
The screenshot will have red rectangular annotations. These annotations highlight the marked elements you can interact with.
|
33 |
+
|
34 |
+
## Mark IDs
|
35 |
+
|
36 |
+
Each annotated element is labeled with a "mark id" in the top-left corner.
|
37 |
+
|
38 |
+
When using tools like typing or clicking, specify the "mark id" to indicate which element you want to interact with.
|
39 |
+
|
40 |
+
If an element is not annotated, you cannot interact with it. This is a limitation of the software. Focus on marked elements only.
|
41 |
+
|
42 |
+
## Text Snippets
|
43 |
+
|
44 |
+
Along with the screenshot, you will receive text snippets describing each annotated element.
|
45 |
+
|
46 |
+
Here’s an example of different element types:
|
47 |
+
|
48 |
+
- [0] `<a>text</a>` → Mark 0 is a link (`<a>` tag) containing the text "text".
|
49 |
+
- [1] `<button>text</button>` → Mark 1 is a button (`<button>` tag) containing the text "text".
|
50 |
+
- [2] `<input value="text"/>` → Mark 2 is an input field (`<input>` tag) with the value "text".
|
51 |
+
- [3] `<select>text</select>` → Mark 3 is a dropdown menu (`<select>` tag) with the option "text" selected.
|
52 |
+
- [4] `<textarea>text</textarea>` → Mark 4 is a text area (`<textarea>` tag) containing the text "text".
|
53 |
+
- [5] `<li>text</li>` → Mark 5 is a list item (`<li>` tag) containing the text "text".
|
54 |
+
- [6] `<div scrollable>text</div>` → Mark 6 is a division (`<div>` tag) containing the text "text" and is scrollable.
|
55 |
+
- [7] `<td>text</td>` → Mark 7 is a table cell (`<td>` tag) containing the text "text".
|
56 |
+
|
57 |
+
Note that these text snippets may be incomplete.
|
58 |
+
|
59 |
+
## History
|
60 |
+
|
61 |
+
You will see your past actions and observations but not old annotated webpages.
|
62 |
+
|
63 |
+
This means annotated webpages showing useful information will not be visible in future actions.
|
64 |
+
|
65 |
+
To get around this, key details from each webpage are stored in observations.
|
66 |
+
|
67 |
+
## Web Browser Actions
|
68 |
+
|
69 |
+
You can only take the following actions with the web browser:
|
70 |
+
{tool_descriptions}
|
71 |
+
|
72 |
+
## Important Browsing Tips
|
73 |
+
|
74 |
+
If there is a modal overlay that is unresponsive on the page try reloading the webpage.
|
75 |
+
|
76 |
+
If there is a cookie consent form covering part of the page just click accept on the form.
|
77 |
+
|
78 |
+
When typing into a text field be sure to click one of the dropdown options (when present). Not selecting a dropdown option will result in the field being cleared after the next action.
|
79 |
+
|
80 |
+
You do not have access any internet accounts (outside of those provided by the user).
|
81 |
+
|
82 |
+
The browser has a built in CAPTCHA solver, if you are asked to solve one just wait and it will be solved for you.
|
83 |
+
|
84 |
+
## Don't Repeat the Same Actions Continuously
|
85 |
+
|
86 |
+
If you find yourself repeating an action without making progress, try another action.
|
87 |
+
|
88 |
+
## Task
|
89 |
+
|
90 |
+
You will now be connected to the user, who will give you their task.""" # noqa: E501
|
91 |
+
|
92 |
+
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
|
93 |
+
MessageLabel.SCREENSHOT: 1,
|
94 |
+
# MessageLabel.REASONING_INDUCTION: 1,
|
95 |
+
# MessageLabel.FORMAT_INSTRUCTIONS: 1,
|
96 |
+
# MessageLabel.ACTION: 1,
|
97 |
+
}
|
98 |
+
|
99 |
+
|
100 |
+
@Agents.register_agent_config("browser")
class BrowserAgentConfig(BaseAgentConfig):
    """Configuration for the ``"browser"`` agent.

    ``history_messages_limit`` defaults to the entries of
    ``MAX_MESSAGES_FOR_CONTEXT_WINDOW``.
    """

    name: Literal["browser"] = "browser"
    history_messages_limit: dict[MessageLabel, int] = Field(
        # Hand each instance its own copy: returning the module-level dict itself
        # would let a mutation on one config leak into every other config.
        default_factory=lambda: dict(MAX_MESSAGES_FOR_CONTEXT_WINDOW),
    )
|
106 |
+
|
107 |
+
|
108 |
+
@Agents.register_agent("browser")
class BrowserAgent(BaseAgent):
    """Agent registered as ``"browser"`` that uses ``BROWSER_AGENT_SYSTEM_PROMPT`` and the environment's tools."""

    config: BrowserAgentConfig
    message_label: MessageLabel = MessageLabel.AGENT_MODEL_RESPONSE

    def __init__(self, **data):
        super().__init__(**data)

    @property
    def system_prompt(self) -> str:
        """Fill the prompt template with the current local time and the tool descriptions."""
        # Despite the placeholder's name, this format carries no weekday name.
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return BROWSER_AGENT_SYSTEM_PROMPT.format(
            date_time_with_day=timestamp,
            tool_descriptions=self.tool_descriptions,
            # The template has no ``{memories}`` placeholder, so this value is ignored.
            memories="",
        )

    @cached_property
    def tools(self) -> list[Tool]:
        """The agent's tools are exactly the ones supplied by the environment."""
        return self.env_tools

    async def get_history_view(self) -> MessageHistory:
        """Return the system prompt followed by the limited view of the stored history."""
        system_part = MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        )
        limited_history = self.history.history_view(
            limits=self.config.history_messages_limit,
        )
        return system_part + limited_history
|
src/proxy_lite/agents/proxy_lite_agent.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cached_property
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
from pydantic import Field
|
5 |
+
|
6 |
+
from proxy_lite.history import MessageHistory, MessageLabel, SystemMessage, Text
|
7 |
+
from proxy_lite.tools import Tool
|
8 |
+
|
9 |
+
from .agent_base import Agents, BaseAgent, BaseAgentConfig
|
10 |
+
|
11 |
+
MODEL_SYSTEM_PROMPT = """You are Proxy-Lite, an AI assistant that can perform actions on a computer screen.
|
12 |
+
You were developed by Convergence AI.
|
13 |
+
The user will instuct you to perform a task.
|
14 |
+
You will be shown a screen as well as relevant interactable elements highlighted by mark_ids and you will be given a set of tools to use to perform the task.
|
15 |
+
You should make observations about the screen, putting them in <observation></observation> tags.
|
16 |
+
You should then reason about what needs to be done to complete the task, putting your thoughts in <thinking></thinking> tags.
|
17 |
+
You should then use the tools to perform the task, putting the tool calls in <tool_call></tool_call> tags.
|
18 |
+
""" # noqa: E501
|
19 |
+
|
20 |
+
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
|
21 |
+
MessageLabel.SCREENSHOT: 1,
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
@Agents.register_agent_config("proxy_lite")
class ProxyLiteAgentConfig(BaseAgentConfig):
    """Configuration for the ``"proxy_lite"`` agent.

    ``history_messages_limit`` defaults to the entries of
    ``MAX_MESSAGES_FOR_CONTEXT_WINDOW``.
    """

    name: Literal["proxy_lite"] = "proxy_lite"
    history_messages_limit: dict[MessageLabel, int] = Field(
        # Hand each instance its own copy: returning the module-level dict itself
        # would let a mutation on one config leak into every other config.
        default_factory=lambda: dict(MAX_MESSAGES_FOR_CONTEXT_WINDOW),
    )
|
31 |
+
|
32 |
+
|
33 |
+
@Agents.register_agent("proxy_lite")
class ProxyLiteAgent(BaseAgent):
    """Agent registered as ``"proxy_lite"`` that uses the fixed ``MODEL_SYSTEM_PROMPT`` and the environment's tools."""

    config: ProxyLiteAgentConfig
    message_label: MessageLabel = MessageLabel.AGENT_MODEL_RESPONSE

    def __init__(self, **data):
        super().__init__(**data)

    @property
    def system_prompt(self) -> str:
        """The static system prompt; nothing is interpolated into it."""
        return MODEL_SYSTEM_PROMPT

    @cached_property
    def tools(self) -> list[Tool]:
        """The agent's tools are exactly the ones supplied by the environment."""
        return self.env_tools

    async def get_history_view(self) -> MessageHistory:
        """Return the system prompt followed by the limited view of the stored history."""
        system_part = MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        )
        limited_history = self.history.history_view(
            limits=self.config.history_messages_limit,
        )
        return system_part + limited_history
|
src/proxy_lite/browser/__init__.py
ADDED
File without changes
|
src/proxy_lite/browser/add_custom_select.js
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
handledSelectElementsConvergence = new WeakSet();
|
2 |
+
|
3 |
+
/**
 * Replace the native dropdown of single-choice <select> elements with a
 * DOM-rendered option list.
 *
 * A `mousedown` listener is attached to every <select> found under the root
 * (including selects inside open shadow roots of its descendants). When the
 * default action has not already been prevented, the listener shows a shared
 * absolutely-positioned container listing the options and suppresses the
 * native dropdown. Selects already processed are remembered in
 * `handledSelectElementsConvergence` and are not wired twice.
 *
 * @param {Element|null} input Root element to scan; defaults to
 *     `document.documentElement`.
 */
overwriteDefaultSelectConvergence = (input = null) => {
    const CUSTOM_SELECT_ID = 'convergence-custom-select-element-X2EmudtLRN';

    // Only ever written inside this closure; tracks which select is open.
    let activeSelectElement = null;

    // Handle iframe input element
    let rootElement = input ? input : document.documentElement;

    function createCustomSelectElement() {
        // Create the custom select container
        const customSelect = document.createElement('div');
        customSelect.id = CUSTOM_SELECT_ID;
        customSelect.style.position = 'absolute'
        customSelect.style.zIndex = 2147483647 - 1;
        customSelect.style.display = 'none';
        document.body.appendChild(customSelect);

        // Create the select options list
        const optionsList = document.createElement('div');
        optionsList.style.border = '1px solid #ccc';
        optionsList.style.backgroundColor = '#fff';
        optionsList.style.color = 'black';
        customSelect.appendChild(optionsList);

        return customSelect;
    }

    // One shared handler for `blur` and `change`. addEventListener discards a
    // repeated registration of the same function for the same event type, so
    // reopening a select no longer stacks up additional listeners on it.
    const hideCustomSelect = () => {
        const openDropdown = rootElement.querySelector(`#${CUSTOM_SELECT_ID}`);
        if (openDropdown) {
            openDropdown.style.display = 'none';
        }
        activeSelectElement = null;
    };

    function showCustomSelect(select) {
        activeSelectElement = select;

        // Clear previous options
        const customSelect = rootElement.querySelector(`#${CUSTOM_SELECT_ID}`);
        let optionsList = customSelect.firstChild;
        optionsList.innerHTML = '';

        // Populate with new options
        Array.from(select.options).forEach(option => {
            const customOption = document.createElement('div');
            customOption.className = 'custom-option';
            customOption.style.padding = '8px';
            customOption.style.cursor = 'pointer';
            customOption.textContent = option.text;
            customOption.dataset.value = option.value;
            optionsList.appendChild(customOption);

            customOption.addEventListener('mouseenter', function () {
                customOption.style.backgroundColor = '#f0f0f0';
            });

            customOption.addEventListener('mouseleave', function () {
                customOption.style.backgroundColor = '';
            });

            customOption.addEventListener('mousedown', (e) => {
                e.stopPropagation();
                select.value = customOption.dataset.value;
                customSelect.style.display = 'none';
                activeSelectElement = null;
                // ensure we trigger all potential event listeners
                select.dispatchEvent(new InputEvent('focus', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('input', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('change', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('blur', { bubbles: true, cancelable: true }));
            });
        });

        // Position and show the custom select directly below the native control
        const selectRect = select.getBoundingClientRect();
        customSelect.style.top = `${selectRect.bottom + window.scrollY}px`;
        customSelect.style.left = `${selectRect.left + window.scrollX}px`;
        customSelect.style.width = `${selectRect.width}px`;
        customSelect.style.display = 'block';
        select.focus();
        select.addEventListener('blur', hideCustomSelect);
        select.addEventListener('change', hideCustomSelect);
    }

    // Ensure we have a custom select element
    let customSelect = rootElement.querySelector(`#${CUSTOM_SELECT_ID}`);
    if (!customSelect) {
        customSelect = createCustomSelectElement();
    }

    // Find selects in shadow DOMs
    function findSelectInShadowRoot(element) {
        if (element.shadowRoot) {
            return element.shadowRoot.querySelectorAll('select');
        }
        return [];
    }
    let shadowSelects = [];
    rootElement.querySelectorAll('*').forEach(el => {
        shadowSelects.push(...findSelectInShadowRoot(el));
    });

    // Find selects in the regular (light) DOM
    const lightSelects = Array.from(rootElement.querySelectorAll('select'));

    // Add event listeners to all select elements
    const allSelects = [...lightSelects, ...shadowSelects];
    allSelects.forEach(select => {
        if (select.hasAttribute('multiple')) {
            // skip special multiple elements as our POI code already handles them
            return;
        }
        if (!handledSelectElementsConvergence.has(select)) {
            select.addEventListener('mousedown', (e) => {
                // only use custom select when the default behaviour is being used
                if (!e.defaultPrevented) {
                    showCustomSelect(select);
                    e.preventDefault();
                }
            });
            handledSelectElementsConvergence.add(select);
        }
    });
}
|
src/proxy_lite/browser/bounding_boxes.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from pydantic import BaseModel, Field, field_validator
|
7 |
+
|
8 |
+
|
9 |
+
class Point(BaseModel):
    """Integer pixel coordinate that can be unpacked or indexed like an ``(x, y)`` pair."""

    x: int
    y: int

    def __iter__(self):
        # Yield the coordinates in (x, y) order so ``*point`` and ``tuple(point)`` work.
        yield self.x
        yield self.y

    def __getitem__(self, index) -> int:
        # Position 0 is x, position 1 is y; indexing follows normal tuple rules.
        pair = (self.x, self.y)
        return pair[index]

    def __tuple__(self) -> tuple[int, int]:
        """Return the coordinates as a plain ``(x, y)`` tuple."""
        return self.x, self.y

    def __repr__(self) -> str:
        return f"Point(x={self.x}, y={self.y})"
|
24 |
+
|
25 |
+
|
26 |
+
class BoundingBox(BaseModel):
    """Labelled rectangle in pixel coordinates.

    Incoming edge values are snapped outwards to whole pixels: ``left``/``top``
    are floored and ``right``/``bottom`` are ceiled before integer validation.
    """

    label: str = Field(..., description="The label that's given for this bounding box")
    left: int = Field(..., description="Left coordinate of the bounding box")
    right: int = Field(..., description="Right coordinate of the bounding box")
    top: int = Field(..., description="Top coordinate of the bounding box")
    bottom: int = Field(..., description="Bottom coordinate of the bounding box")

    @field_validator("left", "top", mode="before")
    @classmethod
    def round_down(cls, v):
        """Floor the raw value so the leading edges never move inwards."""
        as_float = float(v)
        return math.floor(as_float)

    @field_validator("right", "bottom", mode="before")
    @classmethod
    def round_up(cls, v):
        """Ceil the raw value so the trailing edges never move inwards."""
        as_float = float(v)
        return math.ceil(as_float)
|
42 |
+
|
43 |
+
|
44 |
+
class POI(BaseModel):
    """A point of interest: an element description plus its location on the page."""

    # Free-form description of the element (arbitrary string keys).
    info: dict[str, Any]
    # Centre point of the element.
    element_centroid: Point
    # Labelled rectangle enclosing the element.
    bounding_box: BoundingBox
|
48 |
+
|
49 |
+
|
50 |
+
def calculate_dash_points(start, end, dash_length, gap_length):
    """Split the segment ``start -> end`` into dashes and return their endpoints.

    Args:
        start: ``(x, y)`` pair where the segment begins.
        end: ``(x, y)`` pair where the segment ends.
        dash_length: Length of each drawn dash, in the same units as the points.
        gap_length: Length of the blank space between consecutive dashes.

    Returns:
        A flat list of integer ``(x, y)`` tuples laid out as
        ``[dash0_start, dash0_end, dash1_start, dash1_end, ...]``. The last dash
        is clipped at ``end``. A zero-length segment yields an empty list.

    Raises:
        ValueError: If ``dash_length + gap_length`` is not positive for a
            non-degenerate segment, because the walk along the segment could
            then never reach its end.
    """
    x1, y1 = start
    x2, y2 = end
    dx = x2 - x1
    dy = y2 - y1
    dist = np.sqrt(dx * dx + dy * dy)

    if dist == 0:
        return []

    step = dash_length + gap_length
    if step <= 0:
        # Without forward progress the loop below would spin forever.
        raise ValueError(
            f"dash_length + gap_length must be positive, got {dash_length} + {gap_length}",
        )

    unit_x = dx / dist
    unit_y = dy / dist

    dash_points = []
    current_dist = 0
    while current_dist < dist:
        # Clip the final dash so it never overshoots the end of the segment.
        dash_end = min(current_dist + dash_length, dist)
        dash_points.append((int(x1 + unit_x * current_dist), int(y1 + unit_y * current_dist)))
        dash_points.append((int(x1 + unit_x * dash_end), int(y1 + unit_y * dash_end)))
        current_dist += step

    return dash_points
|
76 |
+
|
77 |
+
|
78 |
+
def draw_dashed_rectangle(
    img,
    bbox: BoundingBox,
    color,
    thickness=1,
    dash_length=10,
    gap_length=5,
):
    """Draw the outline of ``bbox`` onto ``img`` as a dashed line, in place.

    Every coordinate of ``bbox`` is shifted by 25 pixels on both axes before
    drawing. The sides are traced in the order top, right, bottom, left.
    """
    x_min, x_max = bbox.left + 25, bbox.right + 25
    y_min, y_max = bbox.top + 25, bbox.bottom + 25

    # Corners in drawing order: top-left, top-right, bottom-right, bottom-left.
    corners = [(x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)]

    dash_endpoints = []
    for position, corner in enumerate(corners):
        next_corner = corners[(position + 1) % len(corners)]
        dash_endpoints += calculate_dash_points(corner, next_corner, dash_length, gap_length)

    if not dash_endpoints:
        return

    # Pair consecutive endpoints into 2-point polylines and draw them in one call.
    dash_segments = np.array(dash_endpoints).reshape((-1, 2, 2))
    cv2.polylines(img, dash_segments, False, color, thickness)
|
119 |
+
|
120 |
+
|
121 |
+
# @time_it(name='Annotate bounding box')
|
122 |
+
def annotate_bounding_box(image: np.ndarray, bbox: BoundingBox) -> None:
    """Draw ``bbox`` and its label onto ``image`` in place.

    Args:
        image: Image array indexed as ``[row, column, channel]`` with three
            colour channels. It is modified in place; its ``shape`` is used to
            clip the label. (The previous ``bytes`` annotation was incorrect:
            raw bytes have no ``shape`` and cannot be slice-assigned.)
        bbox: Box to outline. ``bbox.label`` is the text rendered next to its
            top-left corner. Coordinates are shifted by 25 pixels on both axes.

    Returns:
        None. The annotation is written directly into ``image``.
    """
    # Draw dashed bounding box
    draw_dashed_rectangle(
        image,
        bbox,
        color=(0, 0, 255),
        thickness=1,
        dash_length=10,
        gap_length=5,
    )

    # The label is rendered large and then shrunk, which smooths the glyph edges.
    font_scale = 0.4 * 4
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 3

    (label_width, label_height), _ = cv2.getTextSize(
        bbox.label,
        font,
        font_scale,
        thickness,
    )

    # Four-channel patch: a solid colour in channels 0-2 plus an alpha channel.
    large_label_patch = np.zeros(
        (label_height + 20, label_width + 20, 4),
        dtype=np.uint8,
    )
    large_label_patch[:, :, 0:3] = (0, 0, 255)  # BGR color format: Red background
    large_label_patch[:, :, 3] = 128  # Alpha channel: 50% opacity (128/255 = 0.5)

    cv2.putText(
        large_label_patch,
        bbox.label,
        (8, label_height + 8),
        font,
        font_scale,
        (255, 255, 255, 128),  # White text, 50% opaque (128/255 = 0.5)
        thickness,
    )

    # Scale down the patch to improve anti-aliasing
    label_patch = cv2.resize(
        large_label_patch,
        (label_width // 4 + 5, label_height // 4 + 5),
        interpolation=cv2.INTER_AREA,
    )

    # Place the label just above the box's top-left corner, clamped to the image.
    offset = 2  # Small offset to prevent touching the bounding box edge
    x = min(image.shape[1], max(0, int(bbox.left + 25) - offset))
    y = min(image.shape[0], max(0, int(bbox.top + 25) - label_patch.shape[0] - offset))

    # Crop the patch if it would run past the right or bottom edge of the image.
    x_end = min(image.shape[1], x + label_patch.shape[1])
    y_end = min(image.shape[0], y + label_patch.shape[0])
    label_patch = label_patch[: (y_end - y), : (x_end - x)]

    # Per-pixel opacity in [0, 1], repeated across the three colour channels.
    alpha_mask = label_patch[:, :, 3] / 255.0
    alpha_mask = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2)

    # Alpha-blend the patch over the region of the image it covers.
    image_section = image[y:y_end, x:x_end]
    blended = (1 - alpha_mask) * image_section + alpha_mask * label_patch[:, :, 0:3]
    image[y:y_end, x:x_end] = blended.astype(np.uint8)
|
190 |
+
|
191 |
+
|
192 |
+
def annotate_bounding_boxes(image: bytes, bounding_boxes: list[BoundingBox]) -> bytes:
    """Return a JPEG copy of ``image`` with every bounding box drawn and labelled.

    The decoded image is surrounded by a 25-pixel white border before the boxes
    are drawn, so the output is 50 pixels wider and taller than the input.

    Args:
        image: Encoded image bytes (any format ``cv2.imdecode`` understands).
        bounding_boxes: Boxes to draw, in coordinates of the unpadded image.

    Returns:
        The annotated, padded image encoded as JPEG bytes.

    Raises:
        ValueError: If ``image`` cannot be decoded, or the annotated image
            cannot be encoded.
    """
    # Decode the raw bytes into a three-channel colour image.
    nparr = np.frombuffer(image, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals an unreadable buffer by returning None; fail here with
        # a clear message instead of inside copyMakeBorder.
        raise ValueError("Could not decode the provided bytes as an image.")

    padded_img = cv2.copyMakeBorder(
        img,
        top=25,  # Value chosen based on label size
        bottom=25,  # Value chosen based on label size
        left=25,  # Value chosen based on label size
        right=25,  # Value chosen based on label size
        borderType=cv2.BORDER_CONSTANT,
        value=(255, 255, 255),
    )
    for bounding_box in bounding_boxes:
        # Annotate the image in place with the bounding box and the bounding box label
        annotate_bounding_box(padded_img, bounding_box)

    encoded_ok, buffer = cv2.imencode(".jpeg", padded_img)
    if not encoded_ok:
        raise ValueError("Could not encode the annotated image as JPEG.")
    return buffer.tobytes()
|
src/proxy_lite/browser/browser.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from contextlib import AsyncExitStack
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Literal, Optional, Self
|
7 |
+
|
8 |
+
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
9 |
+
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
10 |
+
from playwright_stealth import stealth_async
|
11 |
+
from pydantic import Field
|
12 |
+
from tenacity import before_sleep_log, retry, stop_after_delay, wait_exponential
|
13 |
+
|
14 |
+
from proxy_lite.browser.bounding_boxes import POI, BoundingBox, Point, annotate_bounding_boxes
|
15 |
+
from proxy_lite.logger import logger
|
16 |
+
|
17 |
+
# Tags rendered in self-closing form (`<tag/>`) when they carry no text.
# many of these are non-interactive but keeping them anyway
SELF_CONTAINED_TAGS = "area base br col embed hr img input link meta param source track wbr".split()
|
34 |
+
|
35 |
+
|
36 |
+
def _truncate(value: str, limit: int = 2500) -> str:
    """Cap ``value`` at ``limit`` characters, ending with an ellipsis if it was cut."""
    if len(value) > limit:
        return value[: limit - 1] + "…"
    return value


def element_as_text(
    mark_id: int,
    tag: Optional[str] = None,
    text: Optional[str] = None,
    **raw_attributes,
) -> str:
    """Return a one-line, HTML-like text representation of a single page element.

    Args:
        mark_id: Identifier shown in square brackets at the start of the line.
        tag: Element tag name; lower-cased in the output. ``None`` is treated as
            an empty tag name instead of raising ``AttributeError``.
        text: Text content of the element; ``None`` is treated as empty.
        **raw_attributes: Element attributes. ``None`` values are skipped,
            ``True`` renders as a bare attribute name, ``False`` is skipped, and
            anything else renders as ``name="value"``.

    Returns:
        ``- [id] <tag attrs/>`` for a self-contained tag without text, otherwise
        ``- [id] <tag attrs>text</tag>``. Attribute values and text longer than
        2500 characters are truncated, and line breaks are replaced with ``⏎``.
    """
    rendered = []
    for name, value in raw_attributes.items():
        if value is None:
            continue
        if isinstance(value, bool):
            # True booleans render as a bare name; False booleans are ignored.
            if value:
                rendered.append(name)
            continue
        rendered.append(f'{name}="{_truncate(str(value))}"')

    # Leading space separates the attributes from the tag name; with no
    # attributes this collapses to an empty string.
    attributes = (" " + " ".join(rendered)).rstrip()
    tag = (tag or "").lower()
    text = _truncate("" if text is None else text)

    # sub-out line breaks so elements are easier to distinguish
    attributes = re.sub(r"\r\n|\r|\n", "⏎", attributes)
    text = re.sub(r"\r\n|\r|\n", "⏎", text)

    if tag in SELF_CONTAINED_TAGS:
        if text:
            # Unexpected, but keep the text: fall through to the open/close form.
            logger.warning(
                f"Got self-contained element '{tag}' which contained text '{text}'.",
            )
        else:
            return f"- [{mark_id}] <{tag}{attributes}/>"
    return f"- [{mark_id}] <{tag}{attributes}>{text}</{tag}>"
|
76 |
+
|
77 |
+
|
78 |
+
class BrowserSession:
    """Async context manager around a Playwright Chromium browser.

    On entry it launches Chromium, opens one page and registers the
    ``add_custom_select.js`` and ``find_pois.js`` init scripts. It tracks the
    page's points of interest (POIs) so that actions can address elements by
    their mark id, i.e. their index in the most recent POI scan.
    """

    def __init__(
        self,
        viewport_width: int = 1280,
        viewport_height: int = 720,
        headless: bool = True,
    ):
        self.viewport_width = viewport_width
        self.viewport_height = viewport_height
        self.headless = headless
        self.playwright: Playwright | None = None
        self.browser: Browser | None = None
        self.context: BrowserContext | None = None
        self._exit_stack: AsyncExitStack | None = None

        # Plain lists. This class is not a pydantic model, so
        # ``Field(default_factory=list)`` would store a FieldInfo object here
        # rather than an empty list.
        self.poi_elements: list = []
        self.poi_centroids: list[Point] = []
        self.bounding_boxes: list[BoundingBox] = []
        self.pois: list[POI] = []

    async def __aenter__(self) -> Self:
        """Start Playwright, launch Chromium and open the first page."""
        self._exit_stack = AsyncExitStack()
        try:
            self.playwright = await async_playwright().start()

            self.browser = await self.playwright.chromium.launch(headless=self.headless)
            self.context = await self.browser.new_context(
                viewport={"width": self.viewport_width, "height": self.viewport_height},
            )
            await self.context.new_page()
            self.context.set_default_timeout(60_000)
            self.current_page.set_default_timeout(60_000)
            await stealth_async(self.current_page)
            await self.context.add_init_script(
                path=Path(__file__).with_name("add_custom_select.js"),
            )
            await self.context.add_init_script(
                path=Path(__file__).with_name("find_pois.js"),
            )
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release whatever
            # was already started before propagating the error.
            await self.__aexit__(None, None, None)
            raise

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the browser, stop Playwright and close the exit stack."""
        # Nested try/finally so a failure in one step does not skip the next.
        try:
            if self.browser:
                await self.browser.close()
        finally:
            try:
                if self.playwright:
                    await self.playwright.stop()
            finally:
                if self._exit_stack:
                    await self._exit_stack.aclose()

    @property
    def current_page(self) -> Optional[Page]:
        """The most recently opened page, or ``None`` if there is no context or page."""
        if self.context is not None and self.context.pages:
            return self.context.pages[-1]
        return None

    @property
    def current_url(self) -> Optional[str]:
        """URL of the current page, or ``None`` when no page is open."""
        page = self.current_page
        if page:
            return page.url
        return None

    # re-run for cases of mid-run redirects
    @retry(
        wait=wait_exponential(multiplier=1, min=1, max=10),
        stop=stop_after_delay(5),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def process_iframe(self, iframe) -> Optional[tuple[dict, dict]]:
        """Collect the POIs inside one iframe.

        Returns:
            ``(poi, offset)`` where ``poi`` is the result of the in-frame POI
            scan and ``offset`` is the iframe's rounded top-left position
            (``{"x": ..., "y": ...}``). ``None`` if the iframe has no bounding
            box, is smaller than 50 pixels on either side, has no content
            frame, yields no POIs, or raises during processing.
        """
        # NOTE: every exception is caught and logged below, so the retry
        # decorator never sees a failure from this body.
        try:
            # Check iframe visibility and size
            bounding_box = await iframe.bounding_box()
            if not bounding_box:
                return None  # Skip if iframe is not visible

            if bounding_box["width"] < 50 or bounding_box["height"] < 50:
                return None

            frame = await iframe.content_frame()
            if not frame:
                return None

            poi = await frame.evaluate(
                """() => {
                    overwriteDefaultSelectConvergence();
                    return findPOIsConvergence();
                }""",
            )
            if not poi:
                return None

            iframe_offset = {"x": round(bounding_box["x"]), "y": round(bounding_box["y"])}
            return poi, iframe_offset
        except Exception as e:
            logger.error(f"Error processing iframe: {e}")
            return None

    # re-run for cases of mid-run redirects
    @retry(
        wait=wait_exponential(multiplier=1, min=1, max=10),
        stop=stop_after_delay(5),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def update_poi(self) -> None:
        """Rescan the current page (and up to 10 iframes) for points of interest.

        Refreshes ``poi_elements``, ``poi_centroids``, ``bounding_boxes`` and
        ``pois``. Iframe coordinates are shifted by the iframe's page offset.
        """
        try:
            await self.current_page.wait_for_load_state(timeout=60000)
        except PlaywrightTimeoutError:
            logger.error(f"Timeout waiting for website load state: {self.current_url}")
        await self.current_page.wait_for_selector("body", timeout=60000, state="visible")
        # Run the bounding box javascript code to highlight the points of interest on the page
        page_info = await self.current_page.evaluate(
            """() => {
                overwriteDefaultSelectConvergence();
                return findPOIsConvergence();
            }""",
        )
        # Get the points of interest on the page
        self.poi_elements = page_info["element_descriptions"]
        element_centroids = page_info["element_centroids"]
        try:
            # Select all iframes on the page
            iframes = await self.current_page.query_selector_all("iframe")

            max_iframes = 10
            results = await asyncio.gather(
                *(self.process_iframe(iframe) for iframe in iframes[:max_iframes]),
            )

            # Combine the points of interest from the iframes with the main page and adjust the centroids
            for result in results:
                if result is None:
                    continue
                iframe_poi, offset = result
                descriptions = iframe_poi["element_descriptions"]
                centroids = iframe_poi["element_centroids"]
                for centroid in centroids:
                    centroid["x"] += offset["x"]
                    centroid["y"] += offset["y"]
                    centroid["left"] += offset["x"]
                    centroid["top"] += offset["y"]
                    centroid["right"] += offset["x"]
                    centroid["bottom"] += offset["y"]
                # Extend both lists only after the whole iframe was adjusted, so a
                # failure above cannot leave descriptions and centroids misaligned.
                self.poi_elements.extend(descriptions)
                element_centroids.extend(centroids)

        except Exception as e:
            logger.error(f"Error in finding iframes: {e}")

        # Get the centroids of the points of interest
        self.poi_centroids = [Point(x=xy["x"], y=xy["y"]) for xy in element_centroids]
        self.bounding_boxes = [BoundingBox(**xy, label=str(i)) for i, xy in enumerate(element_centroids)]
        self.pois = [
            POI(info=info, element_centroid=centroid, bounding_box=bbox)
            for info, centroid, bbox in zip(
                self.poi_elements,
                self.poi_centroids,
                self.bounding_boxes,
                strict=False,
            )
        ]

    @property
    def poi_text(self) -> str:
        """All current points of interest, one text line per element."""
        texts = [element_as_text(mark_id=i, **element) for i, element in enumerate(self.poi_elements)]
        return "\n".join([txt for txt in texts if txt])

    async def screenshot(
        self,
        delay: float = 0.0,
        quality: int = 70,
        type: str = "jpeg",
        scale: str = "css",
    ) -> tuple[bytes, bytes]:
        """Capture the current page and an annotated copy.

        Args:
            delay: Seconds to wait before capturing.
            quality: Passed through to Playwright's ``Page.screenshot``.
            type: Passed through to Playwright's ``Page.screenshot``.
            scale: Passed through to Playwright's ``Page.screenshot``.

        Returns:
            ``(raw, annotated)``: the raw screenshot bytes and the same image
            with the current bounding boxes drawn on it.
        """
        if delay > 0.0:
            await asyncio.sleep(delay)
        await self.update_poi()
        old_poi_positions = [tuple(point) for point in self.poi_centroids]
        img = await self.current_page.screenshot(type=type, quality=quality, scale=scale)
        annotated_img = annotate_bounding_boxes(image=img, bounding_boxes=self.bounding_boxes)
        # check page has not changed since the screenshot was taken
        await self.update_poi()
        new_poi_positions = [tuple(point) for point in self.poi_centroids]
        if new_poi_positions != old_poi_positions:
            # if it has changed, take another
            img = await self.current_page.screenshot(type=type, quality=quality, scale=scale)
            await self.update_poi()
            annotated_img = annotate_bounding_boxes(image=img, bounding_boxes=self.bounding_boxes)
        return img, annotated_img

    async def goto(self, url: str) -> None:
        """Navigate the current page to ``url``."""
        await self.current_page.goto(url, wait_until="domcontentloaded")

    async def reload(self) -> None:
        """Reload the current page."""
        await self.current_page.reload(wait_until="domcontentloaded")

    async def click_tab(self, mark_id: int) -> None:
        """Middle-click the element with the given mark id."""
        point: Point = self.poi_centroids[mark_id]
        await self.hover(point)
        await self.current_page.mouse.click(*point, button="middle")

    async def click(self, mark_id: int) -> None:
        """Click the centre of the element with the given mark id."""
        point: Point = self.poi_centroids[mark_id]
        await self.hover(point)
        await self.current_page.mouse.click(*point)

    async def enter_text(self, mark_id: int, text: str, submit: bool = False) -> None:
        """Replace the element's content with ``text``; press Enter if ``submit``."""
        await self.clear_text_field(mark_id)
        await self.click(mark_id)
        await self.current_page.keyboard.type(text)

        if submit:
            await self.current_page.keyboard.press("Enter")

    async def scroll(
        self,
        direction: Literal["up", "down", "left", "right"],
        mark_id: Optional[int] = None,
    ) -> None:
        """Scroll by 80% of the viewport, or of the marked element's bounding box.

        Args:
            direction: Which way to scroll.
            mark_id: Element to hover over while scrolling. When ``None`` the
                mouse is moved to ``(-1, -1)`` and the viewport size is used.
        """
        if mark_id is None:
            point = Point(x=-1, y=-1)
            max_scroll_x = self.viewport_width
            max_scroll_y = self.viewport_height
        else:
            point: Point = self.poi_centroids[mark_id]
            bbox: BoundingBox = self.bounding_boxes[mark_id]
            max_scroll_x = bbox.right - bbox.left
            max_scroll_y = bbox.bottom - bbox.top

        await self.hover(point=point)
        scroll_x = int(max_scroll_x * 0.8)
        scroll_y = int(max_scroll_y * 0.8)
        sign = -1 if direction in ("up", "left") else 1
        # Only the axis matching the direction receives a non-zero delta.
        if direction in ("up", "down"):
            delta_x, delta_y = 0, sign * scroll_y
        else:
            delta_x, delta_y = sign * scroll_x, 0
        await self.current_page.mouse.wheel(delta_x, delta_y)

    async def go_back(self) -> None:
        """Go back one history entry, closing the tab if it has no history.

        Raises:
            Exception: If going back lands on ``about:blank`` and this is the
                only open page; the navigation is undone first.
        """
        # If there is no tab open then return
        if not self.current_page:
            return

        await self.current_page.go_back(wait_until="domcontentloaded")
        if self.current_page.url == "about:blank":
            if not len(self.context.pages) > 1:
                await self.current_page.go_forward(wait_until="domcontentloaded")
                raise Exception("There is no previous page to go back to.")
            # Another tab exists: drop this blank one so the previous tab becomes current.
            await self.current_page.close()

    async def hover(self, point: Point) -> None:
        """Move the mouse to ``point``."""
        await self.current_page.mouse.move(*point)

    async def focus(self, point: Point) -> None:
        """Focus the element located at ``point``, if it is focusable."""
        await self.current_page.evaluate(
            """
            ([x, y]) => {
                const element = document.elementFromPoint(x, y);
                if (element && element.focus) {
                    element.focus();
                }
            }""",
            tuple(point),
        )

    async def get_text(self, mark_id: int) -> str:
        """Return the ``value`` or ``textContent`` of the marked element, else ``''``."""
        return await self.current_page.evaluate(
            """
            (mark_id) => {
                const element = marked_elements_convergence[mark_id];
                if (element && (element.value !== undefined || element.textContent !== undefined)) {
                    return element.value || element.textContent;
                }
                return '';
            }
            """,
            # Pass the index itself; a 1-tuple would arrive in JS as an array and
            # only work through implicit array-to-string coercion.
            mark_id,
        )

    async def clear_text_field(self, mark_id: int) -> None:
        """Delete the element's existing text, if it has any non-blank text."""
        existing_text = await self.get_text(mark_id)
        if existing_text.strip():
            # Select everything from the start to the end of the field, then delete it.
            await self.click(mark_id)
            await self.current_page.keyboard.press("Control+Home")
            await self.current_page.keyboard.press("Control+Shift+End")
            await self.current_page.keyboard.press("Backspace")
|
377 |
+
|
378 |
+
|
379 |
+
if __name__ == "__main__":
    # Manual smoke test. The previous debugging stub (json.loads on a string that
    # is not valid JSON, followed by exit()) made this code unreachable and has
    # been removed.

    async def dummy_test():
        """Open a page in a visible browser and save a raw and an annotated screenshot."""
        async with BrowserSession(headless=False) as s:
            page = await s.context.new_page()
            await page.goto("http://google.co.uk")
            await asyncio.sleep(5)
            await page.screenshot(path="example.png")
            await s.update_poi()
            _, annotated_image = await s.screenshot()
            # annotate_bounding_boxes encodes its output as JPEG, so name the file to match.
            with open("output.jpeg", "wb") as f:
                f.write(annotated_image)

    asyncio.run(dummy_test())
|
src/proxy_lite/browser/find_pois.js
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
marked_elements_convergence = [];
|
2 |
+
|
3 |
+
const interactiveTags = new Set([
|
4 |
+
'a', 'button', 'details', 'embed', 'input', 'label',
|
5 |
+
'menu', 'menuitem', 'object', 'select', 'textarea', 'summary',
|
6 |
+
'video', 'audio', 'option', 'iframe'
|
7 |
+
]);
|
8 |
+
|
9 |
+
const interactiveRoles = new Set([
|
10 |
+
'button', 'menu', 'menuitem', 'link', 'checkbox', 'radio',
|
11 |
+
'slider', 'tab', 'tabpanel', 'textbox', 'combobox', 'grid',
|
12 |
+
'listbox', 'option', 'progressbar', 'scrollbar', 'searchbox',
|
13 |
+
'switch', 'tree', 'treeitem', 'spinbutton', 'tooltip',
|
14 |
+
'a-button-inner', 'a-dropdown-button', 'click',
|
15 |
+
'menuitemcheckbox', 'menuitemradio', 'a-button-text',
|
16 |
+
'button-text', 'button-icon', 'button-icon-only',
|
17 |
+
'button-text-icon-only', 'dropdown', 'combobox'
|
18 |
+
]);
|
19 |
+
|
20 |
+
findPOIsConvergence = (input = null) => {
|
21 |
+
|
22 |
+
let rootElement = input ? input : document.documentElement;
|
23 |
+
|
24 |
+
function isScrollable(element) {
|
25 |
+
if ((input === null) && (element === document.documentElement)) {
|
26 |
+
// we can always scroll the full page
|
27 |
+
return false;
|
28 |
+
}
|
29 |
+
|
30 |
+
const style = window.getComputedStyle(element);
|
31 |
+
|
32 |
+
const hasScrollableYContent = element.scrollHeight > element.clientHeight
|
33 |
+
const overflowYScroll = style.overflowY === 'scroll' || style.overflowY === 'auto';
|
34 |
+
|
35 |
+
const hasScrollableXContent = element.scrollWidth > element.clientWidth;
|
36 |
+
const overflowXScroll = style.overflowX === 'scroll' || style.overflowX === 'auto';
|
37 |
+
|
38 |
+
return (hasScrollableYContent && overflowYScroll) || (hasScrollableXContent && overflowXScroll);
|
39 |
+
}
|
40 |
+
|
41 |
+
function getEventListeners(element) {
|
42 |
+
try {
|
43 |
+
return window.getEventListeners?.(element) || {};
|
44 |
+
} catch (e) {
|
45 |
+
return {};
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
function isInteractive(element) {
|
50 |
+
if (!element) return false;
|
51 |
+
|
52 |
+
return (hasInteractiveTag(element) ||
|
53 |
+
hasInteractiveAttributes(element) ||
|
54 |
+
hasInteractiveEventListeners(element)) ||
|
55 |
+
isScrollable(element);
|
56 |
+
}
|
57 |
+
|
58 |
+
function hasInteractiveTag(element) {
|
59 |
+
return interactiveTags.has(element.tagName.toLowerCase());
|
60 |
+
}
|
61 |
+
|
62 |
+
function hasInteractiveAttributes(element) {
|
63 |
+
const role = element.getAttribute('role');
|
64 |
+
const ariaRole = element.getAttribute('aria-role');
|
65 |
+
const tabIndex = element.getAttribute('tabindex');
|
66 |
+
const onAttribute = element.getAttribute('on');
|
67 |
+
|
68 |
+
if (element.getAttribute('contenteditable') === 'true') return true;
|
69 |
+
if ((role && interactiveRoles.has(role)) ||
|
70 |
+
(ariaRole && interactiveRoles.has(ariaRole))) return true;
|
71 |
+
if (tabIndex !== null && tabIndex !== '-1') return true;
|
72 |
+
|
73 |
+
// Add check for AMP's 'on' attribute that starts with 'tap:'
|
74 |
+
if (onAttribute && onAttribute.startsWith('tap:')) return true;
|
75 |
+
|
76 |
+
const hasAriaProps = element.hasAttribute('aria-expanded') ||
|
77 |
+
element.hasAttribute('aria-pressed') ||
|
78 |
+
element.hasAttribute('aria-selected') ||
|
79 |
+
element.hasAttribute('aria-checked');
|
80 |
+
|
81 |
+
return hasAriaProps;
|
82 |
+
}
|
83 |
+
|
84 |
+
function hasInteractiveEventListeners(element) {
|
85 |
+
const hasClickHandler = element.onclick !== null ||
|
86 |
+
element.getAttribute('onclick') !== null ||
|
87 |
+
element.hasAttribute('ng-click') ||
|
88 |
+
element.hasAttribute('@click') ||
|
89 |
+
element.hasAttribute('v-on:click');
|
90 |
+
if (hasClickHandler) return true;
|
91 |
+
|
92 |
+
const listeners = getEventListeners(element);
|
93 |
+
return listeners && (
|
94 |
+
listeners.click?.length > 0 ||
|
95 |
+
listeners.mousedown?.length > 0 ||
|
96 |
+
listeners.mouseup?.length > 0 ||
|
97 |
+
listeners.touchstart?.length > 0 ||
|
98 |
+
listeners.touchend?.length > 0
|
99 |
+
);
|
100 |
+
}
|
101 |
+
|
102 |
+
function calculateArea(rects) {
|
103 |
+
return rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
|
104 |
+
}
|
105 |
+
|
106 |
+
/**
 * Collect the client rects of `element` that are actually hit-testable,
 * translated by the owning iframe's offset (if any) and clamped to the viewport.
 *
 * @param element Element whose rects are wanted.
 * @param context Document or shadow root used for elementFromPoint hit-testing.
 * @returns Array of {left, top, right, bottom, width, height}.
 */
function getElementRects(element, context) {
    const viewportWidth = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
    const viewportHeight = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

    let boxes = [...element.getClientRects()];

    // No client rects (likely due to Shadow DOM): fall back to the bounding box.
    if (boxes.length === 0 && element.getBoundingClientRect) {
        boxes = [element.getBoundingClientRect()];
    }

    // Offset of the hosting iframe when the element lives in another document.
    let offsetX = 0;
    let offsetY = 0;
    if (context !== document && context?.defaultView?.frameElement) {
        const frame = context.defaultView.frameElement;
        if (frame) {
            const frameBox = frame.getBoundingClientRect();
            offsetX = frameBox.left;
            offsetY = frameBox.top;
        }
    }

    const visible = [];
    for (const box of boxes) {
        const centerX = box.left + box.width / 2 + offsetX;
        const centerY = box.top + box.height / 2 + offsetY;
        // Hit-test in the context's own coordinate space.
        const hit = context.elementFromPoint(centerX - offsetX, centerY - offsetY);
        if (hit !== element && !element.contains(hit)) continue;

        const left = Math.max(0, box.left + offsetX);
        const top = Math.max(0, box.top + offsetY);
        const right = Math.min(viewportWidth, box.right + offsetX);
        const bottom = Math.min(viewportHeight, box.bottom + offsetY);
        visible.push({
            left,
            top,
            right,
            bottom,
            width: right - left,
            height: bottom - top
        });
    }
    return visible;
}
|
150 |
+
|
151 |
+
/**
 * True when the element has a non-zero offset box and its computed style is
 * neither `visibility: hidden` nor `display: none`.
 */
function isElementVisible(element) {
    const computed = window.getComputedStyle(element);
    if (!(element.offsetWidth > 0)) return false;
    if (!(element.offsetHeight > 0)) return false;
    if (computed.visibility === 'hidden') return false;
    return computed.display !== 'none';
}
|
158 |
+
|
159 |
+
/**
 * Decide whether `element` (or one of its descendants) is the top-most node
 * at the centre of its bounding box.
 *
 * Elements owned by another document (iframes) are always treated as top.
 * Elements inside a shadow root are hit-tested against that root; everything
 * else is hit-tested against `document`. If the hit-test throws, the element
 * is assumed to be on top.
 */
function isTopElement(element) {
    // If in an iframe's document, treat as top
    if (element.ownerDocument !== window.document) {
        return true;
    }

    const root = element.getRootNode();
    const insideShadow = root instanceof ShadowRoot;
    const hitTestRoot = insideShadow ? root : document;
    // The ancestor walk stops here.
    const boundary = insideShadow ? root : document.documentElement;

    const box = element.getBoundingClientRect();
    const centerX = box.left + box.width / 2;
    const centerY = box.top + box.height / 2;

    try {
        const hit = hitTestRoot.elementFromPoint(centerX, centerY);
        if (!hit) return false;
        // Walk up from the hit node; reaching `element` means it owns that point.
        for (let node = hit; node && node !== boundary; node = node.parentElement) {
            if (node === element) return true;
        }
        return false;
    } catch (e) {
        return true;
    }
}
|
197 |
+
|
198 |
+
/**
 * Gather the visible text beneath `element` as a single whitespace-normalised string.
 *
 * Hidden nodes and <noscript> are skipped, images contribute an
 * `[img - alt="..." title="..." aria-label="..."]` token when they carry any
 * of those attributes, and the walk stops entirely at the first descendant
 * that appears in `marked_elements_convergence`.
 *
 * @param element Root of the subtree to read.
 * @param marked_elements_convergence Elements already marked as interactive.
 * @returns The collected text ('' when the root itself is hidden).
 */
function getVisibleText(element, marked_elements_convergence = []) {
    const BLOCK_DISPLAYS = [
        // Basic block elements
        'block', 'flow-root', 'inline-block',
        // Lists
        'list-item',
        // Table elements
        'table', 'inline-table', 'table-row', 'table-cell',
        'table-caption', 'table-header-group', 'table-footer-group',
        'table-row-group',
        // Modern layouts
        'flex', 'inline-flex', 'grid', 'inline-grid'
    ];

    const isHidden = css => css.display === 'none' || css.visibility === 'hidden';

    if (isHidden(window.getComputedStyle(element))) {
        return '';
    }

    const pieces = [];

    // Build the "[img - ...]" token for an image, or null when it has no labels.
    const describeImage = img => {
        const labels = [];
        const alt = img.getAttribute('alt');
        const title = img.getAttribute('title');
        const ariaLabel = img.getAttribute('aria-label');
        if (alt) labels.push(`alt="${alt}"`);
        if (title) labels.push(`title="${title}"`);
        if (ariaLabel) labels.push(`aria-label="${ariaLabel}"`);
        return labels.length > 0 ? `[img - ${labels.join(' ')}]` : null;
    };

    // Returns false to abort the whole traversal, true to keep going.
    function walk(node) {
        const isElementNode = node.nodeType === Node.ELEMENT_NODE;

        // A nested, already-marked interactive element ends the traversal.
        if (isElementNode && node !== element && marked_elements_convergence.includes(node)) {
            return false;
        }

        if (node.nodeType === Node.TEXT_NODE) {
            const text = node.textContent.trim();
            if (text) pieces.push(text);
            return true;
        }

        if (!isElementNode) return true;
        if (node.tagName === 'NOSCRIPT') return true;

        const css = window.getComputedStyle(node);
        if (isHidden(css)) return true;

        const isBlock = BLOCK_DISPLAYS.includes(css.display);

        // Newline before a block element, but only once some text exists.
        if (isBlock && pieces.length > 0) {
            pieces.push('\n');
        }

        if (node.tagName === 'IMG') {
            const token = describeImage(node);
            if (token !== null) pieces.push(token);
            return true;
        }

        for (const child of node.childNodes) {
            if (walk(child) === false) return false;
        }

        // Newline after a block element.
        if (isBlock) {
            pieces.push('\n');
        }
        return true;
    }

    walk(element);

    // Join text and normalize whitespace
    return pieces.join(' ').trim().replace(/\s{2,}/g, ' ').trim();
}
|
294 |
+
|
295 |
+
/**
 * Walk the tree under `rootElement` (descending into shadow roots, slot
 * assignments and same-origin iframes) and collect every element that is
 * interactive, visible and top-most.
 *
 * @returns Array of {element, area, rects, is_scrollable}.
 */
function extractInteractiveItems(rootElement) {
    const found = [];

    const visit = (node, context) => {
        if (!node) return;

        const qualifies = node.nodeType === Node.ELEMENT_NODE
            && isInteractive(node)
            && isElementVisible(node)
            && isTopElement(node);
        if (qualifies) {
            const rects = getElementRects(node, context);
            const area = calculateArea(rects);
            found.push({
                element: node,
                area,
                rects,
                is_scrollable: isScrollable(node),
            });
        }

        // Shadow DOM content is hit-tested against its own shadow root.
        if (node.shadowRoot) {
            for (const child of Array.from(node.shadowRoot.childNodes || [])) {
                visit(child, node.shadowRoot);
            }
        }

        switch (node.tagName) {
            case 'SLOT': {
                // Handle both assigned elements and nodes
                const assigned = node.assignedNodes ? node.assignedNodes() : node.assignedElements();
                for (const child of assigned) {
                    visit(child, context);
                }
                break;
            }
            case 'IFRAME':
                try {
                    const frameDoc = node.contentDocument || node.contentWindow?.document;
                    if (frameDoc && frameDoc.body) {
                        // Elements inside the iframe use the iframe's document as context.
                        visit(frameDoc.body, frameDoc);
                    }
                } catch (e) {
                    console.warn('Unable to access iframe contents:', e);
                }
                break;
            default:
                // Regular element: recurse into its element children.
                for (const child of Array.from(node.children || [])) {
                    visit(child, context);
                }
        }
    };

    visit(rootElement, document);
    return found;
}
|
348 |
+
|
349 |
+
if (marked_elements_convergence) {
|
350 |
+
marked_elements_convergence = [];
|
351 |
+
}
|
352 |
+
let mark_centres = [];
|
353 |
+
let marked_element_descriptions = [];
|
354 |
+
var items = extractInteractiveItems(rootElement);
|
355 |
+
|
356 |
+
// Lets create a floating border on top of these elements that will always be visible
|
357 |
+
let index = 0;
|
358 |
+
items.forEach(function (item) {
|
359 |
+
item.rects.forEach((bbox) => {
|
360 |
+
marked_elements_convergence.push(item.element);
|
361 |
+
mark_centres.push({
|
362 |
+
x: Math.round((bbox.left + bbox.right) / 2),
|
363 |
+
y: Math.round((bbox.top + bbox.bottom) / 2),
|
364 |
+
left: bbox.left,
|
365 |
+
top: bbox.top,
|
366 |
+
right: bbox.right,
|
367 |
+
bottom: bbox.bottom,
|
368 |
+
});
|
369 |
+
marked_element_descriptions.push({
|
370 |
+
tag: item.element.tagName,
|
371 |
+
text: getVisibleText(item.element),
|
372 |
+
// NOTE: all other attributes will be shown to the model when present
|
373 |
+
// TODO: incorporate child attributes, e.g. <img alt="..."> when img is a child of the link element
|
374 |
+
value: item.element.value,
|
375 |
+
placeholder: item.element.getAttribute("placeholder"),
|
376 |
+
element_type: item.element.getAttribute("type"),
|
377 |
+
aria_label: item.element.getAttribute("aria-label"),
|
378 |
+
name: item.element.getAttribute("name"),
|
379 |
+
required: item.element.getAttribute("required"),
|
380 |
+
disabled: item.element.getAttribute("disabled"),
|
381 |
+
pattern: item.element.getAttribute("pattern"),
|
382 |
+
checked: item.element.getAttribute("checked"),
|
383 |
+
minlength: item.element.getAttribute("minlength"),
|
384 |
+
maxlength: item.element.getAttribute("maxlength"),
|
385 |
+
role: item.element.getAttribute("role"),
|
386 |
+
title: item.element.getAttribute("title"),
|
387 |
+
scrollable: item.is_scrollable
|
388 |
+
});
|
389 |
+
index++;
|
390 |
+
});
|
391 |
+
});
|
392 |
+
|
393 |
+
return {
|
394 |
+
element_descriptions: marked_element_descriptions,
|
395 |
+
element_centroids: mark_centres
|
396 |
+
};
|
397 |
+
}
|
src/proxy_lite/cli.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from proxy_lite import Runner, RunnerConfig
|
8 |
+
from proxy_lite.logger import logger
|
9 |
+
|
10 |
+
|
11 |
+
def update_config_from_env(config: RunnerConfig) -> RunnerConfig:
    """Apply ``PROXY_LITE_*`` environment overrides to the solver client config.

    ``PROXY_LITE_API_BASE`` overrides ``api_base`` and ``PROXY_LITE_MODEL``
    overrides ``model_id``; unset or empty variables are ignored. The same
    ``config`` object is mutated and returned.
    """
    env_to_field = (
        ("PROXY_LITE_API_BASE", "api_base"),
        ("PROXY_LITE_MODEL", "model_id"),
    )
    for env_var, field_name in env_to_field:
        override = os.getenv(env_var)
        if override:
            setattr(config.solver.client, field_name, override)
    return config
|
17 |
+
|
18 |
+
|
19 |
+
def do_command(args):
    """Build the runner configuration from ``args`` and run the task.

    Precedence, lowest to highest: the YAML file at ``args.config``, the
    ``PROXY_LITE_*`` environment variables, then explicit command-line options.

    Args:
        args: Parsed CLI namespace with ``task`` (list of words), ``config``,
            ``api_base``, ``model``, ``homepage``, ``viewport_width`` and
            ``viewport_height``.
    """
    do_text = " ".join(args.task)
    logger.info("🤖 Let me help you with that...")
    # Take default config from YAML
    config = RunnerConfig.from_yaml(args.config)
    # Update config from environment variables
    config = update_config_from_env(config)
    # Update config from command-line arguments
    if args.api_base:
        config.solver.client.api_base = args.api_base
    if args.model:
        config.solver.client.model_id = args.model
    # Browser options belong to the environment section: configs/default.yaml
    # nests homepage/viewport_* under `environment`, and those are the fields
    # WebBrowserEnvironmentConfig declares. Writing them onto the top-level
    # config would never reach the browser.
    if args.homepage:
        config.environment.homepage = args.homepage
    if args.viewport_width:
        config.environment.viewport_width = args.viewport_width
    if args.viewport_height:
        config.environment.viewport_height = args.viewport_height
    runner = Runner(config=config)
    asyncio.run(runner.run(do_text))
|
39 |
+
|
40 |
+
|
41 |
+
def main():
|
42 |
+
parser = argparse.ArgumentParser(description="Proxy-Lite")
|
43 |
+
parser.add_argument(
|
44 |
+
"task",
|
45 |
+
type=str,
|
46 |
+
help="The task you want to accomplish",
|
47 |
+
nargs="*",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--model",
|
51 |
+
type=Optional[str],
|
52 |
+
default=None,
|
53 |
+
help="The model to use.",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--api_base",
|
57 |
+
type=Optional[str],
|
58 |
+
default=None,
|
59 |
+
help="The API base URL to use.",
|
60 |
+
)
|
61 |
+
# New option for setting a homepage URL:
|
62 |
+
parser.add_argument(
|
63 |
+
"--homepage",
|
64 |
+
type=Optional[str],
|
65 |
+
default=None,
|
66 |
+
help="The homepage URL to use.",
|
67 |
+
)
|
68 |
+
# New viewport controls:
|
69 |
+
parser.add_argument(
|
70 |
+
"--viewport-width",
|
71 |
+
type=Optional[int],
|
72 |
+
default=None,
|
73 |
+
help="Viewport width in pixels.",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--viewport-height",
|
77 |
+
type=Optional[int],
|
78 |
+
default=None,
|
79 |
+
help="Viewport height in pixels.",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--config",
|
83 |
+
type=Path,
|
84 |
+
default=Path(__file__).parent / "configs/default.yaml",
|
85 |
+
help="Path to config file (default: configs/default.yaml)",
|
86 |
+
)
|
87 |
+
|
88 |
+
args = parser.parse_args()
|
89 |
+
do_command(args)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
main()
|
src/proxy_lite/client.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from functools import cached_property
|
4 |
+
from typing import ClassVar, Literal, Optional, Union
|
5 |
+
|
6 |
+
import httpx
|
7 |
+
from httpx import Limits, Timeout
|
8 |
+
from openai import AsyncOpenAI
|
9 |
+
from openai.types.chat.chat_completion import (
|
10 |
+
ChatCompletion,
|
11 |
+
)
|
12 |
+
from pydantic import BaseModel
|
13 |
+
|
14 |
+
from proxy_lite.history import MessageHistory
|
15 |
+
from proxy_lite.logger import logger
|
16 |
+
from proxy_lite.serializer import (
|
17 |
+
BaseSerializer,
|
18 |
+
OpenAISerializer,
|
19 |
+
)
|
20 |
+
from proxy_lite.tools import Tool
|
21 |
+
|
22 |
+
|
23 |
+
class BaseClientConfig(BaseModel):
    """HTTP settings shared by every client configuration."""

    # Passed to httpx.Timeout when the client builds its HTTP session.
    http_timeout: float = 50
    # Used for both max_connections and max_keepalive_connections.
    http_concurrent_connections: int = 50
|
26 |
+
|
27 |
+
|
28 |
+
class BaseClient(BaseModel, ABC):
    """Abstract chat-completion client.

    Subclasses supply ``create_completion``; ``create`` picks the concrete
    subclass from ``config.name``.
    """

    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create completion from model.
        Expect subclasses to adapt from various endpoints that will handle
        requests differently, make sure to raise appropriate warnings.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format for consistency
        """

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        """Instantiate the client registered for ``config.name``.

        Raises:
            ValueError: If ``config.name`` is not a supported client name.
        """
        supported_clients = {
            # OpenAIClientConfig.name is "openai"; without this key the factory
            # rejected the only config type OpenAIClient accepts.
            "openai": OpenAIClient,
            # Kept so callers that relied on the old key keep working.
            "openai-azure": OpenAIClient,
            "convergence": ConvergenceClient,
        }
        if config.name not in supported_clients:
            error_message = f"Unsupported model: {config.name}."
            raise ValueError(error_message)
        return supported_clients[config.name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Build a new ``httpx.AsyncClient`` from the config's timeout and connection limits.

        A fresh client is created on every access.
        """
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )
|
71 |
+
|
72 |
+
|
73 |
+
class OpenAIClientConfig(BaseClientConfig):
    """Configuration for ``OpenAIClient``."""

    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    # Read with .get(): indexing os.environ raised KeyError while this module
    # was being imported whenever OPENAI_API_KEY was unset, even for users who
    # only need the Convergence client. None leaves key resolution (and the
    # resulting error, if any) to the OpenAI SDK at client-creation time.
    api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
|
77 |
+
|
78 |
+
|
79 |
+
class OpenAIClient(BaseClient):
    """Client that sends chat completions through the OpenAI SDK."""

    config: OpenAIClientConfig
    serializer: ClassVar[OpenAISerializer] = OpenAISerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        """SDK client built once per instance from the configured API key."""
        sdk_client = AsyncOpenAI(
            api_key=self.config.api_key,
            http_client=self.http_client,
        )
        return sdk_client

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """Request a chat completion.

        ``seed`` and ``tools`` are sent only when provided; when tools are
        given, ``tool_choice`` is ``"required"``. ``response_format`` selects
        JSON-object mode when set and plain text otherwise.
        """
        request = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        if seed is not None:
            request["seed"] = seed
        if tools:
            serialized_tools = self.serializer.serialize_tools(tools)
            if serialized_tools is not None:
                request["tools"] = serialized_tools
            request["tool_choice"] = "required"
        if response_format:
            request["response_format"] = {"type": "json_object"}
        else:
            request["response_format"] = {"type": "text"}
        return await self.external_client.chat.completions.create(**request)
|
111 |
+
|
112 |
+
|
113 |
+
class ConvergenceClientConfig(BaseClientConfig):
    """Configuration for ``ConvergenceClient``."""

    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    # Base URL handed to the OpenAI SDK client.
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"
|
118 |
+
|
119 |
+
|
120 |
+
class ConvergenceClient(OpenAIClient):
    """OpenAI-compatible client pointed at ``config.api_base``.

    The model is probed with a one-message request before the first completion.
    """

    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAISerializer] = OpenAISerializer()
    # Set once _validate_model() has succeeded.
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        """Send a minimal request to confirm the configured model responds.

        Raises:
            Exception: Whatever the SDK raised; it is logged and re-raised.
        """
        try:
            await self.external_client.beta.chat.completions.parse(
                model=self.config.model_id,
                messages=[{"role": "user", "content": "Hello"}],
            )
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        else:
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        """SDK client built once per instance for ``config.api_base``."""
        return AsyncOpenAI(
            # Pass the configured key explicitly: it was declared on the config
            # but never used, so the SDK fell back to the OPENAI_API_KEY
            # environment variable and sent that to api_base instead.
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """Request a chat completion, validating the model on first use.

        ``seed`` and ``tools`` are sent only when provided; when tools are
        given, ``tool_choice`` is ``"auto"``.
        """
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            # NOTE(review): a pydantic class is forwarded unchanged here, unlike
            # OpenAIClient which sends {"type": "json_object"} -- confirm the
            # target server accepts it.
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)
|
167 |
+
|
168 |
+
|
169 |
+
# Union of the concrete client config models defined in this module.
ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig]
|
170 |
+
# Union of the concrete client classes defined in this module.
ClientTypes = Union[OpenAIClient, ConvergenceClient]
|
src/proxy_lite/configs/default.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
environment:
|
2 |
+
name: webbrowser
|
3 |
+
annotate_image: true
|
4 |
+
screenshot_delay: 2.0
|
5 |
+
viewport_width: 1280
|
6 |
+
viewport_height: 1920
|
7 |
+
include_poi_text: true
|
8 |
+
headless: false
|
9 |
+
homepage: https://www.google.co.uk
|
10 |
+
solver:
|
11 |
+
name: simple
|
12 |
+
agent:
|
13 |
+
name: proxy_lite
|
14 |
+
client:
|
15 |
+
name: convergence
|
16 |
+
model_id: convergence-ai/subset-distill-tools-7b-15-02-2025
|
17 |
+
api_base: http://slurm1-a3nodeset-4-1:8002/v1
|
18 |
+
local_view: true
|
19 |
+
task_timeout: 1800
|
20 |
+
verbose: true
|
src/proxy_lite/environments/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from .environment_base import (
|
4 |
+
Action,
|
5 |
+
BaseEnvironment,
|
6 |
+
BaseEnvironmentConfig,
|
7 |
+
Environments,
|
8 |
+
Event,
|
9 |
+
EventType,
|
10 |
+
Observation,
|
11 |
+
)
|
12 |
+
from .webbrowser import (
|
13 |
+
WebBrowserEnvironment,
|
14 |
+
WebBrowserEnvironmentConfig,
|
15 |
+
)
|
16 |
+
|
17 |
+
EnvironmentConfigTypes = Union[*list(Environments._environment_config_registry.values())]
|
18 |
+
EnvironmentTypes = Union[*list(Environments._environment_registry.values())]
|
19 |
+
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"Action",
|
23 |
+
"BaseEnvironment",
|
24 |
+
"BaseEnvironmentConfig",
|
25 |
+
"EnvironmentConfigTypes",
|
26 |
+
"Environments",
|
27 |
+
"Event",
|
28 |
+
"EventType",
|
29 |
+
"Observation",
|
30 |
+
"WebBrowserEnvironment",
|
31 |
+
"WebBrowserEnvironmentConfig",
|
32 |
+
]
|
src/proxy_lite/environments/environment_base.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from enum import Enum
|
5 |
+
from functools import cached_property
|
6 |
+
from typing import Any, Literal, Optional, Self
|
7 |
+
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
from proxy_lite.history import ToolCall
|
11 |
+
from proxy_lite.tools import Tool, ToolExecutionResponse
|
12 |
+
|
13 |
+
|
14 |
+
class EventType(str, Enum):
    """String-valued tag identifying the kind of an ``Event``."""

    OBSERVATION = "observation"
    ACTION = "action"
    MESSAGE = "message"
|
18 |
+
|
19 |
+
|
20 |
+
class Event(BaseModel):
    """Base model for events; subclasses pin ``type`` to a single ``EventType``."""

    type: EventType
|
22 |
+
|
23 |
+
|
24 |
+
class State(BaseModel):
    """Snapshot of an environment; every field is optional and defaults to ``None``."""

    text: str | None = None
    image: str | None = None  # base64 encoded image
    html: str | None = None
    tool_responses: list[ToolExecutionResponse] | None = None
|
29 |
+
|
30 |
+
|
31 |
+
class Observation(Event):
    """Event carrying an environment ``State`` plus termination, reward and extra info."""

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    state: State
    terminated: bool
    reward: float | None = None
    info: dict[str, Any] | None = None
|
37 |
+
|
38 |
+
|
39 |
+
class Action(Event):
    """Event describing an action: optional free text, tool calls and extra info."""

    type: Literal[EventType.ACTION] = EventType.ACTION
    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    info: dict[str, Any] | None = None
|
44 |
+
|
45 |
+
|
46 |
+
class BaseEnvironmentConfig(BaseModel):
    """Field-less base class for environment configuration models."""
|
47 |
+
|
48 |
+
|
49 |
+
class BaseEnvironment(BaseModel, ABC):
    """Abstract async-context-manager environment.

    Subclasses provide the tools, the observe/act/evaluate methods and,
    optionally, resource setup/teardown in ``__aenter__``/``__aexit__``.
    """

    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str: ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @abstractmethod
    async def initialise(self) -> Observation: ...

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation: ...

    @abstractmethod
    async def observe(self) -> Observation: ...

    @abstractmethod
    async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]: ...

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch ``tool_call`` to the first tool exposing the named function.

        Args:
            tool_call: Call whose ``function`` mapping holds ``"name"`` and
                JSON-encoded ``"arguments"``.

        Returns:
            Whatever the awaited tool function returns. (The previous ``None``
            annotation was wrong: the result has always been returned.)

        Raises:
            ValueError: If no tool in ``self.tools`` has an attribute with that name.
        """
        function = tool_call.function
        name = function["name"]
        for tool in self.tools:
            if hasattr(tool, name):
                arguments = json.loads(function["arguments"])
                # Arguments may be JSON encoded twice; decode a second time if so.
                if isinstance(arguments, str):
                    arguments = json.loads(arguments)
                return await getattr(tool, name)(
                    **arguments,
                )
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        """Return extra environment info; the base implementation returns ``{}``."""
        return {}
|
97 |
+
|
98 |
+
|
99 |
+
class Environments:
    """Name-keyed registry of environment classes and their config classes."""

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            # The KeyError adds nothing beyond the message, so suppress the
            # implicit "during handling of the above exception" context.
            raise ValueError(f"Environment '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            raise ValueError(f"Environment config for '{name}' not found.") from None
|
src/proxy_lite/environments/webbrowser.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from functools import cached_property
|
3 |
+
from typing import Any, Literal, Optional, Self
|
4 |
+
|
5 |
+
from proxy_lite.browser.browser import BrowserSession
|
6 |
+
from proxy_lite.environments.environment_base import (
|
7 |
+
Action,
|
8 |
+
BaseEnvironment,
|
9 |
+
BaseEnvironmentConfig,
|
10 |
+
Environments,
|
11 |
+
Observation,
|
12 |
+
State,
|
13 |
+
)
|
14 |
+
from proxy_lite.tools import BrowserTool, Tool, ToolExecutionResponse
|
15 |
+
|
16 |
+
|
17 |
+
@Environments.register_environment_config("webbrowser")
class WebBrowserEnvironmentConfig(BaseEnvironmentConfig):
    """Settings for the ``webbrowser`` environment."""

    name: Literal["webbrowser"] = "webbrowser"
    # Page opened by initialise().
    homepage: str = "https://google.com"
    annotate_image: bool = True
    # Passed as `delay` when taking a screenshot.
    screenshot_delay: float = 1.0  # seconds
    # Whether the observation carries the page HTML.
    include_html: bool = True
    # Whether the POI text is appended to the observation text.
    include_poi_text: bool = True
    # Whether the POIs are stored in the observation info.
    record_pois: bool = True
    viewport_width: int = 1280
    viewport_height: int = 720
    browserbase_timeout: int = 7200
    headless: bool = True
    # Whether the un-annotated screenshot is stored in the observation info.
    keep_original_image: bool = False
|
31 |
+
|
32 |
+
|
33 |
+
@Environments.register_environment("webbrowser")
|
34 |
+
class WebBrowserEnvironment(BaseEnvironment):
|
35 |
+
config: WebBrowserEnvironmentConfig
|
36 |
+
browser: Optional[BrowserSession] = None
|
37 |
+
cancelled_last_action: bool = False
|
38 |
+
|
39 |
+
class Config:
|
40 |
+
arbitrary_types_allowed = True
|
41 |
+
|
42 |
+
async def __aenter__(self) -> Self:
    """Start the browser session, install any cookies, and return ``self``.

    If installing cookies fails, the freshly started browser session is
    closed before the exception propagates so it does not leak.
    """
    # Initialize the BrowserSession
    self.browser = self.browser_session(
        viewport_width=self.config.viewport_width,
        viewport_height=self.config.viewport_height,
        headless=self.config.headless,
    )
    await self.browser.__aenter__()
    try:
        # Initialize other resources if necessary
        cookies = self.cookies
        if cookies:
            await self.browser.context.add_cookies(cookies)
    except BaseException as exc:
        # `async with` never calls __aexit__ when __aenter__ raises, so the
        # session has to be torn down here.
        await self.browser.__aexit__(type(exc), exc, exc.__traceback__)
        raise
    # `logger` defaults to None on the base class, so guard before using it.
    if self.logger is not None:
        self.logger.info("🌐 [bold blue]Browser session started.[/]")
    return self
|
55 |
+
|
56 |
+
async def __aexit__(self, exc_type, exc_value, traceback):
    """Close the browser session, forwarding the exception details to it.

    Does nothing when no session was ever started (``browser`` is still ``None``).
    """
    if self.browser is None:
        return
    # Clean up the BrowserSession
    await self.browser.__aexit__(exc_type, exc_value, traceback)
|
59 |
+
|
60 |
+
@property
def info_for_user(self) -> str:
    """One-line, user-facing description of this environment."""
    return (
        "This is a web browser environment. "
        "You can navigate the web, search the web, and perform actions on the web."
    )
|
63 |
+
|
64 |
+
@cached_property
def tools(self) -> list[Tool]:
    """Single-item list holding a ``BrowserTool`` bound to ``self.browser``.

    Cached after the first access.
    """
    browser_tool = BrowserTool(session=self.browser)
    return [browser_tool]
|
67 |
+
|
68 |
+
@cached_property
def browser_session(self) -> type[BrowserSession]:
    """Session class (not an instance) that ``__aenter__`` instantiates."""
    session_cls = BrowserSession
    return session_cls
|
71 |
+
|
72 |
+
@property
|
73 |
+
def cookies(self) -> list[dict]:
|
74 |
+
return []
|
75 |
+
|
76 |
+
async def initialise(self) -> Observation:
|
77 |
+
await self.browser.goto(self.config.homepage)
|
78 |
+
original_img, annotated_img = await self.browser.screenshot(
|
79 |
+
delay=self.config.screenshot_delay,
|
80 |
+
)
|
81 |
+
|
82 |
+
base64_image = base64.b64encode(annotated_img).decode("utf-8")
|
83 |
+
|
84 |
+
html_content = await self.browser.current_page.content() if self.config.include_html else None
|
85 |
+
|
86 |
+
info = {"url": self.browser.current_url}
|
87 |
+
if self.config.record_pois:
|
88 |
+
info["pois"] = self.browser.pois
|
89 |
+
if self.config.keep_original_image:
|
90 |
+
info["original_image"] = base64.b64encode(original_img).decode("utf-8")
|
91 |
+
|
92 |
+
return Observation(
|
93 |
+
state=State(
|
94 |
+
text=f"URL: {self.browser.current_url}"
|
95 |
+
+ (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
|
96 |
+
image=base64_image,
|
97 |
+
html=html_content,
|
98 |
+
),
|
99 |
+
terminated=False,
|
100 |
+
reward=None,
|
101 |
+
info=info,
|
102 |
+
)
|
103 |
+
|
104 |
+
async def should_perform_action(self) -> bool:
|
105 |
+
# if cancelled last action, run the action without updating POIs
|
106 |
+
if self.cancelled_last_action:
|
107 |
+
self.cancelled_last_action = False
|
108 |
+
return True
|
109 |
+
|
110 |
+
# check for page changes
|
111 |
+
old_points = [tuple(point) for point in self.browser.poi_centroids]
|
112 |
+
await self.browser.update_poi()
|
113 |
+
new_points = [tuple(point) for point in self.browser.poi_centroids]
|
114 |
+
page_changed_mid_action = old_points != new_points
|
115 |
+
|
116 |
+
# record if the last action was cancelled
|
117 |
+
if page_changed_mid_action:
|
118 |
+
self.cancelled_last_action = True
|
119 |
+
return False
|
120 |
+
return True
|
121 |
+
|
122 |
+
async def execute_action(self, action: Action) -> Observation:
|
123 |
+
responses = []
|
124 |
+
cancelled_tools_flag = False
|
125 |
+
if await self.should_perform_action():
|
126 |
+
for tool_call in action.tool_calls:
|
127 |
+
# Perform the chosen action
|
128 |
+
try:
|
129 |
+
tool_response: ToolExecutionResponse = await self.execute_tool(
|
130 |
+
tool_call,
|
131 |
+
)
|
132 |
+
tool_response.id = tool_call.id
|
133 |
+
responses.append(tool_response)
|
134 |
+
except Exception as e: # noqa: PERF203
|
135 |
+
self.logger.warning("🌐 An error occurred taking action: %s", str(e), exc_info=False)
|
136 |
+
tool_response = ToolExecutionResponse(content=str(e), id=tool_call.id)
|
137 |
+
responses.append(tool_response)
|
138 |
+
else:
|
139 |
+
self.logger.warning("🌐 Page changed since last observation, cancelling action.")
|
140 |
+
self.cancelled_last_action = True
|
141 |
+
for tool_call in action.tool_calls:
|
142 |
+
tool_response = ToolExecutionResponse(
|
143 |
+
content="The page changed before the action could be executed, instead of being ran it was cancelled.", # noqa: E501
|
144 |
+
id=tool_call.id,
|
145 |
+
)
|
146 |
+
responses.append(tool_response)
|
147 |
+
cancelled_tools_flag = True
|
148 |
+
original_img, annotated_img = await self.browser.screenshot(
|
149 |
+
delay=self.config.screenshot_delay,
|
150 |
+
)
|
151 |
+
|
152 |
+
base64_image = base64.b64encode(annotated_img).decode("utf-8")
|
153 |
+
|
154 |
+
info = {"url": self.browser.current_url, "cancelled_tools": cancelled_tools_flag}
|
155 |
+
if self.config.record_pois:
|
156 |
+
info["pois"] = self.browser.pois
|
157 |
+
if self.config.keep_original_image:
|
158 |
+
info["original_image"] = base64.b64encode(original_img).decode("utf-8")
|
159 |
+
|
160 |
+
html_content = await self.browser.current_page.content() if self.config.include_html else None
|
161 |
+
return Observation(
|
162 |
+
state=State(
|
163 |
+
text=f"URL: {self.browser.current_url}"
|
164 |
+
+ (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
|
165 |
+
image=base64_image,
|
166 |
+
html=html_content,
|
167 |
+
tool_responses=responses,
|
168 |
+
),
|
169 |
+
terminated=False,
|
170 |
+
reward=None,
|
171 |
+
info=info,
|
172 |
+
)
|
173 |
+
|
174 |
+
async def observe(self) -> Observation:
|
175 |
+
return await self.browser.observe()
|
176 |
+
|
177 |
+
async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
|
178 |
+
return {}
|
179 |
+
|
180 |
+
async def get_info(self) -> dict[str, Any]:
|
181 |
+
info = {}
|
182 |
+
return info
|
src/proxy_lite/history.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import base64
|
4 |
+
from collections.abc import Iterator
|
5 |
+
from enum import Enum
|
6 |
+
from typing import Any, Literal, Optional, Set, Union
|
7 |
+
|
8 |
+
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
9 |
+
|
10 |
+
|
11 |
+
class MessageLabel(str, Enum):
    """Tags attached to messages so the history can be filtered by kind."""

    SYSTEM = "system"
    USER_INPUT = "user_input"
    SCREENSHOT = "screenshot"
    AGENT_MODEL_RESPONSE = "agent_model_response"
|
16 |
+
|
17 |
+
|
18 |
+
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
|
19 |
+
MessageLabel.SCREENSHOT: 1,
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
class MessageContent(BaseModel):
    """Common base for the parts that make up a message's content."""
|
25 |
+
|
26 |
+
|
27 |
+
class Text(MessageContent):
    """Text part of a message."""

    type: Literal["text"] = Field(default="text", init=False)
    text: str
|
30 |
+
|
31 |
+
|
32 |
+
class ImageUrl(BaseModel):
    """Wrapper holding the URL of an image part."""

    url: str
|
34 |
+
|
35 |
+
|
36 |
+
class Image(MessageContent):
    """Image part of a message, referenced through an ``ImageUrl``."""

    type: Literal["image_url"] = Field(default="image_url", init=False)
    image_url: ImageUrl
|
39 |
+
|
40 |
+
|
41 |
+
class Message(BaseModel):
    """Chat message made of text and image parts, with an optional label."""

    label: Optional[MessageLabel] = None
    content: list[Union[Text, Image]] = Field(default_factory=list)

    class Config:
        use_enum_values = True

    @property
    def images(self) -> list[Image]:
        """All image parts, in content order."""
        return [part for part in self.content if isinstance(part, Image)]

    @property
    def texts(self) -> list[Text]:
        """All text parts, in content order."""
        return [part for part in self.content if isinstance(part, Text)]

    @property
    def first_image(self) -> Optional[Image]:
        """First image part, or None when there is none."""
        images = self.images
        return images[0] if images else None

    @property
    def first_text(self) -> Optional[Text]:
        """First text part, or None when there is none."""
        texts = self.texts
        return texts[0] if texts else None

    def __len__(self):
        return len(self.content)

    @classmethod
    def from_media(
        cls,
        text: Optional[str] = None,
        image: Optional[bytes | str] = None,
        is_base64: bool = False,
    ) -> Message:
        """Build a message from optional text and an optional image.

        Args:
            text: Text part; placed before the image when both are given.
            image: Raw image bytes, or already base64-encoded data when
                ``is_base64`` is true.
            is_base64: Whether ``image`` is already base64-encoded.

        Returns:
            A message holding only the parts that were supplied; the content
            is empty when neither was given.
        """
        # Only supplied parts are added, so None never reaches validation.
        content: list[Union[Text, Image]] = []
        if text is not None:
            content.append(Text(text=text))
        if image is not None:
            if is_base64:
                # Base64 handed over as bytes must become text, otherwise the
                # URL would contain the ``b'...'`` repr.
                base64_image = image.decode("ascii") if isinstance(image, bytes) else image
            else:
                base64_image = base64.b64encode(image).decode("utf-8")
            data_url = f"data:image/jpeg;base64,{base64_image}"
            content.append(Image(image_url=ImageUrl(url=data_url)))
        return cls(content=content)
|
84 |
+
|
85 |
+
|
86 |
+
class SystemMessage(Message):
    """Message whose role is fixed to ``"system"``."""

    role: Literal["system"] = Field(default="system", init=False)
|
88 |
+
|
89 |
+
|
90 |
+
class UserMessage(Message):
    """Message whose role is fixed to ``"user"``."""

    role: Literal["user"] = Field(default="user", init=False)
|
92 |
+
|
93 |
+
|
94 |
+
class ToolCall(BaseModel):
    """A tool invocation attached to an assistant message."""

    id: str
    type: str
    # Free-form payload; elsewhere in this codebase its "name" and
    # "arguments" keys are read.
    function: dict[str, Any]
|
98 |
+
|
99 |
+
|
100 |
+
class AssistantMessage(Message):
    """Assistant message that may carry tool calls."""

    role: Literal["assistant"] = Field(default="assistant", init=False)
    tool_calls: list[ToolCall] = Field(default_factory=list)

    def model_dump(self, **kwargs):
        """Dump the message, omitting ``tool_calls`` when there are none."""
        data = super().model_dump(**kwargs)
        if not self.tool_calls:
            # The key may already be gone if the caller passed
            # ``exclude``/``include``; a plain ``pop`` would raise KeyError.
            data.pop("tool_calls", None)
        return data

    @field_validator("tool_calls", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Treat a missing (None) ``tool_calls`` value as an empty list."""
        return [] if v is None else v
|
114 |
+
|
115 |
+
|
116 |
+
class ToolMessage(Message):
    """Message with role ``"tool"``, tied to a tool call by ``tool_call_id``."""

    role: Literal["tool"] = Field(default="tool", init=False)
    tool_call_id: str
|
119 |
+
|
120 |
+
|
121 |
+
MessageTypes = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
|
122 |
+
MessageAdapter = TypeAdapter(MessageTypes)
|
123 |
+
|
124 |
+
|
125 |
+
class MessageHistory(BaseModel):
    """Ordered list of messages with list-like helpers."""

    messages: list[MessageTypes] = Field(default_factory=list)

    def append(self, message: MessageTypes, label: Optional[str] = None):
        """Add ``message`` at the end, overwriting its label when one is given."""
        if label is not None:
            message.label = label
        self.messages.append(message)

    def pop(self) -> MessageTypes:
        """Remove and return the last message."""
        return self.messages.pop()

    def extend(self, history: MessageHistory):
        """Append every message of ``history``."""
        self.messages.extend(history.messages)

    def __reversed__(self):
        return MessageHistory(messages=self.messages[::-1])

    def __getitem__(self, index):
        return self.messages[index]

    def __len__(self):
        return len(self.messages)

    def __iter__(self) -> Iterator[MessageTypes]:
        return iter(self.messages)

    def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
        """Dump every message, leaving out the fields named in ``exclude``."""
        excluded = exclude or set()
        return [message.model_dump(exclude=excluded) for message in self.messages]

    def history_view(
        self,
        limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
    ) -> MessageHistory:
        """Context window management.

        Walks the messages newest-first and keeps, for each label present in
        ``limits``, at most that many of the most recent messages. Messages
        whose label is not in ``limits`` are always kept. The result is in
        chronological order.
        """
        kept_per_label = dict.fromkeys(limits, 0)
        newest_first = []
        for message in self.messages[::-1]:
            label = message.label
            if label not in limits:
                newest_first.append(message)
                continue
            if kept_per_label[label] < limits[label]:
                newest_first.append(message)
                kept_per_label[label] += 1
        newest_first.reverse()
        return MessageHistory(messages=newest_first)

    def __add__(self, other: MessageHistory) -> MessageHistory:
        combined = MessageHistory()
        combined.extend(self)
        combined.extend(other)
        return combined

    def __iadd__(self, other: MessageHistory) -> MessageHistory:
        self.extend(other)
        return self
|
src/proxy_lite/logger.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
from typing import Literal
|
4 |
+
from uuid import uuid4
|
5 |
+
|
6 |
+
from rich.logging import RichHandler
|
7 |
+
|
8 |
+
|
9 |
+
class StructuredLogger(logging.Logger):
    """Logger that attaches a ``json_fields`` dict to every record.

    ``json_fields`` holds the logger name, the rendered message, the class
    and text of any exception currently being handled, and the caller's
    ``extra`` entries.
    """

    @staticmethod
    def _render_message(msg, args):
        """Return ``msg`` interpolated with ``args`` without ever raising."""
        if not args:
            return msg
        from collections.abc import Mapping

        # Same convention as ``logging.LogRecord``: a single non-empty mapping
        # supplies the values for ``%(name)s`` placeholders.
        if len(args) == 1 and isinstance(args[0], Mapping) and args[0]:
            args = args[0]
        text = str(msg)
        try:
            return text % args
        except (TypeError, ValueError, KeyError):
            # Bad format string: keep the raw text rather than crash the
            # caller; the handler reports the formatting error on emit.
            return text

    def _log(
        self,
        level,
        msg,
        args,
        exc_info=None,
        extra=None,
        stack_info=False,
        stacklevel=1,
    ):
        """Build ``json_fields`` and delegate to ``logging.Logger._log``."""
        if extra is None:
            extra = {}

        json_fields = {
            "logger_name": self.name,
            "message": self._render_message(msg, args),
        }

        # Picks up the exception being handled, if any, whether or not the
        # caller passed ``exc_info``.
        exc_type, exc_value, _ = sys.exc_info()
        if exc_type is not None:
            json_fields["exception_class"] = exc_type.__name__
            json_fields["exception_message"] = str(exc_value)

        json_fields.update(extra)
        super()._log(
            level,
            msg,
            args,
            exc_info,
            {"json_fields": json_fields},
            stack_info,
            # One more frame so the record points at the real caller.
            stacklevel + 1,
        )
|
43 |
+
|
44 |
+
|
45 |
+
def create_logger(
    name: str,
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
    detailed_name: bool = False,
) -> logging.Logger:
    """Create a non-propagating logger that writes through a ``RichHandler``.

    Args:
        name: Prefix of the logger name; a short random suffix is appended so
            each call gets its own logger.
        level: Level set on the logger.
        detailed_name: Include the logger name above each message.
    """
    rich_logger = logging.getLogger(f"{name}-{str(uuid4())[:8]}")
    rich_logger.setLevel(level)

    rich_handler = RichHandler(
        rich_tracebacks=True,
        markup=True,
        show_path=False,
        show_time=False,
        log_time_format="[%s]",
    )
    line_format = "%(name)s:\n%(message)s\n------" if detailed_name else "%(message)s\n------"
    rich_handler.setFormatter(logging.Formatter(line_format))

    rich_logger.addHandler(rich_handler)
    rich_logger.propagate = False
    return rich_logger
|
67 |
+
|
68 |
+
|
69 |
+
# Set StructuredLogger as the default logger class
|
70 |
+
logging.setLoggerClass(StructuredLogger)
|
71 |
+
|
72 |
+
logger = logging.getLogger(__name__)
|
73 |
+
logger.setLevel(logging.INFO)
|
74 |
+
logger.propagate = True
|
75 |
+
handler = RichHandler(
|
76 |
+
rich_tracebacks=True,
|
77 |
+
markup=True,
|
78 |
+
show_path=False,
|
79 |
+
show_time=False,
|
80 |
+
)
|
81 |
+
logger.addHandler(handler)
|
src/proxy_lite/recorder.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import uuid
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Any, Optional, Self
|
10 |
+
|
11 |
+
from pydantic import BaseModel, Field
|
12 |
+
|
13 |
+
from proxy_lite.environments import EnvironmentConfigTypes
|
14 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
15 |
+
from proxy_lite.history import MessageHistory
|
16 |
+
from proxy_lite.solvers import SolverConfigTypes
|
17 |
+
|
18 |
+
|
19 |
+
class Run(BaseModel):
    """Record of one task attempt: interleaved observations and actions plus metadata."""

    run_id: str  # uuid.UUID
    task: str
    created_at: str  # datetime.datetime
    complete: bool = False
    terminated_at: str | None = None  # datetime.datetime
    evaluation: dict[str, Any] | None = None
    history: list[Observation | Action] = Field(default_factory=list)
    solver_history: MessageHistory | None = None
    result: str | None = None
    env_info: dict[str, Any] = Field(default_factory=dict)
    environment: Optional[EnvironmentConfigTypes] = None
    solver: Optional[SolverConfigTypes] = None

    @classmethod
    def initialise(cls, task: str) -> Self:
        """Create a run for ``task`` with a fresh UUID and a UTC creation time."""
        return cls(
            run_id=str(uuid.uuid4()),
            task=task,
            created_at=str(datetime.datetime.now(datetime.UTC)),
        )

    @property
    def observations(self) -> list[Observation]:
        """Observations recorded so far, in order."""
        return [h for h in self.history if isinstance(h, Observation)]

    @property
    def actions(self) -> list[Action]:
        """Actions recorded so far, in order."""
        return [h for h in self.history if isinstance(h, Action)]

    @property
    def last_action(self) -> Action | None:
        """Most recent action, or None."""
        actions = self.actions
        return actions[-1] if actions else None

    @property
    def last_observation(self) -> Observation | None:
        """Most recent observation, or None."""
        observations = self.observations
        return observations[-1] if observations else None

    def record(
        self,
        observation: Optional[Observation] = None,
        action: Optional[Action] = None,
        solver_history: Optional[MessageHistory] = None,
    ) -> None:
        """Append one event to the history and/or update the solver history.

        Raises:
            ValueError: If both ``observation`` and ``action`` are given.
        """
        # Compare with None, not truthiness: objects defining ``__len__``
        # (an empty MessageHistory, for one) are falsy and would be ignored.
        # expect only one of observation and action to be provided in order to handle ordering
        if observation is not None and action is not None:
            raise ValueError("Only one of observation and action can be provided")
        if observation is not None:
            self.history.append(observation)
        if action is not None:
            self.history.append(action)
        if solver_history is not None:
            self.solver_history = solver_history

    def terminate(self) -> None:
        """Stamp the run with the current UTC time."""
        self.terminated_at = str(datetime.datetime.now(datetime.UTC))
|
76 |
+
|
77 |
+
|
78 |
+
class DataRecorder:
    """Writes runs as JSON files into a local folder.

    Args:
        local_folder: Directory for the run files. When omitted, a
            ``local_trajectories`` directory under ``sys.path[0]`` is used.
    """

    def __init__(self, local_folder: str | None = None):
        self.local_folder = local_folder

    def _resolve_folder(self) -> Path:
        """Return the output directory as a ``Path``, using the default when unset."""
        if self.local_folder is None:
            return Path(os.path.abspath(sys.path[0])) / "local_trajectories"
        return Path(self.local_folder)

    def initialise_run(self, task: str) -> Run:
        """Ensure the output directory exists and start a new run for ``task``."""
        # Keep a folder the caller chose; only fall back to the default.
        self.local_folder = self._resolve_folder()
        os.makedirs(self.local_folder, exist_ok=True)
        return Run.initialise(task)

    async def terminate(
        self,
        run: Run,
        save: bool = True,
    ) -> None:
        """Mark ``run`` as terminated and optionally write it to disk."""
        run.terminate()
        if save:
            await self.save(run)

    async def save(self, run: Run) -> None:
        """Write ``run`` to ``<folder>/<run_id>.json``."""
        folder = self._resolve_folder()
        # ``save`` may be called before ``initialise_run`` created the folder.
        os.makedirs(folder, exist_ok=True)
        json_payload = run.model_dump()
        with open(folder / f"{run.run_id}.json", "w") as f:
            json.dump(json_payload, f)
|
src/proxy_lite/runner.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from collections.abc import AsyncIterator
|
4 |
+
from contextlib import asynccontextmanager
|
5 |
+
from typing import Any, Literal, Self
|
6 |
+
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
from proxy_lite.environments import (
|
11 |
+
Action,
|
12 |
+
BaseEnvironment,
|
13 |
+
EnvironmentConfigTypes,
|
14 |
+
Environments,
|
15 |
+
EventType,
|
16 |
+
Observation,
|
17 |
+
)
|
18 |
+
from proxy_lite.logger import create_logger
|
19 |
+
from proxy_lite.recorder import DataRecorder, Run
|
20 |
+
from proxy_lite.solvers import (
|
21 |
+
BaseSolver,
|
22 |
+
SolverConfigTypes,
|
23 |
+
Solvers,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
@asynccontextmanager
|
28 |
+
async def async_timeout(timeout: float, task_name: str = "timeout"):
|
29 |
+
try:
|
30 |
+
async with asyncio.TaskGroup() as tg:
|
31 |
+
|
32 |
+
async def timeout_task():
|
33 |
+
await asyncio.sleep(timeout)
|
34 |
+
raise TimeoutError(
|
35 |
+
f"Operation {task_name} timed out after {timeout} seconds",
|
36 |
+
)
|
37 |
+
|
38 |
+
# Create the timeout task
|
39 |
+
timeout_handle = tg.create_task(timeout_task())
|
40 |
+
|
41 |
+
try:
|
42 |
+
yield
|
43 |
+
finally:
|
44 |
+
timeout_handle.cancel()
|
45 |
+
except* asyncio.TimeoutError as eg:
|
46 |
+
for e in eg.exceptions:
|
47 |
+
raise e
|
48 |
+
except* Exception as eg:
|
49 |
+
for e in eg.exceptions:
|
50 |
+
raise e
|
51 |
+
|
52 |
+
|
53 |
+
class RunnerConfig(BaseModel):
    """Top-level configuration for a ``Runner``."""

    environment: EnvironmentConfigTypes
    solver: SolverConfigTypes

    # Save the run after every recorded action.
    save_every_step: bool = True
    # Upper bound on executed actions per run.
    max_steps: int = 100
    # Timeouts, in seconds, for the solver's decision, the environment's
    # response and the whole task.
    action_timeout: float = 60.0
    environment_timeout: float = 30.0
    task_timeout: float = 1800.0
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a plain dict, resolving OmegaConf interpolations."""
        resolved = OmegaConf.to_container(OmegaConf.create(config_dict), resolve=True)
        return cls(**resolved)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from a YAML file, resolving OmegaConf interpolations."""
        resolved = OmegaConf.to_container(OmegaConf.load(yaml_path), resolve=True)
        return cls(**resolved)
|
76 |
+
|
77 |
+
|
78 |
+
class Runner(BaseModel):
|
79 |
+
config: RunnerConfig
|
80 |
+
recorder: DataRecorder | None = None
|
81 |
+
environment: type[BaseEnvironment] | None = None
|
82 |
+
solver: type[BaseSolver] | None = None
|
83 |
+
logger: logging.Logger | None = None
|
84 |
+
_run: Run | None = None
|
85 |
+
|
86 |
+
class Config:
|
87 |
+
arbitrary_types_allowed = True
|
88 |
+
|
89 |
+
def model_post_init(self, __context: Any) -> None:
|
90 |
+
super().model_post_init(__context)
|
91 |
+
self.environment = Environments.get(self.config.environment.name)
|
92 |
+
self.solver = Solvers.get(self.config.solver.name)
|
93 |
+
self.recorder = DataRecorder()
|
94 |
+
self.logger = create_logger(
|
95 |
+
name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
|
96 |
+
level=self.config.logger_level,
|
97 |
+
detailed_name=self.config.detailed_logger_name,
|
98 |
+
)
|
99 |
+
|
100 |
+
async def run_generator(self, task: str) -> AsyncIterator[Run]:
|
101 |
+
async with (
|
102 |
+
async_timeout(self.config.task_timeout, "Task"),
|
103 |
+
):
|
104 |
+
if self.config.logger_level is not None:
|
105 |
+
self.logger.setLevel(self.config.logger_level)
|
106 |
+
run = self.recorder.initialise_run(task)
|
107 |
+
run.environment = self.config.environment
|
108 |
+
run.solver = self.config.solver
|
109 |
+
self.logger.debug(f"Run intialised: {run.run_id}")
|
110 |
+
event_queue = asyncio.Queue()
|
111 |
+
async with (
|
112 |
+
self.environment(
|
113 |
+
config=self.config.environment,
|
114 |
+
logger=self.logger,
|
115 |
+
) as environment,
|
116 |
+
self.solver(config=self.config.solver, logger=self.logger) as solver,
|
117 |
+
):
|
118 |
+
run.env_info = await environment.get_info()
|
119 |
+
await solver.initialise(
|
120 |
+
task,
|
121 |
+
environment.tools,
|
122 |
+
environment.info_for_user,
|
123 |
+
)
|
124 |
+
self.logger.debug("Solver initialised.")
|
125 |
+
run.solver_history = solver.history
|
126 |
+
observation: Observation = await environment.initialise()
|
127 |
+
await event_queue.put(observation)
|
128 |
+
self.logger.debug("Environment initialised.")
|
129 |
+
step_count = 0
|
130 |
+
while step_count < self.config.max_steps:
|
131 |
+
event = await event_queue.get()
|
132 |
+
self.logger.debug(f"🤖 [bold purple]Processing event:[/] {event.type}")
|
133 |
+
match event.type:
|
134 |
+
case EventType.OBSERVATION:
|
135 |
+
observation: Observation = event
|
136 |
+
run.record(
|
137 |
+
observation=observation,
|
138 |
+
solver_history=solver.history,
|
139 |
+
)
|
140 |
+
async with async_timeout(
|
141 |
+
self.config.action_timeout,
|
142 |
+
"Action decision",
|
143 |
+
):
|
144 |
+
action: Action = await solver.act(observation)
|
145 |
+
await event_queue.put(action)
|
146 |
+
case EventType.ACTION:
|
147 |
+
action: Action = event
|
148 |
+
self.logger.debug(f"Tool calls: {action.tool_calls}")
|
149 |
+
run.record(action=action, solver_history=solver.history)
|
150 |
+
run.complete = await solver.is_complete(observation)
|
151 |
+
if self.config.save_every_step:
|
152 |
+
await self.recorder.save(run)
|
153 |
+
if run.complete:
|
154 |
+
run.result = action.text
|
155 |
+
self.logger.info(f"🤖 [bold purple]Task complete.[/] ✨ \n{run.result}")
|
156 |
+
break
|
157 |
+
async with async_timeout(
|
158 |
+
self.config.environment_timeout,
|
159 |
+
"Environment response",
|
160 |
+
):
|
161 |
+
observation: Observation = await environment.execute_action(action)
|
162 |
+
step_count += 1
|
163 |
+
await event_queue.put(observation)
|
164 |
+
yield run
|
165 |
+
if not run.complete:
|
166 |
+
self.logger.warning("🤖 [bold purple]Ran out of steps!")
|
167 |
+
await self.recorder.terminate(run, save=True)
|
168 |
+
yield run
|
169 |
+
|
170 |
+
async def run(self, task: str) -> Run:
|
171 |
+
async for run in self.run_generator(task): # noqa: B007
|
172 |
+
self._run = run
|
173 |
+
return run
|
174 |
+
|
175 |
+
def run_concurrent(self, tasks: list[str]) -> list[Run]:
|
176 |
+
async def gather_runs():
|
177 |
+
return await asyncio.gather(
|
178 |
+
*[self.run(task) for task in tasks],
|
179 |
+
return_exceptions=True,
|
180 |
+
)
|
181 |
+
|
182 |
+
return asyncio.run(gather_runs())
|
183 |
+
|
184 |
+
@property
|
185 |
+
def complete(self) -> bool:
|
186 |
+
if self._run is None:
|
187 |
+
raise RuntimeError("Run not initialised")
|
188 |
+
return self._run.complete
|
189 |
+
|
190 |
+
@property
|
191 |
+
def run_id(self) -> str:
|
192 |
+
if self._run is None:
|
193 |
+
raise RuntimeError("Run not initialised")
|
194 |
+
return self._run.run_id
|
195 |
+
|
196 |
+
@property
|
197 |
+
def run_result(self) -> str:
|
198 |
+
if self._run is None:
|
199 |
+
raise RuntimeError("Run not initialised")
|
200 |
+
return self._run.result
|
201 |
+
|
202 |
+
|
203 |
+
if __name__ == "__main__":
|
204 |
+
from proxy_lite.logger import logger
|
205 |
+
|
206 |
+
config = RunnerConfig.from_dict(
|
207 |
+
{
|
208 |
+
"environment": {
|
209 |
+
"name": "webbrowser",
|
210 |
+
"homepage": "https://www.google.com",
|
211 |
+
"viewport_width": 1920,
|
212 |
+
"viewport_height": 1080,
|
213 |
+
"screenshot_delay": 1,
|
214 |
+
"headless": False,
|
215 |
+
},
|
216 |
+
"solver": {
|
217 |
+
"name": "simple",
|
218 |
+
"agent": {
|
219 |
+
"name": "proxy_lite",
|
220 |
+
"client": {
|
221 |
+
"name": "convergence",
|
222 |
+
"model_id": "convergence-ai/all-distill-tools-7b-16-02-2025",
|
223 |
+
"api_base": "http://slurm1-a3nodeset-4-1:8009/v1",
|
224 |
+
# # "model_id": "Qwen/Qwen2.5-VL-3B-Instruct",
|
225 |
+
# # "api_base": "http://0.0.0.0:8000/v1",
|
226 |
+
},
|
227 |
+
},
|
228 |
+
},
|
229 |
+
"max_steps": 150,
|
230 |
+
"action_timeout": 1800,
|
231 |
+
"environment_timeout": 1800,
|
232 |
+
"task_timeout": 18000,
|
233 |
+
"logger_level": "DEBUG",
|
234 |
+
},
|
235 |
+
)
|
236 |
+
logger.info(f"🤖 [bold purple]Config:[/] {config}")
|
237 |
+
|
238 |
+
runner = Runner(config=config)
|
239 |
+
result = asyncio.run(
|
240 |
+
runner.run(
|
241 |
+
"Tell me the tesla stock price" # noqa: E501
|
242 |
+
)
|
243 |
+
)
|
244 |
+
print(runner.run_result)
|
245 |
+
print(runner.complete)
|
src/proxy_lite/serializer.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
from proxy_lite.history import MessageAdapter, MessageHistory
|
7 |
+
from proxy_lite.tools import Tool
|
8 |
+
|
9 |
+
|
10 |
+
class BaseSerializer(BaseModel, ABC):
    """Abstract interface for serializers.

    A serializer translates the internal MessageHistory/Tool objects to the
    format of an external API and back. The reverse direction is not always
    possible; implementations should raise appropriate warnings when it is not.
    """

    @abstractmethod
    def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
        """Convert a message history to API-format dicts."""
        ...

    @abstractmethod
    def deserialize_messages(self, data: list[dict]) -> MessageHistory:
        """Rebuild a message history from API-format dicts."""
        ...

    @abstractmethod
    def serialize_tools(self, tools: list[Tool]) -> list[dict]:
        """Convert tools to API-format schema dicts."""
        ...
|
26 |
+
|
27 |
+
|
28 |
+
class OpenAISerializer(BaseSerializer):
    """Serializer producing OpenAI-style message and tool dicts."""

    def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
        """Dump every message without its ``label`` field."""
        return message_history.to_dict(exclude={"label"})

    def deserialize_messages(self, data: list[dict]) -> MessageHistory:
        """Validate each dict into a message and wrap them in a history."""
        parsed = [MessageAdapter.validate_python(item) for item in data]
        return MessageHistory(messages=parsed)

    def serialize_tools(self, tools: list[Tool]) -> list[dict]:
        """Return one ``{"type": "function", ...}`` entry per schema of every tool."""
        return [{"type": "function", "function": schema} for tool in tools for schema in tool.schema]
|
src/proxy_lite/solvers/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from .simple_solver import SimpleSolver, SimpleSolverConfig
|
6 |
+
from .solver_base import BaseSolver, BaseSolverConfig, Solvers
|
7 |
+
from .structured_solver import StructuredSolver, StructuredSolverConfig
|
8 |
+
|
9 |
+
SolverConfigTypes = Union[*Solvers._solver_config_registry.values()]
|
10 |
+
SolverTypes = Union[*Solvers._solver_registry.values()]
|
11 |
+
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"BaseSolver",
|
15 |
+
"BaseSolverConfig",
|
16 |
+
"SimpleSolver",
|
17 |
+
"SimpleSolverConfig",
|
18 |
+
"StructuredSolver",
|
19 |
+
"StructuredSolverConfig",
|
20 |
+
"SolverConfigTypes",
|
21 |
+
"SolverTypes",
|
22 |
+
"Solvers",
|
23 |
+
]
|
src/proxy_lite/solvers/simple_solver.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ruff: noqa: E501
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from functools import cached_property
|
5 |
+
from typing import Literal, Optional
|
6 |
+
|
7 |
+
from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
|
8 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
9 |
+
from proxy_lite.history import (
|
10 |
+
MessageHistory,
|
11 |
+
MessageLabel,
|
12 |
+
SystemMessage,
|
13 |
+
)
|
14 |
+
from proxy_lite.solvers.solver_base import BaseSolver, BaseSolverConfig, Solvers
|
15 |
+
from proxy_lite.tools import ReturnValueTool, Tool
|
16 |
+
|
17 |
+
WEB_TOOL_TURN = """The action has been attempted in the computer."""
|
18 |
+
|
19 |
+
|
20 |
+
@Solvers.register_solver_config("simple")
class SimpleSolverConfig(BaseSolverConfig):
    """Configuration for the ``simple`` solver."""

    name: Literal["simple"] = "simple"
    # Configuration of the agent the solver drives.
    agent: AgentConfigTypes
|
24 |
+
|
25 |
+
|
26 |
+
@Solvers.register_solver("simple")
class SimpleSolver(BaseSolver):
    # Single-agent solver: every turn the agent is shown the latest observation and
    # its tool calls are relayed as the next action. The solver marks itself complete
    # once the agent calls the `return_value` tool.
    task: Optional[str] = None
    complete: bool = False  # set to True once the agent has called `return_value`

    @cached_property
    def tools(self) -> list[Tool]:
        """Return a `ReturnValueTool` followed by the environment tools."""
        return [ReturnValueTool()] + self.env_tools

    @cached_property
    def agent(self) -> BaseAgent:
        """Build (once) the agent named in the config and hand it this solver's tools."""
        self.logger.debug(f"Tools: {self.tools}")
        return Agents.get(self.config.agent.name)(
            config=self.config.agent,
            env_tools=self.tools,
        )

    @property
    def history(self) -> MessageHistory:
        """Return the agent's message history prefixed with its system prompt."""
        return MessageHistory(
            messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
        )

    async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
        """Store the task and environment tools and send the task to the agent.

        `env_tools` must be assigned before `self.agent` is first touched, because
        the cached `agent`/`tools` properties read it when they are built.
        `env_info` is accepted for interface compatibility but not used here.
        """
        self.env_tools = env_tools
        self.task = task
        self.agent.receive_user_message(
            text=f"Task: {task}",
            label=MessageLabel.USER_INPUT,
        )
        self.logger.debug(f"Initialised with task: {task}")

    async def act(self, observation: Observation) -> Action:
        """Show the observation to the agent and turn its reply into an `Action`.

        If the agent called `return_value`, the solver is marked complete and the
        returned action carries that value as text with no tool calls. Otherwise
        the agent's tool calls and raw text are passed through.
        """
        self.agent.receive_user_message(
            image=observation.state.image,
            text=observation.state.text,
            label=MessageLabel.SCREENSHOT,
            is_base64=True,
        )

        message = await self.agent.generate_output(use_tool=True)

        self.logger.debug(f"Assistant message generated: {message}")

        # Pick the actual `return_value` call: it is not guaranteed to be the first
        # tool call, and `tool_calls` may be empty/None when the agent only wrote text.
        return_call = next(
            (call for call in (message.tool_calls or []) if call.function["name"] == "return_value"),
            None,
        )
        if return_call is not None:
            self.complete = True
            arguments = json.loads(return_call.function["arguments"])
            # Arguments are sometimes double-encoded (a JSON string containing JSON).
            if isinstance(arguments, str):
                arguments = json.loads(arguments)
            return Action(tool_calls=[], text=arguments["value"])

        text_content = message.content[0].text

        observation_match = re.search(r"<observation>(.*?)</observation>", text_content, re.DOTALL)
        observation_content = observation_match.group(1).strip() if observation_match else ""

        self.logger.info(f"🌐 [bold blue]Observation:[/] {observation_content}")

        # Extract text between thinking tags if present; otherwise log the full text.
        thinking_match = re.search(r"<thinking>(.*?)</thinking>", text_content, re.DOTALL)
        thinking_content = thinking_match.group(1).strip() if thinking_match else text_content

        self.logger.info(f"🤖 [bold purple]Action:[/] {thinking_content}")

        return Action(tool_calls=message.tool_calls, text=text_content)

    async def is_complete(self, observation: Observation) -> bool:
        """Return True once the agent has returned a value or the environment terminated."""
        env_terminated = observation.terminated
        return self.complete or env_terminated
|
src/proxy_lite/solvers/solver_base.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from functools import cached_property
|
4 |
+
from typing import Optional, Self, Type, cast
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
9 |
+
from proxy_lite.tools import Tool
|
10 |
+
|
11 |
+
|
12 |
+
# Common base class for solver configuration models; declares no fields of its own.
class BaseSolverConfig(BaseModel):
    pass
|
14 |
+
|
15 |
+
|
16 |
+
class BaseSolver(BaseModel, ABC):
    # Abstract base for solvers: objects that receive environment observations
    # and answer with actions. Usable as an async context manager (no-op here).
    task: Optional[str] = None  # the task being solved
    env_tools: list[Tool] = Field(default_factory=list)  # tools supplied by the environment
    config: BaseSolverConfig  # solver-specific configuration
    logger: logging.Logger | None = None  # logger used by subclasses

    # Pydantic v2 configuration, replacing the deprecated nested `class Config`.
    # Arbitrary types must be allowed because fields such as `logging.Logger`
    # and `Tool` are not pydantic models.
    model_config = {"arbitrary_types_allowed": True}

    async def __aenter__(self) -> Self:
        """Enter the async context; returns the solver itself."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the async context; nothing to release at this level."""
        pass

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @abstractmethod
    async def initialise(
        self,
        task: str,
        env_tools: list[Tool],
        env_info: str,
    ) -> None:
        """
        Initialise the solution with the given task.
        """
        ...

    @abstractmethod
    async def act(self, observation: Observation) -> Action:
        """
        Return an action for interacting with the environment.
        """
        ...

    async def is_complete(self, observation: Observation) -> bool:
        """
        Return a boolean indicating if the task is complete.

        By default this is simply whether the observation reports termination.
        """
        return observation.terminated
|
59 |
+
|
60 |
+
|
61 |
+
class Solvers:
    """Name-keyed registry of solver classes and their configuration classes.

    Both registries are class-level dictionaries shared by every user of
    `Solvers`. Annotations are written as strings so they are never evaluated
    at definition time.
    """

    _solver_registry: "dict[str, type[BaseSolver]]" = {}
    _solver_config_registry: "dict[str, type[BaseSolverConfig]]" = {}

    @classmethod
    def register_solver(cls, name: str):
        """
        Decorator to register a Solver class under a given name.

        Registering a second class under the same name replaces the first.

        Example:
            @Solvers.register_solver("my_solver")
            class MySolver(BaseSolver):
                ...
        """

        def decorator(solver_cls: "type[BaseSolver]") -> "type[BaseSolver]":
            cls._solver_registry[name] = solver_cls
            return solver_cls

        return decorator

    @classmethod
    def register_solver_config(cls, name: str):
        """
        Decorator to register a Solver configuration class under a given name.

        Registering a second class under the same name replaces the first.

        Example:
            @Solvers.register_solver_config("my_solver")
            class MySolverConfig(BaseSolverConfig):
                ...
        """

        def decorator(config_cls: "type[BaseSolverConfig]") -> "type[BaseSolverConfig]":
            cls._solver_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> "type[BaseSolver]":
        """
        Retrieve a registered Solver class by its name.

        Raises:
            ValueError: If no such solver is found.
        """
        try:
            return cast("Type[BaseSolver]", cls._solver_registry[name])
        except KeyError:
            # The KeyError adds nothing beyond the message below; suppress it
            # so the traceback shows only the ValueError.
            raise ValueError(f"Solver '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> "type[BaseSolverConfig]":
        """
        Retrieve a registered Solver configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cast("Type[BaseSolverConfig]", cls._solver_config_registry[name])
        except KeyError:
            raise ValueError(f"Solver config for '{name}' not found.") from None
|
src/proxy_lite/solvers/structured_solver.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ruff: noqa: E501
|
2 |
+
|
3 |
+
from functools import cached_property
|
4 |
+
from typing import Literal, Optional
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
|
9 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
10 |
+
from proxy_lite.history import (
|
11 |
+
MessageHistory,
|
12 |
+
MessageLabel,
|
13 |
+
SystemMessage,
|
14 |
+
)
|
15 |
+
from proxy_lite.tools import Tool
|
16 |
+
|
17 |
+
from .solver_base import BaseSolver, BaseSolverConfig, Solvers
|
18 |
+
|
19 |
+
WEB_TOOL_TURN = """The browser action has been attempted. Please double check if the action was successful."""
|
20 |
+
PLAN_USER_PROMPT = "First create a high-level plan to help solve the task on the web."
|
21 |
+
ACTION_PROMPT = """Now take the most-promising next action in the browser.
|
22 |
+
|
23 |
+
Only refer to the latest web elements from the latest screenshot.
|
24 |
+
|
25 |
+
Using mark ids from older turns will lead to errors as they are no longer valid.
|
26 |
+
|
27 |
+
Only interact with elements visible on the current webpage. Do not make up numbers or elements."""
|
28 |
+
REASONING_PROMPT = """You will now follow these steps.
|
29 |
+
|
30 |
+
1. **Make observations about the state of the webpage**:
|
31 |
+
- Consider the previous screenshot, your attempted previous action, and the current screenshot.
|
32 |
+
- Describe any changes you observe, and try to determine if the previous action succeeded.
|
33 |
+
- For example, if a form is being filled out, check whether the correct information is now displayed.
|
34 |
+
|
35 |
+
2. **Write down any helpful facts you have gathered**:
|
36 |
+
- Describe any useful information on the webpage that might be helpful for completing the task.
|
37 |
+
- For example, if you are viewing a document, you may wish to note down any information you want to refer back to later.
|
38 |
+
|
39 |
+
3. **Reason about the system's status**:
|
40 |
+
- Have you fully completed the task?
|
41 |
+
|
42 |
+
4. **Select one of the following statuses**:
|
43 |
+
- "complete": if the task has been completed.
|
44 |
+
- "continue": if you are ready to continue without information or help.
|
45 |
+
|
46 |
+
5. **Reason through next steps**:
|
47 |
+
- If the status is "continue", write down your reasoning for the next action you will take. You can only take one action at a time.
|
48 |
+
- If the status is not "continue", return an empty string.
|
49 |
+
|
50 |
+
6. **Write a message to the user**:
|
51 |
+
- If the status is "complete", write a message to the user. If they asked a question in the task, make sure the answer is here. Otherwise, just provide other useful information about how the task went or if there was a problem in completing it.
|
52 |
+
- If the status is not "complete", set this to an empty string.
|
53 |
+
|
54 |
+
Tips:
|
55 |
+
- If you have already provided a response, don't provide it again.
|
56 |
+
- If you notice you are repeating previous actions, you're likely stuck. Try something different."""
|
57 |
+
|
58 |
+
|
59 |
+
# Structured output requested from the agent on every turn of the structured
# solver: what it observed, what it learned, whether the task is finished, and
# what it intends to do next.
class Reflection(BaseModel):
    observation: str = Field(
        ...,
        description="Observation of the current browser state, including an assessment on the success of the last action (previous actions and observations are often wrong).",
    )
    # Defaults to an empty list. The previous default of "" did not match the
    # declared `list[str]` type.
    fact_updates: list[str] = Field(
        default_factory=list,
        description="List of new information relevant to the task that was found on the page, ignore input fields holding content you wrote.",
    )
    status_reasoning: str = Field(
        ...,
        description="Reasoning about the current state of the task.",
    )
    status: Literal["complete", "continue"] = Field(
        ...,
        description="Choose a system status based on your status reasoning.",
    )
    next_step_reasoning: str = Field(
        ...,
        description='If status is "continue", reason through the next action you will be taking (do not repeat actions over and over). Otherwise set to "".',
    )
    ending_message: str = Field(
        ...,
        description="If status is 'complete', write a message to the user. If they asked a question in the task, make sure the answer is here. Otherwise, just provide other useful information about how the task went or if there was a problem in completing it. If status is 'continue', set to ''.",
    )
|
84 |
+
|
85 |
+
|
86 |
+
# Configuration model for the "structured" solver; registered with `Solvers` under that name.
@Solvers.register_solver_config("structured")
class StructuredSolverConfig(BaseSolverConfig):
    # Literal name of this solver type; always "structured".
    name: Literal["structured"] = "structured"
    # Configuration of the agent the solver drives; `agent.name` is used to look up the agent class.
    agent: AgentConfigTypes
    # When True, the solver asks the agent for a high-level plan during initialisation.
    start_with_plan: bool = True
|
91 |
+
|
92 |
+
|
93 |
+
@Solvers.register_solver("structured")
class StructuredSolver(BaseSolver):
    # Two-phase solver: each turn the agent first produces a structured
    # `Reflection` on the latest observation, then (unless it declared the task
    # complete) is prompted to choose the next tool call.
    task: Optional[str] = None
    complete: bool = False  # set to True once a reflection reports status "complete"

    @cached_property
    def tools(self) -> list[Tool]:
        """Return the tools supplied by the environment."""
        return self.env_tools

    @cached_property
    def local_tools(self) -> list[Tool]:
        """Return the tools of an attached sandbox, or an empty list if there is none.

        No `sandbox` attribute is declared on this class, so it is looked up
        defensively instead of raising AttributeError when it is absent.
        """
        sandbox = getattr(self, "sandbox", None)
        if sandbox:
            return sandbox.tools
        return []

    @cached_property
    def agent(self) -> BaseAgent:
        """Build (once) the agent named in the config and hand it this solver's tools."""
        self.logger.debug(f"Tools: {self.tools}")
        return Agents.get(self.config.agent.name)(
            config=self.config.agent,
            env_tools=self.tools,
        )

    @property
    def history(self) -> MessageHistory:
        """Return the agent's message history prefixed with its system prompt."""
        return MessageHistory(
            messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
        )

    async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
        """Store the task and tools, brief the agent, and optionally request a plan.

        `env_tools` must be assigned before `self.agent` is first touched, because
        the cached `agent`/`tools` properties read it when they are built.
        """
        self.env_tools = env_tools
        self.agent.receive_user_message(
            text=env_info,
            label=MessageLabel.USER_INPUT,
        )
        self.task = task
        self.agent.receive_user_message(
            text=f"Task: {task}",
            label=MessageLabel.USER_INPUT,
        )
        if self.config.start_with_plan:
            self.agent.receive_user_message(text=PLAN_USER_PROMPT, label=MessageLabel.PLAN)
            # The plan is generated without tools; its return value is not used here.
            await self.agent.generate_output(use_tool=False)

    async def act(self, observation: Observation) -> Action:
        """Reflect on the observation, then either finish or pick the next tool call.

        Returns an action with no tool calls and the ending message when the
        reflection's status is "complete"; otherwise an action carrying the
        agent's tool calls and its next-step reasoning as text.
        """
        # Feed back the results of the previous turn's tool calls, if any.
        if observation.state.tool_responses:
            for tool_response in observation.state.tool_responses:
                await self.agent.receive_tool_message(
                    text=f"{WEB_TOOL_TURN}\n{tool_response.content}",
                    tool_id=tool_response.id,
                    label=MessageLabel.TOOL_RESULT_INDUCTION,
                )

        self.agent.receive_user_message(
            image=observation.state.image,
            text=observation.state.text,
            label=MessageLabel.SCREENSHOT,
            is_base64=True,
        )

        self.agent.receive_user_message(
            text=REASONING_PROMPT,
            label=MessageLabel.REASONING_INDUCTION,
        )

        message = await self.agent.generate_structured_output(model=Reflection)
        self.logger.info(f"🌐 [bold blue]Observation:[/] {message.observation}")

        if message.status == "complete":
            self.complete = True
            return Action(tool_calls=[], text=message.ending_message)

        # Keep the reasoning before `message` is rebound to the tool-call reply below.
        next_step = message.next_step_reasoning

        self.agent.receive_user_message(
            text=ACTION_PROMPT,
            label=MessageLabel.ACTION,
            is_base64=True,
        )
        message = await self.agent.generate_output(use_tool=True)

        return Action(tool_calls=message.tool_calls, text=next_step)

    async def is_complete(self, observation: Observation) -> bool:
        """Return True once a reflection reported completion or the environment terminated."""
        env_terminated = observation.terminated
        return self.complete or env_terminated
|
src/proxy_lite/tools/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .browser_tool import BrowserTool
|
2 |
+
from .return_tool import ReturnValueTool
|
3 |
+
from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
|
4 |
+
|
5 |
+
__all__ = ["BrowserTool", "ReturnValueTool", "Tool", "ToolExecutionResponse", "attach_param_schema"]
|
src/proxy_lite/tools/browser_tool.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from contextlib import AsyncExitStack
|
3 |
+
from typing import List, Literal, Optional
|
4 |
+
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
|
7 |
+
from proxy_lite.browser.browser import BrowserSession
|
8 |
+
from proxy_lite.logger import logger
|
9 |
+
|
10 |
+
from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
|
11 |
+
|
12 |
+
SELF_CONTAINED_TAGS = [
    # many of these are non-interactive but keeping them anyway
    "area",
    "base",
    "br",
    "col",
    "embed",
    "hr",
    "img",
    "input",
    "link",
    "meta",
    "param",
    "source",
    "track",
    "wbr",
]

# Maximum number of characters kept for an attribute value or element text.
_MAX_TEXT_LEN = 2500


def _truncate(value: str) -> str:
    """Cut `value` to `_MAX_TEXT_LEN` characters, marking the cut with an ellipsis."""
    if len(value) > _MAX_TEXT_LEN:
        return value[: _MAX_TEXT_LEN - 1] + "…"
    return value


def element_as_text(
    mark_id: int,
    tag: Optional[str] = None,
    text: Optional[str] = None,
    **raw_attributes,
) -> str:
    """Return an HTML-like text representation of a single page element.

    Args:
        mark_id: Mark ID of the element; rendered as the `id` attribute.
        tag: Tag name of the element (case-insensitive).
        text: Inner text of the element, if any.
        **raw_attributes: Further attributes. `None` values and `False`
            booleans are omitted, `True` booleans are rendered as a bare
            attribute name, everything else as `name="value"`.

    Returns:
        The rendered element, or an empty string when `tag` is missing
        (previously this raised AttributeError on `None.lower()`).
    """
    if tag is None:
        return ""

    rendered = []
    for key, value in raw_attributes.items():
        if value is None:
            continue
        if isinstance(value, bool):
            # True booleans become bare attributes; False ones are ignored.
            if value:
                rendered.append(key)
        else:
            rendered.append(f'{key}="{_truncate(str(value))}"')
    # Leading space separates the attributes from `id=...`; rstrip removes it
    # again when there are no attributes at all.
    attributes = (" " + " ".join(rendered)).rstrip()

    tag = tag.lower()
    text = _truncate(text if text is not None else "")

    if tag in SELF_CONTAINED_TAGS:
        if not text:
            return f"<{tag} id={mark_id}{attributes}/>"
        # Unexpected, but keep the text: fall through to the open/close form.
        logger.warning(
            f"Got self-contained element '{tag}' which contained text '{text}'.",
        )
    return f"<{tag} id={mark_id}{attributes}>{text}</{tag}>"
|
66 |
+
|
67 |
+
|
68 |
+
# Parameter model for the `goto` tool.
class GotoParams(BaseModel):
    url: str = Field(..., description="The web address to visit. Must be a valid URL.")
|
70 |
+
|
71 |
+
|
72 |
+
# Parameter model for the `google_search` tool.
class GoogleSearchParams(BaseModel):
    # Free-text planning field; `google_search` accepts it but only `query` is used to build the URL.
    query_plan: str = Field(
        ...,
        description="Plan out the query you will make. Re-write queries in a way that will yield the best results.",
    )
    query: str = Field(..., description="The Google search to perform.")
|
78 |
+
|
79 |
+
|
80 |
+
# Parameter model for the `click` tool.
class ClickParams(BaseModel):
    mark_id: int = Field(..., description="Element Mark ID.")
|
82 |
+
|
83 |
+
|
84 |
+
# One (element, text) pair for the `type` tool; used as the item type of `TypeParams.entries`.
class TypeEntry(BaseModel):
    mark_id: int = Field(..., description="Element Mark ID.")
    content: str = Field(..., description="The text to type into the element.")
|
87 |
+
|
88 |
+
|
89 |
+
# Parameter model for the `type` tool.
class TypeParams(BaseModel):
    entries: List[TypeEntry] = Field(
        ...,
        description="A list of elements and contents to type.",
    )
    # Only applied to the final entry (see `BrowserTool.type`).
    submit: bool = Field(
        ...,
        description='Whether to press the "Enter" key after typing in the last entry.',
    )
|
98 |
+
|
99 |
+
|
100 |
+
# Parameter model for the `scroll` tool.
class ScrollParams(BaseModel):
    direction: Literal["up", "down", "left", "right"] = Field(
        ...,
        description='Direction to scroll. Must be one of "up", "down", "left" or "right".',
    )
    # -1 is a sentinel for "whole page"; `BrowserTool.scroll` converts it to None.
    mark_id: int = Field(
        ...,
        description="What to scroll. Use -1 to scroll the whole page otherwise give the mark ID of an element that is `scrollable`.",  # noqa: E501
    )
|
109 |
+
|
110 |
+
|
111 |
+
# Parameter model for the `back` tool, which takes no arguments.
class BackParams(BaseModel):
    pass
|
113 |
+
|
114 |
+
|
115 |
+
# Parameter model for the `wait` tool, which takes no arguments.
class WaitParams(BaseModel):
    pass
|
117 |
+
|
118 |
+
|
119 |
+
# Parameter model for the `reload` tool, which takes no arguments.
class ReloadParams(BaseModel):
    pass
|
121 |
+
|
122 |
+
|
123 |
+
# Parameter model for the `do_nothing_tool` tool, which takes no arguments.
class DoNothingParams(BaseModel):
    pass
|
125 |
+
|
126 |
+
|
127 |
+
class BrowserTool(Tool):
    """Tool collection exposing `BrowserSession` actions as schema-annotated methods.

    The docstrings of the `attach_param_schema`-decorated methods are emitted as
    the tool descriptions by `Tool.schema`, so they are part of the runtime
    behaviour; explanatory notes for maintainers are kept in `#` comments.
    """

    def __init__(self, session: BrowserSession) -> None:
        """Wrap an existing (not yet entered) browser session."""
        super().__init__()
        self.browser = session

    async def __aenter__(self):
        """Enter the browser session and return this tool."""
        # The exit stack owns the session so that __aexit__ can unwind it.
        self._exit_stack = AsyncExitStack()
        await self._exit_stack.enter_async_context(self.browser)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close everything entered in `__aenter__`."""
        await self._exit_stack.aclose()

    @property
    def poi_text(self) -> str:
        """Return all points of interest on the page as text, one element per line."""
        # The enumeration index doubles as the element's mark ID.
        texts = [element_as_text(mark_id=i, **element) for i, element in enumerate(self.browser.poi_elements)]
        # Drop elements that rendered to an empty string.
        return "\n".join([txt for txt in texts if txt])

    @attach_param_schema(GotoParams)
    async def goto(self, url: str) -> ToolExecutionResponse:
        """Go directly to a specific web url. Specify the exact URL."""
        await self.browser.goto(url)
        return ToolExecutionResponse()

    @attach_param_schema(GoogleSearchParams)
    async def google_search(self, query_plan: str, query: str) -> ToolExecutionResponse:
        """Perform a generic web search using Google.
        Results may not be relevant. If you see poor results, you can try another query.
        """
        # Local import keeps this block self-contained.
        from urllib.parse import quote_plus

        # URL-encode the query: raw characters such as '&', '#', '+' or spaces
        # would otherwise truncate or corrupt the search term.
        # `query_plan` is not used when building the URL.
        url = f"https://www.google.com/search?q={quote_plus(query)}"
        await self.browser.goto(url)
        return ToolExecutionResponse()

    @attach_param_schema(ClickParams)
    async def click(self, mark_id: int) -> ToolExecutionResponse:
        """Click on an element of the page."""
        await self.browser.click(mark_id=mark_id)
        return ToolExecutionResponse()

    @attach_param_schema(TypeParams)
    async def type(self, entries: List[dict], submit: bool) -> ToolExecutionResponse:
        """Type text.
        You can type into one or more elements.
        Note that the text inside an element is cleared before typing.
        """
        last_index = len(entries) - 1
        for i, entry_dict in enumerate(entries):
            # Entries arrive as plain dicts (the decorator dumps the validated model).
            entry = TypeEntry(**entry_dict)
            last_entry = i == last_index
            old_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
            await self.browser.enter_text(
                mark_id=entry.mark_id,
                text=entry.content,
                # Only press Enter after the final entry.
                submit=submit and last_entry,
            )
            await self.browser.update_poi()
            new_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
            # If the page layout moved, the remaining mark IDs may point at
            # different elements, so stop rather than type into the wrong field.
            if not last_entry and old_poi_positions != new_poi_positions:
                logger.error(
                    "POI positions changed mid-typing, cancelling future type entries.",
                )
                break
        return ToolExecutionResponse()

    @attach_param_schema(ScrollParams)
    async def scroll(self, direction: str, mark_id: int) -> ToolExecutionResponse:
        """Scroll the page (or a scrollable element) up, down, left or right."""
        # -1 is the schema's sentinel for "whole page"; the session is given None instead.
        if mark_id == -1:
            mark_id = None
        await self.browser.scroll(direction=direction, mark_id=mark_id)
        return ToolExecutionResponse()

    @attach_param_schema(BackParams)
    async def back(self) -> ToolExecutionResponse:
        """Go back to the previous page."""
        await self.browser.go_back()
        return ToolExecutionResponse()

    @attach_param_schema(WaitParams)
    async def wait(self) -> ToolExecutionResponse:
        """Wait three seconds. Useful when the page appears to still be loading, or if there are any unfinished webpage processes."""  # noqa: E501
        await asyncio.sleep(3)
        return ToolExecutionResponse()

    @attach_param_schema(ReloadParams)
    async def reload(self) -> ToolExecutionResponse:
        """Reload the current page. Useful when the page seems unresponsive, broken, outdated, or if you want to reset the page to its initial state."""  # noqa: E501
        await self.browser.reload()
        return ToolExecutionResponse()

    @attach_param_schema(DoNothingParams)
    async def do_nothing_tool(self) -> ToolExecutionResponse:
        """Do nothing. Use this if you have no need for the browser at this time."""
        return ToolExecutionResponse()
|
src/proxy_lite/tools/return_tool.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
|
3 |
+
from proxy_lite.tools.tool_base import Tool, attach_param_schema
|
4 |
+
|
5 |
+
|
6 |
+
# Parameter model for the `return_value` tool.
class ReturnValueParams(BaseModel):
    value: str = Field(description="The value to return to the user.")
|
8 |
+
|
9 |
+
|
10 |
+
class ReturnValueTool(Tool):
    # Tool exposing a single `return_value` function.

    def __init__(self):
        # No state to set up.
        pass

    @attach_param_schema(ReturnValueParams)
    def return_value(self, value: str):
        """Return a value to the user. Use this tool when you have finished the task in order to provide any information the user has requested."""
        # The docstring above is the tool description shown to the model.
        # The method itself only prints the value to stdout and returns None.
        print(value)
|
src/proxy_lite/tools/tool_base.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from functools import cached_property, wraps
|
3 |
+
from typing import Any, Callable, Optional
|
4 |
+
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
|
7 |
+
|
8 |
+
class Tool:
    """Base class for tool collections.

    Subclasses expose tools as methods decorated with `attach_param_schema`;
    `schema` turns those methods into JSON-serialisable descriptions. Instances
    can be used as async context managers.
    """

    async def __aenter__(self):
        """Enter the async context and return the tool itself.

        Returning `self` makes `async with tool as t:` bind the tool rather
        than None.
        """
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context; nothing to release and exceptions are not suppressed."""
        pass

    @cached_property
    def schema(self) -> list[dict[str, Any]]:
        """Describe every decorated tool method defined directly on this class.

        Only the instance's own class namespace is inspected, so decorated
        methods inherited from a parent class are not included.

        Returns:
            One dict per tool with its `name`, single-line `description` (taken
            from the method docstring) and `parameters` (JSON schema of the
            attached parameter model), in definition order.

        Raises:
            ValueError: If a decorated tool method has no docstring.
        """
        schema = []
        for name, method in self.__class__.__dict__.items():
            # Skip anything that is not callable or was not decorated with attach_param_schema.
            if not callable(method) or not hasattr(method, "param_model"):
                continue

            docstring = inspect.getdoc(method)
            if not docstring:
                raise ValueError(f"The tool function '{name}' is missing a docstring.")
            # Collapse multi-line docstrings into a single line.
            description = " ".join(line.strip() for line in docstring.split("\n"))

            tool_json = {
                "name": name,
                "description": description,
                "parameters": method.param_model.model_json_schema(),
            }
            schema.append(tool_json)
        return schema
|
36 |
+
|
37 |
+
|
38 |
+
def attach_param_schema(param_model: type[BaseModel]):
    """Build a decorator that validates a tool method's keyword arguments.

    The decorated method accepts keyword arguments only. They are validated by
    instantiating `param_model`, and the method is then called with the dumped
    model fields. The model class is stored on the wrapper as `param_model` so
    that `Tool.schema` can find it.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(self, **kwargs):
            # Instantiating the model raises if the supplied arguments do not match its fields.
            checked = param_model(**kwargs)
            call_kwargs = checked.model_dump()
            return func(self, **call_kwargs)

        setattr(wrapper, "param_model", param_model)
        return wrapper

    return decorator
|
50 |
+
|
51 |
+
|
52 |
+
# Result object returned by tool methods; both fields are optional and default to None.
class ToolExecutionResponse(BaseModel):
    # Optional text content produced by the tool.
    content: Optional[str] = None
    # Optional identifier. NOTE(review): presumably the id of the originating tool call — confirm.
    id: Optional[str] = None
|