Spaces:
Running
Running
Trisha Tomy
commited on
Commit
Β·
c9803a3
0
Parent(s):
Stretch goal experimentation
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .DS_Store +0 -0
- proxy-lite-demo-v2/.gitattributes +35 -0
- proxy-lite-demo-v2/.gitignore +177 -0
- proxy-lite-demo-v2/.idea/.gitignore +14 -0
- proxy-lite-demo-v2/.idea/libraries/my_test_package.xml +9 -0
- proxy-lite-demo-v2/.idea/misc.xml +6 -0
- proxy-lite-demo-v2/.idea/modules.xml +8 -0
- proxy-lite-demo-v2/.idea/proxy-lite-demo-v2.iml +9 -0
- proxy-lite-demo-v2/.idea/vcs.xml +6 -0
- proxy-lite-demo-v2/CODEOWNERS +1 -0
- proxy-lite-demo-v2/Dockerfile +59 -0
- proxy-lite-demo-v2/LICENSE +3 -0
- proxy-lite-demo-v2/Makefile +11 -0
- proxy-lite-demo-v2/Procfile +1 -0
- proxy-lite-demo-v2/README.md +10 -0
- proxy-lite-demo-v2/app.py +350 -0
- proxy-lite-demo-v2/pyproject.toml +65 -0
- proxy-lite-demo-v2/requirements.txt +6 -0
- proxy-lite-demo-v2/src/proxy_lite/__init__.py +3 -0
- proxy-lite-demo-v2/src/proxy_lite/agents/__init__.py +18 -0
- proxy-lite-demo-v2/src/proxy_lite/agents/agent_base.py +238 -0
- proxy-lite-demo-v2/src/proxy_lite/agents/proxy_lite_agent.py +61 -0
- proxy-lite-demo-v2/src/proxy_lite/app.py +239 -0
- proxy-lite-demo-v2/src/proxy_lite/browser/__init__.py +0 -0
- proxy-lite-demo-v2/src/proxy_lite/browser/add_custom_select.js +123 -0
- proxy-lite-demo-v2/src/proxy_lite/browser/bounding_boxes.py +210 -0
- proxy-lite-demo-v2/src/proxy_lite/browser/browser.py +508 -0
- proxy-lite-demo-v2/src/proxy_lite/browser/find_pois.js +397 -0
- proxy-lite-demo-v2/src/proxy_lite/cli.py +112 -0
- proxy-lite-demo-v2/src/proxy_lite/client.py +405 -0
- proxy-lite-demo-v2/src/proxy_lite/configs/default.yaml +23 -0
- proxy-lite-demo-v2/src/proxy_lite/environments/__init__.py +32 -0
- proxy-lite-demo-v2/src/proxy_lite/environments/environment_base.py +161 -0
- proxy-lite-demo-v2/src/proxy_lite/environments/webbrowser.py +205 -0
- proxy-lite-demo-v2/src/proxy_lite/gif_maker.py +122 -0
- proxy-lite-demo-v2/src/proxy_lite/history.py +183 -0
- proxy-lite-demo-v2/src/proxy_lite/logger.py +92 -0
- proxy-lite-demo-v2/src/proxy_lite/recorder.py +103 -0
- proxy-lite-demo-v2/src/proxy_lite/runner.py +240 -0
- proxy-lite-demo-v2/src/proxy_lite/serializer.py +39 -0
- proxy-lite-demo-v2/src/proxy_lite/solvers/__init__.py +20 -0
- proxy-lite-demo-v2/src/proxy_lite/solvers/simple_solver.py +117 -0
- proxy-lite-demo-v2/src/proxy_lite/solvers/solver_base.py +123 -0
- proxy-lite-demo-v2/src/proxy_lite/tools/__init__.py +5 -0
- proxy-lite-demo-v2/src/proxy_lite/tools/browser_tool.py +374 -0
- proxy-lite-demo-v2/src/proxy_lite/tools/return_tool.py +17 -0
- proxy-lite-demo-v2/src/proxy_lite/tools/tool_base.py +54 -0
- proxy-lite-demo-v2/test_tool_calling.py +65 -0
- proxy-lite-demo-v2/uv.lock +0 -0
- proxy-lite-work/.forceignore +12 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
proxy-lite-demo-v2/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
proxy-lite-demo-v2/.gitignore
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# PyPI configuration file
|
171 |
+
.pypirc
|
172 |
+
|
173 |
+
logs/
|
174 |
+
local_trajectories/
|
175 |
+
screenshots/
|
176 |
+
gifs/
|
177 |
+
.DS_Store
|
proxy-lite-demo-v2/.idea/.gitignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Environment-dependent path to Maven home directory
|
7 |
+
/mavenHomeManager.xml
|
8 |
+
# Datasource local storage ignored files
|
9 |
+
/dataSources/
|
10 |
+
/dataSources.local.xml
|
11 |
+
# Core Dev Booster ignored files
|
12 |
+
/compile.flag
|
13 |
+
/coreModuleDependants.csv
|
14 |
+
/.mavenCleaned
|
proxy-lite-demo-v2/.idea/libraries/my_test_package.xml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="libraryTable">
|
2 |
+
<library name="my-test-package">
|
3 |
+
<CLASSES>
|
4 |
+
<root url="jar://$PROJECT_DIR$/venv/lib/python3.13/site-packages/pkg_resources/tests/data/my-test-package-zip/my-test-package.zip!/" />
|
5 |
+
</CLASSES>
|
6 |
+
<JAVADOC />
|
7 |
+
<SOURCES />
|
8 |
+
</library>
|
9 |
+
</component>
|
proxy-lite-demo-v2/.idea/misc.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" languageLevel="JDK_24" default="true" project-jdk-name="24" project-jdk-type="JavaSDK">
|
4 |
+
<output url="file://$PROJECT_DIR$/out" />
|
5 |
+
</component>
|
6 |
+
</project>
|
proxy-lite-demo-v2/.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/proxy-lite-demo-v2.iml" filepath="$PROJECT_DIR$/.idea/proxy-lite-demo-v2.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
proxy-lite-demo-v2/.idea/proxy-lite-demo-v2.iml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="JAVA_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
4 |
+
<exclude-output />
|
5 |
+
<content url="file://$MODULE_DIR$" />
|
6 |
+
<orderEntry type="inheritedJdk" />
|
7 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
8 |
+
</component>
|
9 |
+
</module>
|
proxy-lite-demo-v2/.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
proxy-lite-demo-v2/CODEOWNERS
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
* @aptoul @Fraser-Greenlee @XanderJC
|
proxy-lite-demo-v2/Dockerfile
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Playwright Docker image for Python, matching your Playwright version and Debian base
|
2 |
+
FROM mcr.microsoft.com/playwright/python:v1.53.0-noble
|
3 |
+
|
4 |
+
# Set the working directory inside the container
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# The official Playwright image comes with most necessary system dependencies,
|
8 |
+
# so we only need to add git for proxy-lite and potentially any very specific missing libs.
|
9 |
+
# Removing the extensive list as it's largely redundant with the Playwright base image.
|
10 |
+
RUN apt-get update && apt-get install -y \
|
11 |
+
git \
|
12 |
+
xvfb \
|
13 |
+
# Clean up apt caches to reduce image size
|
14 |
+
&& rm -rf /var/lib/apt/lists/*
|
15 |
+
|
16 |
+
# Copy common Python dependencies first (needed for pip installs)
|
17 |
+
COPY requirements.txt .
|
18 |
+
|
19 |
+
# Copy your Flask application code (app.py) and other project files.
|
20 |
+
COPY . .
|
21 |
+
|
22 |
+
# --- START: Directory permission workaround ---
|
23 |
+
# Create the directory proxy-lite's recorder insists on writing to
|
24 |
+
# and grant full permissions. This addresses the PermissionError.
|
25 |
+
# This line creates the directory *directly* under /app, which is now the correct path
|
26 |
+
RUN mkdir -p /app/local_trajectories \
|
27 |
+
&& chmod -R 777 /app/local_trajectories
|
28 |
+
# --- END: Directory permission workaround ---
|
29 |
+
|
30 |
+
# Upgrade pip, setuptools, and wheel for a robust Python build environment.
|
31 |
+
RUN pip install --no-cache-dir --upgrade pip setuptools wheel
|
32 |
+
|
33 |
+
# Install your local proxy-lite package in editable mode.
|
34 |
+
RUN pip install --no-cache-dir --no-input -e .
|
35 |
+
|
36 |
+
# Install the rest of the Python dependencies from requirements.txt
|
37 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
38 |
+
|
39 |
+
|
40 |
+
# Set environment variables required for Playwright at runtime
|
41 |
+
ENV DISPLAY=:99
|
42 |
+
ENV XDG_RUNTIME_DIR=/tmp
|
43 |
+
# Removed PLAYWRIGHT_BROWSERS_PATH and PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD
|
44 |
+
# as the official Playwright image manages these internally, defaulting to /ms-playwright.
|
45 |
+
|
46 |
+
# --- Debugging: Check Playwright version and browser installation (moved AFTER install in the original setup) ---
|
47 |
+
# Now checking the default Playwright browser installation path /ms-playwright
|
48 |
+
RUN echo "--- Checking Playwright Version (from base image) ---"
|
49 |
+
RUN python -m playwright --version
|
50 |
+
RUN echo "--- Listing Playwright Browser Cache (Recursive, from base image) ---"
|
51 |
+
RUN ls -alR /ms-playwright/
|
52 |
+
RUN echo "-----------------------------------"
|
53 |
+
# --- End Debugging ---
|
54 |
+
|
55 |
+
# Expose the port your Flask app will listen on. Hugging Face Spaces requires 7860.
|
56 |
+
EXPOSE 7860
|
57 |
+
|
58 |
+
# Define the command to run your Flask application using Gunicorn for production.
|
59 |
+
CMD exec gunicorn --bind 0.0.0.0:7860 --workers 2 --worker-class gevent app:app --timeout 300
|
proxy-lite-demo-v2/LICENSE
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Creative Commons Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
proxy-lite-demo-v2/Makefile
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: proxy
|
2 |
+
|
3 |
+
proxy:
|
4 |
+
pip install uv
|
5 |
+
uv venv --python 3.11 --python-preference managed
|
6 |
+
uv sync
|
7 |
+
uv pip install -e .
|
8 |
+
playwright install
|
9 |
+
|
10 |
+
app:
|
11 |
+
streamlit run src/proxy_lite/app.py
|
proxy-lite-demo-v2/Procfile
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
web: gunicorn --bind 0.0.0.0:7860 --workers 2 --worker-class gevent app:app --timeout 300
|
proxy-lite-demo-v2/README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Proxy Lite Demo For Setup
|
3 |
+
emoji: π»
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: gray
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
---
|
9 |
+
|
10 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
proxy-lite-demo-v2/app.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gevent.monkey
|
2 |
+
gevent.monkey.patch_all(asyncio=True) # Keep this at the very top
|
3 |
+
|
4 |
+
import asyncio
|
5 |
+
from flask import Flask, request, jsonify
|
6 |
+
from proxy_lite import Runner, RunnerConfig
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
from datetime import datetime
|
10 |
+
from playwright.async_api import async_playwright, TimeoutError as PlaywrightTimeoutError
|
11 |
+
|
12 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
app = Flask(__name__)
|
16 |
+
|
17 |
+
_runner = None
|
18 |
+
|
19 |
+
async def perform_hardcoded_salesforce_login_and_get_cookies(username, password, login_url, target_url):
|
20 |
+
logger.info("Attempting hardcoded Salesforce login with Playwright to obtain cookies...")
|
21 |
+
async with async_playwright() as p:
|
22 |
+
browser = await p.chromium.launch(headless=False, args=["--no-sandbox", "--disable-setuid-sandbox"])
|
23 |
+
context = await browser.new_context()
|
24 |
+
page = await context.new_page()
|
25 |
+
|
26 |
+
try:
|
27 |
+
await page.goto(login_url, wait_until="domcontentloaded", timeout=60000)
|
28 |
+
logger.info(f"Playwright: Navigated to Salesforce login page: {page.url}")
|
29 |
+
|
30 |
+
await page.fill("#username", username)
|
31 |
+
await page.fill("#password", password)
|
32 |
+
await page.click("#Login")
|
33 |
+
logger.info("Playwright: Filled credentials and clicked Login. Waiting for post-login state...")
|
34 |
+
|
35 |
+
try:
|
36 |
+
await page.wait_for_url(lambda url: "login.salesforce.com" not in url and "unauthorized" not in url.lower(), timeout=60000)
|
37 |
+
logger.info(f"Playwright: Successfully redirected from login page. Current URL: {page.url}")
|
38 |
+
await page.wait_for_selector('button[title="App Launcher"]', timeout=30000)
|
39 |
+
logger.info("Playwright: Main Salesforce Lightning UI (e.g., App Launcher) detected after login.")
|
40 |
+
|
41 |
+
except PlaywrightTimeoutError:
|
42 |
+
logger.error(f"Playwright: Did not detect main UI or expected URL change within timeout after login. Current URL: {page.url}. Login might have failed or stuck on a redirect loop.")
|
43 |
+
raise Exception("Salesforce login redirection failed or main UI not detected.")
|
44 |
+
|
45 |
+
logger.info(f"Playwright: Navigating to target URL: {target_url} to ensure all relevant cookies are captured.")
|
46 |
+
await page.goto(target_url, wait_until="domcontentloaded", timeout=60000)
|
47 |
+
|
48 |
+
try:
|
49 |
+
# Wait for generic Salesforce setup page elements to load
|
50 |
+
await page.wait_for_selector('.setupPage, .slds-page-header, .slds-card, [data-aura-class*="setup"], .forcePageBlockSectionView', timeout=30000)
|
51 |
+
logger.info("Playwright: Detected Salesforce setup page elements loaded successfully.")
|
52 |
+
except PlaywrightTimeoutError:
|
53 |
+
logger.warning("Playwright: Specific setup page elements not found. Trying generic page load check...")
|
54 |
+
try:
|
55 |
+
# Fallback: wait for page to reach network idle state
|
56 |
+
await page.wait_for_load_state("networkidle", timeout=10000)
|
57 |
+
logger.info("Playwright: Page reached network idle state - proceeding with task.")
|
58 |
+
except PlaywrightTimeoutError:
|
59 |
+
logger.info("Playwright: Page load validation timed out, but continuing as page may still be functional.")
|
60 |
+
|
61 |
+
await asyncio.sleep(2)
|
62 |
+
logger.info(f"Playwright: Successfully navigated to and confirmed content on {page.url}")
|
63 |
+
|
64 |
+
cookies = await context.cookies()
|
65 |
+
logger.info(f"Playwright: Extracted {len(cookies)} cookies after successful login and navigation.")
|
66 |
+
return cookies
|
67 |
+
|
68 |
+
except PlaywrightTimeoutError as e:
|
69 |
+
logger.error(f"Playwright login/navigation failed (Timeout): {e}. Current URL: {page.url}")
|
70 |
+
raise
|
71 |
+
except Exception as e:
|
72 |
+
logger.error(f"Playwright login/navigation failed (General Error): {e}. Current URL: {page.url}")
|
73 |
+
raise
|
74 |
+
finally:
|
75 |
+
if browser:
|
76 |
+
await browser.close()
|
77 |
+
|
78 |
+
|
79 |
+
async def initialize_runner_with_cookies(cookies: list, target_url: str):
|
80 |
+
global _runner
|
81 |
+
logger.info("Initializing Proxy-lite Runner with provided cookies...")
|
82 |
+
|
83 |
+
gemini_api_key = os.environ.get("GEMINI_API_KEY")
|
84 |
+
if not gemini_api_key:
|
85 |
+
logger.error("GEMINI_API_KEY environment variable not set. Cannot initialize Runner.")
|
86 |
+
raise ValueError("GEMINI_API_KEY environment variable not set. Please set it as a Space secret.")
|
87 |
+
|
88 |
+
config_dict = {
|
89 |
+
"environment": {
|
90 |
+
"name": "webbrowser",
|
91 |
+
"homepage": "about:blank", # Safe startup, we'll open new tab programmatically
|
92 |
+
"headless": False,
|
93 |
+
"launch_args": ["--no-sandbox", "--disable-setuid-sandbox"],
|
94 |
+
"screenshot_delay": 0.5,
|
95 |
+
"include_html": True,
|
96 |
+
"include_poi_text": True,
|
97 |
+
"record_pois": True,
|
98 |
+
"viewport_width": 1280,
|
99 |
+
"viewport_height": 720,
|
100 |
+
"browserbase_timeout": 7200,
|
101 |
+
"keep_original_image": False,
|
102 |
+
"no_pois_in_image": False,
|
103 |
+
"initial_cookies": cookies
|
104 |
+
},
|
105 |
+
"solver": {
|
106 |
+
"name": "simple",
|
107 |
+
"agent": {
|
108 |
+
"name": "proxy_lite",
|
109 |
+
"client": {
|
110 |
+
"name": "gemini",
|
111 |
+
"model_id": "gemini-2.0-flash-001",
|
112 |
+
"api_key": gemini_api_key,
|
113 |
+
"http_timeout": 50.0,
|
114 |
+
"http_concurrent_connections": 50,
|
115 |
+
},
|
116 |
+
"history_messages_limit": {
|
117 |
+
"screenshot": 1
|
118 |
+
},
|
119 |
+
"history_messages_include": None,
|
120 |
+
}
|
121 |
+
},
|
122 |
+
"environment_timeout": 1800.0,
|
123 |
+
"action_timeout": 1800.0,
|
124 |
+
"task_timeout": 18000.0,
|
125 |
+
"max_steps": 150,
|
126 |
+
"logger_level": "DEBUG",
|
127 |
+
"save_every_step": True,
|
128 |
+
"detailed_logger_name": False
|
129 |
+
}
|
130 |
+
config = RunnerConfig.from_dict(config_dict)
|
131 |
+
|
132 |
+
logger.info(f"DEBUG: app.py - Initializing Proxy-lite Runner with Gemini Flash 2.0 configuration.")
|
133 |
+
_runner = Runner(config=config)
|
134 |
+
logger.info("Proxy-lite Runner initialized successfully with Gemini Flash 2.0 and injected cookies.")
|
135 |
+
return _runner
|
136 |
+
|
137 |
+
|
138 |
+
@app.route('/run_proxy_task', methods=['POST'])
|
139 |
+
async def run_proxy_task_endpoint():
|
140 |
+
data = request.json
|
141 |
+
request_task_instruction = data.get('task')
|
142 |
+
target_url = data.get('url')
|
143 |
+
|
144 |
+
if not request_task_instruction:
|
145 |
+
logger.warning("Received request without 'task' field. Returning 400.")
|
146 |
+
return jsonify({"error": "No 'task' provided in request body"}), 400
|
147 |
+
|
148 |
+
if not target_url:
|
149 |
+
logger.warning("Received request without 'url' field. Returning 400.")
|
150 |
+
return jsonify({"error": "No 'url' provided in request body"}), 400
|
151 |
+
|
152 |
+
logger.info(f"Received user request task: '{request_task_instruction}'")
|
153 |
+
logger.info(f"Target URL: '{target_url}'")
|
154 |
+
|
155 |
+
# Check if this is a Salesforce URL
|
156 |
+
is_salesforce_url = "salesforce.com" in target_url or "force.com" in target_url
|
157 |
+
|
158 |
+
try:
|
159 |
+
if is_salesforce_url:
|
160 |
+
# Salesforce automation - requires login
|
161 |
+
salesforce_username = os.environ.get("SALESFORCE_USERNAME")
|
162 |
+
salesforce_password = os.environ.get("SALESFORCE_PASSWORD")
|
163 |
+
|
164 |
+
if not salesforce_username or not salesforce_password:
|
165 |
+
logger.error("Salesforce credentials (SALESFORCE_USERNAME, SALESFORCE_PASSWORD) environment variables not set.")
|
166 |
+
return jsonify({"error": "Salesforce credentials not configured. Please set SALESFORCE_USERNAME and SALESFORCE_PASSWORD as Space secrets."}), 500
|
167 |
+
|
168 |
+
salesforce_login_url = "https://login.salesforce.com/"
|
169 |
+
logger.info("Executing hardcoded login via Playwright to get session cookies...")
|
170 |
+
session_cookies = await perform_hardcoded_salesforce_login_and_get_cookies(
|
171 |
+
salesforce_username, salesforce_password, salesforce_login_url, target_url
|
172 |
+
)
|
173 |
+
logger.info(f"Successfully obtained {len(session_cookies)} cookies. These will be injected into the agent's browser.")
|
174 |
+
else:
|
175 |
+
# General web browsing - no login required
|
176 |
+
logger.info("Non-Salesforce URL detected. Skipping Salesforce login.")
|
177 |
+
session_cookies = []
|
178 |
+
|
179 |
+
runner = await initialize_runner_with_cookies(session_cookies, target_url)
|
180 |
+
logger.info("Proxy-lite Runner initialized with cookies." if session_cookies else "Proxy-lite Runner initialized for general web browsing.")
|
181 |
+
|
182 |
+
logger.info("Agent will use mandatory new tab tool to bypass loading issues.")
|
183 |
+
|
184 |
+
# MANDATORY new tab navigation task - this is critical to avoid loading issues
|
185 |
+
agent_task = f"""
|
186 |
+
CRITICAL FIRST STEP - MANDATORY:
|
187 |
+
Your VERY FIRST action must be to use the open_new_tab_and_go_to tool to navigate to {target_url}
|
188 |
+
|
189 |
+
DO NOT skip this step. DO NOT use goto. You MUST use: open_new_tab_and_go_to(url='{target_url}')
|
190 |
+
|
191 |
+
This is necessary because direct navigation to this URL gets stuck loading. The new tab approach bypasses this issue.
|
192 |
+
|
193 |
+
STEP 1: Use open_new_tab_and_go_to(url='{target_url}')
|
194 |
+
STEP 2: Wait for the page to be fully loaded (no loading spinners visible)
|
195 |
+
STEP 3: {request_task_instruction}
|
196 |
+
|
197 |
+
CRITICAL WORKFLOW - FOLLOW THESE EXACT STEPS IN SEQUENCE:
|
198 |
+
|
199 |
+
STEP A: Select Permission Set
|
200 |
+
- Use select_option_by_text tool to find and select the target permission set from Available list
|
201 |
+
- Wait for "[ACTION COMPLETED]" response before proceeding
|
202 |
+
|
203 |
+
STEP B: Click Add Button
|
204 |
+
- After successful selection, immediately click the "Add" button to move permission set to Enabled list
|
205 |
+
- Do NOT repeat the selection - proceed directly to Add button
|
206 |
+
|
207 |
+
STEP C: Click Save Button
|
208 |
+
- After clicking Add, immediately click "Save" to persist the changes
|
209 |
+
- After Save, Salesforce redirects to User page indicating SUCCESS
|
210 |
+
|
211 |
+
CRITICAL: Do NOT repeat actions. Each step should happen exactly once in sequence.
|
212 |
+
|
213 |
+
GENERAL INSTRUCTIONS:
|
214 |
+
- You must EXECUTE all actions immediately - do NOT just describe what you plan to do
|
215 |
+
- Do NOT wait for user input or ask "what should I do next?"
|
216 |
+
- Complete the entire task autonomously using the available tools
|
217 |
+
- After completing all steps, use the return_value tool to provide your final response
|
218 |
+
- If you make a plan, IMMEDIATELY execute it step by step using the appropriate tools
|
219 |
+
"""
|
220 |
+
|
221 |
+
logger.info("Executing agent task with mandatory new tab navigation...")
|
222 |
+
result = await runner.run(task=agent_task)
|
223 |
+
|
224 |
+
# Extract the actual result value from the Run object
|
225 |
+
if hasattr(result, 'value') and result.value:
|
226 |
+
task_result = str(result.value)
|
227 |
+
elif hasattr(result, 'result') and result.result:
|
228 |
+
task_result = str(result.result)
|
229 |
+
else:
|
230 |
+
task_result = str(result)
|
231 |
+
|
232 |
+
logger.info(f"Proxy-lite task completed. Output (truncated for log): {task_result[:500]}...")
|
233 |
+
|
234 |
+
# Structure response for LWC integration
|
235 |
+
response = {
|
236 |
+
"status": "success",
|
237 |
+
"message": "Task completed successfully",
|
238 |
+
"data": {
|
239 |
+
"task_result": task_result,
|
240 |
+
"steps_completed": [
|
241 |
+
"Hardcoded Salesforce login completed",
|
242 |
+
"Browser session initialized with cookies",
|
243 |
+
"New tab navigation executed",
|
244 |
+
"Target Salesforce setup page accessed",
|
245 |
+
"Task execution completed successfully"
|
246 |
+
],
|
247 |
+
"environment": {
|
248 |
+
"target_url": target_url,
|
249 |
+
"cookies_count": len(session_cookies),
|
250 |
+
"navigation_method": "new_tab_bypass"
|
251 |
+
}
|
252 |
+
},
|
253 |
+
"timestamp": datetime.now().isoformat(),
|
254 |
+
"task_request": request_task_instruction
|
255 |
+
}
|
256 |
+
|
257 |
+
return jsonify(response)
|
258 |
+
|
259 |
+
except PlaywrightTimeoutError as e:
|
260 |
+
logger.exception(f"Playwright timeout during login/navigation: {e}")
|
261 |
+
error_response = {
|
262 |
+
"status": "error",
|
263 |
+
"error_type": "navigation_timeout",
|
264 |
+
"message": "Page loading timed out during login or navigation",
|
265 |
+
"data": {
|
266 |
+
"error_details": str(e),
|
267 |
+
"suggested_action": "Retry the request - network issues may be temporary",
|
268 |
+
"steps_completed": ["Login attempted", "Navigation failed due to timeout"]
|
269 |
+
},
|
270 |
+
"timestamp": datetime.now().isoformat(),
|
271 |
+
"task_request": request_task_instruction
|
272 |
+
}
|
273 |
+
return jsonify(error_response), 500
|
274 |
+
|
275 |
+
except ValueError as e:
|
276 |
+
logger.exception(f"Configuration error: {e}")
|
277 |
+
error_response = {
|
278 |
+
"status": "error",
|
279 |
+
"error_type": "configuration_error",
|
280 |
+
"message": "System configuration issue",
|
281 |
+
"data": {
|
282 |
+
"error_details": str(e),
|
283 |
+
"suggested_action": "Check environment variables and system configuration",
|
284 |
+
"steps_completed": ["Configuration validation failed"]
|
285 |
+
},
|
286 |
+
"timestamp": datetime.now().isoformat(),
|
287 |
+
"task_request": request_task_instruction
|
288 |
+
}
|
289 |
+
return jsonify(error_response), 500
|
290 |
+
|
291 |
+
except Exception as e:
|
292 |
+
logger.exception(f"Unexpected error processing Salesforce task: {e}")
|
293 |
+
error_response = {
|
294 |
+
"status": "error",
|
295 |
+
"error_type": "unexpected_error",
|
296 |
+
"message": "An unexpected error occurred during task execution",
|
297 |
+
"data": {
|
298 |
+
"error_details": str(e),
|
299 |
+
"error_class": type(e).__name__,
|
300 |
+
"suggested_action": "Check logs for detailed error information and retry",
|
301 |
+
"steps_completed": ["Login attempted", "Error occurred during execution"]
|
302 |
+
},
|
303 |
+
"timestamp": datetime.now().isoformat(),
|
304 |
+
"task_request": request_task_instruction
|
305 |
+
}
|
306 |
+
return jsonify(error_response), 500
|
307 |
+
|
308 |
+
@app.route('/')
def root():
    """Informational landing endpoint.

    Returns a plain-text hint directing callers to POST /run_proxy_task.
    Useful as a quick liveness probe when /health is not needed.
    """
    logger.info("Root endpoint accessed.")
    return "Proxy-lite API is running. Send POST requests to /run_proxy_task with a 'task' in JSON body."
|
312 |
+
|
313 |
+
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint for monitoring and debugging.

    Reports process liveness, which required environment variables are set
    (presence only — values are never echoed), the available endpoints, and
    the Salesforce pages the automation has been exercised against.

    Returns:
        JSON payload with status, environment-variable presence flags,
        endpoint catalogue, supported pages, and an ISO-8601 timestamp.
    """
    logger.info("Health check endpoint accessed.")

    # Check environment variables (presence only, never the secret values).
    # NOTE(review): the status glyphs below appear to be mojibake — both the
    # "set" and "missing" branches currently render the same character;
    # confirm against the original file encoding (likely check/cross emoji).
    env_status = {
        "GEMINI_API_KEY": "β" if os.environ.get("GEMINI_API_KEY") else "β",
        "SALESFORCE_USERNAME": "β" if os.environ.get("SALESFORCE_USERNAME") else "β",
        "SALESFORCE_PASSWORD": "β" if os.environ.get("SALESFORCE_PASSWORD") else "β"
    }

    health_response = {
        "status": "healthy",
        "message": "Proxy-lite API is running",
        "environment_variables": env_status,
        "endpoints": {
            "POST /run_proxy_task": "Execute Salesforce automation tasks (requires 'task' and 'url' parameters)",
            "GET /health": "Health check and status",
            "GET /": "API information"
        },
        "supported_pages": [
            "Warranty Lifecycle Management",
            "Account Forecasting Settings",
            "Sales Agreements",
            "Account Manager Targets",
            "Any Salesforce Setup page"
        ],
        "timestamp": datetime.now().isoformat()
    }

    return jsonify(health_response)
|
345 |
+
|
346 |
+
if __name__ == '__main__':
    # Local development entry point only; production runs under gunicorn
    # (see Procfile). Warn loudly — but do not abort — when the LLM API key
    # is missing, so local runs fail in the logs rather than mid-request.
    if not os.environ.get("GEMINI_API_KEY"):
        logger.error("GEMINI_API_KEY environment variable is not set. Please set it for local testing.")
    logger.info("Starting Flask development server on 0.0.0.0:6101...")
    # debug=True enables the reloader/debugger — never use in production.
    app.run(host='0.0.0.0', port=6101, debug=True)
|
proxy-lite-demo-v2/pyproject.toml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "proxy-lite"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Proxy Lite - A mini, open-weights, version of the Convergence AI Proxy assistant."
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.11"
|
7 |
+
dependencies = [
|
8 |
+
"omegaconf>=2.3.0",
|
9 |
+
"openai>=1.61.1",
|
10 |
+
"opencv-python>=4.11.0.86",
|
11 |
+
"opencv-python-headless>=4.11.0.86",
|
12 |
+
"playwright-stealth>=1.0.6",
|
13 |
+
"playwright>=1.50.0",
|
14 |
+
"pydantic>=2.10.6",
|
15 |
+
"rich>=13.9.4",
|
16 |
+
"setuptools>=75.8.0",
|
17 |
+
"tenacity>=9.0.0",
|
18 |
+
"torch>=2.5.1",
|
19 |
+
"torchvision>=0.20.1",
|
20 |
+
"streamlit>=1.40.2",
|
21 |
+
"pre-commit>=4.1.0",
|
22 |
+
]
|
23 |
+
|
24 |
+
[project.scripts]
|
25 |
+
proxy = "proxy_lite.cli:main"
|
26 |
+
|
27 |
+
[project.optional-dependencies]
|
28 |
+
serving = [
|
29 |
+
"transformers",
|
30 |
+
"vllm==0.7.2",
|
31 |
+
]
|
32 |
+
|
33 |
+
[build-system]
|
34 |
+
requires = ["setuptools"]
|
35 |
+
build-backend = "setuptools.build_meta"
|
36 |
+
|
37 |
+
[tool.setuptools]
|
38 |
+
packages = { find = { where = ["src"] } }
|
39 |
+
|
40 |
+
[tool.setuptools.package-data]
|
41 |
+
proxy_lite = ["**/*.json"]
|
42 |
+
|
43 |
+
[tool.ruff]
|
44 |
+
line-length = 120
|
45 |
+
|
46 |
+
[tool.ruff.lint]
|
47 |
+
select = ["E", "F", "B", "I", "SIM"]
|
48 |
+
ignore = [
|
49 |
+
"B028",
|
50 |
+
"E722", # ignore bare except
|
51 |
+
"B904", # ignore raise from requirement
|
52 |
+
"FA102",
|
53 |
+
]
|
54 |
+
[tool.ruff.lint.flake8-bugbear]
|
55 |
+
|
56 |
+
extend-immutable-calls = [
|
57 |
+
"fastapi.Depends",
|
58 |
+
"fastapi.params.Depends",
|
59 |
+
"fastapi.Query",
|
60 |
+
"fastapi.params.Query",
|
61 |
+
]
|
62 |
+
|
63 |
+
[tool.uv.sources]
|
64 |
+
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "336dc69d63d56f232a183a3e7f52790429b871ef" }
|
65 |
+
|
proxy-lite-demo-v2/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Flask[async]
|
2 |
+
-e .
|
3 |
+
playwright
|
4 |
+
playwright-stealth==1.0.6
|
5 |
+
gunicorn
|
6 |
+
gevent
|
proxy-lite-demo-v2/src/proxy_lite/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .runner import Runner, RunnerConfig
|
2 |
+
|
3 |
+
__all__ = ["Runner", "RunnerConfig"]
|
proxy-lite-demo-v2/src/proxy_lite/agents/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from .agent_base import Agents, BaseAgent, BaseAgentConfig
|
4 |
+
from .proxy_lite_agent import ProxyLiteAgent, ProxyLiteAgentConfig
|
5 |
+
|
6 |
+
AgentTypes = Union[*list(Agents._agent_registry.values())]
|
7 |
+
AgentConfigTypes = Union[*list(Agents._agent_config_registry.values())]
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"AgentConfigTypes",
|
12 |
+
"AgentTypes",
|
13 |
+
"Agents",
|
14 |
+
"BaseAgent",
|
15 |
+
"BaseAgentConfig",
|
16 |
+
"ProxyLiteAgent",
|
17 |
+
"ProxyLiteAgentConfig",
|
18 |
+
]
|
proxy-lite-demo-v2/src/proxy_lite/agents/agent_base.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from contextlib import AsyncExitStack
|
5 |
+
from functools import cached_property
|
6 |
+
from typing import Any, Optional, Type, cast
|
7 |
+
|
8 |
+
from pydantic import BaseModel, Field
|
9 |
+
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
|
10 |
+
|
11 |
+
from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig
|
12 |
+
from proxy_lite.history import (
|
13 |
+
AssistantMessage,
|
14 |
+
MessageHistory,
|
15 |
+
MessageLabel,
|
16 |
+
SystemMessage,
|
17 |
+
Text,
|
18 |
+
ToolCall,
|
19 |
+
ToolMessage,
|
20 |
+
UserMessage,
|
21 |
+
)
|
22 |
+
from proxy_lite.logger import logger
|
23 |
+
from proxy_lite.tools import Tool
|
24 |
+
|
25 |
+
# if TYPE_CHECKING:
|
26 |
+
# from proxy_lite.tools import Tool
|
27 |
+
|
28 |
+
|
29 |
+
class BaseAgentConfig(BaseModel):
    """Base configuration shared by all agents.

    Attributes:
        client: Configuration for the LLM client used for completions.
        history_messages_limit: Per-label cap on how many messages of each
            type are kept in the context window (0 excludes the label;
            an absent label is left to ``MessageHistory.history_view``).
        history_messages_include: Optional allow-list. When set, every label
            is capped at 0 except those listed here (see model_post_init).
    """

    client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
    # default_factory=dict is the idiomatic (and cheaper) spelling of the
    # previous `lambda: dict()`; behaviour is identical.
    history_messages_limit: dict[MessageLabel, int] = Field(default_factory=dict)
    history_messages_include: Optional[dict[MessageLabel, int]] = Field(
        default=None,
        description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
    )

    def model_post_init(self, __context: Any) -> None:
        # When an include-list is supplied, rebuild the limits: everything
        # off (0), then re-enable only the explicitly included labels.
        if self.history_messages_include is not None:
            self.history_messages_limit = {label: 0 for label in MessageLabel}
            self.history_messages_limit.update(self.history_messages_include)
