Spaces:
Running
Running
File size: 3,350 Bytes
6a0e448 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import logging
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Optional, Self, Type, cast
from pydantic import BaseModel, Field
from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.tools import Tool
class BaseSolverConfig(BaseModel):
pass
class BaseSolver(BaseModel, ABC):
task: Optional[str] = None
env_tools: list[Tool] = Field(default_factory=list)
config: BaseSolverConfig
logger: logging.Logger | None = None
class Config:
arbitrary_types_allowed = True
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass
@cached_property
@abstractmethod
def tools(self) -> list[Tool]: ...
@abstractmethod
async def initialise(
self,
task: str,
env_tools: list[Tool],
env_info: str,
) -> None:
"""
Initialise the solution with the given task.
"""
...
@abstractmethod
async def act(self, observation: Observation) -> Action:
"""
Return an action for interacting with the environment.
"""
...
async def is_complete(self, observation: Observation) -> bool:
"""
Return a boolean indicating if the task is complete.
"""
return observation.terminated
class Solvers:
_solver_registry: dict[str, type[BaseSolver]] = {}
_solver_config_registry: dict[str, type[BaseSolverConfig]] = {}
@classmethod
def register_solver(cls, name: str):
"""
Decorator to register a Solver class under a given name.
Example:
@Solvers.register_solver("my_solver")
class MySolver(BaseSolver):
...
"""
def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]:
cls._solver_registry[name] = solver_cls
return solver_cls
return decorator
@classmethod
def register_solver_config(cls, name: str):
"""
Decorator to register a Solver configuration class under a given name.
Example:
@Solvers.register_solver_config("my_solver")
class MySolverConfig(BaseSolverConfig):
...
"""
def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]:
cls._solver_config_registry[name] = config_cls
return config_cls
return decorator
@classmethod
def get(cls, name: str) -> type[BaseSolver]:
"""
Retrieve a registered Solver class by its name.
Raises:
ValueError: If no such solver is found.
"""
try:
return cast(Type[BaseSolver], cls._solver_registry[name])
except KeyError:
raise ValueError(f"Solver '{name}' not found.")
@classmethod
def get_config(cls, name: str) -> type[BaseSolverConfig]:
"""
Retrieve a registered Solver configuration class by its name.
Raises:
ValueError: If no such config is found.
"""
try:
return cast(Type[BaseSolverConfig], cls._solver_config_registry[name])
except KeyError:
raise ValueError(f"Solver config for '{name}' not found.")
|