XanderJC's picture
setup
f0f6e5c
from __future__ import annotations
import base64
from collections.abc import Iterator
from enum import Enum
from typing import Any, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, TypeAdapter, field_validator
class MessageLabel(str, Enum):
    """Tags that can be attached to a message via ``Message.label``.

    Members subclass ``str``, so each compares equal to its plain value
    (``MessageLabel.SYSTEM == "system"``).
    """

    SYSTEM = "system"
    USER_INPUT = "user_input"
    SCREENSHOT = "screenshot"
    AGENT_MODEL_RESPONSE = "agent_model_response"


# Per-label cap used by MessageHistory.history_view: at most this many of the
# most recent messages carrying the label are kept. Labels that are not keys
# here are never trimmed.
MAX_MESSAGES_FOR_CONTEXT_WINDOW: dict[MessageLabel, int] = {
    MessageLabel.SCREENSHOT: 1,
}
class MessageContent(BaseModel):
    # Shared base for the content parts of a Message (Text, Image).
    # It declares no fields of its own.
    ...
class Text(MessageContent):
    # Plain-text content part. ``type`` is pinned to the "text" tag by the
    # Literal annotation, so any other value fails validation.
    type: Literal["text"] = Field("text", init=False)
    text: str
class ImageUrl(BaseModel):
    # Wrapper holding the URL of an image. Message.from_media fills it with a
    # base64 "data:" URL.
    url: str
class Image(MessageContent):
    # Image content part. ``type`` is pinned to the "image_url" tag, and the
    # image itself is referenced through an ImageUrl.
    type: Literal["image_url"] = Field("image_url", init=False)
    image_url: ImageUrl
class Message(BaseModel):
    """A chat message: an optional label plus an ordered list of content parts.

    Content parts are ``Text`` and ``Image`` instances. ``label`` is what
    ``MessageHistory.history_view`` matches against its per-label limits.
    When set through validation, it is stored as its plain string value
    (``use_enum_values``).
    """

    label: Optional[MessageLabel] = None
    content: list[Union[Text, Image]] = Field(default_factory=list)

    # Dict-form config replaces the class-based ``Config`` that pydantic v2
    # deprecates. The setting itself is unchanged: enum fields store their value.
    model_config = {"use_enum_values": True}

    @property
    def images(self) -> list[Image]:
        """All image parts, in content order."""
        return [part for part in self.content if isinstance(part, Image)]

    @property
    def texts(self) -> list[Text]:
        """All text parts, in content order."""
        return [part for part in self.content if isinstance(part, Text)]

    @property
    def first_image(self) -> Optional[Image]:
        """The first image part, or ``None`` if there is none."""
        # Stop at the first match instead of building the full filtered list.
        return next((part for part in self.content if isinstance(part, Image)), None)

    @property
    def first_text(self) -> Optional[Text]:
        """The first text part, or ``None`` if there is none."""
        return next((part for part in self.content if isinstance(part, Text)), None)

    def __len__(self) -> int:
        """Number of content parts.

        Because ``__len__`` is defined, a message with no parts is falsy.
        """
        return len(self.content)

    @classmethod
    def from_media(
        cls,
        text: Optional[str] = None,
        image: Optional[bytes | str] = None,
        is_base64: bool = False,
        mime_type: str = "image/jpeg",
    ) -> Message:
        """Build a message from optional text and an optional image.

        Args:
            text: Text part; placed before the image when both are given.
            image: Raw image bytes. When ``is_base64`` is true, it is instead
                data that is already base64-encoded (``str`` or ``bytes``).
            is_base64: Whether ``image`` is already base64-encoded.
            mime_type: MIME type written into the data URL.

        Returns:
            A message whose content is the provided parts in ``[text, image]``
            order. It is empty when neither was given (this previously failed
            validation on ``[None]``).

        Raises:
            TypeError: If ``image`` is a ``str`` while ``is_base64`` is false,
                since only bytes can be base64-encoded.
        """
        parts: list[Union[Text, Image]] = []
        if text is not None:
            parts.append(Text(text=text))
        if image is not None:
            if is_base64:
                # Pre-encoded data may arrive as bytes. Interpolating bytes into
                # the f-string would embed "b'...'" in the URL, so decode first.
                encoded = image.decode("utf-8") if isinstance(image, (bytes, bytearray)) else image
            else:
                encoded = base64.b64encode(image).decode("utf-8")
            data_url = f"data:{mime_type};base64,{encoded}"
            parts.append(Image(image_url=ImageUrl(url=data_url)))
        return cls(content=parts)
class SystemMessage(Message):
    # Message whose role is fixed to "system".
    role: Literal["system"] = Field("system", init=False)
class UserMessage(Message):
    # Message whose role is fixed to "user".
    role: Literal["user"] = Field("user", init=False)
