# XanderJC's picture
# setup
# f0f6e5c
import json
import logging
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from functools import cached_property
from typing import Any, Optional, Type, cast
from pydantic import BaseModel, Field
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential
from proxy_lite.client import BaseClient, ClientConfigTypes, OpenAIClientConfig
from proxy_lite.history import (
AssistantMessage,
MessageHistory,
MessageLabel,
SystemMessage,
Text,
ToolCall,
ToolMessage,
UserMessage,
)
from proxy_lite.logger import logger
from proxy_lite.tools import Tool
# if TYPE_CHECKING:
# from proxy_lite.tools import Tool
class BaseAgentConfig(BaseModel):
    """Configuration shared by every agent.

    ``history_messages_include``, when set, takes precedence over
    ``history_messages_limit`` (see ``model_post_init``).
    """

    # LLM client configuration; defaults to an ``OpenAIClientConfig``.
    client: ClientConfigTypes = Field(default_factory=OpenAIClientConfig)
    # Per-label limits handed to ``MessageHistory.history_view`` by the agent;
    # presumably the number of messages kept per label — confirm in history module.
    history_messages_limit: dict[MessageLabel, int] = Field(default_factory=dict)
    history_messages_include: Optional[dict[MessageLabel, int]] = Field(
        default=None,
        description="If set, overrides history_messages_limit by setting all message types to 0 except those specified",
    )

    def model_post_init(self, __context: Any) -> None:
        """Expand ``history_messages_include`` into a complete limits mapping.

        When ``history_messages_include`` is set, every ``MessageLabel`` gets a
        limit of 0 and the included labels' values are then applied on top,
        replacing whatever ``history_messages_limit`` held before.
        """
        # Cooperative call so pydantic / mixin post-init hooks still run.
        super().model_post_init(__context)
        if self.history_messages_include is not None:
            limits = {label: 0 for label in MessageLabel}
            limits.update(self.history_messages_include)
            self.history_messages_limit = limits
