import logging
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Optional, Self, Type, TypeVar, cast

from pydantic import BaseModel, Field

from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.tools import Tool


class BaseSolverConfig(BaseModel):
    """Base class for solver configuration models; concrete solvers subclass it."""

    pass


class BaseSolver(BaseModel, ABC):
    """Abstract base for solvers that turn environment observations into actions.

    Subclasses must implement ``tools``, ``initialise`` and ``act``.
    ``is_complete`` has a default implementation based on the observation.
    The class can be used as an async context manager; by default entering
    returns the solver itself and exiting performs no cleanup.
    """

    # Optional task description.
    task: Optional[str] = None
    # Tools exposed by the environment; a fresh list per instance.
    env_tools: list[Tool] = Field(default_factory=list)
    # Solver-specific configuration.
    config: BaseSolverConfig
    # Optional logger; None when no logger is supplied.
    logger: logging.Logger | None = None

    class Config:
        # Lets pydantic accept field types it cannot validate natively,
        # such as logging.Logger.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        """Enter the async context and return this solver."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the async context; no cleanup by default and exceptions are not suppressed."""
        pass

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools available to this solver (computed once per instance and cached)."""
        ...

    @abstractmethod
    async def initialise(
        self,
        task: str,
        env_tools: list[Tool],
        env_info: str,
    ) -> None:
        """
        Initialise the solution with the given task.
        """
        ...

    @abstractmethod
    async def act(self, observation: Observation) -> Action:
        """
        Return an action for interacting with the environment.
        """
        ...

    async def is_complete(self, observation: Observation) -> bool:
        """
        Return a boolean indicating if the task is complete.

        The default implementation returns ``observation.terminated``.
        """
        return observation.terminated


# Type variables bound to the base classes so the registration decorators
# preserve the concrete subclass type for static type checkers.
_SolverT = TypeVar("_SolverT", bound=BaseSolver)
_SolverConfigT = TypeVar("_SolverConfigT", bound=BaseSolverConfig)


class Solvers:
    """Name-based registry of solver classes and their configuration classes.

    Both registries are class attributes, so every registration is visible
    through the ``Solvers`` class itself. Registering a name that already
    exists silently replaces the previous entry.
    """

    _solver_registry: dict[str, type[BaseSolver]] = {}
    _solver_config_registry: dict[str, type[BaseSolverConfig]] = {}

    @classmethod
    def register_solver(cls, name: str):
        """
        Decorator to register a Solver class under a given name.

        Example:
            @Solvers.register_solver("my_solver")
            class MySolver(BaseSolver):
                ...
        """

        def decorator(solver_cls: type[_SolverT]) -> type[_SolverT]:
            cls._solver_registry[name] = solver_cls
            # Return the class unchanged so decoration has no other effect.
            return solver_cls

        return decorator

    @classmethod
    def register_solver_config(cls, name: str):
        """
        Decorator to register a Solver configuration class under a given name.

        Example:
            @Solvers.register_solver_config("my_solver")
            class MySolverConfig(BaseSolverConfig):
                ...
        """

        def decorator(config_cls: type[_SolverConfigT]) -> type[_SolverConfigT]:
            cls._solver_config_registry[name] = config_cls
            # Return the class unchanged so decoration has no other effect.
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseSolver]:
        """
        Retrieve a registered Solver class by its name.

        Raises:
            ValueError: If no such solver is found.
        """
        try:
            return cast(Type[BaseSolver], cls._solver_registry[name])
        except KeyError:
            # Suppress the internal KeyError so callers see a single,
            # clear ValueError instead of a chained traceback.
            raise ValueError(f"Solver '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseSolverConfig]:
        """
        Retrieve a registered Solver configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cast(Type[BaseSolverConfig], cls._solver_config_registry[name])
        except KeyError:
            # Suppress the internal KeyError so callers see a single,
            # clear ValueError instead of a chained traceback.
            raise ValueError(f"Solver config for '{name}' not found.") from None