# Spaces: Running
import json
import logging
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from typing import Any, Literal, Optional, Self

from pydantic import BaseModel

from proxy_lite.history import ToolCall
from proxy_lite.tools import Tool, ToolExecutionResponse
class EventType(str, Enum):
    """Closed set of event kinds, used as the ``type`` tag on ``Event`` models.

    Because of the ``str`` mixin every member is also a string and compares
    equal to its plain value (``EventType.ACTION == "action"``).
    """

    OBSERVATION = "observation"
    ACTION = "action"
    MESSAGE = "message"
# Common base for event models. It only declares the `type` tag; the
# Observation and Action subclasses narrow it to a single EventType member.
# (Documented with comments rather than a docstring so the model's generated
# schema description is left untouched.)
class Event(BaseModel):
    type: EventType
# Payload carried in `Observation.state`. Every field is optional and
# defaults to None, so an empty State() is valid.
class State(BaseModel):
    text: str | None = None
    image: str | None = None  # base64 encoded image
    html: str | None = None
    # Responses produced by tool executions, when there are any to report.
    tool_responses: list[ToolExecutionResponse] | None = None
# Event whose `type` is fixed to EventType.OBSERVATION. This is the return
# type of BaseEnvironment.initialise / execute_action / observe.
class Observation(Event):
    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    # Required fields: the state payload and the termination flag.
    state: State
    terminated: bool
    # Optional extras; both default to None.
    reward: float | None = None
    info: dict[str, Any] | None = None
# Event whose `type` is fixed to EventType.ACTION. This is the argument type
# of BaseEnvironment.execute_action. All payload fields are optional.
class Action(Event):
    type: Literal[EventType.ACTION] = EventType.ACTION
    text: str | None = None
    # Tool invocations requested by this action, if any.
    tool_calls: list[ToolCall] | None = None
    info: dict[str, Any] | None = None
# Base class for environment configuration models. It declares no fields of
# its own; it is the declared type of `BaseEnvironment.config`.
class BaseEnvironmentConfig(BaseModel):
    pass
# Abstract base for environments: a pydantic model holding a config and an
# optional logger, usable as an async context manager. Subclasses provide the
# abstract members; `execute_tool` and `get_info` come with default
# implementations.
# (Documented with comments rather than a class docstring so the model's
# generated schema description is left untouched.)
class BaseEnvironment(BaseModel, ABC):
    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        # `logging.Logger` is not a pydantic-native type, so arbitrary types
        # have to be allowed for the `logger` field.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        """Enter the async context; the base implementation just returns self."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Exit the async context.

        The base implementation does no cleanup and returns None, so an
        in-flight exception is not suppressed.
        """
        pass

    # NOTE(review): left as a plain abstract method because nothing in this
    # file shows how it is accessed; a subclass may still override it with a
    # property. Confirm against callers.
    @abstractmethod
    def info_for_user(self) -> str:
        """Return environment information as a string (content is subclass-defined)."""
        ...

    @property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tool objects that `execute_tool` searches for a matching function.

        This must be attribute-like: `execute_tool` iterates `self.tools`
        without calling it.
        """
        ...

    @abstractmethod
    async def initialise(self) -> Observation:
        """Set the environment up and return the first observation."""
        ...

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation:
        """Apply ``action`` and return the resulting observation."""
        ...

    @abstractmethod
    async def observe(self) -> Observation:
        """Return an observation of the environment."""
        ...

    @abstractmethod
    async def evaluate(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
        """Evaluate the environment and return a dict of results."""
        ...

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Run the tool function named by ``tool_call`` and return its result.

        The first object in ``self.tools`` that has an attribute with the
        requested name handles the call. ``function["arguments"]`` is decoded
        from JSON and passed to that attribute as keyword arguments.

        Args:
            tool_call: Call whose ``function`` mapping provides ``"name"``
                and JSON-encoded ``"arguments"``.

        Returns:
            Whatever the awaited tool function returns.

        Raises:
            ValueError: If no tool has an attribute with the requested name.
        """
        function = tool_call.function
        name = function["name"]
        for tool in self.tools:
            if not hasattr(tool, name):
                continue
            arguments = json.loads(function["arguments"])
            # The arguments may be double-encoded (a JSON string that itself
            # contains JSON); decode once more in that case.
            if isinstance(arguments, str):
                arguments = json.loads(arguments)
            return await getattr(tool, name)(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        """Return extra information about the environment; empty by default."""
        return {}
class Environments:
    """Name-based registry of environment classes and their config classes.

    Every method is a classmethod and the two registries are class-level
    dicts mutated in place, so registrations are shared by all users of this
    class (and by subclasses that do not define their own dicts).
    """

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        Registering a second class under the same name replaces the first.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            # Hand the class back unchanged so the decorated name still
            # refers to it.
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        Registering a second class under the same name replaces the first.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            # The KeyError adds nothing beyond the message below, so its
            # context is dropped from the traceback.
            raise ValueError(f"Environment '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            raise ValueError(f"Environment config for '{name}' not found.") from None