Alex J. Chan committed on
Commit
11250db
·
2 Parent(s): 0eff0a1 fab2b20

Merge pull request #3 from convergence-ai/alex/descriptions

Browse files
.gitignore CHANGED
@@ -171,4 +171,5 @@ cython_debug/
171
  .pypirc
172
 
173
  logs/
174
- local_trajectories/
 
 
171
  .pypirc
172
 
173
  logs/
174
+ local_trajectories/
175
+ screenshots/
src/proxy_lite/cli.py CHANGED
@@ -1,10 +1,12 @@
1
  import argparse
2
  import asyncio
 
3
  import os
4
  from pathlib import Path
5
  from typing import Optional
6
 
7
  from proxy_lite import Runner, RunnerConfig
 
8
  from proxy_lite.logger import logger
9
 
10
 
@@ -35,7 +37,21 @@ def do_command(args):
35
  if args.viewport_height:
36
  config.viewport_height = args.viewport_height
37
  o = Runner(config=config)
38
- asyncio.run(o.run(do_text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def main():
 
1
  import argparse
2
  import asyncio
3
+ import base64
4
  import os
5
  from pathlib import Path
6
  from typing import Optional
7
 
8
  from proxy_lite import Runner, RunnerConfig
9
+ from proxy_lite.gif_maker import create_run_gif
10
  from proxy_lite.logger import logger
11
 
12
 
 
37
  if args.viewport_height:
38
  config.viewport_height = args.viewport_height
39
  o = Runner(config=config)
40
+ result = asyncio.run(o.run(do_text))
41
+
42
+ final_screenshot = result.observations[-1].info["original_image"]
43
+ folder_path = Path(__file__).parent.parent.parent / "screenshots"
44
+ folder_path.mkdir(parents=True, exist_ok=True)
45
+ path = folder_path / f"{result.run_id}.png"
46
+ with open(path, "wb") as f:
47
+ f.write(base64.b64decode(final_screenshot))
48
+ logger.info(f"🤖 Final screenshot saved to {path}")
49
+
50
+ gif_folder_path = Path(__file__).parent.parent.parent / "gifs"
51
+ gif_folder_path.mkdir(parents=True, exist_ok=True)
52
+ gif_path = gif_folder_path / f"{result.run_id}.gif"
53
+ create_run_gif(result, gif_path, duration=1500)
54
+ logger.info(f"🤖 GIF saved to {gif_path}")
55
 
56
 
57
  def main():
src/proxy_lite/client.py CHANGED
@@ -15,7 +15,7 @@ from proxy_lite.history import MessageHistory
15
  from proxy_lite.logger import logger
16
  from proxy_lite.serializer import (
17
  BaseSerializer,
18
- OpenAISerializer,
19
  )
20
  from proxy_lite.tools import Tool
21
 
@@ -78,7 +78,7 @@ class OpenAIClientConfig(BaseClientConfig):
78
 
79
  class OpenAIClient(BaseClient):
80
  config: OpenAIClientConfig
81
- serializer: ClassVar[OpenAISerializer] = OpenAISerializer()
82
 
83
  @cached_property
84
  def external_client(self) -> AsyncOpenAI:
@@ -119,14 +119,14 @@ class ConvergenceClientConfig(BaseClientConfig):
119
 
120
  class ConvergenceClient(OpenAIClient):
121
  config: ConvergenceClientConfig
122
- serializer: ClassVar[OpenAISerializer] = OpenAISerializer()
123
  _model_validated: bool = False
124
 
125
  async def _validate_model(self) -> None:
126
  try:
127
- await self.external_client.beta.chat.completions.parse(
128
- model=self.config.model_id,
129
- messages=[{"role": "user", "content": "Hello"}],
130
  )
131
  self._model_validated = True
132
  logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
 
15
  from proxy_lite.logger import logger
16
  from proxy_lite.serializer import (
17
  BaseSerializer,
18
+ OpenAICompatibleSerializer,
19
  )
20
  from proxy_lite.tools import Tool
21
 
 
78
 
79
  class OpenAIClient(BaseClient):
80
  config: OpenAIClientConfig
81
+ serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
82
 
83
  @cached_property
84
  def external_client(self) -> AsyncOpenAI:
 
119
 
120
  class ConvergenceClient(OpenAIClient):
121
  config: ConvergenceClientConfig
122
+ serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
123
  _model_validated: bool = False
124
 
125
  async def _validate_model(self) -> None:
126
  try:
127
+ response = await self.external_client.models.list()
128
+ assert self.config.model_id in [model.id for model in response.data], (
129
+ f"Model {self.config.model_id} not found in {response.data}"
130
  )
131
  self._model_validated = True
132
  logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
src/proxy_lite/configs/default.yaml CHANGED
@@ -7,6 +7,7 @@ environment:
7
  include_poi_text: true
8
  headless: false
9
  homepage: https://www.google.co.uk
 
10
  solver:
11
  name: simple
12
  agent:
@@ -17,4 +18,6 @@ solver:
17
  api_base: https://convergence-ai-demo-api.hf.space/v1
18
  local_view: true
19
  task_timeout: 1800
 
 
20
  verbose: true
 
7
  include_poi_text: true
8
  headless: false
9
  homepage: https://www.google.co.uk
10
+ keep_original_image: true
11
  solver:
12
  name: simple
13
  agent:
 
18
  api_base: https://convergence-ai-demo-api.hf.space/v1
19
  local_view: true
20
  task_timeout: 1800
21
+ environment_timeout: 1800
22
+ action_timeout: 1800
23
  verbose: true
src/proxy_lite/gif_maker.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import re
3
+ import textwrap
4
+ from io import BytesIO
5
+
6
+ from PIL import Image, ImageDraw, ImageFont
7
+
8
+ from proxy_lite.environments.environment_base import Action, Observation
9
+ from proxy_lite.recorder import Run
10
+
11
+
12
+ def create_run_gif(
13
+ run: Run, output_path: str, white_panel_width: int = 300, duration: int = 1500, resize_factor: int = 4
14
+ ) -> None:
15
+ """
16
+ Generate a gif from the Run object's history.
17
+
18
+ For each Observation record, the observation image is decoded from its base64
19
+ encoded string. If the next record is an Action, its text is drawn onto a
20
+ white panel. The observation image and the white panel are then concatenated
21
+ horizontally to produce a frame.
22
+
23
+ Parameters:
24
+ run (Run): A Run object with its history containing Observation and Action records.
25
+ output_path (str): The path where the GIF will be saved.
26
+ white_panel_width (int): The width of the white panel for displaying text.
27
+ Default increased to 400 for larger images.
28
+ duration (int): Duration between frames in milliseconds.
29
+ Increased here to slow the FPS (default is 1000ms).
30
+ resize_factor (int): The factor to resize the image down by.
31
+ """
32
+ frames = []
33
+ history = run.history
34
+ i = 0
35
+ while i < len(history):
36
+ if isinstance(history[i], Observation):
37
+ observation = history[i]
38
+ image_data = observation.state.image
39
+ if not image_data:
40
+ i += 1
41
+ continue
42
+ # Decode the base64 image
43
+ image_bytes = base64.b64decode(image_data)
44
+ obs_img = Image.open(BytesIO(image_bytes)).convert("RGB")
45
+
46
+ # scale the image down
47
+ obs_img = obs_img.resize((obs_img.width // resize_factor, obs_img.height // resize_factor))
48
+
49
+ # Check if the next record is an Action and extract its text if available
50
+ action_text = ""
51
+ if i + 1 < len(history) and isinstance(history[i + 1], Action):
52
+ action = history[i + 1]
53
+ if action.text:
54
+ action_text = action.text
55
+
56
+ # extract observation and thinking from tags in the action text
57
+ observation_match = re.search(r"<observation>(.*?)</observation>", action_text, re.DOTALL)
58
+ observation_content = observation_match.group(1).strip() if observation_match else None
59
+
60
+ # Extract text between thinking tags if present
61
+ thinking_match = re.search(r"<thinking>(.*?)</thinking>", action_text, re.DOTALL)
62
+ thinking_content = thinking_match.group(1).strip() if thinking_match else None
63
+
64
+ if observation_content and thinking_content:
65
+ action_text = f"**OBSERVATION**\n{observation_content}\n\n**THINKING**\n{thinking_content}"
66
+
67
+ # Create a white panel (same height as the observation image)
68
+ panel = Image.new("RGB", (white_panel_width, obs_img.height), "white")
69
+ draw = ImageDraw.Draw(panel)
70
+ font = ImageFont.load_default()
71
+
72
+ # Wrap the action text if it is too long
73
+ max_chars_per_line = 40 # Adjusted for larger font size
74
+ wrapped_text = textwrap.fill(action_text, width=max_chars_per_line)
75
+
76
+ # Calculate text block size and center it on the panel
77
+ try:
78
+ # Use multiline_textbbox if available (returns bounding box tuple)
79
+ bbox = draw.multiline_textbbox((0, 0), wrapped_text, font=font)
80
+ text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
81
+ except AttributeError:
82
+ # Fallback for older Pillow versions: compute size for each line
83
+ lines = wrapped_text.splitlines() or [wrapped_text]
84
+ line_sizes = [draw.textsize(line, font=font) for line in lines]
85
+ text_width = max(width for width, _ in line_sizes)
86
+ text_height = sum(height for _, height in line_sizes)
87
+ text_x = (white_panel_width - text_width) // 2
88
+ text_y = (obs_img.height - text_height) // 2
89
+ draw.multiline_text((text_x, text_y), wrapped_text, fill="black", font=font, align="center")
90
+
91
+ # Create the combined frame by concatenating the observation image and the panel
92
+ total_width = obs_img.width + white_panel_width
93
+ combined_frame = Image.new("RGB", (total_width, obs_img.height))
94
+ combined_frame.paste(obs_img, (0, 0))
95
+ combined_frame.paste(panel, (obs_img.width, 0))
96
+ frames.append(combined_frame)
97
+
98
+ # Skip the Action record since it has been processed with this Observation
99
+ if i + 1 < len(history) and isinstance(history[i + 1], Action):
100
+ i += 2
101
+ else:
102
+ i += 1
103
+ else:
104
+ i += 1
105
+
106
+ if frames:
107
+ frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=duration, loop=0)
108
+ else:
109
+ raise ValueError("No frames were generated from the Run object's history.")
110
+
111
+
112
+ # Example usage:
113
+ if __name__ == "__main__":
114
+ from proxy_lite.recorder import Run
115
+
116
+ dummy_run = Run.load("0abdb4cb-f289-48b0-ba13-35ed1210f7c1")
117
+
118
+ num_steps = int(len(dummy_run.history) / 2)
119
+ print(f"Number of steps: {num_steps}")
120
+ output_gif_path = "trajectory.gif"
121
+ create_run_gif(dummy_run, output_gif_path, duration=1000)
122
+ print(f"Trajectory GIF saved to {output_gif_path}")
src/proxy_lite/recorder.py CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
3
  import datetime
4
  import json
5
  import os
6
- import sys
7
  import uuid
8
  from pathlib import Path
9
  from typing import Any, Optional, Self
@@ -39,6 +38,11 @@ class Run(BaseModel):
39
  created_at=str(datetime.datetime.now(datetime.UTC)),
40
  )
41
 
 
 
 
 
 
42
  @property
43
  def observations(self) -> list[Observation]:
44
  return [h for h in self.history if isinstance(h, Observation)]
@@ -80,7 +84,7 @@ class DataRecorder:
80
  self.local_folder = local_folder
81
 
82
  def initialise_run(self, task: str) -> Run:
83
- self.local_folder = Path(os.path.abspath(".")) / "local_trajectories"
84
  os.makedirs(self.local_folder, exist_ok=True)
85
  return Run.initialise(task)
86
 
 
3
  import datetime
4
  import json
5
  import os
 
6
  import uuid
7
  from pathlib import Path
8
  from typing import Any, Optional, Self
 
38
  created_at=str(datetime.datetime.now(datetime.UTC)),
39
  )
40
 
41
+ @classmethod
42
+ def load(cls, run_id: str) -> Self:
43
+ with open(Path(__file__).parent.parent.parent / "local_trajectories" / f"{run_id}.json", "r") as f:
44
+ return cls(**json.load(f))
45
+
46
  @property
47
  def observations(self) -> list[Observation]:
48
  return [h for h in self.history if isinstance(h, Observation)]
 
84
  self.local_folder = local_folder
85
 
86
  def initialise_run(self, task: str) -> Run:
87
+ self.local_folder = Path(__file__).parent.parent.parent / "local_trajectories"
88
  os.makedirs(self.local_folder, exist_ok=True)
89
  return Run.initialise(task)
90
 
src/proxy_lite/serializer.py CHANGED
@@ -25,7 +25,7 @@ class BaseSerializer(BaseModel, ABC):
25
  def serialize_tools(self, tools: list[Tool]) -> list[dict]: ...
26
 
27
 
28
- class OpenAISerializer(BaseSerializer):
29
  def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
30
  return message_history.to_dict(exclude={"label"})
31
 
 
25
  def serialize_tools(self, tools: list[Tool]) -> list[dict]: ...
26
 
27
 
28
+ class OpenAICompatibleSerializer(BaseSerializer):
29
  def serialize_messages(self, message_history: MessageHistory) -> list[dict]:
30
  return message_history.to_dict(exclude={"label"})
31