Spaces:
Running
Running
import asyncio | |
import logging | |
from collections.abc import AsyncIterator | |
from contextlib import asynccontextmanager | |
from typing import Any, Literal, Self | |
from omegaconf import OmegaConf | |
from pydantic import BaseModel | |
from proxy_lite.environments import ( | |
Action, | |
BaseEnvironment, | |
EnvironmentConfigTypes, | |
Environments, | |
EventType, | |
Observation, | |
) | |
from proxy_lite.logger import create_logger | |
from proxy_lite.recorder import DataRecorder, Run | |
from proxy_lite.solvers import ( | |
BaseSolver, | |
SolverConfigTypes, | |
Solvers, | |
) | |
async def async_timeout(timeout: float, task_name: str = "timeout"): | |
try: | |
async with asyncio.TaskGroup() as tg: | |
async def timeout_task(): | |
await asyncio.sleep(timeout) | |
raise TimeoutError( | |
f"Operation {task_name} timed out after {timeout} seconds", | |
) | |
# Create the timeout task | |
timeout_handle = tg.create_task(timeout_task()) | |
try: | |
yield | |
finally: | |
timeout_handle.cancel() | |
except* asyncio.TimeoutError as eg: | |
for e in eg.exceptions: | |
raise e | |
except* Exception as eg: | |
for e in eg.exceptions: | |
raise e | |
class RunnerConfig(BaseModel): | |
environment: EnvironmentConfigTypes | |
solver: SolverConfigTypes | |
save_every_step: bool = True | |
max_steps: int = 50 | |
action_timeout: float = 600.0 | |
environment_timeout: float = 300.0 | |
task_timeout: float = 1800.0 | |
logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO" | |
detailed_logger_name: bool = False | |
def from_dict(cls, config_dict: dict) -> Self: | |
conf = OmegaConf.create(config_dict) | |
config_dict = OmegaConf.to_container(conf, resolve=True) | |
return cls(**config_dict) | |
def from_yaml(cls, yaml_path: str) -> Self: | |
conf = OmegaConf.load(yaml_path) | |
config_dict = OmegaConf.to_container(conf, resolve=True) | |
return cls(**config_dict) | |
class Runner(BaseModel): | |
config: RunnerConfig | |
recorder: DataRecorder | None = None | |
environment: type[BaseEnvironment] | None = None | |
solver: type[BaseSolver] | None = None | |
logger: logging.Logger | None = None | |
_run: Run | None = None | |
class Config: | |
arbitrary_types_allowed = True | |
def model_post_init(self, __context: Any) -> None: | |
super().model_post_init(__context) | |
self.environment = Environments.get(self.config.environment.name) | |
self.solver = Solvers.get(self.config.solver.name) | |
self.recorder = DataRecorder() | |
self.logger = create_logger( | |
name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])", | |
level=self.config.logger_level, | |
detailed_name=self.config.detailed_logger_name, | |
) | |
async def run_generator(self, task: str) -> AsyncIterator[Run]: | |
async with ( | |
async_timeout(self.config.task_timeout, "Task"), | |
): | |
if self.config.logger_level is not None: | |
self.logger.setLevel(self.config.logger_level) | |
run = self.recorder.initialise_run(task) | |
run.environment = self.config.environment | |
run.solver = self.config.solver | |
self.logger.debug(f"Run intialised: {run.run_id}") | |
event_queue = asyncio.Queue() | |
async with ( | |
self.environment( | |
config=self.config.environment, | |
logger=self.logger, | |
) as environment, | |
self.solver(config=self.config.solver, logger=self.logger) as solver, | |
): | |
run.env_info = await environment.get_info() | |
await solver.initialise( | |
task, | |
environment.tools, | |
environment.info_for_user, | |
) | |
self.logger.debug("Solver initialised.") | |
run.solver_history = solver.history | |
observation: Observation = await environment.initialise() | |
await event_queue.put(observation) | |
self.logger.debug("Environment initialised.") | |
step_count = 0 | |
while step_count < self.config.max_steps: | |
event = await event_queue.get() | |
self.logger.debug(f"π€ [bold purple]Processing event:[/] {event.type}") | |
match event.type: | |
case EventType.OBSERVATION: | |
observation: Observation = event | |
run.record( | |
observation=observation, | |
solver_history=solver.history, | |
) | |
async with async_timeout( | |
self.config.action_timeout, | |
"Action decision", | |
): | |
action: Action = await solver.act(observation) | |
await event_queue.put(action) | |
case EventType.ACTION: | |
action: Action = event | |
self.logger.debug(f"Tool calls: {action.tool_calls}") | |
run.record(action=action, solver_history=solver.history) | |
run.complete = await solver.is_complete(observation) | |
if self.config.save_every_step: | |
await self.recorder.save(run) | |
if run.complete: | |
run.result = action.text | |
self.logger.info(f"π€ [bold purple]Task complete.[/] β¨ \n{run.result}") | |
break | |
self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds") | |
async with async_timeout( | |
self.config.environment_timeout, | |
"Environment response", | |
): | |
observation: Observation = await environment.execute_action(action) | |
step_count += 1 | |
await event_queue.put(observation) | |
yield run | |
if not run.complete: | |
self.logger.warning("π€ [bold purple]Ran out of steps!") | |
await self.recorder.terminate(run, save=True) | |
yield run | |
async def run(self, task: str) -> Run: | |
async for run in self.run_generator(task): | |
self._run = run | |
return run | |
def run_concurrent(self, tasks: list[str]) -> list[Run]: | |
async def gather_runs(): | |
return await asyncio.gather( | |
*[self.run(task) for task in tasks], | |
return_exceptions=True, | |
) | |
return asyncio.run(gather_runs()) | |
def complete(self) -> bool: | |
if self._run is None: | |
raise RuntimeError("Run not initialised") | |
return self._run.complete | |
def run_id(self) -> str: | |
if self._run is None: | |
raise RuntimeError("Run not initialised") | |
return self._run.run_id | |
def run_result(self) -> str: | |
if self._run is None: | |
raise RuntimeError("Run not initialised") | |
return self._run.result | |
if __name__ == "__main__": | |
from proxy_lite.logger import logger | |
config = RunnerConfig.from_dict( | |
{ | |
"environment": { | |
"name": "webbrowser", | |
"homepage": "https://www.google.com", | |
"viewport_width": 1280, | |
"viewport_height": 1920, | |
"screenshot_delay": 1, | |
"headless": False, | |
}, | |
"solver": { | |
"name": "simple", | |
"agent": { | |
"name": "proxy_lite", | |
"client": { | |
"name": "convergence", | |
"model_id": "convergence-ai/proxy-lite", | |
"api_base": "https://convergence-ai-demo-api.hf.space/v1", | |
}, | |
}, | |
}, | |
"max_steps": 150, | |
"action_timeout": 1800, | |
"environment_timeout": 1800, | |
"task_timeout": 18000, | |
"logger_level": "DEBUG", | |
}, | |
) | |
logger.info(f"π€ [bold purple]Config:[/] {config}") | |
runner = Runner(config=config) | |
result = asyncio.run(runner.run("Tell me the tesla stock price.")) | |
print(runner.run_result) | |
print(runner.complete) | |