|
41 |
+
|
42 |
+
|
43 |
+
class BaseAgent(BaseModel, ABC):
    """Abstract base for LLM-driven agents.

    Owns the message history, the LLM client (built from config in
    model_post_init), and the environment-provided tools. Subclasses must
    provide ``system_prompt`` and ``tools``, and are expected to define a
    ``message_label`` attribute used when appending model responses
    (see generate_output).
    """

    # Agent configuration; concrete subclasses narrow this type.
    config: BaseAgentConfig
    # Sampling temperature forwarded to the client (clamped to [0, 2]).
    temperature: float = Field(default=0.7, ge=0, le=2)
    # Shared conversation history for this agent instance.
    history: MessageHistory = Field(default_factory=MessageHistory)
    # LLM client; populated from config in model_post_init.
    client: Optional[BaseClient] = None
    # Tools injected by the environment (e.g. browser actions).
    env_tools: list[Tool] = Field(default_factory=list)
    task: Optional[str] = Field(default=None)
    # Optional sampling seed for reproducibility.
    seed: Optional[int] = Field(default=None)

    class Config:
        # Allow non-pydantic types (e.g. BaseClient) as field values.
        arbitrary_types_allowed = True

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Private runtime state, deliberately outside pydantic's field set.
        self._exit_stack = AsyncExitStack()
        self._tools_init_task = None

    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)
        # Build the concrete client from its config once validation is done.
        self.client = BaseClient.create(self.config.client)

    @property
    @abstractmethod
    def system_prompt(self) -> str: ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @cached_property
    def tool_descriptions(self) -> str:
        """Human-readable summary of every tool function, for prompts."""
        tool_descriptions = []
        for tool in self.tools:
            func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
            # Only prefix with the tool class name when disambiguation helps.
            tool_title = f"{tool.__class__.__name__}:\n" if len(self.tools) > 1 else ""
            tool_descriptions.append(f"{tool_title}{func_descriptions}")
        return "\n\n".join(tool_descriptions)

    async def get_history_view(self) -> MessageHistory:
        """Return system prompt + history trimmed by the per-label limits."""
        return MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        ) + self.history.history_view(
            limits=self.config.history_messages_limit,
        )

    @retry(
        # Retry transient client failures with exponential backoff, at most
        # 3 attempts, re-raising the last error and logging before sleeping.
        wait=wait_exponential(multiplier=1, min=4, max=10),
        stop=stop_after_attempt(3),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def generate_output(
        self,
        use_tool: bool = False,
        response_format: Optional[type[BaseModel]] = None,
        append_assistant_message: bool = True,
    ) -> AssistantMessage:
        """Run one completion against the current history view.

        Args:
            use_tool: Expose this agent's tools to the model.
            response_format: Optional structured-output schema.
            append_assistant_message: Append the reply to history under
                ``self.message_label`` (attribute supplied by the subclass).

        Returns:
            The parsed AssistantMessage (content and/or tool calls).
        """
        messages: MessageHistory = await self.get_history_view()
        response_content = (
            await self.client.create_completion(
                messages=messages,
                temperature=self.temperature,
                seed=self.seed,
                response_format=response_format,
                tools=self.tools if use_tool else None,
            )
        ).model_dump()
        # OpenAI-style response shape: take the first choice's message.
        response_content = response_content["choices"][0]["message"]
        assistant_message = AssistantMessage(
            role=response_content["role"],
            content=[Text(text=response_content["content"])] if response_content["content"] else [],
            tool_calls=response_content["tool_calls"],
        )
        if append_assistant_message:
            self.history.append(message=assistant_message, label=self.message_label)
        return assistant_message

    def receive_user_message(
        self,
        text: Optional[str] = None,
        image: Optional[list[bytes]] = None,
        label: Optional[MessageLabel] = None,
        is_base64: bool = False,
    ) -> None:
        """Append a user message (text and/or images) to history."""
        message = UserMessage.from_media(
            text=text,
            image=image,
            is_base64=is_base64,
        )
        self.history.append(message=message, label=label)

    def receive_system_message(
        self,
        text: Optional[str] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a system message to history."""
        message = SystemMessage.from_media(text=text)
        self.history.append(message=message, label=label)

    def receive_assistant_message(
        self,
        content: Optional[str] = None,
        tool_calls: Optional[list[ToolCall]] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append an externally-constructed assistant message to history."""
        message = AssistantMessage(
            content=[Text(text=content)] if content else [],
            tool_calls=tool_calls,
        )
        self.history.append(message=message, label=label)

    async def use_tool(self, tool_call: ToolCall):
        """Dispatch a tool call to the first tool exposing the named method.

        Raises:
            ValueError: If no registered tool has the requested function.
        """
        function = tool_call.function
        for tool in self.tools:
            if hasattr(tool, function["name"]):
                return await getattr(tool, function["name"])(
                    **json.loads(function["arguments"]),
                )
        msg = f'No tool function with name "{function["name"]}"'
        raise ValueError(msg)

    async def receive_tool_message(
        self,
        text: str,
        tool_id: str,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a tool-result message (linked to its call id) to history."""
        self.history.append(
            message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
            label=label,
        )
|
174 |
+
|
175 |
+
|
176 |
+
class Agents:
    """Registry mapping string names to Agent classes and their configs.

    Populated at import time via the ``register_agent`` /
    ``register_agent_config`` decorators; looked up with ``get`` /
    ``get_config``.
    """

    _agent_registry: dict[str, type[BaseAgent]] = {}
    _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}

    @classmethod
    def register_agent(cls, name: str):
        """
        Decorator to register an Agent class under a given name.

        Example:
            @Agents.register_agent("browser")
            class BrowserAgent(BaseAgent):
                ...
        """

        def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
            cls._agent_registry[name] = agent_cls
            return agent_cls

        return decorator

    @classmethod
    def register_agent_config(cls, name: str):
        """
        Decorator to register a configuration class under a given name.

        Example:
            @Agents.register_agent_config("browser")
            class BrowserAgentConfig(BaseAgentConfig):
                ...
        """

        def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
            cls._agent_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseAgent]:
        """
        Retrieve a registered Agent class by its name.

        Raises:
            ValueError: If no such agent is found.
        """
        try:
            # Use the builtin ``type[...]`` form consistently (this cast
            # previously used typing.Type while get_config used ``type``).
            return cast(type[BaseAgent], cls._agent_registry[name])
        except KeyError:
            raise ValueError(f"Agent '{name}' not found.")

    @classmethod
    def get_config(cls, name: str) -> type[BaseAgentConfig]:
        """
        Retrieve a registered Agent configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cast(type[BaseAgentConfig], cls._agent_config_registry[name])
        except KeyError:
            raise ValueError(f"Agent config for '{name}' not found.")
|
proxy-lite-demo-v2/src/proxy_lite/agents/proxy_lite_agent.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import cached_property
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
from pydantic import Field
|
5 |
+
|
6 |
+
from proxy_lite.history import MessageHistory, MessageLabel, SystemMessage, Text
|
7 |
+
from proxy_lite.tools import Tool
|
8 |
+
|
9 |
+
from .agent_base import Agents, BaseAgent, BaseAgentConfig
|
10 |
+
|
11 |
+
MODEL_SYSTEM_PROMPT = """You are Proxy-Lite, an AI assistant that can perform actions on a computer screen.
|
12 |
+
You were developed by Convergence AI.
|
13 |
+
The user will instruct you to perform a task.
|
14 |
+
You will be shown a screen as well as relevant interactable elements highlighted by mark_ids and you will be given a set of tools to use to perform the task.
|
15 |
+
|
16 |
+
CRITICAL WORKFLOW INSTRUCTIONS:
|
17 |
+
1. Make observations about the screen, putting them in <observation></observation> tags.
|
18 |
+
2. Reason about what needs to be done to complete the task, putting your thoughts in <thinking></thinking> tags.
|
19 |
+
3. Use the tools to perform actions - DO NOT just describe what you plan to do, EXECUTE the actions immediately.
|
20 |
+
4. When you receive "[ACTION COMPLETED]" feedback, analyze the new screen state to determine your next action.
|
21 |
+
5. Continue executing actions step by step until the entire task is complete.
|
22 |
+
6. Use the return_value tool only when the ENTIRE task is finished.
|
23 |
+
|
24 |
+
IMPORTANT: Do NOT stop after one action. Multi-step tasks require multiple tool calls. When you receive action completion feedback, immediately analyze the screen and continue with the next required action.
|
25 |
+
""" # noqa: E501
|
26 |
+
|
27 |
+
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
|
28 |
+
MessageLabel.SCREENSHOT: 1,
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
@Agents.register_agent_config("proxy_lite")
class ProxyLiteAgentConfig(BaseAgentConfig):
    """Configuration for :class:`ProxyLiteAgent`."""

    name: Literal["proxy_lite"] = "proxy_lite"
    # Bug fix: copy the module-level default instead of handing every
    # instance the same dict object. BaseAgentConfig.model_post_init mutates
    # history_messages_limit in place when history_messages_include is set,
    # which previously corrupted the shared MAX_MESSAGES_FOR_CONTEXT_WINDOW
    # constant for all later instances.
    history_messages_limit: dict[MessageLabel, int] = Field(
        default_factory=lambda: dict(MAX_MESSAGES_FOR_CONTEXT_WINDOW),
    )
|
38 |
+
|
39 |
+
|
40 |
+
@Agents.register_agent("proxy_lite")
class ProxyLiteAgent(BaseAgent):
    """Agent wrapper around the proxy-lite model.

    Uses a fixed system prompt, takes its tools entirely from the
    environment, and keeps only the most recent screenshot in the context
    window (per MAX_MESSAGES_FOR_CONTEXT_WINDOW via its config).
    """

    config: ProxyLiteAgentConfig
    # Label under which this agent's model responses are appended to history.
    message_label: MessageLabel = MessageLabel.AGENT_MODEL_RESPONSE

    # NOTE: a redundant __init__ that only forwarded to super().__init__ was
    # removed; inherited construction behaviour is unchanged.

    @property
    def system_prompt(self) -> str:
        """Static system prompt instructing the step-by-step tool workflow."""
        return MODEL_SYSTEM_PROMPT

    @cached_property
    def tools(self) -> list[Tool]:
        """All tools come from the environment; the agent adds none."""
        return self.env_tools

    async def get_history_view(self) -> MessageHistory:
        """Return the system prompt plus the label-limited message history."""
        return MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        ) + self.history.history_view(
            limits=self.config.history_messages_limit,
        )
|
proxy-lite-demo-v2/src/proxy_lite/app.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import base64
|
3 |
+
from io import BytesIO
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from proxy_lite import Runner, RunnerConfig
|
9 |
+
|
10 |
+
|
11 |
+
def get_user_config(config_expander):
    """Build the runner configuration, exposing a subset of it in the UI.

    Starts from a hard-coded default config (webbrowser environment pointed
    at a Salesforce setup page, proxy-lite client via the Convergence demo
    API) and lets the user override selected fields through Streamlit
    widgets rendered inside *config_expander*.

    Args:
        config_expander: A Streamlit expander container to render widgets in.

    Returns:
        dict: Nested config consumable by ``RunnerConfig.from_dict``.
    """
    # Defaults; only the fields surfaced as widgets below are user-editable.
    config = {
        "environment": {
            "name": "webbrowser",
            "annotate_image": True,
            "screenshot_delay": 2.0,
            "include_html": False,
            "viewport_width": 1280,
            "viewport_height": 1920,
            "include_poi_text": True,
            "homepage": "https://dwd000006jia1mae.lightning.force.com/lightning/setup/AccountForecastSettings/home",
            "keep_original_image": False,
            "headless": False,  # without proxies headless mode often results in getting bot blocked
        },
        "solver": {
            "name": "simple",
            "agent": {
                "name": "proxy_lite",
                "client": {
                    "name": "convergence",
                    "model_id": "convergence-ai/proxy-lite-3b",
                    "api_base": "https://convergence-ai-demo-api.hf.space/v1",
                },
            },
        },
        "local_view": False,
        "verbose": True,
        "task_timeout": 1800,  # 30 minutes
        "action_timeout": 300,
        "environment_timeout": 120,
    }

    with config_expander:
        st.subheader("Environment Settings")
        col1, col2 = st.columns(2)

        with col1:
            config["environment"]["include_html"] = st.checkbox(
                "Include HTML",
                value=config["environment"]["include_html"],
                help="Include HTML in observations",
            )
            config["environment"]["include_poi_text"] = st.checkbox(
                "Include POI Text",
                value=config["environment"]["include_poi_text"],
                help="Include points of interest text in observations",
            )
            config["environment"]["homepage"] = st.text_input(
                "Homepage",
                value=config["environment"]["homepage"],
                help="Homepage to start from",
            )

        with col2:
            config["solver"]["agent"]["client"]["api_base"] = st.text_input(
                "VLLM Server URL",
                value=config["solver"]["agent"]["client"]["api_base"],
                help="URL of a vllm server running proxy-lite",
            )
            config["environment"]["screenshot_delay"] = st.slider(
                "Screenshot Delay (seconds)",
                min_value=0.5,
                max_value=10.0,
                value=config["environment"]["screenshot_delay"],
                step=0.5,
                help="Delay before taking screenshots",
            )

        st.subheader("Advanced Settings")
        config["task_timeout"] = st.number_input(
            "Task Timeout (seconds)",
            min_value=60,
            max_value=3600,
            step=60,
            value=config["task_timeout"],
            help="Maximum time allowed for task completion",
        )
        config["action_timeout"] = st.number_input(
            "Action Timeout (seconds)",
            min_value=10,
            max_value=300,
            step=10,
            value=config["action_timeout"],
            help="Maximum time allowed for an action to complete",
        )
        config["environment_timeout"] = st.number_input(
            "Environment Timeout (seconds)",
            min_value=10,
            max_value=300,
            step=10,
            value=config["environment_timeout"],
            help="Maximum time allowed for environment to respond",
        )

    return config
|
106 |
+
|
107 |
+
|
108 |
+
async def run_task_async(
    task: str,
    status_placeholder,
    action_placeholder,
    environment_placeholder,
    image_placeholder,
    history_placeholder,
    config: dict,
):
    """Run a Proxy-Lite task and stream progress into Streamlit placeholders.

    Consumes the runner's generator, updating the latest step, the latest
    environment screenshot, and a cumulative history expander on every
    iteration; writes the final step text as the result when done.

    Args:
        task: Natural-language task for the runner.
        status_placeholder / action_placeholder / environment_placeholder /
        image_placeholder / history_placeholder: ``st.empty()`` slots owned
            by the caller, updated in place.
        config: Raw config dict, validated via ``RunnerConfig.from_dict``.
    """
    try:
        config = RunnerConfig.from_dict(config)
    except Exception as e:
        # Surface config errors in the UI and bail; nothing was started yet.
        st.error(f"Error loading RunnerConfig: {e!s}")
        return
    runner = Runner(config=config)

    # Add the spinning animation using HTML
    status_placeholder.markdown(
        """
        <style>
        @keyframes spin {
            0% { content: "β‘"; }
            25% { content: "β‘."; }
            50% { content: "β‘.."; }
            75% { content: "β‘..."; }
        }
        .spinner::before {
            content: "β‘";
            animation: spin 2s linear infinite;
            display: inline-block;
        }
        </style>
        <div><b>Resolving your task </b><span class="spinner"></span></div>
        """,
        unsafe_allow_html=True,
    )

    all_steps = []
    all_screenshots = []
    all_soms = []
    # Bug fix: initialise so the final result write below cannot raise
    # NameError when the runner yields no actions at all. (Also removed a
    # leftover debug print of the config.)
    latest_step = ""

    async for run in runner.run_generator(task):
        # Update status with latest step
        if run.actions:
            latest_step = run.actions[-1].text
            latest_step += "".join(
                [
                    f'<tool_call>{{"name": {tool_call.function["name"]}, "arguments": {tool_call.function["arguments"]}}}</tool_call>'  # noqa: E501
                    for tool_call in run.actions[-1].tool_calls
                ]
            )
            action_placeholder.write(f"β‘ **Latest Step:** {latest_step}")
            all_steps.append(latest_step)

        # Update image if available
        if run.observations and run.observations[-1].state.image:
            environment_placeholder.write("π **Environment:**")
            image_bytes = base64.b64decode(run.observations[-1].state.image)
            image = Image.open(BytesIO(image_bytes))
            image_placeholder.image(image, use_container_width=True)
            all_screenshots.append(image)
            som = run.observations[-1].state.text
            all_soms.append(som)

        # Update history (re-rendered in full each iteration)
        with history_placeholder, st.expander("π **History**"):
            for idx, (action, img, som) in enumerate(zip(all_steps, all_screenshots, all_soms, strict=False)):
                st.write(f"**Step {idx + 1}**")
                st.image(img, use_container_width=True)
                st.markdown(som)
                st.write(action)
    action_placeholder.write(" ")
    status_placeholder.write(f"β¨ **Result:** {latest_step}")
|
182 |
+
|
183 |
+
|
184 |
+
def main():
    """Streamlit entry point: collect config, accept a task, stream the run.

    Renders the configuration expander, a task-submission form, and on
    submit drives :func:`run_task_async` through ``asyncio.run``.
    (A nested ``img_to_base64`` helper that was defined but never used has
    been removed.)
    """
    st.title("β‘ Proxy-Lite")

    st.markdown("Powered by **Proxy-Lite**", unsafe_allow_html=True)

    # Expander open/closed state must survive Streamlit reruns.
    if "config_expanded" not in st.session_state:
        st.session_state.config_expanded = False
    if "settings_expanded" not in st.session_state:
        st.session_state.settings_expanded = False

    config_expander = st.expander("βοΈ Proxy-Lite Configuration", expanded=st.session_state.config_expanded)
    config = get_user_config(config_expander)

    with st.form(key="run_task_form"):
        task = st.text_input(
            "Submit a task",
            key="task_input",
            help="Enter a task to be completed",
        )
        submit_button = st.form_submit_button("Submit a task", type="primary", use_container_width=True)

    if submit_button:
        # Collapse the config panel once a run starts.
        st.session_state.config_expanded = False
        if task:
            # Create placeholders for dynamic updates
            status_placeholder = st.empty()
            st.write(" ")
            action_placeholder = st.empty()
            environment_placeholder = st.empty()
            image_placeholder = st.empty()
            history_placeholder = st.empty()

            # Drive the async task from Streamlit's synchronous context.
            asyncio.run(
                run_task_async(
                    task,
                    status_placeholder,
                    action_placeholder,
                    environment_placeholder,
                    image_placeholder,
                    history_placeholder,
                    config,
                ),
            )

            st.success("Task completed!", icon="β¨")
        else:
            st.error("Please give a task first!")
|
236 |
+
|
237 |
+
|
238 |
+
if __name__ == "__main__":
|
239 |
+
main()
|
proxy-lite-demo-v2/src/proxy_lite/browser/__init__.py
ADDED
File without changes
|
proxy-lite-demo-v2/src/proxy_lite/browser/add_custom_select.js
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Tracks <select> elements that already have the custom handler attached,
// so repeated invocations (e.g. after DOM mutations) don't double-bind.
// WeakSet lets removed nodes be garbage-collected.
handledSelectElementsConvergence = new WeakSet();

// Replace the browser-native <select> dropdown with an in-page DOM popup so
// the automation can see and click options as ordinary elements (native
// dropdowns are not part of the page DOM and can't be screenshotted or
// clicked by the agent). `input` optionally scopes the scan (e.g. an iframe
// document element); defaults to the whole document.
overwriteDefaultSelectConvergence = (input = null) => {
    let activeSelectElement = null;

    // Handle iframe input element
    let rootElement = input ? input : document.documentElement;

    // Create the singleton floating container that hosts the options list.
    function createCustomSelectElement() {
        // Create the custom select container
        const customSelect = document.createElement('div');
        // Random suffix keeps the id from colliding with page elements.
        customSelect.id = 'convergence-custom-select-element-X2EmudtLRN';
        customSelect.style.position = 'absolute'
        // Near-maximum z-index so the popup renders above page content.
        customSelect.style.zIndex = 2147483647 - 1;
        customSelect.style.display = 'none';
        document.body.appendChild(customSelect);

        // Create the select options list
        const optionsList = document.createElement('div');
        optionsList.style.border = '1px solid #ccc';
        optionsList.style.backgroundColor = '#fff';
        optionsList.style.color = 'black';
        customSelect.appendChild(optionsList);

        return customSelect;
    }

    // Populate the popup from `select`'s options and show it just below
    // the select element; wires selection, hover and dismissal handlers.
    function showCustomSelect(select) {
        activeSelectElement = select;

        // Clear previous options
        const customSelect = rootElement.querySelector('#convergence-custom-select-element-X2EmudtLRN');
        let optionsList = customSelect.firstChild;
        optionsList.innerHTML = '';

        // Populate with new options
        Array.from(select.options).forEach(option => {
            const customOption = document.createElement('div');
            customOption.className = 'custom-option';
            customOption.style.padding = '8px';
            customOption.style.cursor = 'pointer';
            customOption.textContent = option.text;
            customOption.dataset.value = option.value;
            optionsList.appendChild(customOption);

            customOption.addEventListener('mouseenter', function () {
                customOption.style.backgroundColor = '#f0f0f0';
            });

            customOption.addEventListener('mouseleave', function () {
                customOption.style.backgroundColor = '';
            });

            // mousedown (not click) so selection wins over the select's own
            // blur handling; stopPropagation keeps the page from reacting.
            customOption.addEventListener('mousedown', (e) => {
                e.stopPropagation();
                select.value = customOption.dataset.value;
                customSelect.style.display = 'none';
                activeSelectElement = null;
                // ensure we trigger all potential event listeners
                select.dispatchEvent(new InputEvent('focus', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('input', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('change', { bubbles: true, cancelable: true }));
                select.dispatchEvent(new InputEvent('blur', { bubbles: true, cancelable: true }));
            });
        });

        // Position and show the custom select
        const selectRect = select.getBoundingClientRect();
        customSelect.style.top = `${selectRect.bottom + window.scrollY}px`;
        customSelect.style.left = `${selectRect.left + window.scrollX}px`;
        customSelect.style.width = `${selectRect.width}px`;
        customSelect.style.display = 'block';
        select.focus();
        // Dismiss the popup when the underlying select loses focus or its
        // value changes through some other path.
        select.addEventListener('blur', function (e) {
            customSelect.style.display = 'none';
            activeSelectElement = null;
        });
        select.addEventListener('change', function (e) {
            customSelect.style.display = 'none';
            activeSelectElement = null;
        });
    }

    // Ensure we have a custom select element
    let customSelect = rootElement.querySelector(`#convergence-custom-select-element-X2EmudtLRN`);
    if (!customSelect) {
        customSelect = createCustomSelectElement();
    }

    // Find selects in shadow DOMs
    // NOTE(review): this only descends one shadow level per element; selects
    // inside nested shadow roots are not found -- confirm that is acceptable.
    function findSelectInShadowRoot(element) {
        if (element.shadowRoot) {
            return element.shadowRoot.querySelectorAll('select');
        }
        return [];
    }
    let shadowSelects = [];
    rootElement.querySelectorAll('*').forEach(el => {
        shadowSelects.push(...findSelectInShadowRoot(el));
    });

    // Find selects in the regular (light) DOM
    const lightSelects = Array.from(rootElement.querySelectorAll('select'));

    // Add event listeners to all select elements
    const allSelects = [...lightSelects, ...shadowSelects];
    allSelects.forEach(select => {
        if (select.hasAttribute('multiple')) {
            // skip special multiple elements as our POI code already handles them
            return;
        }
        if (!handledSelectElementsConvergence.has(select)) {
            select.addEventListener('mousedown', (e) => {
                // only use custom select when the default behaviour is being used
                if (!e.defaultPrevented) {
                    showCustomSelect(select);
                    e.preventDefault();
                }
            });
            handledSelectElementsConvergence.add(select);
        }
    });
}
|
proxy-lite-demo-v2/src/proxy_lite/browser/bounding_boxes.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from pydantic import BaseModel, Field, field_validator
|
7 |
+
|
8 |
+
|
9 |
+
class Point(BaseModel):
    """An (x, y) pixel coordinate that also behaves like a 2-tuple."""

    x: int
    y: int

    def __iter__(self):
        # Supports unpacking, e.g. ``mouse.click(*point)``.
        yield self.x
        yield self.y

    def __getitem__(self, index) -> int:
        # Tuple-style indexing: point[0] -> x, point[1] -> y.
        return (self.x, self.y)[index]

    def __tuple__(self) -> tuple[int, int]:
        # NOTE(review): ``__tuple__`` is not a protocol Python invokes
        # implicitly; kept for any caller that invokes it explicitly.
        return (self.x, self.y)

    def __repr__(self) -> str:
        return f"Point(x={self.x}, y={self.y})"
class BoundingBox(BaseModel):
    """An axis-aligned, labelled box whose coordinates are snapped to ints."""

    label: str = Field(..., description="The label that's given for this bounding box")
    left: int = Field(..., description="Left coordinate of the bounding box")
    right: int = Field(..., description="Right coordinate of the bounding box")
    top: int = Field(..., description="Top coordinate of the bounding box")
    bottom: int = Field(..., description="Bottom coordinate of the bounding box")

    # Min edges snap down and max edges snap up, so the integer box always
    # fully contains the original float box.
    @field_validator("left", "top", mode="before")
    @classmethod
    def round_down(cls, v):
        return math.floor(float(v))

    @field_validator("right", "bottom", mode="before")
    @classmethod
    def round_up(cls, v):
        return math.ceil(float(v))
class POI(BaseModel):
    """A point of interest on the page: raw element info plus its geometry."""

    info: dict[str, Any]  # element description returned by the injected page JS
    element_centroid: Point  # centre point used for mouse actions
    bounding_box: BoundingBox  # labelled box drawn on annotated screenshots
def calculate_dash_points(start, end, dash_length, gap_length):
    """Return endpoints of the dashes making up a dashed line segment.

    The result is a flat list of integer ``(x, y)`` points where consecutive
    pairs are the start and end of one dash along ``start`` -> ``end``.
    A zero-length segment yields an empty list.
    """
    x1, y1 = start
    x2, y2 = end
    delta_x = x2 - x1
    delta_y = y2 - y1
    length = np.sqrt(delta_x * delta_x + delta_y * delta_y)

    if length == 0:
        return []

    # Unit direction vector of the segment.
    unit_x = delta_x / length
    unit_y = delta_y / length

    points = []
    travelled = 0
    while travelled < length:
        # Clamp the final dash so it never overshoots the segment end.
        dash_end = min(travelled + dash_length, length)
        points.append((int(x1 + unit_x * travelled), int(y1 + unit_y * travelled)))
        points.append((int(x1 + unit_x * dash_end), int(y1 + unit_y * dash_end)))
        travelled += dash_length + gap_length

    return points
def draw_dashed_rectangle(
    img,
    bbox: BoundingBox,
    color,
    thickness=1,
    dash_length=10,
    gap_length=5,
):
    """Draw ``bbox`` onto ``img`` as a dashed rectangle.

    The constant +25 offsets compensate for the 25px white border that
    ``annotate_bounding_boxes`` pads around the screenshot.
    """
    # Corners in padded-image coordinates, clockwise from top-left.
    top_left = (bbox.left + 25, bbox.top + 25)
    top_right = (bbox.right + 25, bbox.top + 25)
    bottom_right = (bbox.right + 25, bbox.bottom + 25)
    bottom_left = (bbox.left + 25, bbox.bottom + 25)

    segments = []
    for side_start, side_end in (
        (top_left, top_right),
        (top_right, bottom_right),
        (bottom_right, bottom_left),
        (bottom_left, top_left),
    ):
        segments.extend(calculate_dash_points(side_start, side_end, dash_length, gap_length))

    # A single polylines call draws every dash at once.
    if segments:
        dash_pairs = np.array(segments).reshape((-1, 2, 2))
        cv2.polylines(img, dash_pairs, False, color, thickness)
def annotate_bounding_box(image: np.ndarray, bbox: BoundingBox) -> None:
    """Draw ``bbox`` and its label onto ``image`` in place.

    ``image`` is the padded BGR array produced by ``annotate_bounding_boxes``
    (the +25 offsets below assume that padding). The previous ``bytes``
    annotation was wrong: the body indexes ``image.shape`` and assigns into
    slices, which only works on a mutable ndarray.
    """
    # Draw dashed bounding box
    draw_dashed_rectangle(
        image,
        bbox,
        color=(0, 0, 255),
        thickness=1,
        dash_length=10,
        gap_length=5,
    )

    # Render the label at 4x size, then shrink it for cheap anti-aliasing.
    font_scale = 0.4 * 4  # Increased by 4x for the larger patch
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 3  # Increased thickness for the larger patch

    # Get text size for the larger patch
    (label_width, label_height), _ = cv2.getTextSize(
        bbox.label,
        font,
        font_scale,
        thickness,
    )

    # Create a larger BGRA patch (4x): red background at 50% opacity.
    large_label_patch = np.zeros(
        (label_height + 20, label_width + 20, 4),
        dtype=np.uint8,
    )
    large_label_patch[:, :, 0:3] = (0, 0, 255)  # BGR: red background
    large_label_patch[:, :, 3] = 128  # Alpha: 50% opacity (128/255)

    # Draw text on the larger patch
    cv2.putText(
        large_label_patch,
        bbox.label,
        (8, label_height + 8),  # Adjusted position for the larger patch
        font,
        font_scale,
        (255, 255, 255, 128),  # White text, 50% opaque
        thickness,
    )

    # Scale down the patch to improve anti-aliasing
    label_patch = cv2.resize(
        large_label_patch,
        (label_width // 4 + 5, label_height // 4 + 5),
        interpolation=cv2.INTER_AREA,
    )

    # Place the label just above the box's top-left corner, clamped to the image.
    offset = 2  # Small offset to prevent touching the bounding box edge
    x = min(image.shape[1], max(0, int(bbox.left + 25) - offset))
    y = min(image.shape[0], max(0, int(bbox.top + 25) - label_patch.shape[0] - offset))

    # Ensure we're not out of bounds; crop the patch to what fits.
    x_end = min(image.shape[1], x + label_patch.shape[1])
    y_end = min(image.shape[0], y + label_patch.shape[0])
    label_patch = label_patch[: (y_end - y), : (x_end - x)]

    # Alpha-blend the label patch onto the covered image region.
    alpha_mask = label_patch[:, :, 3] / 255.0
    alpha_mask = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2)
    image_section = image[y:y_end, x:x_end]
    blended = (1 - alpha_mask) * image_section + alpha_mask * label_patch[:, :, 0:3]
    image[y:y_end, x:x_end] = blended.astype(np.uint8)
def annotate_bounding_boxes(image: bytes, bounding_boxes: list[BoundingBox]) -> bytes:
    """Decode a screenshot, draw every labelled box on it, and re-encode it."""
    raw = np.frombuffer(image, np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    # A 25px white border leaves room for labels of elements at the image
    # edge; the drawing helpers compensate with matching +25 offsets.
    canvas = cv2.copyMakeBorder(
        decoded,
        top=25,
        bottom=25,
        left=25,
        right=25,
        borderType=cv2.BORDER_CONSTANT,
        value=(255, 255, 255),
    )
    # Annotate the padded image in place, box by box.
    for box in bounding_boxes:
        annotate_bounding_box(canvas, box)
    _, encoded = cv2.imencode(".jpeg", canvas)
    return encoded.tobytes()
proxy-lite-demo-v2/src/proxy_lite/browser/browser.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
import platform
|
4 |
+
import re
|
5 |
+
from contextlib import AsyncExitStack
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Literal, Optional, Self
|
8 |
+
|
9 |
+
from playwright.async_api import Browser, BrowserContext, Page, Playwright, async_playwright
|
10 |
+
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
11 |
+
from playwright_stealth import StealthConfig, stealth_async
|
12 |
+
from pydantic import Field
|
13 |
+
from tenacity import before_sleep_log, retry, stop_after_delay, wait_exponential
|
14 |
+
|
15 |
+
from proxy_lite.browser.bounding_boxes import POI, BoundingBox, Point, annotate_bounding_boxes
|
16 |
+
from proxy_lite.logger import logger
|
17 |
+
|
18 |
+
import base64
|
19 |
+
|
20 |
+
# HTML void elements: they never wrap text, so they are rendered in the
# self-closing "<tag ... />" form. Many of these are non-interactive but
# are kept anyway.
SELF_CONTAINED_TAGS = [
    "area",
    "base",
    "br",
    "col",
    "embed",
    "hr",
    "img",
    "input",
    "link",
    "meta",
    "param",
    "source",
    "track",
    "wbr",
]
def element_as_text(
    mark_id: int,
    tag: Optional[str] = None,
    text: Optional[str] = None,
    **raw_attributes,
) -> str:
    """Return a one-line text representation of a single marked element.

    Attribute values and element text are truncated to 2500 characters, and
    line breaks are replaced with spaces so each element renders on one line.
    Void elements (``SELF_CONTAINED_TAGS``) render in self-closing form.
    """
    attributes = []
    for k, v in raw_attributes.items():
        if v is None:
            continue
        if isinstance(v, bool):
            # Boolean attributes render as bare names; False ones are dropped.
            if v:
                attributes.append(k)
        else:
            v = str(v)
            if len(v) > 2500:
                v = v[: 2500 - 1] + "β¦"
            attributes.append(f'{k}="{v}"')
    attributes = " ".join(attributes)
    attributes = (" " + attributes).rstrip()
    # BUG FIX: ``tag`` defaults to None, and ``None.lower()`` raised
    # AttributeError; fall back to an empty tag name instead.
    tag = tag.lower() if tag else ""
    if text is None:
        text = ""
    if len(text) > 2500:
        text = text[: 2500 - 1] + "β¦"

    # sub-out line breaks so elements are easier to distinguish
    attributes = re.sub(r"\r\n|\r|\n", "β", attributes)
    text = re.sub(r"\r\n|\r|\n", "β", text)

    if tag in SELF_CONTAINED_TAGS:
        if text:
            # Unexpected: void elements should not carry text. Log it and
            # fall through to the paired-tag form so the text is not lost.
            logger.warning(
                f"Got self-contained element '{tag}' which contained text '{text}'.",
            )
        else:
            return f"- [{mark_id}] <{tag}{attributes}/>"
    return f"- [{mark_id}] <{tag}{attributes}>{text}</{tag}>"
class BrowserSession:
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
viewport_width: int = 1280,
|
85 |
+
viewport_height: int = 720,
|
86 |
+
headless: bool = True,
|
87 |
+
):
|
88 |
+
self.viewport_width = viewport_width
|
89 |
+
self.viewport_height = viewport_height
|
90 |
+
self.headless = headless
|
91 |
+
self.playwright: Playwright | None = None
|
92 |
+
self.browser: Browser | None = None
|
93 |
+
self.context: BrowserContext | None = None
|
94 |
+
self._exit_stack: AsyncExitStack | None = None
|
95 |
+
|
96 |
+
self.poi_elements: list = Field(default_factory=list)
|
97 |
+
self.poi_centroids: list[Point] = Field(default_factory=list)
|
98 |
+
self.bounding_boxes: list[BoundingBox] = Field(default_factory=list)
|
99 |
+
self.pois: list[POI] = Field(default_factory=list)
|
100 |
+
|
101 |
+
async def __aenter__(self) -> Self:
    """Start Playwright, launch Chromium, and prepare an instrumented context."""
    self._exit_stack = AsyncExitStack()
    self.playwright = await async_playwright().start()

    self.browser = await self.playwright.chromium.launch(headless=self.headless)
    self.context = await self.browser.new_context(
        viewport={"width": self.viewport_width, "height": self.viewport_height},
    )
    # Guarantee at least one open page so ``current_page`` is usable.
    if not self.context.pages:
        await self.context.new_page()

    self.context.set_default_timeout(60_000)
    self.current_page.set_default_timeout(60_000)
    # Reduce automation fingerprinting while keeping the real user agent.
    await stealth_async(self.current_page, StealthConfig(navigator_user_agent=False))
    # Inject the custom-select and POI-finder helpers into every page.
    for helper_script in ("add_custom_select.js", "find_pois.js"):
        await self.context.add_init_script(
            path=Path(__file__).with_name(helper_script),
        )

    return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Close the browser and driver, then release any stacked resources."""
    browser, driver, stack = self.browser, self.playwright, self._exit_stack
    if browser:
        await browser.close()
    if driver:
        await driver.stop()
    if stack:
        await stack.aclose()
@property
def current_page(self) -> Optional[Page]:
    """The most recently opened page, or ``None`` when no page exists."""
    pages = self.context.pages if self.context else []
    return pages[-1] if pages else None
@property
def current_url(self) -> Optional[str]:
    """URL of the current page, or ``None`` when no page is open."""
    page = self.current_page
    return page.url if page else None
# re-run for cases of mid-run redirects
@retry(
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_delay(5),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.ERROR),
)
async def process_iframe(self, iframe) -> Optional[tuple[dict, dict]]:
    """Collect POIs from a single iframe.

    Returns the iframe's POI payload together with its page-coordinate
    offset, or ``None`` when the iframe is hidden, too small, detached,
    or raises while being evaluated.
    """
    try:
        box = await iframe.bounding_box()
        if not box:
            # Not rendered - nothing to collect.
            return None

        if box["width"] < 50 or box["height"] < 50:
            # Skip tiny (decorative/tracking) iframes.
            return None

        frame = await iframe.content_frame()
        if not frame:
            return None

        payload = await frame.evaluate(
            """() => {
            overwriteDefaultSelectConvergence();
            return findPOIsConvergence();
        }""",
        )
        if not payload:
            return None

        # Offset for translating frame-local coordinates into page coordinates.
        offset = {"x": round(box["x"]), "y": round(box["y"])}
        return payload, offset
    except Exception as e:
        logger.error(f"Error processing iframe: {e}")
        return None
@retry(
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_delay(5),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.ERROR),
)
async def update_poi(self) -> None:
    """Refresh the cached points of interest for the current page.

    Waits for page readiness (with Salesforce-specific URL heuristics),
    runs the injected POI-finder script on the main frame and on up to 10
    visible iframes, then rebuilds ``poi_elements``, ``poi_centroids``,
    ``bounding_boxes`` and ``pois``.
    """
    try:
        # The DOM must exist before anything else; this wait always applies.
        await self.current_page.wait_for_load_state("domcontentloaded", timeout=60000)
        logger.debug(f"DEBUG: wait_for_load_state('domcontentloaded') completed for {self.current_page.url}.")

        current_url = self.current_page.url

        # URL fingerprints for the different Salesforce page states.
        login_url_patterns = [
            "login.salesforce.com",
            "identity.force.com",
            "auth.lightning.force.com",
            "setup.salesforce.com",  # Sometimes a setup login redirects here temporarily
            "my.salesforce.com"  # Your specific custom domain login redirects here
        ]

        # Lightning app shell URL seen right after login, treated as an
        # intermediate loading state before the specific target page.
        intermediate_app_url_pattern = "/one/one.app"

        is_on_login_page = any(pattern in current_url for pattern in login_url_patterns)
        is_on_intermediate_app_page = intermediate_app_url_pattern in current_url
        is_on_target_forecast_page = "/AccountForecastSettings/home" in current_url

        # Conditional waiting strategy keyed off the URL shape.
        if is_on_target_forecast_page:
            logger.info(f"INFO: Detected target Account Forecast Settings page: {current_url}. Waiting for content.")
            # On the target page, wait for known spinners to disappear...
            spinner_selectors = [
                "div.slds-spinner_container",
                "div.auraLoadingBox",
                "div.dxp_axb_container",  # Main overlay from your inspect screenshot
                "div.slds-sprite-astro-x-large"  # Specific animated element itself
            ]
            for selector in spinner_selectors:
                try:
                    await self.current_page.wait_for_selector(selector, state="hidden", timeout=5000)
                    logger.debug(f"DEBUG: Spinner element '{selector}' became hidden for {self.current_page.url}.")
                except PlaywrightTimeoutError:
                    logger.warning(f"DEBUGGING: Spinner element '{selector}' not detected or did not disappear on {self.current_page.url} within 5s.")

            # ...then for a known content element to confirm the page rendered.
            try:
                await self.current_page.wait_for_selector("h1.slds-page-header__title, h2, .account-forecast-settings-component, div[data-aura-rendered-by]", state="visible", timeout=15000)
                logger.debug(f"DEBUG: Confirmed main page element visible for {self.current_page.url}.")
            except PlaywrightTimeoutError:
                logger.warning(f"DEBUGGING: Main page element not visible on {self.current_page.url} within 15s. This might indicate incomplete page load despite no spinner.")

        elif is_on_login_page:
            logger.info(f"INFO: Detected Salesforce login page: {current_url}. Waiting for login elements.")
            # On a login page, only the login form needs to be visible.
            try:
                await self.current_page.wait_for_selector("input[type='email'], input[type='password'], input[type='submit'], #username, #password, #Login", state="visible", timeout=10000)
                logger.debug(f"DEBUG: Login page elements visible on {self.current_page.url}.")
            except PlaywrightTimeoutError:
                logger.warning(f"DEBUGGING: Login page elements not visible on {self.current_page.url} within 10s. This may happen if elements are in an iframe or if page is extremely slow.")

        elif is_on_intermediate_app_page:
            logger.info(f"INFO: Detected intermediate Salesforce Lightning app loading page: {current_url}. Waiting for network idle and app spinner.")
            # The /one/one.app shell: wait for general load, not specific content.
            try:
                await self.current_page.wait_for_load_state("networkidle", timeout=30000)
                logger.debug(f"DEBUG: Network idle detected on intermediate app page: {current_url}.")
            except PlaywrightTimeoutError:
                logger.warning(f"DEBUGGING: Network idle timeout on intermediate app page: {current_url}. Proceeding anyway.")

            try:
                await self.current_page.wait_for_selector('div.app-spinner, div.auraLoadingBox', state='hidden', timeout=15000)
                logger.debug(f"DEBUG: App spinner on intermediate page became hidden.")
            except PlaywrightTimeoutError:
                logger.warning(f"DEBUGGING: App spinner on intermediate page not found or did not disappear.")

        else:
            logger.info(f"INFO: Detected unhandled URL type: {current_url}. Performing generic body wait.")
            # Fallback for any other page: just wait for a visible body.
            try:
                await self.current_page.wait_for_selector("body", timeout=5000, state="visible")
                logger.debug(f"DEBUG: wait_for_selector('body', state='visible') completed for {self.current_page.url}.")
            except PlaywrightTimeoutError:
                logger.warning(f"DEBUGGING: Playwright Timeout (5s) on body selector for {self.current_page.url}. Continuing anyway.")

    except PlaywrightTimeoutError as e:
        logger.error(f"ERROR: Timeout waiting for page readiness for {self.current_page.url}: {e}")
        raise
    except Exception as e:
        logger.error(f"ERROR: An unexpected error occurred during page readiness check for {self.current_page.url}: {e}")
        raise

    # Run the injected script to locate POIs on the main frame.
    page_info = await self.current_page.evaluate(
        """() => {
        overwriteDefaultSelectConvergence();
        return findPOIsConvergence();
    }""",
    )
    self.poi_elements = page_info["element_descriptions"]
    element_centroids = page_info["element_centroids"]
    try:
        # Also collect POIs from visible iframes, bounded for performance.
        iframes = await self.current_page.query_selector_all("iframe")
        max_iframes = 10
        tasks = [asyncio.create_task(self.process_iframe(iframe)) for iframe in iframes[:max_iframes]]
        results = await asyncio.gather(*tasks)
        usable = [result for result in results if result is not None]

        # Merge iframe POIs, translating their geometry into page coordinates.
        for frame_poi, frame_offset in usable:
            self.poi_elements.extend(frame_poi["element_descriptions"])
            for centroid in frame_poi["element_centroids"]:
                centroid["x"] += frame_offset["x"]
                centroid["y"] += frame_offset["y"]
                centroid["left"] += frame_offset["x"]
                centroid["top"] += frame_offset["y"]
                centroid["right"] += frame_offset["x"]
                centroid["bottom"] += frame_offset["y"]
            element_centroids.extend(frame_poi["element_centroids"])

    except Exception as e:
        logger.error(f"Error in finding iframes: {e}")

    # Rebuild the cached geometry structures.
    self.poi_centroids = [Point(x=xy["x"], y=xy["y"]) for xy in element_centroids]
    self.bounding_boxes = [BoundingBox(**xy, label=str(i)) for i, xy in enumerate(element_centroids)]
    self.pois = [
        POI(info=info, element_centroid=centroid, bounding_box=bbox)
        for info, centroid, bbox in zip(
            self.poi_elements,
            self.poi_centroids,
            self.bounding_boxes,
            strict=False,
        )
    ]
@property
def poi_text(self) -> str:
    """All current POIs rendered as one text line per element."""
    rendered = (
        element_as_text(mark_id=i, **element)
        for i, element in enumerate(self.poi_elements)
    )
    # Drop any empty renderings before joining.
    return "\n".join(line for line in rendered if line)
async def screenshot(
    self,
    delay: float = 0.0,
    quality: int = 70,
    type: str = "jpeg",
    scale: str = "css",
) -> tuple[bytes, bytes]:
    """Capture the current page after refreshing POIs.

    Returns ``(raw_image, annotated_image)``: the raw screenshot bytes and
    a copy with the POI bounding boxes drawn on it.
    """
    if delay > 0.0:
        await asyncio.sleep(delay)
    await self.update_poi()
    raw = await self.current_page.screenshot(type=type, quality=quality, scale=scale)
    annotated = annotate_bounding_boxes(image=raw, bounding_boxes=self.bounding_boxes)
    return raw, annotated
async def goto(self, url: str) -> None:
    """Navigate the current page to ``url``, waiting for DOM readiness."""
    await self.current_page.goto(url, wait_until="domcontentloaded")
async def reload(self) -> None:
    """Reload the current page, waiting for DOM readiness."""
    await self.current_page.reload(wait_until="domcontentloaded")
async def click_tab(self, mark_id: int) -> None:
    """Middle-click the centroid of POI ``mark_id``."""
    target: Point = self.poi_centroids[mark_id]
    await self.hover(target)
    await self.current_page.mouse.click(*target, button="middle")
async def click(self, mark_id: int) -> None:
    """Move to and left-click the centroid of POI ``mark_id``."""
    target: Point = self.poi_centroids[mark_id]
    await self.hover(target)
    await self.current_page.mouse.click(*target)
async def enter_text(self, mark_id: int, text: str, submit: bool = False) -> None:
    """Type ``text`` into POI ``mark_id`` after clearing its current content.

    When ``submit`` is True, press Enter afterwards.
    """
    await self.clear_text_field(mark_id)
    await self.click(mark_id)
    await self.current_page.keyboard.type(text)
    if submit:
        await self.current_page.keyboard.press("Enter")
async def scroll(
    self,
    direction: Literal["up", "down", "left", "right"],
    mark_id: Optional[int] = None,
) -> None:
    """Scroll the viewport, or the element at POI ``mark_id``, by 80% of its size."""
    if mark_id is None:
        # No target: scroll the whole page from an off-page cursor position.
        target = Point(x=-1, y=-1)
        span_x = self.viewport_width
        span_y = self.viewport_height
    else:
        target: Point = self.poi_centroids[mark_id]
        box: BoundingBox = self.bounding_boxes[mark_id]
        span_x = box.right - box.left
        span_y = box.bottom - box.top

    await self.hover(point=target)
    step_x = int(span_x * 0.8)
    step_y = int(span_y * 0.8)
    vertical = direction in ("up", "down")
    sign = -1 if direction in ("up", "left") else 1
    # Exactly one wheel axis is non-zero, depending on scroll orientation.
    await self.current_page.mouse.wheel(
        step_x * sign * (not vertical),
        step_y * sign * vertical,
    )
async def go_back(self) -> None:
    """Go back in history; if that lands on a blank page, close the tab instead.

    Raises when there is no previous page and no other tab to fall back to.
    """
    # Nothing to navigate if no tab is open.
    if not self.current_page:
        return

    await self.current_page.go_back(wait_until="domcontentloaded")
    if self.current_page.url == "about:blank":
        if not len(self.context.pages) > 1:
            # Undo the no-op back navigation, then report the failure.
            await self.current_page.go_forward(wait_until="domcontentloaded")
            raise Exception("There is no previous page to go back to.")
        await self.current_page.close()
async def hover(self, point: Point) -> None:
    """Move the mouse cursor to ``point``."""
    await self.current_page.mouse.move(*point)
async def focus(self, point: Point) -> None:
    """Give keyboard focus to the element rendered at ``point``, if focusable."""
    await self.current_page.evaluate(
        """
        ([x, y]) => {
            const element = document.elementFromPoint(x, y);
            if (element && element.focus) {
                element.focus();
            }
        }""",
        tuple(point),
    )
async def get_text(self, mark_id: int) -> str:
    """Return the value or text content of POI ``mark_id`` ('' if neither).

    FIX: pass ``mark_id`` directly instead of as a 1-tuple. The tuple was
    serialized into the page as a one-element array, so the JS callback
    indexed ``marked_elements_convergence`` with an array and only worked
    through implicit string coercion of the index.
    """
    return await self.current_page.evaluate(
        """
        (mark_id) => {
            const element = marked_elements_convergence[mark_id];
            if (element && (element.value !== undefined || element.textContent !== undefined)) {
                return element.value || element.textContent;
            }
            return '';
        }
        """,
        mark_id,
    )
    async def clear_text_field(self, mark_id: int) -> None:
        """Select all text in the field marked ``mark_id`` and delete it.

        No-op when the field is already empty (avoids spurious clicks).
        """
        existing_text = await self.get_text(mark_id)
        if existing_text.strip():
            # Clear existing text only if it exists
            await self.click(mark_id)
            if platform.system() == "Darwin":  # selecting all text is OS-specific
                # NOTE(review): the field is clicked a second time on macOS —
                # presumably to guarantee focus before Cmd+A; confirm intended.
                await self.click(mark_id)
                await self.current_page.keyboard.press("Meta+a")
                await self.current_page.keyboard.press("Backspace")
            else:
                # Non-mac: jump to start, extend the selection to the end, delete.
                await self.current_page.keyboard.press("Control+Home")
                await self.current_page.keyboard.press("Control+Shift+End")
                await self.current_page.keyboard.press("Backspace")
|
471 |
+
|
472 |
+
    async def open_new_tab_and_go_to(self, url: str) -> None:
        """
        Opens a new browser tab/page and navigates to the specified URL.
        Closes the old page if it's not the last one remaining.
        """
        logger.info(f"Attempting to open a new tab and navigate to: {url}")
        new_page = await self.context.new_page()

        # Close the previous page if it's not the only one left in the context
        if len(self.context.pages) > 1 and self.current_page and self.current_page != new_page:
            try:
                await self.current_page.close()
                logger.debug("Closed previous page.")
            except Exception as e:
                logger.warning(f"Could not close previous page (might already be closed or detached): {e}")

        # After navigation, trigger POI update to reflect the new page's state
        await new_page.goto(url, wait_until="domcontentloaded")
        logger.info(f"Successfully navigated to {url} in a new tab.")
        # Crucial: update_poi uses self.current_page, which is now new_page implicitly
        # NOTE(review): nothing in this method assigns self.current_page = new_page;
        # this relies on a page/context event handler elsewhere keeping it in
        # sync — confirm that handler exists before relying on update_poi here.
        await self.update_poi()
|
493 |
+
|
494 |
+
|
495 |
+
if __name__ == "__main__":

    async def dummy_test():
        """Manual smoke test: open a page, screenshot it, dump annotated POIs."""
        async with BrowserSession(headless=False) as session:
            tab = await session.context.new_page()
            await tab.goto("http://google.co.uk")
            await asyncio.sleep(5)
            await tab.screenshot(path="example.png")
            await session.update_poi()
            _, annotated_image = await session.screenshot()
            with open("output.png", "wb") as out_file:
                out_file.write(annotated_image)

    asyncio.run(dummy_test())
|
proxy-lite-demo-v2/src/proxy_lite/browser/find_pois.js
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Global registry of elements marked as points of interest (POIs).
// Indexed by mark id; read back from Python (e.g. BrowserSession.get_text)
// via page.evaluate. Deliberately declared without const/let so it lives on
// the window and survives re-evaluation of this script.
marked_elements_convergence = [];

// HTML tags treated as inherently interactive.
const interactiveTags = new Set([
    'a', 'button', 'details', 'embed', 'input', 'label',
    'menu', 'menuitem', 'object', 'select', 'textarea', 'summary',
    'video', 'audio', 'option', 'iframe'
]);

// ARIA role values (plus some framework/site-specific ones, e.g. Amazon's
// 'a-button-inner') treated as interactive.
const interactiveRoles = new Set([
    'button', 'menu', 'menuitem', 'link', 'checkbox', 'radio',
    'slider', 'tab', 'tabpanel', 'textbox', 'combobox', 'grid',
    'listbox', 'option', 'progressbar', 'scrollbar', 'searchbox',
    'switch', 'tree', 'treeitem', 'spinbutton', 'tooltip',
    'a-button-inner', 'a-dropdown-button', 'click',
    'menuitemcheckbox', 'menuitemradio', 'a-button-text',
    'button-text', 'button-icon', 'button-icon-only',
    'button-text-icon-only', 'dropdown', 'combobox'
]);
|
19 |
+
|
20 |
+
// Scan the page (or the subtree rooted at `input`) for interactive elements
// ("points of interest"). Side effect: repopulates the global
// marked_elements_convergence registry. Returns, for each visible on-screen
// rect of each interactive element, a description object and its centroid.
findPOIsConvergence = (input = null) => {

    let rootElement = input ? input : document.documentElement;

    // True when the element has overflowing content AND CSS overflow allows
    // scrolling it. The document root is excluded for whole-page scans.
    function isScrollable(element) {
        if ((input === null) && (element === document.documentElement)) {
            // we can always scroll the full page
            return false;
        }

        const style = window.getComputedStyle(element);

        const hasScrollableYContent = element.scrollHeight > element.clientHeight
        const overflowYScroll = style.overflowY === 'scroll' || style.overflowY === 'auto';

        const hasScrollableXContent = element.scrollWidth > element.clientWidth;
        const overflowXScroll = style.overflowX === 'scroll' || style.overflowX === 'auto';

        return (hasScrollableYContent && overflowYScroll) || (hasScrollableXContent && overflowXScroll);
    }

    // window.getEventListeners only exists in the DevTools console context;
    // fall back to an empty map elsewhere.
    function getEventListeners(element) {
        try {
            return window.getEventListeners?.(element) || {};
        } catch (e) {
            return {};
        }
    }

    // An element is interactive when any signal fires: tag, attributes,
    // event listeners, or scrollability.
    function isInteractive(element) {
        if (!element) return false;

        return (hasInteractiveTag(element) ||
            hasInteractiveAttributes(element) ||
            hasInteractiveEventListeners(element)) ||
            isScrollable(element);
    }

    function hasInteractiveTag(element) {
        return interactiveTags.has(element.tagName.toLowerCase());
    }

    // Attribute-based signals: contenteditable, role/aria-role, tabindex,
    // AMP 'on=tap:...' handlers, and stateful aria-* properties.
    function hasInteractiveAttributes(element) {
        const role = element.getAttribute('role');
        const ariaRole = element.getAttribute('aria-role');
        const tabIndex = element.getAttribute('tabindex');
        const onAttribute = element.getAttribute('on');

        if (element.getAttribute('contenteditable') === 'true') return true;
        if ((role && interactiveRoles.has(role)) ||
            (ariaRole && interactiveRoles.has(ariaRole))) return true;
        if (tabIndex !== null && tabIndex !== '-1') return true;

        // Add check for AMP's 'on' attribute that starts with 'tap:'
        if (onAttribute && onAttribute.startsWith('tap:')) return true;

        const hasAriaProps = element.hasAttribute('aria-expanded') ||
            element.hasAttribute('aria-pressed') ||
            element.hasAttribute('aria-selected') ||
            element.hasAttribute('aria-checked');

        return hasAriaProps;
    }

    // Listener-based signals: inline/framework click attributes, plus any
    // click/mouse/touch listeners visible via getEventListeners.
    function hasInteractiveEventListeners(element) {
        const hasClickHandler = element.onclick !== null ||
            element.getAttribute('onclick') !== null ||
            element.hasAttribute('ng-click') ||
            element.hasAttribute('@click') ||
            element.hasAttribute('v-on:click');
        if (hasClickHandler) return true;

        const listeners = getEventListeners(element);
        return listeners && (
            listeners.click?.length > 0 ||
            listeners.mousedown?.length > 0 ||
            listeners.mouseup?.length > 0 ||
            listeners.touchstart?.length > 0 ||
            listeners.touchend?.length > 0
        );
    }

    // Sum of rect areas (currently computed but only stored, not used here).
    function calculateArea(rects) {
        return rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
    }

    // Client rects of `element`, translated by the owning iframe's offset,
    // filtered to rects whose centre actually hits the element (i.e. not
    // occluded), and clamped to the viewport.
    function getElementRects(element, context) {
        const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
        const vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

        let rects = [...element.getClientRects()];

        // If rects are empty (likely due to Shadow DOM), try to estimate position
        if (rects.length === 0 && element.getBoundingClientRect) {
            rects = [element.getBoundingClientRect()];
        }

        // Get iframe offset if element is in an iframe
        let iframeOffset = { x: 0, y: 0 };
        if (context !== document && context?.defaultView?.frameElement) {
            const iframe = context.defaultView.frameElement;
            if (iframe) {
                const iframeRect = iframe.getBoundingClientRect();
                iframeOffset = {
                    x: iframeRect.left,
                    y: iframeRect.top
                };
            }
        }

        return rects.filter(bb => {
            const center_x = bb.left + bb.width / 2 + iframeOffset.x;
            const center_y = bb.top + bb.height / 2 + iframeOffset.y;
            // Hit-test in the element's own document, so subtract the offset back out.
            const elAtCenter = context.elementFromPoint(center_x - iframeOffset.x, center_y - iframeOffset.y);

            return elAtCenter === element || element.contains(elAtCenter);
        }).map(bb => {
            const rect = {
                left: Math.max(0, bb.left + iframeOffset.x),
                top: Math.max(0, bb.top + iframeOffset.y),
                right: Math.min(vw, bb.right + iframeOffset.x),
                bottom: Math.min(vh, bb.bottom + iframeOffset.y)
            };
            return {
                ...rect,
                width: rect.right - rect.left,
                height: rect.bottom - rect.top
            };
        });
    }

    // Cheap visibility check: non-zero layout box and not display:none /
    // visibility:hidden.
    function isElementVisible(element) {
        const style = window.getComputedStyle(element);
        return element.offsetWidth > 0 &&
            element.offsetHeight > 0 &&
            style.visibility !== 'hidden' &&
            style.display !== 'none';
    }

    // True when the element (or a descendant) is what a hit-test at its
    // centre returns, i.e. it is not fully covered by something else.
    // Iframe documents and hit-test failures are optimistically treated as top.
    function isTopElement(element) {
        let doc = element.ownerDocument;
        if (doc !== window.document) {
            // If in an iframe's document, treat as top
            return true;
        }
        const shadowRoot = element.getRootNode();
        if (shadowRoot instanceof ShadowRoot) {
            const rect = element.getBoundingClientRect();
            const point = { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
            try {
                const topEl = shadowRoot.elementFromPoint(point.x, point.y);
                if (!topEl) return false;
                let current = topEl;
                // Walk up from the hit element looking for `element`.
                while (current && current !== shadowRoot) {
                    if (current === element) return true;
                    current = current.parentElement;
                }
                return false;
            } catch (e) {
                return true;
            }
        }
        const rect = element.getBoundingClientRect();
        const point = { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
        try {
            const topEl = document.elementFromPoint(point.x, point.y);
            if (!topEl) return false;
            let current = topEl;
            while (current && current !== document.documentElement) {
                if (current === element) return true;
                current = current.parentElement;
            }
            return false;
        } catch (e) {
            return true;
        }
    }

    // Collect the human-visible text of `element`, stopping at nested
    // elements present in the passed marked list (so a container's text
    // excludes text belonging to separately-marked children).
    // NOTE(review): the parameter shadows the global registry and the call
    // site below passes nothing, so the stop-at-marked-children behaviour is
    // effectively disabled — confirm whether that is intended.
    function getVisibleText(element, marked_elements_convergence = []) {
        const blockLikeDisplays = [
            // Basic block elements
            'block', 'flow-root', 'inline-block',
            // Lists
            'list-item',
            // Table elements
            'table', 'inline-table', 'table-row', 'table-cell',
            'table-caption', 'table-header-group', 'table-footer-group',
            'table-row-group',
            // Modern layouts
            'flex', 'inline-flex', 'grid', 'inline-grid'
        ];

        // Check if element is hidden
        const style = window.getComputedStyle(element);
        if (style.display === 'none' || style.visibility === 'hidden') {
            return '';
        }

        let collectedText = [];

        function isMarkedInteractive(el) {
            return marked_elements_convergence.includes(el);
        }

        // Depth-first walk; returns false to abort the whole traversal when a
        // separately-marked descendant is reached.
        function traverse(node) {
            if (
                node.nodeType === Node.ELEMENT_NODE &&
                node !== element &&
                isMarkedInteractive(node)
            ) {
                return false;
            }

            if (node.nodeType === Node.TEXT_NODE) {
                const trimmed = node.textContent.trim();
                if (trimmed) {
                    collectedText.push(trimmed);
                }
            } else if (node.nodeType === Node.ELEMENT_NODE) {
                // Skip noscript elements
                if (node.tagName === 'NOSCRIPT') {
                    return true;
                }

                const nodeStyle = window.getComputedStyle(node);

                // Skip hidden elements
                if (nodeStyle.display === 'none' || nodeStyle.visibility === 'hidden') {
                    return true;
                }

                // Add newline before block elements if we have text
                if (blockLikeDisplays.includes(nodeStyle.display) && collectedText.length > 0) {
                    collectedText.push('\n');
                }

                // Images contribute their descriptive attributes instead of text.
                if (node.tagName === 'IMG') {
                    const textParts = [];
                    const alt = node.getAttribute('alt');
                    const title = node.getAttribute('title');
                    const ariaLabel = node.getAttribute('aria-label');
                    // Add more as needed (e.g., 'aria-describedby', 'data-caption', etc.)

                    if (alt) textParts.push(`alt="${alt}"`);
                    if (title) textParts.push(`title="${title}"`);
                    if (ariaLabel) textParts.push(`aria-label="${ariaLabel}"`);

                    if (textParts.length > 0) {
                        collectedText.push(`[img - ${textParts.join(' ')}]`);
                    }
                    return true;
                }

                for (const child of node.childNodes) {
                    const shouldContinue = traverse(child);
                    if (shouldContinue === false) {
                        return false;
                    }
                }

                // Add newline after block elements
                if (blockLikeDisplays.includes(nodeStyle.display)) {
                    collectedText.push('\n');
                }
            }

            return true;
        }

        traverse(element);

        // Join text and normalize whitespace
        return collectedText.join(' ').trim().replace(/\s{2,}/g, ' ').trim();
    }

    // Recursively walk the DOM (including shadow roots, slots, and
    // same-origin iframes) collecting interactive, visible, unoccluded
    // elements together with their on-screen rects.
    function extractInteractiveItems(rootElement) {
        const items = [];

        function processElement(element, context) {
            if (!element) return;

            // Recursively process elements
            if (element.nodeType === Node.ELEMENT_NODE && isInteractive(element) && isElementVisible(element) && isTopElement(element)) {
                const rects = getElementRects(element, context);
                const area = calculateArea(rects);
                items.push({
                    element: element,
                    area,
                    rects,
                    is_scrollable: isScrollable(element),
                });
            }

            if (element.shadowRoot) {
                // if it's shadow DOM, process elements in the shadow DOM
                Array.from(element.shadowRoot.childNodes || []).forEach(child => {
                    processElement(child, element.shadowRoot);
                });
            }

            if (element.tagName === 'SLOT') {
                // Handle both assigned elements and nodes
                const assigned = element.assignedNodes ? element.assignedNodes() : element.assignedElements();
                assigned.forEach(child => {
                    processElement(child, context);
                });
            }
            else if (element.tagName === 'IFRAME') {
                try {
                    // Cross-origin iframes throw here; they are skipped with a warning.
                    const iframeDoc = element.contentDocument || element.contentWindow?.document;
                    if (iframeDoc && iframeDoc.body) {
                        // Process elements inside iframe
                        processElement(iframeDoc.body, iframeDoc);
                    }
                } catch (e) {
                    console.warn('Unable to access iframe contents:', e);
                }
            } else {
                // if it's regular child elements, process regular child elements
                Array.from(element.children || []).forEach(child => {
                    processElement(child, context);
                });
            }
        }

        processElement(rootElement, document);
        return items;
    }

    // Reset the global registry before repopulating it for this scan.
    if (marked_elements_convergence) {
        marked_elements_convergence = [];
    }
    let mark_centres = [];
    let marked_element_descriptions = [];
    var items = extractInteractiveItems(rootElement);

    // Lets create a floating border on top of these elements that will always be visible
    // NOTE(review): `index` is incremented but never read — the mark id is
    // implicitly the array position in marked_elements_convergence.
    let index = 0;
    items.forEach(function (item) {
        // One entry per visible rect, so a wrapped element can get several marks.
        item.rects.forEach((bbox) => {
            marked_elements_convergence.push(item.element);
            mark_centres.push({
                x: Math.round((bbox.left + bbox.right) / 2),
                y: Math.round((bbox.top + bbox.bottom) / 2),
                left: bbox.left,
                top: bbox.top,
                right: bbox.right,
                bottom: bbox.bottom,
            });
            marked_element_descriptions.push({
                tag: item.element.tagName,
                text: getVisibleText(item.element),
                // NOTE: all other attributes will be shown to the model when present
                // TODO: incorperate child attributes, e.g. <img alt="..."> when img is a child of the link element
                value: item.element.value,
                placeholder: item.element.getAttribute("placeholder"),
                element_type: item.element.getAttribute("type"),
                aria_label: item.element.getAttribute("aria-label"),
                name: item.element.getAttribute("name"),
                required: item.element.getAttribute("required"),
                disabled: item.element.getAttribute("disabled"),
                pattern: item.element.getAttribute("pattern"),
                checked: item.element.getAttribute("checked"),
                minlength: item.element.getAttribute("minlength"),
                maxlength: item.element.getAttribute("maxlength"),
                role: item.element.getAttribute("role"),
                title: item.element.getAttribute("title"),
                scrollable: item.is_scrollable
            });
            index++;
        });
    });

    return {
        element_descriptions: marked_element_descriptions,
        element_centroids: mark_centres
    };
}
|
proxy-lite-demo-v2/src/proxy_lite/cli.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import base64
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from proxy_lite import Runner, RunnerConfig
|
8 |
+
from proxy_lite.gif_maker import create_run_gif
|
9 |
+
from proxy_lite.logger import logger
|
10 |
+
|
11 |
+
|
12 |
+
def update_config_from_env(config: "RunnerConfig") -> "RunnerConfig":
    """Overlay environment-variable overrides onto ``config``.

    Mutates ``config`` in place and returns it. Recognised variables:
    PROXY_LITE_API_BASE, PROXY_LITE_MODEL, PROXY_LITE_VIEWPORT_WIDTH,
    PROXY_LITE_VIEWPORT_HEIGHT. Unset or empty variables leave the
    corresponding field untouched.
    """
    # Read each variable once (the previous version called os.getenv twice
    # per variable: once for the truthiness check, once for the value).
    if api_base := os.getenv("PROXY_LITE_API_BASE"):
        config.solver.agent.client.api_base = api_base
    if model_id := os.getenv("PROXY_LITE_MODEL"):
        config.solver.agent.client.model_id = model_id
    if viewport_width := os.getenv("PROXY_LITE_VIEWPORT_WIDTH"):
        # Viewport sizes arrive as strings; a non-numeric value raises ValueError.
        config.environment.viewport_width = int(viewport_width)
    if viewport_height := os.getenv("PROXY_LITE_VIEWPORT_HEIGHT"):
        config.environment.viewport_height = int(viewport_height)
    return config
|
22 |
+
|
23 |
+
|
24 |
+
def do_command(args):
    """Run one task end-to-end and persist the final screenshot and run GIF.

    Config precedence (lowest to highest): YAML file -> environment
    variables -> command-line flags. Artifacts are written next to the
    package root: screenshots/<run_id>.png and gifs/<run_id>.gif.
    """
    # `task` is collected with nargs="*", so re-join it into one string.
    do_text = " ".join(args.task)
    logger.info("🤖 Let me help you with that...")
    # Take default config from YAML
    config = RunnerConfig.from_yaml(args.config)
    # Update config from environment variables
    config = update_config_from_env(config)
    # Update config from command-line arguments
    if args.api_base:
        config.solver.agent.client.api_base = args.api_base
    if args.model:
        config.solver.agent.client.model_id = args.model
    if args.homepage:
        config.environment.homepage = args.homepage
    if args.viewport_width:
        config.environment.viewport_width = args.viewport_width
    if args.viewport_height:
        config.environment.viewport_height = args.viewport_height
    o = Runner(config=config)
    result = asyncio.run(o.run(do_text))

    # The final observation carries the last screenshot as base64-encoded PNG.
    final_screenshot = result.observations[-1].info["original_image"]
    folder_path = Path(__file__).parent.parent.parent / "screenshots"
    folder_path.mkdir(parents=True, exist_ok=True)
    path = folder_path / f"{result.run_id}.png"
    with open(path, "wb") as f:
        f.write(base64.b64decode(final_screenshot))
    logger.info(f"🤖 Final screenshot saved to {path}")

    gif_folder_path = Path(__file__).parent.parent.parent / "gifs"
    gif_folder_path.mkdir(parents=True, exist_ok=True)
    gif_path = gif_folder_path / f"{result.run_id}.gif"
    create_run_gif(result, gif_path, duration=1500)
    logger.info(f"🤖 GIF saved to {gif_path}")
|
58 |
+
|
59 |
+
|
60 |
+
def main():
    """CLI entry point: parse arguments and dispatch to ``do_command``."""
    parser = argparse.ArgumentParser(description="Proxy-Lite")
    # Free-form task words; joined back into a single sentence downstream.
    parser.add_argument(
        "task",
        type=str,
        help="The task you want to accomplish",
        nargs="*",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The model to use.",
    )
    parser.add_argument(
        "--api_base",
        type=str,
        default=None,
        help="The API base URL to use.",
    )
    # New option for setting a homepage URL:
    parser.add_argument(
        "--homepage",
        type=str,
        default=None,
        help="The homepage URL to use.",
    )
    # New viewport controls:
    parser.add_argument(
        "--viewport-width",
        type=int,
        default=None,
        help="Viewport width in pixels.",
    )
    parser.add_argument(
        "--viewport-height",
        type=int,
        default=None,
        help="Viewport height in pixels.",
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=Path(__file__).parent / "configs/default.yaml",
        help="Path to config file (default: configs/default.yaml)",
    )

    args = parser.parse_args()
    do_command(args)


if __name__ == "__main__":
    main()
|
proxy-lite-demo-v2/src/proxy_lite/client.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from functools import cached_property
|
4 |
+
from typing import ClassVar, Literal, Optional, Union
|
5 |
+
|
6 |
+
import httpx
|
7 |
+
from httpx import Limits, Timeout
|
8 |
+
from openai import AsyncOpenAI
|
9 |
+
from openai.types.chat.chat_completion import (
|
10 |
+
ChatCompletion,
|
11 |
+
)
|
12 |
+
from pydantic import BaseModel
|
13 |
+
|
14 |
+
from proxy_lite.history import MessageHistory
|
15 |
+
from proxy_lite.logger import logger
|
16 |
+
from proxy_lite.serializer import (
|
17 |
+
BaseSerializer,
|
18 |
+
OpenAICompatibleSerializer,
|
19 |
+
)
|
20 |
+
from proxy_lite.tools import Tool
|
21 |
+
|
22 |
+
|
23 |
+
class BaseClientConfig(BaseModel):
    """Shared HTTP transport settings inherited by all client configs."""

    # Per-request timeout, in seconds.
    http_timeout: float = 50
    # Connection-pool size: both max concurrent and max keep-alive connections.
    http_concurrent_connections: int = 50
|
26 |
+
|
27 |
+
|
28 |
+
class BaseClient(BaseModel, ABC):
    """Abstract base for chat-completion clients.

    Subclasses adapt provider-specific endpoints while always returning
    the OpenAI ``ChatCompletion`` shape so downstream code stays uniform.
    """

    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """Create completion from model.

        Expect subclasses to adapt from various endpoints that will handle
        requests differently, make sure to raise appropriate warnings.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format for consistency
        """
        # FIX: this text previously sat *after* the method as a bare
        # class-level string, so it was never attached as the docstring.
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        """Instantiate the concrete client registered for ``config.name``.

        Raises:
            ValueError: if ``config.name`` is not a supported client name.
        """
        supported_clients = {
            "openai": OpenAIClient,
            "openai-azure": OpenAIClient,
            "convergence": ConvergenceClient,
            "gemini": GeminiClient,
        }
        if config.name not in supported_clients:
            error_message = f"Unsupported model: {config.name}."
            raise ValueError(error_message)
        return supported_clients[config.name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
        """A new httpx.AsyncClient honouring the configured timeout and limits."""
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )
|
73 |
+
|
74 |
+
|
75 |
+
class OpenAIClientConfig(BaseClientConfig):
    """Configuration for the OpenAI chat-completions client."""

    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    # FIX: os.environ.get returns None when the variable is unset, which
    # violated the previous ``str`` annotation; Optional makes the absent
    # key explicit instead of failing validation obscurely at import time.
    api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
    # Override the default OpenAI endpoint (e.g. Azure deployments, proxies).
    api_base: Optional[str] = None
|
80 |
+
|
81 |
+
|
82 |
+
class OpenAIClient(BaseClient):
    """Client for OpenAI (and OpenAI-compatible) chat-completions endpoints."""

    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        # Built lazily and cached for the client's lifetime; base_url is only
        # passed when configured so the SDK default endpoint is used otherwise.
        client_params = {
            "api_key": self.config.api_key,
            "http_client": self.http_client,
        }
        if self.config.api_base:
            client_params["base_url"] = self.config.api_base
        return AsyncOpenAI(**client_params)

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """Request a chat completion; optional params are sent only when set.

        When ``tools`` are supplied, tool use is forced via
        ``tool_choice="required"``. ``response_format`` only toggles JSON mode
        here — the schema itself is not forwarded.
        """
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        # Drop the None-valued options so they are omitted from the request.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)
|
117 |
+
|
118 |
+
|
119 |
+
class ConvergenceClientConfig(BaseClientConfig):
    """Configuration for a self-hosted (vLLM-style) proxy-lite endpoint."""

    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    # Default assumes a local vLLM server exposing the OpenAI-compatible API.
    api_base: str = "http://localhost:8000/v1"
    # Local servers typically require no key; "none" is a placeholder value.
    api_key: str = "none"
|
124 |
+
|
125 |
+
|
126 |
+
class ConvergenceClient(OpenAIClient):
    """OpenAI-compatible client for the self-hosted proxy-lite model.

    Differs from OpenAIClient in that it validates (once, lazily) that the
    configured model is actually served, uses ``tool_choice="auto"``, and
    forwards ``response_format`` as-is.
    """

    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    # Set to True after the first successful models.list() check.
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        # Confirm the configured model id is present on the server; any
        # failure (connection error or missing model) is logged and re-raised.
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise e

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        # Unlike OpenAIClient, api_base is mandatory here, so it is always passed.
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """Request a chat completion, validating the served model on first use."""
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            "response_format": response_format if response_format else {"type": "text"},
        }
        # Drop None-valued options so they are omitted from the request.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)
|
174 |
+
|
175 |
+
|
176 |
+
class GeminiClientConfig(BaseClientConfig):
    """Configuration for the native Google Gemini API client."""

    name: Literal["gemini"] = "gemini"
    model_id: str = "gemini-2.0-flash-001"
    api_key: str = ""
|
180 |
+
|
181 |
+
|
182 |
+
class GeminiClient(BaseClient):
    """Client for Google's native Gemini generateContent API.

    Reuses the OpenAI-compatible serializer for message history and adapts
    Gemini responses back into OpenAI ``ChatCompletion`` objects.
    """

    config: GeminiClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
|
185 |
+
|
186 |
+
def _convert_messages_to_gemini_format(self, messages):
|
187 |
+
"""Convert OpenAI format messages to Gemini format"""
|
188 |
+
gemini_parts = []
|
189 |
+
for msg in messages:
|
190 |
+
if msg["role"] == "user":
|
191 |
+
gemini_parts.append({"text": msg["content"]})
|
192 |
+
elif msg["role"] == "assistant":
|
193 |
+
gemini_parts.append({"text": msg["content"]})
|
194 |
+
# Skip system messages or add them to the first user message
|
195 |
+
return gemini_parts
|
196 |
+
|
197 |
+
def _clean_schema_for_gemini(self, schema):
|
198 |
+
"""Clean up JSON schema for Gemini function calling - remove $defs and $ref"""
|
199 |
+
if not isinstance(schema, dict):
|
200 |
+
return schema
|
201 |
+
|
202 |
+
cleaned = {}
|
203 |
+
for key, value in schema.items():
|
204 |
+
if key == "$defs":
|
205 |
+
# Skip $defs - we'll inline the definitions
|
206 |
+
continue
|
207 |
+
elif key == "$ref":
|
208 |
+
# Skip $ref - we'll inline the referenced schema
|
209 |
+
continue
|
210 |
+
elif isinstance(value, dict):
|
211 |
+
cleaned[key] = self._clean_schema_for_gemini(value)
|
212 |
+
elif isinstance(value, list):
|
213 |
+
cleaned[key] = [self._clean_schema_for_gemini(item) for item in value]
|
214 |
+
else:
|
215 |
+
cleaned[key] = value
|
216 |
+
|
217 |
+
# If we have $defs, we need to inline them
|
218 |
+
if "$defs" in schema:
|
219 |
+
cleaned = self._inline_definitions(cleaned, schema["$defs"])
|
220 |
+
|
221 |
+
return cleaned
|
222 |
+
|
223 |
+
def _inline_definitions(self, schema, definitions):
|
224 |
+
"""Inline $ref definitions into the schema"""
|
225 |
+
if not isinstance(schema, dict):
|
226 |
+
return schema
|
227 |
+
|
228 |
+
if "$ref" in schema:
|
229 |
+
# Extract the reference name (e.g., "#/$defs/TypeEntry" -> "TypeEntry")
|
230 |
+
ref_name = schema["$ref"].split("/")[-1]
|
231 |
+
if ref_name in definitions:
|
232 |
+
# Replace the $ref with the actual definition
|
233 |
+
return self._inline_definitions(definitions[ref_name], definitions)
|
234 |
+
else:
|
235 |
+
# If we can't find the definition, remove the $ref
|
236 |
+
return {k: v for k, v in schema.items() if k != "$ref"}
|
237 |
+
|
238 |
+
# Recursively process nested objects
|
239 |
+
inlined = {}
|
240 |
+
for key, value in schema.items():
|
241 |
+
if isinstance(value, dict):
|
242 |
+
inlined[key] = self._inline_definitions(value, definitions)
|
243 |
+
elif isinstance(value, list):
|
244 |
+
inlined[key] = [self._inline_definitions(item, definitions) for item in value]
|
245 |
+
else:
|
246 |
+
inlined[key] = value
|
247 |
+
|
248 |
+
return inlined
|
249 |
+
|
250 |
+
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,  # NOTE(review): not forwarded to Gemini — confirm intended
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,  # NOTE(review): not forwarded to Gemini
    ) -> ChatCompletion:
        """Call the native Gemini ``generateContent`` API and adapt the reply.

        Messages are flattened into Gemini "contents" (tool results are folded
        into user turns), tools are exposed as function declarations, and the
        JSON response is converted back into an OpenAI ``ChatCompletion``
        object so callers can treat every client uniformly.
        """
        import json
        from openai.types.chat.chat_completion import ChatCompletion, Choice
        from openai.types.chat.chat_completion_message import ChatCompletionMessage
        from openai.types.completion_usage import CompletionUsage
        from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

        # Convert messages to format expected by Gemini
        serialized_messages = self.serializer.serialize_messages(messages)

        # For Gemini API, we need to format contents correctly with proper roles
        contents = []
        current_user_text = ""

        for msg in serialized_messages:
            # Extract the actual text content from the serialized message
            content_text = ""
            if isinstance(msg["content"], list):
                # Handle complex content format (list of text parts / strings)
                for item in msg["content"]:
                    if isinstance(item, dict) and "text" in item:
                        content_text += item["text"]
                    elif isinstance(item, str):
                        content_text += item
            elif isinstance(msg["content"], str):
                content_text = msg["content"]

            if msg["role"] == "user":
                # Accumulate user messages
                current_user_text += content_text + "\n"
            elif msg["role"] == "assistant":
                # If we have accumulated user text, add it first
                if current_user_text.strip():
                    contents.append({
                        "role": "user",
                        "parts": [{"text": current_user_text.strip()}]
                    })
                    current_user_text = ""

                # Add assistant message with role "model"
                contents.append({
                    "role": "model",
                    "parts": [{"text": content_text}]
                })
            elif msg["role"] == "tool":
                # Add tool messages as user messages so they're included in context
                # Format tool message more clearly for the agent to understand
                current_user_text += f"[ACTION COMPLETED] {content_text}\n"

        # Add any remaining user text
        if current_user_text.strip():
            contents.append({
                "role": "user",
                "parts": [{"text": current_user_text.strip()}]
            })

        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
            }
        }

        # Add function calling support if tools are provided
        if tools:
            # Convert tools to Gemini function declaration format
            function_declarations = []
            for tool in tools:
                for tool_schema in tool.schema:
                    # Clean up the schema for Gemini - remove $defs and $ref
                    cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
                    function_declarations.append({
                        "name": tool_schema["name"],
                        "description": tool_schema["description"],
                        "parameters": cleaned_parameters
                    })

            payload["tools"] = [{
                "function_declarations": function_declarations
            }]

        # Make direct HTTP request to native Gemini API
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.config.model_id}:generateContent?key={self.config.api_key}"

        response = await self.http_client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"}
        )

        response.raise_for_status()
        response_data = response.json()

        # Convert Gemini response to OpenAI ChatCompletion format
        if "candidates" in response_data and len(response_data["candidates"]) > 0:
            candidate = response_data["candidates"][0]

            # Extract text from response
            content = ""
            tool_calls = []

            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        content += part["text"]
                    elif "functionCall" in part:
                        # Handle function call
                        func_call = part["functionCall"]
                        tool_call = ChatCompletionMessageToolCall(
                            # Synthetic id: Gemini does not return call ids.
                            id=f"call_{hash(str(func_call))}"[:16],
                            type="function",
                            function={
                                "name": func_call["name"],
                                "arguments": json.dumps(func_call.get("args", {}))
                            }
                        )
                        tool_calls.append(tool_call)

            choice = Choice(
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=content if content else None,
                    tool_calls=tool_calls if tool_calls else None
                ),
                # Gemini finish reasons are not mapped; "stop" is always reported.
                finish_reason="stop"
            )

            # Create a mock ChatCompletion response
            completion = ChatCompletion(
                id="gemini-" + str(hash(content))[:8],
                choices=[choice],
                created=int(__import__('time').time()),
                model=self.config.model_id,
                object="chat.completion",
                # Token usage is approximated by whitespace word counts.
                usage=CompletionUsage(
                    completion_tokens=len(content.split()),
                    prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
                    total_tokens=len(content.split()) + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages)
                )
            )

            return completion
        else:
            raise Exception(f"No valid response from Gemini API: {response_data}")
|
402 |
+
|
403 |
+
|
404 |
+
# Union aliases covering every client/config implementation in this module.
ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]
|
proxy-lite-demo-v2/src/proxy_lite/configs/default.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default run configuration for proxy-lite.
environment:
  name: webbrowser
  annotate_image: true
  screenshot_delay: 2.0          # seconds to wait before each screenshot
  viewport_width: 1280
  viewport_height: 1920
  include_poi_text: true
  headless: false                # show the browser window
  homepage: https://www.google.co.uk
  keep_original_image: true      # also keep the un-annotated screenshot
solver:
  name: simple
  agent:
    name: proxy_lite
    client:
      name: convergence
      model_id: convergence-ai/proxy-lite-3b
      api_base: https://convergence-ai-demo-api.hf.space/v1
local_view: true
# Timeouts below are presumably in seconds — confirm against the runner.
task_timeout: 1800
environment_timeout: 1800
action_timeout: 1800
verbose: true
|
proxy-lite-demo-v2/src/proxy_lite/environments/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from .environment_base import (
|
4 |
+
Action,
|
5 |
+
BaseEnvironment,
|
6 |
+
BaseEnvironmentConfig,
|
7 |
+
Environments,
|
8 |
+
Event,
|
9 |
+
EventType,
|
10 |
+
Observation,
|
11 |
+
)
|
12 |
+
from .webbrowser import (
|
13 |
+
WebBrowserEnvironment,
|
14 |
+
WebBrowserEnvironmentConfig,
|
15 |
+
)
|
16 |
+
|
17 |
+
EnvironmentConfigTypes = Union[*list(Environments._environment_config_registry.values())]
|
18 |
+
EnvironmentTypes = Union[*list(Environments._environment_registry.values())]
|
19 |
+
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"Action",
|
23 |
+
"BaseEnvironment",
|
24 |
+
"BaseEnvironmentConfig",
|
25 |
+
"EnvironmentConfigTypes",
|
26 |
+
"Environments",
|
27 |
+
"Event",
|
28 |
+
"EventType",
|
29 |
+
"Observation",
|
30 |
+
"WebBrowserEnvironment",
|
31 |
+
"WebBrowserEnvironmentConfig",
|
32 |
+
]
|
proxy-lite-demo-v2/src/proxy_lite/environments/environment_base.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from enum import Enum
|
5 |
+
from functools import cached_property
|
6 |
+
from typing import Any, Literal, Optional, Self
|
7 |
+
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
from proxy_lite.history import ToolCall
|
11 |
+
from proxy_lite.tools import Tool, ToolExecutionResponse
|
12 |
+
|
13 |
+
|
14 |
+
class EventType(str, Enum):
    """Kinds of events recorded in a run's history."""

    OBSERVATION = "observation"
    ACTION = "action"
    MESSAGE = "message"


class Event(BaseModel):
    """Base record for anything appended to a run's history."""

    type: EventType


class State(BaseModel):
    """Snapshot of the environment as presented to the agent."""

    text: Optional[str] = None
    image: Optional[str] = None  # base64 encoded image
    html: Optional[str] = None
    tool_responses: Optional[list[ToolExecutionResponse]] = None


class Observation(Event):
    """Environment output after initialisation or an executed action."""

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    state: State
    terminated: bool
    reward: Optional[float] = None
    info: Optional[dict[str, Any]] = None


class Action(Event):
    """Agent output: free text and/or tool calls to run in the environment."""

    type: Literal[EventType.ACTION] = EventType.ACTION
    text: Optional[str] = None
    tool_calls: Optional[list[ToolCall]] = None
    info: Optional[dict[str, Any]] = None


# Marker base class: concrete environments subclass this for their config.
class BaseEnvironmentConfig(BaseModel): ...
|
47 |
+
|
48 |
+
|
49 |
+
class BaseEnvironment(BaseModel, ABC):
    """Abstract interface every environment implementation must provide."""

    config: BaseEnvironmentConfig
    # Optional logger; may be None until set by the runner.
    logger: logging.Logger | None = None

    class Config:
        # Allow non-pydantic types (e.g. logging.Logger) as field values.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str: ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @abstractmethod
    async def initialise(self) -> Observation: ...

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation: ...

    @abstractmethod
    async def observe(self) -> Observation: ...

    @abstractmethod
    async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]: ...

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch a tool call to the first tool exposing the named function.

        Returns the awaited result of the tool function.

        Raises:
            ValueError: If no tool provides the requested function.
        """
        function = tool_call.function
        for tool in self.tools:
            if hasattr(tool, function["name"]):
                arguments = json.loads(function["arguments"])
                # Some models double-encode arguments as a JSON string; decode again.
                if isinstance(arguments, str):
                    arguments = json.loads(arguments)
                return await getattr(tool, function["name"])(
                    **arguments,
                )
        msg = f'No tool function with name "{function["name"]}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        # Default: no extra info; subclasses may override.
        return {}
|
97 |
+
|
98 |
+
|
99 |
+
class Environments:
    """Name-indexed registries for environment and environment-config classes."""

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """Class decorator registering an Environment under ``name``.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...
        """

        def _add_to_registry(environment_cls):
            cls._environment_registry[name] = environment_cls
            return environment_cls

        return _add_to_registry

    @classmethod
    def register_environment_config(cls, name: str):
        """Class decorator registering an Environment config under ``name``.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...
        """

        def _add_to_registry(configuration_cls):
            cls._environment_config_registry[name] = configuration_cls
            return configuration_cls

        return _add_to_registry

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """Look up a registered Environment class by name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            raise ValueError(f"Environment '{name}' not found.")

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """Look up a registered Environment configuration class by name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            raise ValueError(f"Environment config for '{name}' not found.")
|
proxy-lite-demo-v2/src/proxy_lite/environments/webbrowser.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from functools import cached_property
|
3 |
+
from typing import Any, Literal, Optional, Self, List # Added List import
|
4 |
+
|
5 |
+
from proxy_lite.browser.browser import BrowserSession
|
6 |
+
from proxy_lite.environments.environment_base import (
|
7 |
+
Action,
|
8 |
+
BaseEnvironment,
|
9 |
+
BaseEnvironmentConfig,
|
10 |
+
Environments,
|
11 |
+
Observation,
|
12 |
+
State,
|
13 |
+
)
|
14 |
+
from proxy_lite.tools import BrowserTool, Tool, ToolExecutionResponse
|
15 |
+
from proxy_lite.logger import logger
|
16 |
+
|
17 |
+
@Environments.register_environment_config("webbrowser")
class WebBrowserEnvironmentConfig(BaseEnvironmentConfig):
    """Configuration for the web-browser environment."""

    name: Literal["webbrowser"] = "webbrowser"
    homepage: str = "https://google.com"
    annotate_image: bool = True
    screenshot_delay: float = 1.0  # seconds
    include_html: bool = True
    include_poi_text: bool = True
    record_pois: bool = True
    viewport_width: int = 1280
    viewport_height: int = 720
    browserbase_timeout: int = 7200
    headless: bool = True
    keep_original_image: bool = False
    no_pois_in_image: bool = False
    # --- MODIFICATION START ---
    # Added to accept initial cookies from the RunnerConfig
    initial_cookies: Optional[List[dict]] = None
    # --- MODIFICATION END ---
|
36 |
+
|
37 |
+
|
38 |
+
@Environments.register_environment("webbrowser")
class WebBrowserEnvironment(BaseEnvironment):
    """Environment that drives a browser session and reports observations."""

    config: WebBrowserEnvironmentConfig
    browser: Optional[BrowserSession] = None
    # Set when an action was skipped because the page changed mid-action.
    cancelled_last_action: bool = False

    class Config:
        # Allow the non-pydantic BrowserSession type as a field value.
        arbitrary_types_allowed = True
|
46 |
+
|
47 |
+
async def __aenter__(self) -> Self:
|
48 |
+
# Initialize the BrowserSession
|
49 |
+
self.browser = self.browser_session(
|
50 |
+
viewport_width=self.config.viewport_width,
|
51 |
+
viewport_height=self.config.viewport_height,
|
52 |
+
headless=self.config.headless,
|
53 |
+
)
|
54 |
+
await self.browser.__aenter__()
|
55 |
+
# Initialize other resources if necessary
|
56 |
+
# --- MODIFICATION START ---
|
57 |
+
# Changed to use self.config.initial_cookies
|
58 |
+
if self.config.initial_cookies:
|
59 |
+
self.logger.info(f"π [bold blue]Adding {len(self.config.initial_cookies)} initial cookies to browser context.[/]")
|
60 |
+
await self.browser.context.add_cookies(self.config.initial_cookies)
|
61 |
+
# --- MODIFICATION END ---
|
62 |
+
self.logger.info("π [bold blue]Browser session started.[/]")
|
63 |
+
return self
|
64 |
+
|
65 |
+
    async def __aexit__(self, exc_type, exc_value, traceback):
        """Tear down the browser session, forwarding any in-flight exception."""
        # Clean up the BrowserSession
        await self.browser.__aexit__(exc_type, exc_value, traceback)
|
68 |
+
|
69 |
+
    @property
    def info_for_user(self) -> str:
        # Human-readable description of this environment's capabilities.
        return "This is a web browser environment. You can navigate the web, search the web, and perform actions on the web."  # noqa: E501

    @cached_property
    def tools(self) -> list[Tool]:
        # A single browser tool bound to the live session.
        return [BrowserTool(session=self.browser)]

    @cached_property
    def browser_session(self) -> type[BrowserSession]:
        # Factory hook: returns the session class to instantiate in __aenter__.
        return BrowserSession

    # --- MODIFICATION START ---
    # Modified this property to return cookies from the config.
    # It was previously hardcoded to return an empty list.
    @property
    def cookies(self) -> list[dict]:
        return self.config.initial_cookies if self.config.initial_cookies is not None else []
    # --- MODIFICATION END ---
|
88 |
+
|
89 |
+
    async def initialise(self) -> Observation:
        """Navigate to the configured homepage and return the first observation."""
        self.logger.debug(f"DEBUG: Initialising WebBrowserEnvironment. Homepage: {self.config.homepage}")
        try:
            await self.browser.goto(self.config.homepage)
            self.logger.debug(f"DEBUG: Browser navigated to homepage. Current URL: {self.browser.current_url}")
        except Exception as e:
            self.logger.error(f"ERROR: Failed to navigate to homepage {self.config.homepage}: {e}")
            raise  # Re-raise to propagate the error

        original_img, annotated_img = await self.browser.screenshot(
            delay=self.config.screenshot_delay,
        )
        # Choose between the raw and the POI-annotated screenshot for the agent.
        if self.config.no_pois_in_image:
            base64_image = base64.b64encode(original_img).decode("utf-8")
        else:
            base64_image = base64.b64encode(annotated_img).decode("utf-8")

        html_content = await self.browser.current_page.content() if self.config.include_html else None

        info = {"url": self.browser.current_url}
        if self.config.record_pois:
            info["pois"] = self.browser.pois
        if self.config.keep_original_image:
            # Keep the un-annotated screenshot alongside the annotated one.
            info["original_image"] = base64.b64encode(original_img).decode("utf-8")

        self.logger.debug(f"DEBUG: Initial observation captured. URL: {self.browser.current_url}")
        return Observation(
            state=State(
                text=f"URL: {self.browser.current_url}"
                + (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
                image=base64_image,
                html=html_content,
            ),
            terminated=False,
            reward=None,
            info=info,
        )
|
126 |
+
|
127 |
+
async def should_perform_action(self) -> bool:
|
128 |
+
# if cancelled last action, run the action without updating POIs
|
129 |
+
if self.cancelled_last_action:
|
130 |
+
self.cancelled_last_action = False
|
131 |
+
return True
|
132 |
+
|
133 |
+
# check for page changes
|
134 |
+
old_points = [tuple(point) for point in self.browser.poi_centroids]
|
135 |
+
await self.browser.update_poi()
|
136 |
+
new_points = [tuple(point) for point in self.browser.poi_centroids]
|
137 |
+
page_changed_mid_action = old_points != new_points
|
138 |
+
|
139 |
+
# record if the last action was cancelled
|
140 |
+
if page_changed_mid_action:
|
141 |
+
self.cancelled_last_action = True
|
142 |
+
return False
|
143 |
+
return True
|
144 |
+
|
145 |
+
    async def execute_action(self, action: Action) -> Observation:
        """Run the action's tool calls (unless cancelled) and observe the result.

        Every tool call receives a ToolExecutionResponse — on success, on
        error (the exception text), and on cancellation alike — so the agent
        always sees one response per call.
        """
        responses = []
        cancelled_tools_flag = False
        if await self.should_perform_action():
            for tool_call in action.tool_calls:
                # Perform the chosen action
                try:
                    tool_response: ToolExecutionResponse = await self.execute_tool(
                        tool_call,
                    )
                    tool_response.id = tool_call.id
                    responses.append(tool_response)
                except Exception as e:  # noqa: PERF203
                    # Errors become tool responses instead of aborting the run.
                    self.logger.warning("π An error occurred taking action: %s", str(e), exc_info=False)
                    tool_response = ToolExecutionResponse(content=str(e), id=tool_call.id)
                    responses.append(tool_response)
        else:
            self.logger.warning("π Page changed since last observation, cancelling action.")
            self.cancelled_last_action = True
            for tool_call in action.tool_calls:
                tool_response = ToolExecutionResponse(
                    content="The page changed before the action could be executed, instead of being ran it was cancelled.",  # noqa: E501
                    id=tool_call.id,
                )
                responses.append(tool_response)
            cancelled_tools_flag = True
        original_img, annotated_img = await self.browser.screenshot(
            delay=self.config.screenshot_delay,
        )

        base64_image = base64.b64encode(annotated_img).decode("utf-8")

        info = {"url": self.browser.current_url, "cancelled_tools": cancelled_tools_flag}
        if self.config.record_pois:
            info["pois"] = self.browser.pois
        if self.config.keep_original_image:
            info["original_image"] = base64.b64encode(original_img).decode("utf-8")

        html_content = await self.browser.current_page.content() if self.config.include_html else None
        return Observation(
            state=State(
                text=f"URL: {self.browser.current_url}"
                + (f"\n{self.browser.poi_text}" if self.config.include_poi_text else ""),
                image=base64_image,
                html=html_content,
                tool_responses=responses,
            ),
            terminated=False,
            reward=None,
            info=info,
        )
|
196 |
+
|
197 |
+
    async def observe(self) -> Observation:
        """Return a fresh observation from the browser session."""
        return await self.browser.observe()

    async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
        # No automatic evaluation for the web browser environment.
        return {}

    async def get_info(self) -> dict[str, Any]:
        # No extra environment info is reported.
        info = {}
        return info
|
proxy-lite-demo-v2/src/proxy_lite/gif_maker.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import re
|
3 |
+
import textwrap
|
4 |
+
from io import BytesIO
|
5 |
+
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
|
8 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
9 |
+
from proxy_lite.recorder import Run
|
10 |
+
|
11 |
+
|
12 |
+
def create_run_gif(
    run: Run, output_path: str, white_panel_width: int = 300, duration: int = 1500, resize_factor: int = 4
) -> None:
    """
    Generate a gif from the Run object's history.

    For each Observation record, the observation image is decoded from its base64
    encoded string. If the next record is an Action, its text is drawn onto a
    white panel. The observation image and the white panel are then concatenated
    horizontally to produce a frame.

    Parameters:
        run (Run): A Run object with its history containing Observation and Action records.
        output_path (str): The path where the GIF will be saved.
        white_panel_width (int): The width of the white panel for displaying text.
            Default increased to 400 for larger images.
        duration (int): Duration between frames in milliseconds.
            Increased here to slow the FPS (default is 1000ms).
        resize_factor (int): The factor to resize the image down by.

    Raises:
        ValueError: If no frames could be produced from the history.
    """
    frames = []
    history = run.history
    i = 0
    while i < len(history):
        if isinstance(history[i], Observation):
            observation = history[i]
            image_data = observation.state.image
            if not image_data:
                # Observation without an image contributes no frame.
                i += 1
                continue
            # Decode the base64 image
            image_bytes = base64.b64decode(image_data)
            obs_img = Image.open(BytesIO(image_bytes)).convert("RGB")

            # scale the image down
            obs_img = obs_img.resize((obs_img.width // resize_factor, obs_img.height // resize_factor))

            # Check if the next record is an Action and extract its text if available
            action_text = ""
            if i + 1 < len(history) and isinstance(history[i + 1], Action):
                action = history[i + 1]
                if action.text:
                    action_text = action.text

            # extract observation and thinking from tags in the action text
            observation_match = re.search(r"<observation>(.*?)</observation>", action_text, re.DOTALL)
            observation_content = observation_match.group(1).strip() if observation_match else None

            # Extract text between thinking tags if present
            thinking_match = re.search(r"<thinking>(.*?)</thinking>", action_text, re.DOTALL)
            thinking_content = thinking_match.group(1).strip() if thinking_match else None

            if observation_content and thinking_content:
                action_text = f"**OBSERVATION**\n{observation_content}\n\n**THINKING**\n{thinking_content}"

            # Create a white panel (same height as the observation image)
            panel = Image.new("RGB", (white_panel_width, obs_img.height), "white")
            draw = ImageDraw.Draw(panel)
            font = ImageFont.load_default()

            # Wrap the action text if it is too long
            max_chars_per_line = 40  # Adjusted for larger font size
            wrapped_text = textwrap.fill(action_text, width=max_chars_per_line)

            # Calculate text block size and center it on the panel
            try:
                # Use multiline_textbbox if available (returns bounding box tuple)
                bbox = draw.multiline_textbbox((0, 0), wrapped_text, font=font)
                text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            except AttributeError:
                # Fallback for older Pillow versions: compute size for each line
                lines = wrapped_text.splitlines() or [wrapped_text]
                line_sizes = [draw.textsize(line, font=font) for line in lines]
                text_width = max(width for width, _ in line_sizes)
                text_height = sum(height for _, height in line_sizes)
            text_x = (white_panel_width - text_width) // 2
            text_y = (obs_img.height - text_height) // 2
            draw.multiline_text((text_x, text_y), wrapped_text, fill="black", font=font, align="center")

            # Create the combined frame by concatenating the observation image and the panel
            total_width = obs_img.width + white_panel_width
            combined_frame = Image.new("RGB", (total_width, obs_img.height))
            combined_frame.paste(obs_img, (0, 0))
            combined_frame.paste(panel, (obs_img.width, 0))
            frames.append(combined_frame)

            # Skip the Action record since it has been processed with this Observation
            if i + 1 < len(history) and isinstance(history[i + 1], Action):
                i += 2
            else:
                i += 1
        else:
            i += 1

    if frames:
        frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=duration, loop=0)
    else:
        raise ValueError("No frames were generated from the Run object's history.")
|
110 |
+
|
111 |
+
|
112 |
+
# Example usage:
|
113 |
+
if __name__ == "__main__":
    from proxy_lite.recorder import Run

    # Load a previously recorded run by its id and render it to a GIF.
    dummy_run = Run.load("0abdb4cb-f289-48b0-ba13-35ed1210f7c1")

    # Each step contributes an Observation/Action pair to the history.
    num_steps = int(len(dummy_run.history) / 2)
    print(f"Number of steps: {num_steps}")
    output_gif_path = "trajectory.gif"
    create_run_gif(dummy_run, output_gif_path, duration=1000)
    print(f"Trajectory GIF saved to {output_gif_path}")
|
proxy-lite-demo-v2/src/proxy_lite/history.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import base64
|
4 |
+
from collections.abc import Iterator
|
5 |
+
from enum import Enum
|
6 |
+
from typing import Any, Literal, Optional, Set, Union
|
7 |
+
|
8 |
+
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
9 |
+
|
10 |
+
|
11 |
+
class MessageLabel(str, Enum):
    """Semantic tag attached to messages; used by history_view() to filter the context window."""

    SYSTEM = "system"
    USER_INPUT = "user_input"
    SCREENSHOT = "screenshot"
    AGENT_MODEL_RESPONSE = "agent_model_response"
|
16 |
+
|
17 |
+
|
18 |
+
# Per-label cap on how many recent messages survive MessageHistory.history_view();
# only the single most recent screenshot is kept in the context window.
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
    MessageLabel.SCREENSHOT: 1,
}
|
21 |
+
|
22 |
+
|
23 |
+
class MessageContent(BaseModel):
    # Marker base class for the content parts (Text, Image) a Message can hold.
    pass
|
25 |
+
|
26 |
+
|
27 |
+
class Text(MessageContent):
    """Plain-text content part (matches the OpenAI "text" content schema)."""

    # Discriminator fixed at "text"; init=False keeps it out of the constructor.
    type: Literal["text"] = Field(default="text", init=False)
    text: str
|
30 |
+
|
31 |
+
|
32 |
+
class ImageUrl(BaseModel):
    # Wrapper matching the OpenAI image_url object; url is a data: URL in from_media().
    url: str
|
34 |
+
|
35 |
+
|
36 |
+
class Image(MessageContent):
    """Image content part (matches the OpenAI "image_url" content schema)."""

    # Discriminator fixed at "image_url"; init=False keeps it out of the constructor.
    type: Literal["image_url"] = Field(default="image_url", init=False)
    image_url: ImageUrl
|
39 |
+
|
40 |
+
|
41 |
+
class Message(BaseModel):
    """A single chat message: an ordered list of text/image content parts plus an internal label."""

    label: Optional[MessageLabel] = None  # internal tag used for context-window filtering
    content: list[Union[Text, Image]] = Field(default_factory=list)

    class Config:
        use_enum_values = True

    @property
    def images(self) -> list[Image]:
        """All image content parts, in order."""
        return [content for content in self.content if isinstance(content, Image)]

    @property
    def texts(self) -> list[Text]:
        """All text content parts, in order."""
        return [content for content in self.content if isinstance(content, Text)]

    @property
    def first_image(self) -> Optional[Image]:
        return self.images[0] if self.images else None

    @property
    def first_text(self) -> Optional[Text]:
        return self.texts[0] if self.texts else None

    def __len__(self):
        return len(self.content)

    @classmethod
    def from_media(
        cls,
        text: Optional[str] = None,
        image: Optional[bytes | str] = None,
        is_base64: bool = False,
    ) -> Message:
        """Build a Message from raw media.

        Args:
            text: Optional plain-text body.
            image: Optional image, either raw bytes or an already-base64 string.
            is_base64: True when *image* is already base64-encoded.

        Returns:
            A Message whose content holds the text part (if any) followed by
            the image part (if any) — same ordering as before.

        Raises:
            ValueError: if neither *text* nor *image* is supplied. Previously
                this path built content=[None] and crashed later with an
                opaque pydantic validation error.
        """
        if text is None and image is None:
            raise ValueError("from_media requires at least one of text or image")
        # Build the parts without shadowing the parameters with wrapper objects.
        content: list[Union[Text, Image]] = []
        if text is not None:
            content.append(Text(text=text))
        if image is not None:
            base64_image = image if is_base64 else base64.b64encode(image).decode("utf-8")
            data_url = f"data:image/jpeg;base64,{base64_image}"
            content.append(Image(image_url=ImageUrl(url=data_url)))
        return cls(content=content)
|
84 |
+
|
85 |
+
|
86 |
+
class SystemMessage(Message):
    # Message with the fixed "system" role.
    role: Literal["system"] = Field(default="system", init=False)
|
88 |
+
|
89 |
+
|
90 |
+
class UserMessage(Message):
    # Message with the fixed "user" role.
    role: Literal["user"] = Field(default="user", init=False)
|
92 |
+
|
93 |
+
|
94 |
+
class ToolCall(BaseModel):
    """One function call requested by the assistant."""

    id: str  # correlates with ToolMessage.tool_call_id
    type: str
    # Callers in this package read function["name"] and function["arguments"].
    function: dict[str, Any]
|
98 |
+
|
99 |
+
|
100 |
+
class AssistantMessage(Message):
    """Assistant turn; may carry tool calls alongside its content parts."""

    role: Literal["assistant"] = Field(default="assistant", init=False)
    tool_calls: list[ToolCall] = Field(default_factory=list)

    def model_dump(self, **kwargs):
        """Serialise, omitting the tool_calls key entirely when there are none.

        Some chat-completion APIs reject an explicit empty tool_calls list.
        """
        data = super().model_dump(**kwargs)
        if not self.tool_calls:
            # pop with a default: the key may already be absent when callers
            # exclude it via model_dump(exclude=...), which previously raised KeyError.
            data.pop("tool_calls", None)
        return data

    @field_validator("tool_calls", mode="before")
    @classmethod
    def ensure_list(cls, v):
        # APIs sometimes return null instead of an empty list.
        return [] if v is None else v
|
114 |
+
|
115 |
+
|
116 |
+
class ToolMessage(Message):
    # Tool-result message; tool_call_id links it back to the originating ToolCall.id.
    role: Literal["tool"] = Field(default="tool", init=False)
    tool_call_id: str
|
119 |
+
|
120 |
+
|
121 |
+
# Union of all concrete message types plus a TypeAdapter for validating raw
# dicts back into the right message class (used when deserialising API payloads).
MessageTypes = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
MessageAdapter = TypeAdapter(MessageTypes)
|
123 |
+
|
124 |
+
|
125 |
+
class MessageHistory(BaseModel):
    """Ordered list of chat messages with list-like helpers and context-window filtering."""

    messages: list[MessageTypes] = Field(default_factory=list)

    def append(self, message: MessageTypes, label: Optional[MessageLabel] = None):
        """Append *message*, optionally tagging it with *label* first."""
        if label is not None:
            message.label = label
        self.messages.append(message)

    def pop(self) -> MessageTypes:
        """Remove and return the most recent message."""
        return self.messages.pop()

    def extend(self, history: MessageHistory):
        """Append all messages from another history, preserving order."""
        self.messages.extend(history.messages)

    def __reversed__(self):
        # Returns a new MessageHistory (not an iterator) with messages reversed.
        return MessageHistory(messages=self.messages[::-1])

    def __getitem__(self, index):
        return self.messages[index]

    def __len__(self):
        return len(self.messages)

    def __iter__(self) -> Iterator[MessageTypes]:
        return iter(self.messages)

    def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
        """Dump every message to a dict, excluding the given field names."""
        exclude = exclude or set()
        return [message.model_dump(exclude=exclude) for message in self.messages]

    def history_view(
        self,
        limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
    ) -> MessageHistory:
        """Context window management.

        Filters messages in reverse order, retaining a limited number of recent
        messages per label (e.g. only the latest screenshot); unlabelled or
        unlimited messages are always kept.
        """
        label_counts = {label: 0 for label in limits}
        filtered_messages = []
        for message in reversed(self.messages):
            if message.label in limits:
                if label_counts[message.label] < limits[message.label]:
                    filtered_messages.append(message)
                    label_counts[message.label] += 1
            else:
                filtered_messages.append(message)
        # Materialise as a list: pydantic list-field validation expects a real
        # sequence, not a one-shot reversed() iterator.
        return MessageHistory(messages=list(reversed(filtered_messages)))

    def __add__(self, other: MessageHistory) -> MessageHistory:
        new_history = MessageHistory()
        new_history.extend(self)
        new_history.extend(other)
        return new_history

    def __iadd__(self, other: MessageHistory) -> MessageHistory:
        self.extend(other)
        return self
|
proxy-lite-demo-v2/src/proxy_lite/logger.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
from typing import Literal
|
5 |
+
from uuid import uuid4
|
6 |
+
|
7 |
+
from rich.logging import RichHandler
|
8 |
+
|
9 |
+
|
10 |
+
class StructuredLogger(logging.Logger):
    """Logger that attaches structured ``json_fields`` metadata to every record."""

    async def stream_message(self, message: str) -> None:
        """Print *message* one character at a time for a typewriter effect.

        Any stdout failure is deliberately swallowed — streaming is cosmetic.
        """
        try:
            sys.stdout.write("\r")  # overwrite the current line
            for character in message:
                sys.stdout.write(character)
                sys.stdout.flush()
                await asyncio.sleep(0.002)
            sys.stdout.write("\n")
        except Exception:
            pass

    def _log(
        self,
        level,
        msg,
        args,
        exc_info=None,
        extra=None,
        stack_info=False,
        stacklevel=1,
    ):
        """Intercept logging to bundle structured metadata into the record's extra."""
        extra = {} if extra is None else extra

        # Base structured payload: the logger name and the fully-formatted message.
        fields = {
            "logger_name": self.name,
            "message": msg % args if args else msg,
        }

        # If we're inside an exception handler, capture the active exception too.
        current_type, current_value, _ = sys.exc_info()
        if current_type is not None:
            fields["exception_class"] = current_type.__name__
            fields["exception_message"] = str(current_value)

        # Caller-supplied extras win over the defaults above.
        fields.update(extra)

        # stacklevel + 1 compensates for this extra frame in the call chain.
        super()._log(
            level,
            msg,
            args,
            exc_info,
            {"json_fields": fields},
            stack_info,
            stacklevel + 1,
        )
|
57 |
+
|
58 |
+
|
59 |
+
def create_logger(
    name: str,
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
    detailed_name: bool = False,
) -> logging.Logger:
    """Build an isolated Rich-backed logger.

    A short uuid suffix keeps every logger instance distinct, so repeated calls
    with the same *name* never share handlers.

    Args:
        name: Human-readable base name for the logger.
        level: Initial logging level.
        detailed_name: Prefix each message with the logger name when True.
    """
    logger = logging.getLogger(f"{name}-{str(uuid4())[:8]}")
    logger.setLevel(level)

    # Rich handler renders markup and pretty tracebacks; time/path columns off.
    handler = RichHandler(
        rich_tracebacks=True,
        markup=True,
        show_path=False,
        show_time=False,
        log_time_format="[%s]",
    )
    fmt = "%(name)s:\n%(message)s" if detailed_name else "-----\n%(message)s"
    handler.setFormatter(logging.Formatter(fmt))

    logger.addHandler(handler)
    logger.propagate = False  # keep messages out of the root logger

    return logger
|
86 |
+
|
87 |
+
|
88 |
+
# Install StructuredLogger as the default logger class so every logger created
# after this import gains json_fields support and stream_message().
logging.setLoggerClass(StructuredLogger)

# Module-level logger for this package.
logger = create_logger(__name__, level="INFO")
|
proxy-lite-demo-v2/src/proxy_lite/recorder.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import uuid
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any, Optional, Self
|
9 |
+
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
+
|
12 |
+
from proxy_lite.environments import EnvironmentConfigTypes
|
13 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
14 |
+
from proxy_lite.history import MessageHistory
|
15 |
+
from proxy_lite.solvers import SolverConfigTypes
|
16 |
+
|
17 |
+
|
18 |
+
class Run(BaseModel):
    """Record of one task execution: metadata plus the observation/action history."""

    run_id: str  # uuid.UUID serialised as a string
    task: str
    created_at: str  # datetime.datetime serialised as a string
    complete: bool = False
    terminated_at: str | None = None  # datetime.datetime serialised as a string
    evaluation: dict[str, Any] | None = None
    history: list[Observation | Action] = Field(default_factory=list)
    solver_history: MessageHistory | None = None
    result: str | None = None
    env_info: dict[str, Any] = Field(default_factory=dict)
    environment: Optional[EnvironmentConfigTypes] = None
    solver: Optional[SolverConfigTypes] = None

    @classmethod
    def initialise(cls, task: str) -> Self:
        """Create a fresh Run with a new id and a UTC creation timestamp."""
        run_id = str(uuid.uuid4())
        return cls(
            run_id=run_id,
            task=task,
            created_at=str(datetime.datetime.now(datetime.UTC)),
        )

    @classmethod
    def load(cls, run_id: str, folder: str | Path | None = None) -> Self:
        """Load a saved Run by id.

        Args:
            run_id: Identifier of the run to load.
            folder: Directory containing ``<run_id>.json``. Defaults to the
                package-local ``local_trajectories`` directory, so existing
                callers are unaffected.
        """
        if folder is None:
            folder = Path(__file__).parent.parent.parent / "local_trajectories"
        with open(Path(folder) / f"{run_id}.json", encoding="utf-8") as f:
            return cls(**json.load(f))

    @property
    def observations(self) -> list[Observation]:
        """All Observation entries in the history, in order."""
        return [h for h in self.history if isinstance(h, Observation)]

    @property
    def actions(self) -> list[Action]:
        """All Action entries in the history, in order."""
        return [h for h in self.history if isinstance(h, Action)]

    @property
    def last_action(self) -> Action | None:
        return self.actions[-1] if self.actions else None

    @property
    def last_observation(self) -> Observation | None:
        return self.observations[-1] if self.observations else None

    def record(
        self,
        observation: Optional[Observation] = None,
        action: Optional[Action] = None,
        solver_history: Optional[MessageHistory] = None,
    ) -> None:
        """Append one event (and optionally refresh the solver history snapshot).

        Raises:
            ValueError: if both an observation and an action are given —
                only one per call so the history ordering stays unambiguous.
        """
        if observation and action:
            raise ValueError("Only one of observation and action can be provided")
        if observation:
            self.history.append(observation)
        if action:
            self.history.append(action)
        if solver_history:
            self.solver_history = solver_history

    def terminate(self) -> None:
        """Stamp the run with a UTC termination time."""
        self.terminated_at = str(datetime.datetime.now(datetime.UTC))
|
80 |
+
|
81 |
+
|
82 |
+
class DataRecorder:
    """Persists Run objects as JSON files in a local folder."""

    def __init__(self, local_folder: str | None = None):
        # Folder may be supplied up front; otherwise a default is chosen lazily
        # in initialise_run().
        self.local_folder = local_folder

    def initialise_run(self, task: str) -> Run:
        """Create a new Run and ensure the output folder exists.

        Bug fix: a folder passed to the constructor was previously ignored and
        unconditionally overwritten with the package default.
        """
        if self.local_folder is None:
            self.local_folder = Path(__file__).parent.parent.parent / "local_trajectories"
        else:
            self.local_folder = Path(self.local_folder)
        os.makedirs(self.local_folder, exist_ok=True)
        return Run.initialise(task)

    async def terminate(
        self,
        run: Run,
        save: bool = True,
    ) -> None:
        """Mark *run* as terminated and optionally persist it."""
        run.terminate()
        if save:
            await self.save(run)

    async def save(self, run: Run) -> None:
        """Serialise *run* to ``<local_folder>/<run_id>.json``."""
        json_payload = run.model_dump()
        with open(self.local_folder / f"{run.run_id}.json", "w", encoding="utf-8") as f:
            json.dump(json_payload, f)
|
proxy-lite-demo-v2/src/proxy_lite/runner.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from collections.abc import AsyncIterator
|
4 |
+
from contextlib import asynccontextmanager
|
5 |
+
from typing import Any, Literal, Self
|
6 |
+
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from pydantic import BaseModel
|
9 |
+
|
10 |
+
from proxy_lite.environments import (
|
11 |
+
Action,
|
12 |
+
BaseEnvironment,
|
13 |
+
EnvironmentConfigTypes,
|
14 |
+
Environments,
|
15 |
+
EventType,
|
16 |
+
Observation,
|
17 |
+
)
|
18 |
+
from proxy_lite.logger import create_logger
|
19 |
+
from proxy_lite.recorder import DataRecorder, Run
|
20 |
+
from proxy_lite.solvers import (
|
21 |
+
BaseSolver,
|
22 |
+
SolverConfigTypes,
|
23 |
+
Solvers,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
@asynccontextmanager
|
28 |
+
async def async_timeout(timeout: float, task_name: str = "timeout"):
|
29 |
+
try:
|
30 |
+
async with asyncio.TaskGroup() as tg:
|
31 |
+
|
32 |
+
async def timeout_task():
|
33 |
+
await asyncio.sleep(timeout)
|
34 |
+
raise TimeoutError(
|
35 |
+
f"Operation {task_name} timed out after {timeout} seconds",
|
36 |
+
)
|
37 |
+
|
38 |
+
# Create the timeout task
|
39 |
+
timeout_handle = tg.create_task(timeout_task())
|
40 |
+
|
41 |
+
try:
|
42 |
+
yield
|
43 |
+
finally:
|
44 |
+
timeout_handle.cancel()
|
45 |
+
except* asyncio.TimeoutError as eg:
|
46 |
+
for e in eg.exceptions:
|
47 |
+
raise e
|
48 |
+
except* Exception as eg:
|
49 |
+
for e in eg.exceptions:
|
50 |
+
raise e
|
51 |
+
|
52 |
+
|
53 |
+
class RunnerConfig(BaseModel):
    """Top-level configuration for a Runner: environment + solver plus run limits."""

    environment: EnvironmentConfigTypes
    solver: SolverConfigTypes

    save_every_step: bool = True  # persist the Run after every executed action
    max_steps: int = 50  # hard cap on observation/action iterations
    action_timeout: float = 600.0  # seconds allowed for the solver to pick an action
    environment_timeout: float = 300.0  # seconds allowed for the environment to respond
    task_timeout: float = 1800.0  # seconds allowed for the whole task
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a plain dict, resolving OmegaConf interpolations first."""
        conf = OmegaConf.create(config_dict)
        config_dict = OmegaConf.to_container(conf, resolve=True)
        return cls(**config_dict)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from a YAML file, resolving OmegaConf interpolations first."""
        conf = OmegaConf.load(yaml_path)
        config_dict = OmegaConf.to_container(conf, resolve=True)
        return cls(**config_dict)
|
76 |
+
|
77 |
+
|
78 |
+
class Runner(BaseModel):
    """Drives one task: wires a solver to an environment and records the run.

    The environment/solver classes are resolved from their registries in
    model_post_init; run_generator yields the evolving Run after each step.
    """

    config: RunnerConfig
    recorder: DataRecorder | None = None
    environment: type[BaseEnvironment] | None = None  # resolved class, not instance
    solver: type[BaseSolver] | None = None  # resolved class, not instance
    logger: logging.Logger | None = None
    _run: Run | None = None  # last run seen by run(); backs the properties below

    class Config:
        arbitrary_types_allowed = True

    def model_post_init(self, __context: Any) -> None:
        """Resolve registered environment/solver classes and build the logger."""
        super().model_post_init(__context)
        self.environment = Environments.get(self.config.environment.name)
        self.solver = Solvers.get(self.config.solver.name)
        self.recorder = DataRecorder()
        self.logger = create_logger(
            name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
            level=self.config.logger_level,
            detailed_name=self.config.detailed_logger_name,
        )

    async def run_generator(self, task: str) -> AsyncIterator[Run]:
        """Execute *task* step by step, yielding the Run after each step.

        Events (observations/actions) flow through an internal queue:
        OBSERVATION -> solver.act -> ACTION -> environment.execute_action -> ...
        until the solver reports completion or max_steps is hit.
        """
        async with (
            async_timeout(self.config.task_timeout, "Task"),
        ):
            if self.config.logger_level is not None:
                self.logger.setLevel(self.config.logger_level)
            run = self.recorder.initialise_run(task)
            run.environment = self.config.environment
            run.solver = self.config.solver
            self.logger.debug(f"Run intialised: {run.run_id}")
            event_queue = asyncio.Queue()
            # Environment and solver are async context managers; both are torn
            # down together when the task finishes or fails.
            async with (
                self.environment(
                    config=self.config.environment,
                    logger=self.logger,
                ) as environment,
                self.solver(config=self.config.solver, logger=self.logger) as solver,
            ):
                run.env_info = await environment.get_info()
                await solver.initialise(
                    task,
                    environment.tools,
                    environment.info_for_user,
                )
                self.logger.debug("Solver initialised.")
                run.solver_history = solver.history
                observation: Observation = await environment.initialise()
                await event_queue.put(observation)
                self.logger.debug("Environment initialised.")
                step_count = 0
                while step_count < self.config.max_steps:
                    event = await event_queue.get()
                    self.logger.debug(f"π€ [bold purple]Processing event:[/] {event.type}")
                    match event.type:
                        case EventType.OBSERVATION:
                            observation: Observation = event
                            run.record(
                                observation=observation,
                                solver_history=solver.history,
                            )
                            # The solver decides the next action under its own timeout.
                            async with async_timeout(
                                self.config.action_timeout,
                                "Action decision",
                            ):
                                action: Action = await solver.act(observation)
                            await event_queue.put(action)
                        case EventType.ACTION:
                            action: Action = event
                            self.logger.debug(f"Tool calls: {action.tool_calls}")
                            run.record(action=action, solver_history=solver.history)
                            run.complete = await solver.is_complete(observation)
                            if self.config.save_every_step:
                                await self.recorder.save(run)
                            if run.complete:
                                run.result = action.text
                                self.logger.info(f"π€ [bold purple]Task complete.[/] β¨ \n{run.result}")
                                break
                            self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds")
                            # The environment applies the action under its own timeout.
                            async with async_timeout(
                                self.config.environment_timeout,
                                "Environment response",
                            ):
                                observation: Observation = await environment.execute_action(action)
                            step_count += 1
                            await event_queue.put(observation)
                    # Yield the partial run so callers can render progress live.
                    yield run
                if not run.complete:
                    self.logger.warning("π€ [bold purple]Ran out of steps!")
                await self.recorder.terminate(run, save=True)
                yield run

    async def run(self, task: str) -> Run:
        """Run *task* to completion and return the final Run."""
        async for run in self.run_generator(task):
            self._run = run
        return run

    def run_concurrent(self, tasks: list[str]) -> list[Run]:
        """Run several tasks concurrently; exceptions are returned, not raised."""
        async def gather_runs():
            return await asyncio.gather(
                *[self.run(task) for task in tasks],
                return_exceptions=True,
            )

        return asyncio.run(gather_runs())

    @property
    def complete(self) -> bool:
        # Raises if run() has not been called yet.
        if self._run is None:
            raise RuntimeError("Run not initialised")
        return self._run.complete

    @property
    def run_id(self) -> str:
        if self._run is None:
            raise RuntimeError("Run not initialised")
        return self._run.run_id

    @property
    def run_result(self) -> str:
        if self._run is None:
            raise RuntimeError("Run not initialised")
        return self._run.result
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == "__main__":
|
205 |
+
from proxy_lite.logger import logger
|
206 |
+
|
207 |
+
config = RunnerConfig.from_dict(
|
208 |
+
{
|
209 |
+
"environment": {
|
210 |
+
"name": "webbrowser",
|
211 |
+
"homepage": "https://www.google.com",
|
212 |
+
"viewport_width": 1280,
|
213 |
+
"viewport_height": 1920,
|
214 |
+
"screenshot_delay": 1,
|
215 |
+
"headless": False,
|
216 |
+
},
|
217 |
+
"solver": {
|
218 |
+
"name": "simple",
|
219 |
+
"agent": {
|
220 |
+
"name": "proxy_lite",
|
221 |
+
"client": {
|
222 |
+
"name": "convergence",
|
223 |
+
"model_id": "convergence-ai/proxy-lite",
|
224 |
+
"api_base": "https://convergence-ai-demo-api.hf.space/v1",
|
225 |
+
},
|
226 |
+
},
|
227 |
+
},
|
228 |
+
"max_steps": 150,
|
229 |
+
"action_timeout": 1800,
|
230 |
+
"environment_timeout": 1800,
|
231 |
+
"task_timeout": 18000,
|
232 |
+
"logger_level": "DEBUG",
|
233 |
+
},
|
234 |
+
)
|
235 |
+
logger.info(f"π€ [bold purple]Config:[/] {config}")
|
236 |
+
|
237 |
+
runner = Runner(config=config)
|
238 |
+
result = asyncio.run(runner.run("Tell me the tesla stock price."))
|
239 |
+
print(runner.run_result)
|
240 |
+
print(runner.complete)
|
proxy-lite-demo-v2/src/proxy_lite/serializer.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
from proxy_lite.history import MessageAdapter, MessageHistory
|
7 |
+
from proxy_lite.tools import Tool
|
8 |
+
|
9 |
+
|
10 |
+
class BaseSerializer(BaseModel, ABC):
    """Base class for serializers.

    Serializers are responsible for converting between the internal MessageHistory/Tool
    objects and the external API format. Deserialise is not always possible, so raise
    appropriate warnings.
    """

    @abstractmethod
    def serialize_messages(self, message_history: MessageHistory) -> list[dict]: ...

    @abstractmethod
    def deserialize_messages(self, data: list[dict]) -> MessageHistory: ...

    @abstractmethod
    def serialize_tools(self, tools: list[Tool]) -> list[dict]: ...
|
26 |
+
|
27 |
+
|
28 |
+
class OpenAICompatibleSerializer(BaseSerializer):
    """Serializer targeting the OpenAI chat-completions wire format."""

    def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
        # Labels are internal bookkeeping and are not part of the API schema.
        return message_history.to_dict(exclude={"label"})

    def deserialize_messages(self, data: list[dict]) -> MessageHistory:
        validated = [MessageAdapter.validate_python(entry) for entry in data]
        return MessageHistory(messages=validated)

    def serialize_tools(self, tools: list[Tool]) -> list[dict]:
        # Flatten every tool's schemas into one list of function specs,
        # preserving tool order then schema order.
        return [
            {"type": "function", "function": schema}
            for tool in tools
            for schema in tool.schema
        ]
|
proxy-lite-demo-v2/src/proxy_lite/solvers/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations

from typing import Union

from .simple_solver import SimpleSolver, SimpleSolverConfig
from .solver_base import BaseSolver, BaseSolverConfig, Solvers

# Union of every registered solver/config class. The registries are populated
# by the register_* decorators at import time, so these must be built after
# the imports above have run.
SolverConfigTypes = Union[*Solvers._solver_config_registry.values()]
SolverTypes = Union[*Solvers._solver_registry.values()]


__all__ = [
    "BaseSolver",
    "BaseSolverConfig",
    "SimpleSolver",
    "SimpleSolverConfig",
    "SolverConfigTypes",
    "SolverTypes",
    "Solvers",
]
|
proxy-lite-demo-v2/src/proxy_lite/solvers/simple_solver.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ruff: noqa: E501
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from functools import cached_property
|
5 |
+
from typing import Literal, Optional
|
6 |
+
|
7 |
+
from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
|
8 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
9 |
+
from proxy_lite.history import (
|
10 |
+
MessageHistory,
|
11 |
+
MessageLabel,
|
12 |
+
SystemMessage,
|
13 |
+
)
|
14 |
+
from proxy_lite.solvers.solver_base import BaseSolver, BaseSolverConfig, Solvers
|
15 |
+
from proxy_lite.tools import ReturnValueTool, Tool
|
16 |
+
|
17 |
+
WEB_TOOL_TURN = """The action has been attempted in the computer."""
|
18 |
+
|
19 |
+
|
20 |
+
@Solvers.register_solver_config("simple")
class SimpleSolverConfig(BaseSolverConfig):
    # Configuration for SimpleSolver: a fixed name plus the nested agent config.
    name: Literal["simple"] = "simple"
    agent: AgentConfigTypes
|
24 |
+
|
25 |
+
|
26 |
+
@Solvers.register_solver("simple")
class SimpleSolver(BaseSolver):
    """Single-agent solver: forwards observations to one agent and relays its tool calls."""

    task: Optional[str] = None
    complete: bool = False  # set once the agent calls return_value

    @cached_property
    def tools(self) -> list[Tool]:
        # Always expose return_value so the agent has a way to finish the task.
        return [ReturnValueTool()] + self.env_tools

    @cached_property
    def agent(self) -> BaseAgent:
        if self.logger:
            self.logger.debug(f"Tools: {self.tools}")
        return Agents.get(self.config.agent.name)(
            config=self.config.agent,
            env_tools=self.tools,
        )

    @property
    def history(self) -> MessageHistory:
        # Prepend the system prompt so the full conversation can be replayed.
        return MessageHistory(
            messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
        )

    async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
        """Store the environment tools and seed the agent with the task description."""
        self.env_tools = env_tools
        self.task = task
        self.agent.receive_user_message(
            text=f"Task: {task}",
            label=MessageLabel.USER_INPUT,
        )
        self.logger.debug(f"Initialised with task: {task}")

    async def act(self, observation: Observation) -> Action:
        """Produce the next Action for *observation*.

        Relays pending tool responses to the agent, sends the screenshot, then
        parses the agent's reply (observation/thinking tags and tool calls).
        """
        # Send tool responses to agent as tool messages if they exist
        if observation.state.tool_responses:
            for tool_response in observation.state.tool_responses:
                if tool_response.content and tool_response.id:
                    await self.agent.receive_tool_message(
                        text=tool_response.content,
                        tool_id=tool_response.id,
                    )
                else:
                    # Use the logger (not print) for consistency with the rest of the solver.
                    self.logger.debug(
                        f"Skipping tool response - content exists: {bool(tool_response.content)}, "
                        f"id exists: {bool(tool_response.id)}",
                    )
        else:
            self.logger.debug("No tool responses to process")

        self.agent.receive_user_message(
            image=observation.state.image,
            text=observation.state.text,
            label=MessageLabel.SCREENSHOT,
            is_base64=True,
        )

        message = await self.agent.generate_output(use_tool=True)

        self.logger.debug(f"Assistant message generated: {message}")

        # Check tool calls for return_value. Read the arguments from the
        # *matching* call rather than blindly from tool_calls[0], which may be
        # a different tool when several calls are returned together.
        return_call = next(
            (tc for tc in message.tool_calls if tc.function["name"] == "return_value"),
            None,
        )
        if return_call is not None:
            self.complete = True
            arguments = json.loads(return_call.function["arguments"])
            if isinstance(arguments, str):
                # Some responses double-encode the arguments payload.
                arguments = json.loads(arguments)
            return_value = arguments["value"]
            return Action(tool_calls=[], text=return_value)

        # Handle empty content array from API response
        if not message.content or len(message.content) == 0:
            self.logger.warning("Message content is empty, using empty string as fallback")
            text_content = ""
        else:
            text_content = message.content[0].text

        observation_match = re.search(r"<observation>(.*?)</observation>", text_content, re.DOTALL)
        observation_content = observation_match.group(1).strip() if observation_match else ""

        self.logger.info("π [bold blue]Observation:[/]")
        await self.logger.stream_message(observation_content)

        # Extract text between thinking tags if present
        thinking_match = re.search(r"<thinking>(.*?)</thinking>", text_content, re.DOTALL)
        thinking_content = thinking_match.group(1).strip() if thinking_match else text_content

        self.logger.info("π§ [bold purple]Thinking:[/]")
        await self.logger.stream_message(thinking_content)

        return Action(tool_calls=message.tool_calls, text=text_content)

    async def is_complete(self, observation: Observation) -> bool:
        """Complete when the solver returned a value or the environment terminated."""
        env_terminated = observation.terminated
        return self.complete or env_terminated
|
proxy-lite-demo-v2/src/proxy_lite/solvers/solver_base.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from functools import cached_property
|
4 |
+
from typing import Optional, Self, Type, cast
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from proxy_lite.environments.environment_base import Action, Observation
|
9 |
+
from proxy_lite.tools import Tool
|
10 |
+
|
11 |
+
|
12 |
+
class BaseSolverConfig(BaseModel):
    """Base (empty) configuration model for solvers.

    Subclasses declare solver-specific settings and are registered under the
    same name as their solver via ``Solvers.register_solver_config``.
    """

    pass
|
14 |
+
|
15 |
+
|
16 |
+
class BaseSolver(BaseModel, ABC):
    """Abstract base class for task solvers.

    A solver receives a task description plus the environment's tools, then
    repeatedly produces ``Action``s from ``Observation``s until the task is
    complete. Concrete implementations must provide ``tools``, ``initialise``
    and ``act``.
    """

    # The task description; None until initialise() is called.
    task: Optional[str] = None
    # Tools provided by the environment (populated at initialisation).
    env_tools: list[Tool] = Field(default_factory=list)
    # Solver-specific configuration (subclass of BaseSolverConfig).
    config: BaseSolverConfig
    # Optional logger injected by the runner.
    logger: logging.Logger | None = None

    class Config:
        # Allow non-pydantic types (Tool, logging.Logger) as field values.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        # Default: no async resources to acquire; subclasses may override.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        # Default: no async resources to release; subclasses may override.
        pass

    # NOTE(review): functools.cached_property does not forward
    # __isabstractmethod__, so ABC may not actually enforce that subclasses
    # implement `tools` — confirm subclasses always define it.
    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @abstractmethod
    async def initialise(
        self,
        task: str,
        env_tools: list[Tool],
        env_info: str,
    ) -> None:
        """
        Initialise the solution with the given task.
        """
        ...

    @abstractmethod
    async def act(self, observation: Observation) -> Action:
        """
        Return an action for interacting with the environment.
        """
        ...

    async def is_complete(self, observation: Observation) -> bool:
        """
        Return a boolean indicating if the task is complete.
        """
        # Default implementation: complete only when the environment says so.
        return observation.terminated
|
59 |
+
|
60 |
+
|
61 |
+
class Solvers:
    """Registry mapping string names to solver classes and their configs.

    Populated at import time via the ``register_solver`` and
    ``register_solver_config`` decorators; consumers look entries up with
    ``get`` / ``get_config``.
    """

    # name -> BaseSolver subclass
    _solver_registry: dict[str, type[BaseSolver]] = {}
    # name -> BaseSolverConfig subclass (keyed by the same solver name)
    _solver_config_registry: dict[str, type[BaseSolverConfig]] = {}

    @classmethod
    def register_solver(cls, name: str):
        """
        Decorator to register a Solver class under a given name.

        Example:
            @Solvers.register_solver("my_solver")
            class MySolver(BaseSolver):
                ...
        """

        def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]:
            cls._solver_registry[name] = solver_cls
            return solver_cls

        return decorator

    @classmethod
    def register_solver_config(cls, name: str):
        """
        Decorator to register a Solver configuration class under a given name.

        Example:
            @Solvers.register_solver_config("my_solver")
            class MySolverConfig(BaseSolverConfig):
                ...
        """

        def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]:
            cls._solver_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseSolver]:
        """
        Retrieve a registered Solver class by its name.

        Raises:
            ValueError: If no such solver is found.
        """
        try:
            return cast(Type[BaseSolver], cls._solver_registry[name])
        except KeyError:
            # `from None` suppresses the implicit KeyError chaining so callers
            # see only the meaningful ValueError.
            raise ValueError(f"Solver '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseSolverConfig]:
        """
        Retrieve a registered Solver configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cast(Type[BaseSolverConfig], cls._solver_config_registry[name])
        except KeyError:
            raise ValueError(f"Solver config for '{name}' not found.") from None
|
proxy-lite-demo-v2/src/proxy_lite/tools/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .browser_tool import BrowserTool
|
2 |
+
from .return_tool import ReturnValueTool
|
3 |
+
from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
|
4 |
+
|
5 |
+
__all__ = ["BrowserTool", "ReturnValueTool", "Tool", "ToolExecutionResponse", "attach_param_schema"]
|
proxy-lite-demo-v2/src/proxy_lite/tools/browser_tool.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from contextlib import AsyncExitStack
|
3 |
+
from typing import List, Literal, Optional, Any
|
4 |
+
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
|
7 |
+
from proxy_lite.browser.browser import BrowserSession
|
8 |
+
from proxy_lite.logger import logger
|
9 |
+
|
10 |
+
from .tool_base import Tool, ToolExecutionResponse, attach_param_schema
|
11 |
+
|
12 |
+
SELF_CONTAINED_TAGS = [
    # many of these are non-interactive but keeping them anyway
    "area",
    "base",
    "br",
    "col",
    "embed",
    "hr",
    "img",
    "input",
    "link",
    "meta",
    "param",
    "source",
    "track",
    "wbr",
]


def element_as_text(
    mark_id: int,
    tag: Optional[str] = None,
    text: Optional[str] = None,
    **raw_attributes,
) -> str:
    """Render one page element as an HTML-like text snippet.

    Args:
        mark_id: Numeric mark identifier emitted as the ``id`` attribute.
        tag: Element tag name (case-insensitive). ``None`` is tolerated and
            renders as an empty tag.
        text: Inner text; ``None`` is treated as empty.
        **raw_attributes: Remaining element attributes. ``None`` values are
            dropped; booleans render as bare flags when True and are dropped
            when False; everything else is stringified and quoted.

    Returns:
        A self-closing form for void (self-contained) tags without text,
        otherwise an open/close pair wrapping the (truncated) text.
    """
    max_len = 2500  # truncation limit for attribute values and inner text

    attributes = []
    for key, value in raw_attributes.items():
        if value is None:
            continue
        if isinstance(value, bool):
            if value:
                attributes.append(key)
            # we ignore False bool attributes
        else:
            value = str(value)
            if len(value) > max_len:
                value = value[: max_len - 1] + "…"
            attributes.append(f'{key}="{value}"')
    attributes = " ".join(attributes)
    attributes = (" " + attributes).rstrip()

    # Robustness fix: the original called tag.lower() unconditionally and
    # raised AttributeError when tag was None (its declared default).
    tag = (tag or "").lower()
    if text is None:
        text = ""
    if len(text) > max_len:
        text = text[: max_len - 1] + "…"

    if tag in SELF_CONTAINED_TAGS:
        if text:
            # Void elements should not carry text; warn and fall through to
            # the paired form so the text is not silently lost.
            logger.warning(
                f"Got self-contained element '{tag}' which contained text '{text}'.",
            )
        else:
            return f"<{tag} id={mark_id}{attributes}/>"
    return f"<{tag} id={mark_id}{attributes}>{text}</{tag}>"
|
66 |
+
|
67 |
+
|
68 |
+
class GotoParams(BaseModel):
|
69 |
+
url: str = Field(..., description="The web address to visit. Must be a valid URL.")
|
70 |
+
|
71 |
+
|
72 |
+
class GoogleSearchParams(BaseModel):
|
73 |
+
query_plan: str = Field(
|
74 |
+
...,
|
75 |
+
description="Plan out the query you will make. Re-write queries in a way that will yield the best results.",
|
76 |
+
)
|
77 |
+
query: str = Field(..., description="The Google search to perform.")
|
78 |
+
|
79 |
+
|
80 |
+
class ClickParams(BaseModel):
|
81 |
+
mark_id: int = Field(..., description="Element Mark ID.")
|
82 |
+
|
83 |
+
|
84 |
+
class TypeEntry(BaseModel):
|
85 |
+
mark_id: int = Field(..., description="Element Mark ID.")
|
86 |
+
content: str = Field(..., description="The text to type into the element.")
|
87 |
+
|
88 |
+
|
89 |
+
class TypeParams(BaseModel):
|
90 |
+
entries: List[TypeEntry] = Field(
|
91 |
+
...,
|
92 |
+
description="A list of elements and contents to type.",
|
93 |
+
)
|
94 |
+
submit: bool = Field(
|
95 |
+
...,
|
96 |
+
description='Whether to press the "Enter" key after typing in the last entry.',
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
class ScrollParams(BaseModel):
|
101 |
+
direction: Literal["up", "down", "left", "right"] = Field(
|
102 |
+
...,
|
103 |
+
description='Direction to scroll. Must be one of "up", "down", "left" or "right".',
|
104 |
+
)
|
105 |
+
mark_id: int = Field(
|
106 |
+
...,
|
107 |
+
description="What to scroll. Use -1 to scroll the whole page otherwise give the mark ID of an element that is `scrollable`.", # noqa: E501
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
class BackParams(BaseModel):
|
112 |
+
pass
|
113 |
+
|
114 |
+
|
115 |
+
class WaitParams(BaseModel):
|
116 |
+
pass
|
117 |
+
|
118 |
+
|
119 |
+
class ReloadParams(BaseModel):
|
120 |
+
pass
|
121 |
+
|
122 |
+
|
123 |
+
class DoNothingParams(BaseModel):
|
124 |
+
pass
|
125 |
+
|
126 |
+
# --- NEW: Parameters for open_new_tab_and_go_to tool ---
|
127 |
+
class OpenNewTabAndGoToParams(BaseModel):
|
128 |
+
url: str = Field(..., description="The URL to navigate to in the new tab.")
|
129 |
+
|
130 |
+
# --- NEW: Parameters for select_option_by_text tool ---
|
131 |
+
class SelectOptionByTextParams(BaseModel):
|
132 |
+
mark_id: int = Field(..., description="The mark ID of the select element.")
|
133 |
+
option_text: str = Field(..., description="The text content of the option to select.")
|
134 |
+
|
135 |
+
|
136 |
+
class BrowserTool(Tool):
|
137 |
+
def __init__(self, session: BrowserSession) -> None:
|
138 |
+
super().__init__()
|
139 |
+
self.browser = session
|
140 |
+
|
141 |
+
async def __aenter__(self):
|
142 |
+
self._exit_stack = AsyncExitStack()
|
143 |
+
await self._exit_stack.enter_async_context(self.browser)
|
144 |
+
return self
|
145 |
+
|
146 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
147 |
+
await self._exit_stack.aclose()
|
148 |
+
|
149 |
+
@property
|
150 |
+
def poi_text(self) -> str:
|
151 |
+
# Get all points of interest on the page as text
|
152 |
+
texts = [element_as_text(mark_id=i, **element) for i, element in enumerate(self.browser.poi_elements)]
|
153 |
+
# Return formatted text of points of interest on page
|
154 |
+
return "\n".join([txt for txt in texts if txt])
|
155 |
+
|
156 |
+
@attach_param_schema(GotoParams)
|
157 |
+
async def goto(self, url: str) -> ToolExecutionResponse:
|
158 |
+
"""Go directly to a specific web url. Specify the exact URL."""
|
159 |
+
await self.browser.goto(url)
|
160 |
+
return ToolExecutionResponse(content=f"Successfully navigated to URL: {url}")
|
161 |
+
|
162 |
+
@attach_param_schema(GoogleSearchParams)
|
163 |
+
async def google_search(self, query_plan: str, query: str) -> ToolExecutionResponse:
|
164 |
+
"""Perform a generic web search using Google.
|
165 |
+
Results may not be relevant. If you see poor results, you can try another query.
|
166 |
+
"""
|
167 |
+
url = f"https://www.google.com/search?q={query}"
|
168 |
+
await self.browser.goto(url)
|
169 |
+
return ToolExecutionResponse(content=f"Performed Google search for: {query}")
|
170 |
+
|
171 |
+
@attach_param_schema(ClickParams)
|
172 |
+
async def click(self, mark_id: int) -> ToolExecutionResponse:
|
173 |
+
"""Click on an element of the page."""
|
174 |
+
try:
|
175 |
+
await self.browser.click(mark_id=mark_id)
|
176 |
+
return ToolExecutionResponse(content=f"Clicked element with mark ID: {mark_id}")
|
177 |
+
except IndexError as e:
|
178 |
+
# This happens if mark_id is out of bounds for browser.poi_centroids
|
179 |
+
logger.error(f"Click failed: Mark ID {mark_id} not found or POI list empty. Error: {e}")
|
180 |
+
return ToolExecutionResponse(content=f"Failed to click element with mark ID {mark_id}. Element not found or POI list invalid.")
|
181 |
+
except Exception as e:
|
182 |
+
logger.error(f"Click failed with unexpected error for mark ID {mark_id}: {e}")
|
183 |
+
return ToolExecutionResponse(content=f"An unexpected error occurred while trying to click element {mark_id}: {e}")
|
184 |
+
|
185 |
+
|
186 |
+
@attach_param_schema(TypeParams)
|
187 |
+
async def type(self, entries: List[dict], submit: bool) -> ToolExecutionResponse:
|
188 |
+
"""Type text.
|
189 |
+
You can type into one or more elements.
|
190 |
+
Note that the text inside an element is cleared before typing.
|
191 |
+
"""
|
192 |
+
typed_ids = []
|
193 |
+
for i, entry_dict in enumerate(entries):
|
194 |
+
try:
|
195 |
+
entry = TypeEntry(**entry_dict)
|
196 |
+
last_entry = i == len(entries) - 1
|
197 |
+
old_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
|
198 |
+
await self.browser.enter_text(
|
199 |
+
mark_id=entry.mark_id,
|
200 |
+
text=entry.content,
|
201 |
+
submit=submit and last_entry,
|
202 |
+
)
|
203 |
+
typed_ids.append(entry.mark_id)
|
204 |
+
await self.browser.update_poi()
|
205 |
+
new_poi_positions = [tuple(point) for point in self.browser.poi_centroids]
|
206 |
+
if not last_entry and old_poi_positions != new_poi_positions:
|
207 |
+
logger.error(
|
208 |
+
"POI positions changed mid-typing, cancelling future type entries.",
|
209 |
+
)
|
210 |
+
break
|
211 |
+
except IndexError as e:
|
212 |
+
logger.error(f"Type failed: Mark ID {entry.mark_id} not found or POI list empty. Error: {e}")
|
213 |
+
return ToolExecutionResponse(content=f"Failed to type into element with mark ID {entry.mark_id}. Element not found or POI list invalid. Typed into: {typed_ids if typed_ids else 'none'}.")
|
214 |
+
except Exception as e:
|
215 |
+
logger.error(f"Type failed with unexpected error for mark ID {entry.mark_id}: {e}")
|
216 |
+
return ToolExecutionResponse(content=f"An unexpected error occurred while trying to type into element {entry.mark_id}: {e}. Typed into: {typed_ids if typed_ids else 'none'}.")
|
217 |
+
|
218 |
+
return ToolExecutionResponse(
|
219 |
+
content=f"Typed text into elements with mark IDs: {typed_ids}",
|
220 |
+
)
|
221 |
+
|
222 |
+
@attach_param_schema(ScrollParams)
|
223 |
+
async def scroll(self, direction: str, mark_id: int) -> ToolExecutionResponse:
|
224 |
+
"""Scroll the page (or a scrollable element) up, down, left or right."""
|
225 |
+
try:
|
226 |
+
if mark_id == -1:
|
227 |
+
mark_id_for_browser = None # Pass None to browser.scroll for page scroll
|
228 |
+
else:
|
229 |
+
mark_id_for_browser = mark_id
|
230 |
+
|
231 |
+
await self.browser.scroll(direction=direction, mark_id=mark_id_for_browser)
|
232 |
+
return ToolExecutionResponse(content=f"Scrolled {direction} on element with mark ID: {mark_id if mark_id != -1 else 'page'}")
|
233 |
+
except IndexError as e:
|
234 |
+
logger.error(f"Scroll failed: Mark ID {mark_id} not found or POI list empty. Error: {e}")
|
235 |
+
return ToolExecutionResponse(content=f"Failed to scroll element with mark ID {mark_id}. Element not found or POI list invalid.")
|
236 |
+
except Exception as e:
|
237 |
+
logger.error(f"Scroll failed with unexpected error for mark ID {mark_id}: {e}")
|
238 |
+
return ToolExecutionResponse(content=f"An unexpected error occurred while trying to scroll element {mark_id}: {e}")
|
239 |
+
|
240 |
+
@attach_param_schema(BackParams)
|
241 |
+
async def back(self) -> ToolExecutionResponse:
|
242 |
+
"""Go back to the previous page."""
|
243 |
+
try:
|
244 |
+
await self.browser.go_back()
|
245 |
+
return ToolExecutionResponse(content="Went back to the previous page.")
|
246 |
+
except Exception as e:
|
247 |
+
logger.error(f"Go back failed: {e}")
|
248 |
+
return ToolExecutionResponse(content=f"Failed to go back: {e}")
|
249 |
+
|
250 |
+
|
251 |
+
@attach_param_schema(WaitParams)
|
252 |
+
async def wait(self) -> ToolExecutionResponse:
|
253 |
+
"""Wait three seconds. Useful when the page appears to still be loading, or if there are any unfinished webpage processes.""" # noqa: E501
|
254 |
+
await asyncio.sleep(3)
|
255 |
+
return ToolExecutionResponse(content="Waited for a few seconds.")
|
256 |
+
|
257 |
+
@attach_param_schema(ReloadParams)
|
258 |
+
async def reload(self) -> ToolExecutionResponse:
|
259 |
+
"""Reload the current page. Useful when the page seems unresponsive, broken, outdated, or if you want to reset the page to its initial state.""" # noqa: E501
|
260 |
+
try:
|
261 |
+
await self.browser.reload()
|
262 |
+
return ToolExecutionResponse(content="Reloaded the current page.")
|
263 |
+
except Exception as e:
|
264 |
+
logger.error(f"Reload failed: {e}")
|
265 |
+
return ToolExecutionResponse(content=f"Failed to reload the page: {e}")
|
266 |
+
|
267 |
+
|
268 |
+
@attach_param_schema(DoNothingParams)
|
269 |
+
async def do_nothing_tool(self) -> ToolExecutionResponse:
|
270 |
+
"""Do nothing. Use this if you have no need for the browser at this time."""
|
271 |
+
return ToolExecutionResponse(content="Did nothing in the browser.")
|
272 |
+
|
273 |
+
# --- NEW: Expose the open_new_tab_and_go_to method as a tool ---
|
274 |
+
@attach_param_schema(OpenNewTabAndGoToParams)
|
275 |
+
async def open_new_tab_and_go_to(self, url: str) -> ToolExecutionResponse:
|
276 |
+
"""
|
277 |
+
Opens a new browser tab/page and navigates to the specified URL.
|
278 |
+
Closes the old page if it's not the last one remaining.
|
279 |
+
Use this to bypass loading issues by forcing a new navigation.
|
280 |
+
"""
|
281 |
+
try:
|
282 |
+
await self.browser.open_new_tab_and_go_to(url)
|
283 |
+
return ToolExecutionResponse(
|
284 |
+
content=f"Successfully opened new tab and navigated to: {url}",
|
285 |
+
)
|
286 |
+
except Exception as e:
|
287 |
+
logger.error(f"Error opening new tab and navigating to {url}: {e}")
|
288 |
+
return ToolExecutionResponse(content=f"Failed to open new tab and navigate to {url}: {e}")
|
289 |
+
|
290 |
+
# --- NEW: Select option by text from select element ---
|
291 |
+
@attach_param_schema(SelectOptionByTextParams)
|
292 |
+
async def select_option_by_text(self, mark_id: int, option_text: str) -> ToolExecutionResponse:
|
293 |
+
"""
|
294 |
+
Selects an option from a select element (including dual select picklists) by finding the option with matching text.
|
295 |
+
This is especially useful for Salesforce dual select picklists where you need to find and select a specific option.
|
296 |
+
Uses Playwright's native iframe handling to bypass CORS restrictions.
|
297 |
+
"""
|
298 |
+
try:
|
299 |
+
logger.info(f"Attempting to select option '{option_text}' from element {mark_id}")
|
300 |
+
|
301 |
+
# First, try to click the select element to ensure it's focused
|
302 |
+
await self.browser.click(mark_id=mark_id)
|
303 |
+
await asyncio.sleep(0.5) # Wait for click to register
|
304 |
+
|
305 |
+
# Use Playwright's native frame handling instead of JavaScript evaluation
|
306 |
+
# This bypasses CORS restrictions that prevent JavaScript access
|
307 |
+
|
308 |
+
# Find all frames on the page
|
309 |
+
main_frame = self.browser.current_page.main_frame
|
310 |
+
all_frames = [main_frame] + main_frame.child_frames
|
311 |
+
|
312 |
+
logger.info(f"Searching for element {mark_id} across {len(all_frames)} frames")
|
313 |
+
|
314 |
+
for frame_idx, frame in enumerate(all_frames):
|
315 |
+
try:
|
316 |
+
# Look for select elements in this frame
|
317 |
+
select_elements = await frame.query_selector_all('select')
|
318 |
+
logger.info(f"Frame {frame_idx}: Found {len(select_elements)} select elements")
|
319 |
+
|
320 |
+
for select_elem in select_elements:
|
321 |
+
# Get all options for this select
|
322 |
+
options = await select_elem.query_selector_all('option')
|
323 |
+
|
324 |
+
# Check if any option contains our target text
|
325 |
+
for opt_idx, option in enumerate(options):
|
326 |
+
option_text_content = await option.text_content()
|
327 |
+
option_value = await option.get_attribute('value')
|
328 |
+
|
329 |
+
logger.info(f"Frame {frame_idx}, Select {select_elem}, Option {opt_idx}: text='{option_text_content}', value='{option_value}'")
|
330 |
+
|
331 |
+
if option_text_content and option_text.lower().strip() == option_text_content.lower().strip():
|
332 |
+
# Found the option! Click it directly instead of using select_option
|
333 |
+
try:
|
334 |
+
# Direct click with force=True to bypass visibility checks and short timeout
|
335 |
+
await option.click(force=True, timeout=5000)
|
336 |
+
logger.info(f"Successfully clicked option '{option_text_content.strip()}' in frame {frame_idx}")
|
337 |
+
|
338 |
+
return ToolExecutionResponse(
|
339 |
+
content=f"[ACTION COMPLETED] Successfully selected '{option_text_content.strip()}' from dual select picklist"
|
340 |
+
)
|
341 |
+
|
342 |
+
except Exception as select_error:
|
343 |
+
logger.info(f"Click timed out in frame {frame_idx}, but option may have been selected: {select_error}")
|
344 |
+
# Continue to next frame/option instead of failing completely
|
345 |
+
continue
|
346 |
+
|
347 |
+
except Exception as frame_error:
|
348 |
+
logger.info(f"Could not access frame {frame_idx}: {frame_error}")
|
349 |
+
continue
|
350 |
+
|
351 |
+
# If we get here, the option wasn't found in any frame
|
352 |
+
# Try to get available options for debugging
|
353 |
+
all_options = []
|
354 |
+
for frame in all_frames:
|
355 |
+
try:
|
356 |
+
select_elements = await frame.query_selector_all('select')
|
357 |
+
for select_elem in select_elements:
|
358 |
+
options = await select_elem.query_selector_all('option')
|
359 |
+
for option in options[:5]: # Limit to first 5 options per select
|
360 |
+
text = await option.text_content()
|
361 |
+
if text:
|
362 |
+
all_options.append(text.strip())
|
363 |
+
except:
|
364 |
+
continue
|
365 |
+
|
366 |
+
available_options_str = ', '.join(all_options[:10]) if all_options else 'None found'
|
367 |
+
return ToolExecutionResponse(
|
368 |
+
content=f"Failed to find option '{option_text}' in any select element. Available options (first 10): {available_options_str}"
|
369 |
+
)
|
370 |
+
|
371 |
+
except Exception as e:
|
372 |
+
logger.error(f"Error selecting option '{option_text}' from element {mark_id}: {e}")
|
373 |
+
return ToolExecutionResponse(content=f"An unexpected error occurred while selecting option '{option_text}': {e}")
|
374 |
+
|
proxy-lite-demo-v2/src/proxy_lite/tools/return_tool.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
|
3 |
+
from proxy_lite.tools.tool_base import Tool, attach_param_schema
|
4 |
+
|
5 |
+
|
6 |
+
class ReturnValueParams(BaseModel):
    """Parameter schema for ``ReturnValueTool.return_value``."""

    value: str = Field(description="The value to return to the user.")
|
8 |
+
|
9 |
+
|
10 |
+
class ReturnValueTool(Tool):
    """Terminal tool for handing a final answer back to the user.

    Invoking ``return_value`` is treated by the solver as task completion.
    """

    def __init__(self):
        # Stateless: no resources to set up.
        pass

    @attach_param_schema(ReturnValueParams)
    def return_value(self, value: str):
        """Return a value to the user. Use this tool when you have finished the task in order to provide any information the user has requested."""  # noqa: E501
        print(value)
|
proxy-lite-demo-v2/src/proxy_lite/tools/tool_base.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from functools import cached_property, wraps
|
3 |
+
from typing import Any, Callable, Optional
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
|
8 |
+
class Tool:
    """Base class for tools exposed to the model.

    Subclasses implement tool methods decorated with ``attach_param_schema``;
    ``schema`` collects those methods into JSON-schema tool descriptions.
    """

    async def __aenter__(self):
        # No resources acquired by default; subclasses may override.
        pass

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # No resources released by default; subclasses may override.
        pass

    @cached_property
    def schema(self) -> list[dict[str, Any]]:
        """Build JSON-schema descriptions for every decorated tool method.

        Only callables carrying a ``param_model`` attribute (set by the
        ``attach_param_schema`` decorator) are included.

        Raises:
            ValueError: If a decorated tool method has no docstring.
        """
        entries: list[dict[str, Any]] = []
        for attr_name, member in self.__class__.__dict__.items():
            # Skip anything that is not a decorated tool method.
            is_tool_method = isinstance(member, Callable) and hasattr(member, "param_model")
            if not is_tool_method:
                continue

            doc = inspect.getdoc(member)
            if not doc:
                raise ValueError(f"The tool function '{attr_name}' is missing a docstring.")
            # Collapse multi-line docstrings into a single-line description.
            flattened = " ".join(part.strip() for part in doc.split("\n"))

            entries.append(
                {
                    "name": attr_name,
                    "description": flattened,
                    "parameters": member.param_model.model_json_schema(),
                },
            )
        return entries
|
36 |
+
|
37 |
+
|
38 |
+
def attach_param_schema(param_model: type[BaseModel]):
    """Decorator factory that validates a tool method's kwargs via a pydantic model.

    The wrapped method gains a ``param_model`` attribute, which ``Tool.schema``
    uses to emit the tool's JSON schema. Calls are validated by constructing
    *param_model* from the keyword arguments before delegating.
    """

    def decorate(tool_method: Callable) -> Callable:
        @wraps(tool_method)
        def validated_call(self, **kwargs):
            # Pydantic raises if kwargs don't match the declared fields,
            # surfacing any mismatch between the function and its model.
            parsed = param_model(**kwargs)
            return tool_method(self, **parsed.model_dump())

        # Expose the model for schema generation.
        validated_call.param_model = param_model
        return validated_call

    return decorate
|
50 |
+
|
51 |
+
|
52 |
+
class ToolExecutionResponse(BaseModel):
    """Outcome of executing a tool call.

    Both fields are optional so tools may return text, an id, both or neither.
    """

    # Human/model-readable result text; None when the tool produced no output.
    content: Optional[str] = None
    # Identifier correlating this response with its originating tool call.
    id: Optional[str] = None
|
proxy-lite-demo-v2/test_tool_calling.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import asyncio
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
sys.path.insert(0, 'src')
|
6 |
+
|
7 |
+
from proxy_lite.client import GeminiClient, GeminiClientConfig
|
8 |
+
from proxy_lite.history import MessageHistory, UserMessage, Text
|
9 |
+
from proxy_lite.tools.browser_tool import BrowserTool
|
10 |
+
from proxy_lite.browser.browser import BrowserSession
|
11 |
+
|
12 |
+
async def test_tool_calling():
|
13 |
+
# Setup client
|
14 |
+
api_key = os.environ.get("GEMINI_API_KEY")
|
15 |
+
if not api_key:
|
16 |
+
print("β GEMINI_API_KEY not set")
|
17 |
+
return
|
18 |
+
|
19 |
+
config = GeminiClientConfig(api_key=api_key)
|
20 |
+
client = GeminiClient(config=config)
|
21 |
+
|
22 |
+
# Create a dummy browser tool
|
23 |
+
class DummyBrowserSession:
|
24 |
+
async def __aenter__(self):
|
25 |
+
return self
|
26 |
+
async def __aexit__(self, *args):
|
27 |
+
pass
|
28 |
+
async def open_new_tab_and_go_to(self, url):
|
29 |
+
print(f"β
Would open new tab and go to: {url}")
|
30 |
+
return True
|
31 |
+
|
32 |
+
browser_tool = BrowserTool(DummyBrowserSession())
|
33 |
+
|
34 |
+
# Create message history
|
35 |
+
messages = MessageHistory()
|
36 |
+
messages.append(UserMessage(content=[Text(text="Please use the open_new_tab_and_go_to tool to navigate to https://google.com")]))
|
37 |
+
|
38 |
+
print("π Testing Gemini tool calling...")
|
39 |
+
|
40 |
+
try:
|
41 |
+
# Test tool calling
|
42 |
+
response = await client.create_completion(
|
43 |
+
messages=messages,
|
44 |
+
tools=[browser_tool],
|
45 |
+
temperature=0.7
|
46 |
+
)
|
47 |
+
|
48 |
+
print(f"β
Response received: {response}")
|
49 |
+
|
50 |
+
if response.choices[0].message.tool_calls:
|
51 |
+
print(f"β
Tool calls found: {len(response.choices[0].message.tool_calls)}")
|
52 |
+
for tool_call in response.choices[0].message.tool_calls:
|
53 |
+
print(f" - Tool: {tool_call.function.name}")
|
54 |
+
print(f" - Args: {tool_call.function.arguments}")
|
55 |
+
else:
|
56 |
+
print("β No tool calls found")
|
57 |
+
print(f"Content: {response.choices[0].message.content}")
|
58 |
+
|
59 |
+
except Exception as e:
|
60 |
+
print(f"β Error: {e}")
|
61 |
+
import traceback
|
62 |
+
traceback.print_exc()
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
asyncio.run(test_tool_calling())
|
proxy-lite-demo-v2/uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
proxy-lite-work/.forceignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# List files or directories below to ignore them when running force:source:push, force:source:pull, and force:source:status
|
2 |
+
# More information: https://developer.salesforce.com/docs/atlas.en-us.sfdx_dev.meta/sfdx_dev/sfdx_dev_exclude_source.htm
|
3 |
+
#
|
4 |
+
|
5 |
+
package.xml
|
6 |
+
|
7 |
+
# LWC configuration files
|
8 |
+
**/jsconfig.json
|
9 |
+
**/.eslintrc.json
|
10 |
+
|
11 |
+
# LWC Jest
|
12 |
+
**/__tests__/**
|