File size: 3,350 Bytes
6a0e448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import logging
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Optional, Self, Type, cast

from pydantic import BaseModel, Field

from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.tools import Tool


class BaseSolverConfig(BaseModel):
    """Base class for solver configuration models; it declares no fields of its own."""


class BaseSolver(BaseModel, ABC):
    """
    Abstract base for solvers that choose actions in an environment.

    Subclasses provide ``tools``, ``initialise`` and ``act``; ``is_complete``
    has a default implementation. A solver can be used with ``async with``;
    the base enter/exit hooks only hand back the solver and do no cleanup.
    """

    # Task text; defaults to None.
    task: Optional[str] = None
    # Tools supplied by the environment; each instance gets its own fresh list.
    env_tools: list[Tool] = Field(default_factory=list)
    # Solver-specific configuration (required).
    config: BaseSolverConfig
    # Optional logger; not used by the base class itself.
    logger: logging.Logger | None = None

    class Config:
        # Lets pydantic accept field types it has no built-in validator for
        # (e.g. ``logging.Logger``).
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        # No setup in the base class; the solver itself is the context value.
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """No-op exit hook; returning None lets exceptions from the body propagate."""

    # NOTE(review): ``functools.cached_property`` appears not to expose
    # ``__isabstractmethod__``, so ABC may not actually force subclasses to
    # override ``tools`` (an un-overridden ``tools`` would evaluate to None).
    # Confirm before relying on the abstract contract.
    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools this solver exposes; computed once per instance and cached."""

    @abstractmethod
    async def initialise(
        self,
        task: str,
        env_tools: list[Tool],
        env_info: str,
    ) -> None:
        """
        Prepare the solver to work on ``task``.

        Args:
            task: Description of the task to solve.
            env_tools: Tools provided by the environment.
            env_info: Additional textual information about the environment.
        """

    @abstractmethod
    async def act(self, observation: Observation) -> Action:
        """
        Choose the next action to take, given the latest ``observation``.
        """

    async def is_complete(self, observation: Observation) -> bool:
        """
        Report whether the task is finished.

        The default defers entirely to the environment and returns
        ``observation.terminated``.
        """
        finished = observation.terminated
        return finished


class Solvers:
    """
    Name-based registry of solver classes and solver configuration classes.

    Classes register themselves with the ``register_solver`` /
    ``register_solver_config`` decorators and are retrieved with ``get`` /
    ``get_config``. The two registries are independent: registering a solver
    does not register a config under the same name, or vice versa.
    """

    # Class-level dicts, shared by every user of ``Solvers``: these ARE the
    # global registries.
    _solver_registry: dict[str, type[BaseSolver]] = {}
    _solver_config_registry: dict[str, type[BaseSolverConfig]] = {}

    @classmethod
    def register_solver(cls, name: str):
        """
        Decorator to register a Solver class under a given name.

        The decorated class is returned unchanged. Registering another class
        under an existing name replaces the earlier entry.

        Example:
            @Solvers.register_solver("my_solver")
            class MySolver(BaseSolver):
                ...
        """

        def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]:
            cls._solver_registry[name] = solver_cls
            return solver_cls

        return decorator

    @classmethod
    def register_solver_config(cls, name: str):
        """
        Decorator to register a Solver configuration class under a given name.

        The decorated class is returned unchanged. Registering another class
        under an existing name replaces the earlier entry.

        Example:
            @Solvers.register_solver_config("my_solver")
            class MySolverConfig(BaseSolverConfig):
                ...
        """

        def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]:
            cls._solver_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseSolver]:
        """
        Retrieve a registered Solver class by its name.

        Raises:
            ValueError: If no such solver is found.
        """
        try:
            return cls._solver_registry[name]
        except KeyError:
            # ``from None``: the KeyError carries nothing beyond ``name`` and
            # would otherwise appear as a confusing "during handling of the
            # above exception" chained traceback.
            raise ValueError(f"Solver '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseSolverConfig]:
        """
        Retrieve a registered Solver configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cls._solver_config_registry[name]
        except KeyError:
            # Suppress the uninformative KeyError context (see ``get``).
            raise ValueError(f"Solver config for '{name}' not found.") from None