import json
import logging
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from functools import cached_property
from typing import Any, Optional, Type, cast

from pydantic import BaseModel, ConfigDict, Field
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential

from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig
from proxy_lite.history import (
    AssistantMessage,
    MessageHistory,
    MessageLabel,
    SystemMessage,
    Text,
    ToolCall,
    ToolMessage,
    UserMessage,
)
from proxy_lite.logger import logger
from proxy_lite.tools import Tool


class BaseAgentConfig(BaseModel):
    """Configuration shared by all agents.

    Attributes:
        client: Configuration used to build the agent's completion client
            (``OpenAIClientConfig()`` by default).
        history_messages_limit: Per-label limits passed to
            ``MessageHistory.history_view`` when the prompt history is built.
        history_messages_include: Optional allow-list. When set, every
            ``MessageLabel`` limit is first reset to 0 and then these entries
            are applied on top, overriding ``history_messages_limit``.
    """

    client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
    history_messages_limit: dict[MessageLabel, int] = Field(default_factory=dict)
    history_messages_include: Optional[dict[MessageLabel, int]] = Field(
        default=None,
        description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
    )

    def model_post_init(self, __context: Any) -> None:
        """Resolve ``history_messages_include`` into ``history_messages_limit``."""
        super().model_post_init(__context)
        if self.history_messages_include is not None:
            # Start from "exclude every label", then re-enable only the requested ones.
            self.history_messages_limit = {label: 0 for label in MessageLabel}
            self.history_messages_limit.update(self.history_messages_include)


