Spaces:
Running
Running
File size: 8,828 Bytes
f0f6e5c fe73d4f b8f3e5d f0f6e5c 7263162 f0f6e5c 701b92f f0f6e5c 2cedb9d f0f6e5c 2cedb9d f0f6e5c 2cedb9d f0f6e5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Literal, Self

from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from proxy_lite.environments import (
    Action,
    BaseEnvironment,
    EnvironmentConfigTypes,
    Environments,
    EventType,
    Observation,
)
from proxy_lite.logger import create_logger
from proxy_lite.recorder import DataRecorder, Run
from proxy_lite.solvers import (
    BaseSolver,
    SolverConfigTypes,
    Solvers,
)
@asynccontextmanager
async def async_timeout(timeout: float, task_name: str = "timeout"):
try:
async with asyncio.TaskGroup() as tg:
async def timeout_task():
await asyncio.sleep(timeout)
raise TimeoutError(
f"Operation {task_name} timed out after {timeout} seconds",
)
# Create the timeout task
timeout_handle = tg.create_task(timeout_task())
try:
yield
finally:
timeout_handle.cancel()
except* asyncio.TimeoutError as eg:
for e in eg.exceptions:
raise e
except* Exception as eg:
for e in eg.exceptions:
raise e
class RunnerConfig(BaseModel):
    """Settings consumed by ``Runner``.

    ``environment`` and ``solver`` are sub-configurations; their ``name``
    fields select which environment/solver implementation is used.
    All timeouts are expressed in seconds.
    """

    environment: EnvironmentConfigTypes
    solver: SolverConfigTypes
    save_every_step: bool = True  # persist the run after every recorded action
    max_steps: int = 50  # upper bound on executed actions per task
    action_timeout: float = 600.0  # budget for each solver decision
    environment_timeout: float = 300.0  # budget for each environment action
    task_timeout: float = 1800.0  # budget for the whole task
    logger_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO"
    detailed_logger_name: bool = False

    @classmethod
    def from_dict(cls, config_dict: dict) -> Self:
        """Build a config from a mapping, resolving OmegaConf interpolations first."""
        resolved = OmegaConf.to_container(OmegaConf.create(config_dict), resolve=True)
        return cls(**resolved)

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Build a config from the YAML file at ``yaml_path``, resolving interpolations."""
        resolved = OmegaConf.to_container(OmegaConf.load(yaml_path), resolve=True)
        return cls(**resolved)
class Runner(BaseModel):
config: RunnerConfig
recorder: DataRecorder | None = None
environment: type[BaseEnvironment] | None = None
solver: type[BaseSolver] | None = None
logger: logging.Logger | None = None
_run: Run | None = None
class Config:
arbitrary_types_allowed = True
def model_post_init(self, __context: Any) -> None:
super().model_post_init(__context)
self.environment = Environments.get(self.config.environment.name)
self.solver = Solvers.get(self.config.solver.name)
self.recorder = DataRecorder()
self.logger = create_logger(
name=f"([bold purple]{self.config.solver.name}[/]-[bold blue]{self.config.environment.name}[/])",
level=self.config.logger_level,
detailed_name=self.config.detailed_logger_name,
)
async def run_generator(self, task: str) -> AsyncIterator[Run]:
async with (
async_timeout(self.config.task_timeout, "Task"),
):
if self.config.logger_level is not None:
self.logger.setLevel(self.config.logger_level)
run = self.recorder.initialise_run(task)
run.environment = self.config.environment
run.solver = self.config.solver
self.logger.debug(f"Run intialised: {run.run_id}")
event_queue = asyncio.Queue()
async with (
self.environment(
config=self.config.environment,
logger=self.logger,
) as environment,
self.solver(config=self.config.solver, logger=self.logger) as solver,
):
run.env_info = await environment.get_info()
await solver.initialise(
task,
environment.tools,
environment.info_for_user,
)
self.logger.debug("Solver initialised.")
run.solver_history = solver.history
observation: Observation = await environment.initialise()
await event_queue.put(observation)
self.logger.debug("Environment initialised.")
step_count = 0
while step_count < self.config.max_steps:
event = await event_queue.get()
self.logger.debug(f"π€ [bold purple]Processing event:[/] {event.type}")
match event.type:
case EventType.OBSERVATION:
observation: Observation = event
run.record(
observation=observation,
solver_history=solver.history,
)
async with async_timeout(
self.config.action_timeout,
"Action decision",
):
action: Action = await solver.act(observation)
await event_queue.put(action)
case EventType.ACTION:
action: Action = event
self.logger.debug(f"Tool calls: {action.tool_calls}")
run.record(action=action, solver_history=solver.history)
run.complete = await solver.is_complete(observation)
if self.config.save_every_step:
await self.recorder.save(run)
if run.complete:
run.result = action.text
self.logger.info(f"π€ [bold purple]Task complete.[/] β¨ \n{run.result}")
break
self.logger.debug(f"DEBUG: Using environment_timeout: {self.config.environment_timeout} seconds")
async with async_timeout(
self.config.environment_timeout,
"Environment response",
):
observation: Observation = await environment.execute_action(action)
step_count += 1
await event_queue.put(observation)
yield run
if not run.complete:
self.logger.warning("π€ [bold purple]Ran out of steps!")
await self.recorder.terminate(run, save=True)
yield run
async def run(self, task: str) -> Run:
async for run in self.run_generator(task):
self._run = run
return run
def run_concurrent(self, tasks: list[str]) -> list[Run]:
async def gather_runs():
return await asyncio.gather(
*[self.run(task) for task in tasks],
return_exceptions=True,
)
return asyncio.run(gather_runs())
@property
def complete(self) -> bool:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.complete
@property
def run_id(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.run_id
@property
def run_result(self) -> str:
if self._run is None:
raise RuntimeError("Run not initialised")
return self._run.result
if __name__ == "__main__":
    # Demo entry point: run a single browser task end to end and print the
    # outcome. The package logger is only needed here, so import it locally.
    from proxy_lite.logger import logger

    config = RunnerConfig.from_dict(
        {
            "environment": {
                "name": "webbrowser",
                "homepage": "https://www.google.com",
                "viewport_width": 1280,
                "viewport_height": 1920,
                "screenshot_delay": 1,
                "headless": False,
            },
            "solver": {
                "name": "simple",
                "agent": {
                    "name": "proxy_lite",
                    "client": {
                        "name": "convergence",
                        "model_id": "convergence-ai/proxy-lite",
                        "api_base": "https://convergence-ai-demo-api.hf.space/v1",
                    },
                },
            },
            "max_steps": 150,
            "action_timeout": 1800,
            "environment_timeout": 1800,
            "task_timeout": 18000,
            "logger_level": "DEBUG",
        },
    )
    logger.info(f"🤖 [bold purple]Config:[/] {config}")
    runner = Runner(config=config)
    result = asyncio.run(runner.run("Tell me the tesla stock price."))
    # ``run()`` returns the same Run object the runner's properties read from.
    print(result.result)
    print(result.complete)
|