class ToolCall(BaseModel):
    # One tool invocation carried in AssistantMessage.tool_calls.
    # ``function`` is an untyped mapping; its keys are not validated here.
    id: str
    type: str
    function: dict[str, Any]
class AssistantMessage(Message):
    """Assistant reply, optionally carrying tool calls.

    ``tool_calls`` is left out of :meth:`model_dump` output when it is empty.
    """

    role: Literal["assistant"] = Field(default="assistant", init=False)
    tool_calls: list[ToolCall] = Field(default_factory=list)

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """Dump like the base model, then drop an empty ``tool_calls`` entry.

        All keyword arguments are forwarded unchanged to pydantic's
        ``model_dump``.
        """
        data = super().model_dump(**kwargs)
        if not self.tool_calls:
            # The key may already be absent, e.g. when the caller passes
            # exclude={"tool_calls"} or an include set without it. An
            # unconditional pop raised KeyError in that case.
            data.pop("tool_calls", None)
        return data

    @field_validator("tool_calls", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Turn an explicit ``None`` into an empty list before validation."""
        return [] if v is None else v
class ToolMessage(Message):
    # Message whose role is fixed to "tool". ``tool_call_id`` is required.
    # Presumably it matches the ToolCall.id being answered; TODO confirm with
    # callers.
    role: Literal["tool"] = Field("tool", init=False)
    tool_call_id: str
# Every concrete message class a MessageHistory may hold.
# No discriminator is configured on this union.
MessageTypes = Union[
    SystemMessage,
    UserMessage,
    AssistantMessage,
    ToolMessage,
]
# Standalone validator for the union above (usable without a BaseModel wrapper).
MessageAdapter = TypeAdapter(MessageTypes)
class MessageHistory(BaseModel):
    """Ordered, list-like container of chat messages.

    Supports ``len()``, indexing, iteration, ``reversed()`` (which returns a
    new ``MessageHistory`` rather than an iterator), ``+`` and ``+=``.
    """

    messages: list[MessageTypes] = Field(default_factory=list)

    def append(self, message: MessageTypes, label: Optional[str] = None) -> None:
        """Append ``message`` in place.

        If ``label`` is given, it first overwrites ``message.label``. This
        mutates the passed message object.
        """
        if label is not None:
            message.label = label
        self.messages.append(message)

    def pop(self) -> MessageTypes:
        """Remove and return the last message (``IndexError`` if empty)."""
        return self.messages.pop()

    def extend(self, history: MessageHistory) -> None:
        """Append all messages of ``history`` to this history, in place."""
        self.messages.extend(history.messages)

    def __reversed__(self) -> MessageHistory:
        """Return a new history with the messages in reverse order."""
        return MessageHistory(messages=self.messages[::-1])

    def __getitem__(self, index):
        """Index or slice the underlying message list."""
        return self.messages[index]

    def __len__(self) -> int:
        """Number of messages (an empty history is therefore falsy)."""
        return len(self.messages)

    def __iter__(self) -> Iterator[MessageTypes]:
        """Iterate over the messages, oldest first.

        This replaces pydantic's default field iteration.
        """
        return iter(self.messages)

    def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
        """Dump every message with ``model_dump``, omitting the ``exclude`` fields."""
        exclude = exclude or set()
        return [message.model_dump(exclude=exclude) for message in self.messages]

    def history_view(
        self,
        limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
    ) -> MessageHistory:
        """Return a trimmed copy of the history for the model's context window.

        Messages are scanned from newest to oldest:

        - A message whose label is a key of ``limits`` is kept only while fewer
          than ``limits[label]`` messages with that label have been kept, so the
          most recent ones survive.
        - Messages whose label is not in ``limits`` (including unlabeled ones)
          are always kept.
        - The result keeps the original oldest-to-newest order.

        With the default limits, only the most recent screenshot is kept.

        Args:
            limits: Mapping from label to the maximum number of messages with
                that label to retain. It is read but never modified.

        Returns:
            A new ``MessageHistory``. This history is not modified.
        """
        kept_counts = {label: 0 for label in limits}
        kept_newest_first = []
        for message in reversed(self.messages):
            label = message.label
            if label in limits:
                if kept_counts[label] >= limits[label]:
                    continue  # quota for this label already used by newer messages
                kept_counts[label] += 1
            kept_newest_first.append(message)
        # Restore chronological order. Pass pydantic a concrete list rather than
        # a reverse iterator, which it only accepts through lax-mode coercion.
        kept_newest_first.reverse()
        return MessageHistory(messages=kept_newest_first)

    def __add__(self, other: MessageHistory) -> MessageHistory:
        """Return a new history: this one's messages followed by ``other``'s."""
        new_history = MessageHistory()
        new_history.extend(self)
        new_history.extend(other)
        return new_history

    def __iadd__(self, other: MessageHistory) -> MessageHistory:
        """Extend this history in place with ``other``'s messages and return it."""
        self.extend(other)
        return self