Spaces:
Running
Running
File size: 4,587 Bytes
6a0e448 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import json
import logging
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from typing import Any, Literal, Optional, Self
from pydantic import BaseModel
from proxy_lite.history import ToolCall
from proxy_lite.tools import Tool, ToolExecutionResponse
class EventType(str, Enum):
    """String-valued tags naming the kind of an event.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (e.g. ``EventType.ACTION == "action"``).
    """

    # Tag carried by Observation events.
    OBSERVATION = "observation"
    # Tag carried by Action events.
    ACTION = "action"
    # Tag for message events (no model for it is defined in this section).
    MESSAGE = "message"
class Event(BaseModel):
    """Common base model for events; ``type`` tags which kind of event it is."""

    type: EventType
class State(BaseModel):
    """Snapshot payload attached to an observation; every part is optional."""

    # Free-form textual content, if any.
    text: str | None = None
    # Screenshot or other image, stored as a base64-encoded string.
    image: str | None = None
    # Raw HTML content, if any.
    html: str | None = None
    # Results of tool executions included in this state, if any.
    tool_responses: list[ToolExecutionResponse] | None = None
class Observation(Event):
    """Event carrying a ``State`` plus episode bookkeeping.

    ``type`` is pinned to ``EventType.OBSERVATION``.
    """

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    state: State
    # Whether the episode has ended.
    terminated: bool
    # Optional scalar reward; None when not provided.
    reward: float | None = None
    # Optional extra metadata.
    info: dict[str, Any] | None = None
class Action(Event):
    """Event describing what is sent to an environment's ``execute_action``.

    ``type`` is pinned to ``EventType.ACTION``. Both the text and the tool
    calls are optional.
    """

    type: Literal[EventType.ACTION] = EventType.ACTION
    # Optional free-form text accompanying the action.
    text: str | None = None
    # Optional list of tool invocations to perform.
    tool_calls: list[ToolCall] | None = None
    # Optional extra metadata.
    info: dict[str, Any] | None = None
class BaseEnvironmentConfig(BaseModel):
    """Empty base class for environment configuration models; subclasses add fields."""
class BaseEnvironment(BaseModel, ABC):
    """Abstract base for environments, configured by a ``BaseEnvironmentConfig``.

    Subclasses provide ``info_for_user`` and ``tools``, and implement the
    coroutines ``initialise``, ``execute_action``, ``observe`` and
    ``evaluate``. Instances work as async context managers; the base
    ``__aenter__`` only returns ``self`` and ``__aexit__`` does nothing.
    """

    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        # logging.Logger is not a pydantic-native type.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str:
        """User-facing information string; defined by subclasses."""

    # NOTE(review): functools.cached_property does not appear to expose
    # ``__isabstractmethod__``, so ABCMeta may not treat ``tools`` as abstract.
    # In that case a subclass that forgets to override it still instantiates,
    # and ``self.tools`` evaluates to None. Verify on the target Python
    # version. This is left unchanged to avoid breaking existing subclasses.
    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools whose methods ``execute_tool`` can dispatch to (cached per instance)."""

    @abstractmethod
    async def initialise(self) -> Observation:
        """Initialise the environment and return an Observation."""

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation:
        """Apply ``action`` to the environment and return the resulting Observation."""

    @abstractmethod
    async def observe(self) -> Observation:
        """Return an Observation of the environment."""

    @abstractmethod
    async def evaluate(self, **kwargs: Any) -> dict[str, Any]:
        """Evaluate the environment; the keyword arguments are subclass-defined."""

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch ``tool_call`` to the matching method on one of ``self.tools``.

        The first tool, in ``self.tools`` order, that exposes a public callable
        attribute named ``tool_call.function["name"]`` is awaited with the
        decoded keyword arguments.

        Args:
            tool_call: Call whose ``function`` mapping provides ``"name"`` and
                a JSON-encoded ``"arguments"`` value, which may itself be a
                JSON-encoded string (double-encoded).

        Returns:
            Whatever the awaited tool method returns.

        Raises:
            ValueError: If no tool exposes a matching public callable, or if
                the arguments are not valid JSON (``json.JSONDecodeError`` is
                a ``ValueError`` subclass).
            TypeError: If the decoded arguments are not a JSON object.
        """
        function = tool_call.function
        name = function["name"]
        # The name presumably originates from model output. Refuse private or
        # dunder attributes (e.g. "__init__", "__class__") and non-callable
        # attributes, so that only genuine tool methods can be invoked.
        if not name.startswith("_"):
            for tool in self.tools:
                method = getattr(tool, name, None)
                if not callable(method):
                    continue
                arguments = json.loads(function["arguments"])
                # Tolerate double-encoded arguments (a JSON string of JSON).
                if isinstance(arguments, str):
                    arguments = json.loads(arguments)
                if not isinstance(arguments, dict):
                    msg = (
                        f'Arguments for tool function "{name}" must be a JSON object, '
                        f"got {type(arguments).__name__}"
                    )
                    raise TypeError(msg)
                return await method(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        """Return extra environment info; the base implementation returns an empty dict."""
        return {}
class Environments:
    """Name-keyed registries of environment classes and their config classes.

    Both registries are class-level dicts shared process-wide. Registering a
    name that already exists overwrites the previous entry.
    """

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...

        The decorated class is returned unchanged.
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...

        The decorated class is returned unchanged.
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @staticmethod
    def _lookup(registry: dict, name: str, not_found_msg: str):
        """Return ``registry[name]``; raise ``ValueError(not_found_msg)`` if absent."""
        try:
            return registry[name]
        except KeyError:
            # The ValueError already names the missing key, so suppress the
            # KeyError context rather than chaining a redundant traceback.
            raise ValueError(not_found_msg) from None

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        return cls._lookup(cls._environment_registry, name, f"Environment '{name}' not found.")

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        return cls._lookup(
            cls._environment_config_registry,
            name,
            f"Environment config for '{name}' not found.",
        )
|