class BaseAgent(BaseModel, ABC):
    """Abstract base for LLM-driven agents.

    Holds the conversation ``history``, an LLM ``client`` built from
    ``config.client``, and the tools the model may call. Subclasses must
    provide ``system_prompt`` and ``tools``.
    """

    config: BaseAgentConfig
    # Sampling temperature forwarded to the client; validated to [0, 2].
    temperature: float = Field(default=0.7, ge=0, le=2)
    history: MessageHistory = Field(default_factory=MessageHistory)
    # Always (re)built from ``config.client`` in ``model_post_init``.
    client: Optional[BaseClient] = None
    env_tools: list[Tool] = Field(default_factory=list)
    task: Optional[str] = Field(default=None)
    # Forwarded to the client on every completion request.
    seed: Optional[int] = Field(default=None)

    # Pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    # Allows field types pydantic cannot validate natively (presumably
    # ``BaseClient`` / ``Tool``).
    model_config = {"arbitrary_types_allowed": True}

    def __init__(self, **data) -> None:
        super().__init__(**data)
        # Runtime-only state, not pydantic fields. Unused in this class;
        # presumably consumed by subclasses.
        self._exit_stack = AsyncExitStack()
        self._tools_init_task = None

    def model_post_init(self, __context: Any) -> None:
        """Create the LLM client from ``config.client`` after field validation."""
        super().model_post_init(__context)
        self.client = BaseClient.create(self.config.client)

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        """System prompt placed first in every history view sent to the model."""
        ...

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Tools whose methods the model may call (computed once per instance)."""
        # NOTE(review): ``functools.cached_property`` does not expose
        # ``__isabstractmethod__``, so ABC does not enforce that subclasses
        # override this property — confirm every subclass does.
        ...

    @cached_property
    def tool_descriptions(self) -> str:
        """Human-readable listing of all tool functions.

        Each tool contributes one ``- name: description`` line per entry in
        its ``schema``. When there is more than one tool, each group is
        prefixed with the tool's class name. Groups are separated by a blank
        line.
        """
        show_titles = len(self.tools) > 1
        tool_descriptions = []
        for tool in self.tools:
            func_descriptions = "\n".join(
                "- {name}: {description}".format(**schema) for schema in tool.schema
            )
            tool_title = f"{tool.__class__.__name__}:\n" if show_titles else ""
            tool_descriptions.append(f"{tool_title}{func_descriptions}")
        return "\n\n".join(tool_descriptions)

    async def get_history_view(self) -> MessageHistory:
        """Return the system prompt followed by the label-limited history."""
        return MessageHistory(
            messages=[SystemMessage(content=[Text(text=self.system_prompt)])],
        ) + self.history.history_view(
            limits=self.config.history_messages_limit,
        )

    @retry(
        wait=wait_exponential(multiplier=1, min=4, max=10),
        stop=stop_after_attempt(3),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.ERROR),
    )
    async def generate_output(
        self,
        use_tool: bool = False,
        response_format: Optional[type[BaseModel]] = None,
        append_assistant_message: bool = True,
    ) -> AssistantMessage:
        """Request one completion and wrap its first choice as an AssistantMessage.

        Makes at most 3 attempts, with exponential back-off clamped to 4-10 s
        between them. Each failure is logged at ERROR level before sleeping,
        and the last exception is re-raised.

        Args:
            use_tool: If True, pass ``self.tools`` to the client.
            response_format: Optional pydantic model forwarded to the client.
            append_assistant_message: If True, append the reply to ``history``.

        Returns:
            The assistant message built from the first choice.
        """
        messages: MessageHistory = await self.get_history_view()
        completion = await self.client.create_completion(
            messages=messages,
            temperature=self.temperature,
            seed=self.seed,
            response_format=response_format,
            tools=self.tools if use_tool else None,
        )
        response_content = completion.model_dump()["choices"][0]["message"]
        assistant_message = AssistantMessage(
            role=response_content["role"],
            content=[Text(text=response_content["content"])] if response_content["content"] else [],
            tool_calls=response_content["tool_calls"],
        )
        if append_assistant_message:
            # NOTE(review): ``message_label`` is not declared on BaseAgent;
            # presumably every subclass defines it — confirm.
            self.history.append(message=assistant_message, label=self.message_label)
        return assistant_message

    def receive_user_message(
        self,
        text: Optional[str] = None,
        image: Optional[list[bytes]] = None,
        label: Optional[MessageLabel] = None,
        is_base64: bool = False,
    ) -> None:
        """Append a user message built from ``text`` and/or ``image`` to history.

        ``is_base64`` is forwarded to ``UserMessage.from_media``. It presumably
        marks the images as already base64-encoded — confirm there.
        """
        message = UserMessage.from_media(
            text=text,
            image=image,
            is_base64=is_base64,
        )
        self.history.append(message=message, label=label)

    def receive_system_message(
        self,
        text: Optional[str] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a system message containing ``text`` to history."""
        message = SystemMessage.from_media(text=text)
        self.history.append(message=message, label=label)

    def receive_assistant_message(
        self,
        content: Optional[str] = None,
        tool_calls: Optional[list[ToolCall]] = None,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append an assistant message (text and/or tool calls) to history.

        Empty or missing ``content`` produces a message with no text parts.
        """
        message = AssistantMessage(
            content=[Text(text=content)] if content else [],
            tool_calls=tool_calls,
        )
        self.history.append(message=message, label=label)

    async def use_tool(self, tool_call: ToolCall):
        """Dispatch a model-issued tool call to the first tool exposing it.

        ``function["name"]`` is looked up as a public callable attribute on
        each tool in ``self.tools``, in order. The JSON-encoded
        ``function["arguments"]`` are passed to it as keyword arguments, and
        the awaited result is returned. Empty or missing arguments mean no
        arguments.

        Raises:
            ValueError: If no tool exposes a public callable with that name.
        """
        function = tool_call.function
        name = function["name"]
        # The name comes from model output: never dispatch to private/dunder
        # attributes (e.g. ``__init__``) and skip non-callable attributes.
        if not name.startswith("_"):
            for tool in self.tools:
                method = getattr(tool, name, None)
                if callable(method):
                    raw_arguments = function["arguments"]
                    kwargs = json.loads(raw_arguments) if raw_arguments else {}
                    return await method(**kwargs)
        msg = f'No tool function with name "{function["name"]}"'
        raise ValueError(msg)

    async def receive_tool_message(
        self,
        text: str,
        tool_id: str,
        label: Optional[MessageLabel] = None,
    ) -> None:
        """Append a tool-result message for the call ``tool_id`` to history.

        Declared ``async`` although it awaits nothing, so callers must still
        await it.
        """
        self.history.append(
            message=ToolMessage(content=[Text(text=text)], tool_call_id=tool_id),
            label=label,
        )
class Agents:
    """Name-keyed registry of agent classes and their configuration classes.

    Classes are added with the ``register_agent`` / ``register_agent_config``
    decorators and retrieved with ``get`` / ``get_config``. Both registries
    are class-level dicts, so they are shared by every user of this class.
    Registering an existing name replaces the previous entry.
    """

    _agent_registry: dict[str, type[BaseAgent]] = {}
    _agent_config_registry: dict[str, type[BaseAgentConfig]] = {}

    @classmethod
    def register_agent(cls, name: str):
        """
        Decorator to register an Agent class under a given name.

        Example:
            @Agents.register_agent("browser")
            class BrowserAgent(BaseAgent):
                ...
        """

        def decorator(agent_cls: type[BaseAgent]) -> type[BaseAgent]:
            cls._agent_registry[name] = agent_cls
            # Return the class unchanged so the decorator is transparent.
            return agent_cls

        return decorator

    @classmethod
    def register_agent_config(cls, name: str):
        """
        Decorator to register a configuration class under a given name.

        Example:
            @Agents.register_agent_config("browser")
            class BrowserAgentConfig(BaseAgentConfig):
                ...
        """

        def decorator(config_cls: type[BaseAgentConfig]) -> type[BaseAgentConfig]:
            cls._agent_config_registry[name] = config_cls
            # Return the class unchanged so the decorator is transparent.
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseAgent]:
        """
        Retrieve a registered Agent class by its name.

        Raises:
            ValueError: If no such agent is found.
        """
        try:
            agent_cls = cls._agent_registry[name]
        except KeyError:
            # The KeyError adds nothing beyond this message; hide it.
            raise ValueError(f"Agent '{name}' not found.") from None
        return cast(type[BaseAgent], agent_cls)

    @classmethod
    def get_config(cls, name: str) -> type[BaseAgentConfig]:
        """
        Retrieve a registered Agent configuration class by its name.

        Raises:
            ValueError: If no such config is found.
        """
        try:
            config_cls = cls._agent_config_registry[name]
        except KeyError:
            # The KeyError adds nothing beyond this message; hide it.
            raise ValueError(f"Agent config for '{name}' not found.") from None
        return cast(type[BaseAgentConfig], config_cls)