class BaseAgent(BaseModel, ABC):
    """Abstract base for agents that converse with a completion client and call tools.

    Subclasses must provide ``system_prompt`` and ``tools``. ``generate_output``
    also records the assistant reply under ``self.message_label`` when a
    subclass defines it (``None`` otherwise).

    Attributes:
        config: Agent configuration; ``config.client`` is used to build ``client``.
        temperature: Sampling temperature forwarded to the client (0 to 2).
        history: Conversation history that the ``receive_*`` helpers append to.
        client: Completion client, created in ``model_post_init``.
        env_tools: Tool instances supplied from outside; not read in this module
            (presumably consumed by subclasses - verify).
        task: Optional task text; not read in this module.
        seed: Optional sampling seed forwarded to the client.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    config: BaseAgentConfig
    temperature: float = Field(default=0.7, ge=0, le=2)
    history: MessageHistory = Field(default_factory=MessageHistory)
    client: Optional[BaseClient] = None
    env_tools: list[Tool] = Field(default_factory=list)
    task: Optional[str] = Field(default=None)
    seed: Optional[int] = Field(default=None)

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Not used in this module; state made available to subclasses.
        self._exit_stack = AsyncExitStack()
        self._tools_init_task = None

    def model_post_init(self, __context: Any) -> None:
        """Build ``self.client`` from ``self.config.client``."""
        super().model_post_init(__context)
        # NOTE(review): this unconditionally replaces any ``client`` passed to the
        # constructor - confirm that is intended.
        self.client = BaseClient.create(self.config.client)

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        """System prompt prepended to every history view."""
        ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools the agent may call; also passed to the client when ``use_tool`` is set."""
        ...

    @cached_property
    def tool_descriptions(self) -> str:
        """Render a plain-text listing of the functions in every tool's ``schema``.

        Each schema entry becomes ``- {name}: {description}``. When more than
        one tool is present, each group is prefixed by the tool's class name.
        Groups are separated by a blank line.
        """
        show_titles = len(self.tools) > 1
        tool_descriptions = []
        for tool in self.tools:
            func_descriptions = "\n".join("- {name}: {description}".format(**schema) for schema in tool.schema)
            tool_title = f"{tool.__class__.__name__}:\n" if show_titles else ""
            tool_descriptions.append(f"{tool_title}{func_descriptions}")
        return "\n\n".join(tool_descriptions)

    async def get_history_view(self) -> MessageHistory:
        """Return a history made of the system prompt followed by the limited view of ``self.history``.

        The view is produced by ``self.history.history_view`` using
        ``self.config.history_messages_limit``.
        """
        return MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        ) + self.history.history_view(
            limits=self.config.history_messages_limit,
        )

    @retry(
        wait=wait_exponential(multiplier=1, min=4, max=10),
        stop=stop_after_attempt(3),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def generate_output(
        self,
        use_tool: bool = False,
        response_format: Optional[type[BaseModel]] = None,
        append_assistant_message: bool = True,
    ) -> AssistantMessage:
        """Request a completion for the current history view and wrap it as an ``AssistantMessage``.

        The whole call is retried up to 3 attempts, with an exponential wait
        clamped to 4-10 s, on any exception. The last exception is re-raised.

        Args:
            use_tool: When True, ``self.tools`` is passed to the client.
            response_format: Optional pydantic model forwarded as the response format.
            append_assistant_message: When True, the reply is appended to ``self.history``.

        Returns:
            The assistant message built from the first choice of the completion.
        """
        messages: MessageHistory = await self.get_history_view()
        completion = await self.client.create_completion(
            messages=messages,
            temperature=self.temperature,
            seed=self.seed,
            response_format=response_format,
            tools=self.tools if use_tool else None,
        )
        response_message = completion.model_dump()["choices"][0]["message"]
        assistant_message = AssistantMessage(
            role=response_message["role"],
            content=[Text(text=response_message["content"])] if response_message["content"] else [],
            tool_calls=response_message["tool_calls"],
        )
        if append_assistant_message:
            # ``message_label`` is not declared on BaseAgent. Falling back to None
            # (the same default the receive_* helpers use) avoids raising an
            # AttributeError after the completion was already requested, which
            # would also make tenacity re-issue the request.
            self.history.append(message=assistant_message, label=getattr(self, "message_label", None))
        return assistant_message

    def receive_user_message(
        self,
        text: Optional[str] = None,
        image: Optional[list[bytes]] = None,
        label: Optional[MessageLabel] = None,
        is_base64: bool = False,
    ) -> None:
        """Append a ``UserMessage`` built by ``UserMessage.from_media`` to the history."""
        message = UserMessage.from_media(
            text=text,
            image=image,
            is_base64=is_base64,
        )
        self.history.append(message=message, label=label)

    def receive_system_message(
        self,
        text: Optional[str] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a ``SystemMessage`` built by ``SystemMessage.from_media`` to the history."""
        message = SystemMessage.from_media(text=text)
        self.history.append(message=message, label=label)

    def receive_assistant_message(
        self,
        content: Optional[str] = None,
        tool_calls: Optional[list[ToolCall]] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append an ``AssistantMessage`` to the history; empty ``content`` yields no text part."""
        message = AssistantMessage(
            content=[Text(text=content)] if content else [],
            tool_calls=tool_calls,
        )
        self.history.append(message=message, label=label)

    async def use_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch ``tool_call`` to the first tool in ``self.tools`` that has a matching attribute.

        ``tool_call.function`` is indexed with ``"name"`` and ``"arguments"``.
        The arguments are decoded with ``json.loads`` and passed as keyword
        arguments, and the call's result is awaited and returned.

        Raises:
            ValueError: If no tool has an attribute with the requested name.
        """
        function = tool_call.function
        name = function["name"]
        for tool in self.tools:
            # NOTE(review): hasattr matches *any* attribute (including private or
            # dunder names), not only functions advertised in ``tool.schema``;
            # consider restricting dispatch to schema entries.
            if hasattr(tool, name):
                return await getattr(tool, name)(
                    **json.loads(function["arguments"]),
                )
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def receive_tool_message(
        self,
        text: str,
        tool_id: str,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a ``ToolMessage`` answering the tool call ``tool_id`` to the history."""
        self.history.append(
            message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
            label=label,
        )


class Agents:
    """Name-based registry of agent classes and their configuration classes.

    Registering a second class under an existing name replaces the first.
    """

    _agent_registry: dict[str, type[BaseAgent]] = {}
    _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}

    @classmethod
    def register_agent(cls, name: str):
        """
        Decorator to register an Agent class under a given name.

        Example:
            @Agents.register_agent("browser")
            class BrowserAgent(BaseAgent):
                ...
        """

        def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
            cls._agent_registry[name] = agent_cls
            return agent_cls

        return decorator

    @classmethod
    def register_agent_config(cls, name: str):
        """
        Decorator to register a configuration class under a given name.

        Example:
            @Agents.register_agent_config("browser")
            class BrowserAgentConfig(BaseAgentConfig):
                ...
        """

        def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
            cls._agent_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseAgent]:
        """
        Retrieve a registered Agent class by its name.

        Raises:
            ValueError: If no such agent is found.
        """
        try:
            return cast(Type[BaseAgent], cls._agent_registry[name])
        except KeyError:
            # The KeyError carries no extra information; present a single clean error.
            raise ValueError(f"Agent '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseAgentConfig]:
        """
        Retrieve a registered Agent configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            return cast(type[BaseAgentConfig], cls._agent_config_registry[name])
        except KeyError:
            raise ValueError(f"Agent config for '{name}' not found.") from None