# Trisha Tomy
# init
# 6a0e448
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal, Self
from omegaconf import OmegaConf
from pydantic import BaseModel
from proxy_lite.environments import (
Action,
BaseEnvironment,
EnvironmentConfigTypes,
Environments,
EventType,
Observation,
)
from proxy_lite.logger import create_logger
from proxy_lite.recorder import DataRecorder, Run
from proxy_lite.solvers import (
BaseSolver,
SolverConfigTypes,
Solvers,
)
@asynccontextmanager
async def async_timeout(timeout: float, task_name: str = "timeout"):
try:
async with asyncio.TaskGroup() as tg:
async def timeout_task():
await asyncio.sleep(timeout)
raise TimeoutError(
f"Operation {task_name} timed out after {timeout} seconds",
)
# Create the timeout task
timeout_handle = tg.create_task(timeout_task())
try:
yield
finally:
timeout_handle.cancel()
except* asyncio.TimeoutError as eg:
for e in eg.exceptions:
raise e
except* Exception as eg:
for e in eg.exceptions:
raise e
class RunnerConfig(BaseModel):
    """Settings consumed by ``Runner``: which environment/solver to use, plus limits."""

    # Configuration objects for the environment and the solver; ``Runner``
    # looks both up by their ``name`` attribute.
    environment: EnvironmentConfigTypes
    solver: SolverConfigTypes
    # When true, the run is saved through the recorder after every action.
    save_every_step: bool = True
    # Upper bound on the number of executed actions in one run.
    max_steps: int = 50
    # Timeouts in seconds for: one solver decision, one environment action,
    # and the task as a whole.
    action_timeout: float = 600.0
    environment_timeout: float = 300.0
    task_timeout: float = 1800.0
    # Passed to ``create_logger`` when the runner is constructed.
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a plain dict, routed through OmegaConf with ``resolve=True``."""
        resolved = OmegaConf.to_container(
            OmegaConf.create(config_dict),
            resolve=True,
        )
        return cls(**resolved)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from a YAML file loaded by OmegaConf with ``resolve=True``."""
        resolved = OmegaConf.to_container(
            OmegaConf.load(yaml_path),
            resolve=True,
        )
        return cls(**resolved)
class Runner(BaseModel):
config: RunnerConfig
recorder: DataRecorder | None = None
environment: type[BaseEnvironment] | None = None
solver: type[BaseSolver] | None = None
logger: logging.Logger | None = None
_run: Run | None = None
class Config:
arbitrary_types_allowed = True
def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
self.environment = Environments.get(self.config.environment.name)
self.solver = Solvers.get(self.config.solver.name)
self.recorder = DataRecorder()
self.logger = create_logger(
name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
level=self.config.logger_level,
detailed_name=self.config.detailed_logger_name,
)
async def run_generator(self, task: str) -> AsyncIterator[Run]:
async with (
async_timeout(self.config.task_timeout, "Task"),
):
if self.config.logger_level is not None:
self.logger.setLevel(self.config.logger_level)
run = self.recorder.initialise_run(task)
run.environment = self.config.environment
run.solver = self.config.solver
self.logger.debug(f"Run intialised: {run.run_id}")
event_queue = asyncio.Queue()
async with (
self.environment(
config=self.config.environment,
logger=self.logger,
) as environment,
self.solver(config=self.config.solver, logger=self.logger) as solver,
):
run.env_info = await environment.get_info()
await solver.initialise(
task,
environment.tools,
environment.info_for_user,
)
self.logger.debug("Solver initialised.")
run.solver_history = solver.history
observation: Observation = await environment.initialise()
await event_queue.put(observation)
self.logger.debug("Environment initialised.")
step_count = 0
while step_count < self.config.max_steps:
event = await event_queue.get()
self.logger.debug(f"πŸ€– [bold purple]Processing event:[/] {event.type}")
match event.type:
case EventType.OBSERVATION:
observation: Observation = event
run.record(
observation=observation,
solver_history=solver.history,
)
async with async_timeout(
self.config.action_timeout,
"Action decision",
):
action: Action = await solver.act(observation)
await event_queue.put(action)
case EventType.ACTION:
action: Action = event
self.logger.debug(f"Tool calls: {action.tool_calls}")
run.record(action=action, solver_history=solver.history)
run.complete = await solver.is_complete(observation)
if self.config.save_every_step:
await self.recorder.save(run)
if run.complete:
run.result = action.text
self.logger.info(f"πŸ€– [bold purple]Task complete.[/] ✨ \n{run.result}")
break
self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds")
async with async_timeout(
self.config.environment_timeout,
"Environment response",
):
observation: Observation = await environment.execute_action(action)
step_count += 1
await event_queue.put(observation)
yield run
if not run.complete:
self.logger.warning("πŸ€– [bold purple]Ran out of steps!")
await self.recorder.terminate(run, save=True)
yield run
async def run(self, task: str) -> Run:
async for run in self.run_generator(task):
self._run = run
return run
def run_concurrent(self, tasks: list[str]) -> list[Run]:
async def gather_runs():
return await asyncio.gather(
*[self.run(task) for task in tasks],
return_exceptions=True,
)
return asyncio.run(gather_runs())
@property
def complete(self) -> bool:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.complete
@property
def run_id(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.run_id
@property
def run_result(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.result
if __name__ == "__main__":
    # Manual smoke run: build a config, execute one task, print the outcome.
    from proxy_lite.logger import logger

    # Settings for the "webbrowser" environment.
    environment_settings = {
        "name": "webbrowser",
        "homepage": "https://www.google.com",
        "viewport_width": 1280,
        "viewport_height": 1920,
        "screenshot_delay": 1,
        "headless": False,
    }
    # Settings for the "simple" solver and the agent/client it wraps.
    client_settings = {
        "name": "convergence",
        "model_id": "convergence-ai/proxy-lite",
        "api_base": "https://convergence-ai-demo-api.hf.space/v1",
    }
    solver_settings = {
        "name": "simple",
        "agent": {
            "name": "proxy_lite",
            "client": client_settings,
        },
    }

    config = RunnerConfig.from_dict(
        {
            "environment": environment_settings,
            "solver": solver_settings,
            "max_steps": 150,
            "action_timeout": 1800,
            "environment_timeout": 1800,
            "task_timeout": 18000,
            "logger_level": "DEBUG",
        },
    )
    logger.info(f"πŸ€– [bold purple]Config:[/] {config}")

    runner = Runner(config=config)
    result = asyncio.run(runner.run("Tell me the tesla stock price."))
    print(runner.run_result)
    print(runner.complete)