"""Pydantic models for chat messages and a filterable message history."""

from __future__ import annotations

import base64
from collections.abc import Iterator
from enum import Enum
from typing import Any, Literal, Optional, Set, Union

from pydantic import BaseModel, Field, TypeAdapter, field_validator


class MessageLabel(str, Enum):
    """Tags that can be attached to a message via ``Message.label``."""

    SYSTEM = "system"
    USER_INPUT = "user_input"
    SCREENSHOT = "screenshot"
    AGENT_MODEL_RESPONSE = "agent_model_response"


# Default per-label cap used by ``MessageHistory.history_view``: for each
# label listed here, only that many of the most recent messages are kept.
# Labels that are not listed are never dropped.
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
    MessageLabel.SCREENSHOT: 1,
}


class MessageContent(BaseModel):
    """Common base class for the content parts of a message."""


class Text(MessageContent):
    """A text content part."""

    type: Literal["text"] = Field(default="text", init=False)
    text: str


class ImageUrl(BaseModel):
    """Wrapper holding the URL (or data URL) of an image."""

    url: str


class Image(MessageContent):
    """An image content part, referenced through an ``ImageUrl``."""

    type: Literal["image_url"] = Field(default="image_url", init=False)
    image_url: ImageUrl


class Message(BaseModel):
    """A message made of an optional label and a list of content parts.

    Because ``__len__`` is defined, a message with no content is falsy;
    use ``message is not None`` when testing for presence.
    """

    label: Optional[MessageLabel] = None
    content: list[Union[Text, Image]] = Field(default_factory=list)

    class Config:
        use_enum_values = True

    @property
    def images(self) -> list[Image]:
        """All image parts of the message, in order."""
        return [part for part in self.content if isinstance(part, Image)]

    @property
    def texts(self) -> list[Text]:
        """All text parts of the message, in order."""
        return [part for part in self.content if isinstance(part, Text)]

    @property
    def first_image(self) -> Optional[Image]:
        """The first image part, or ``None`` if the message has no image."""
        return next(
            (part for part in self.content if isinstance(part, Image)), None
        )

    @property
    def first_text(self) -> Optional[Text]:
        """The first text part, or ``None`` if the message has no text."""
        return next(
            (part for part in self.content if isinstance(part, Text)), None
        )

    def __len__(self) -> int:
        """Return the number of content parts."""
        return len(self.content)

    @classmethod
    def from_media(
        cls,
        text: Optional[str] = None,
        image: Optional[bytes | str] = None,
        is_base64: bool = False,
        **fields: Any,
    ) -> Message:
        """Build a message from optional text and an optional image.

        Args:
            text: Text for a ``Text`` part; omitted from the content if None.
            image: Image payload; omitted from the content if None. Raw
                bytes are base64-encoded unless ``is_base64`` is true.
            is_base64: Set when ``image`` is already base64-encoded (either
                as ``str`` or as ASCII ``bytes``).
            **fields: Extra keyword arguments forwarded to the constructor,
                e.g. ``tool_call_id`` for ``ToolMessage`` or ``label``.

        Returns:
            An instance of ``cls`` whose content holds the text part (if
            any) followed by the image part (if any). With neither, the
            content is empty.

        Raises:
            TypeError: If ``image`` is a ``str`` and ``is_base64`` is false,
                since only bytes-like data can be base64-encoded.
        """
        content: list[Union[Text, Image]] = []
        if text is not None:
            content.append(Text(text=text))
        if image is not None:
            if not is_base64:
                encoded = base64.b64encode(image).decode("utf-8")
            elif isinstance(image, (bytes, bytearray)):
                # Interpolating bytes into the f-string below would embed
                # their repr ("b'...'") in the URL, so decode them first.
                encoded = image.decode("ascii")
            else:
                encoded = image
            data_url = f"data:image/jpeg;base64,{encoded}"
            content.append(Image(image_url=ImageUrl(url=data_url)))
        return cls(content=content, **fields)


class SystemMessage(Message):
    """Message with the fixed role ``"system"``."""

    role: Literal["system"] = Field(default="system", init=False)


class UserMessage(Message):
    """Message with the fixed role ``"user"``."""

    role: Literal["user"] = Field(default="user", init=False)


class ToolCall(BaseModel):
    """A tool invocation attached to an assistant message."""

    id: str
    type: str
    function: dict[str, Any]


class AssistantMessage(Message):
    """Message with the fixed role ``"assistant"`` and optional tool calls."""

    role: Literal["assistant"] = Field(default="assistant", init=False)
    tool_calls: list[ToolCall] = Field(default_factory=list)

    def model_dump(self, **kwargs):
        """Dump the model, omitting ``tool_calls`` when there are none.

        NOTE(review): only direct calls (such as those made by
        ``MessageHistory.to_dict``) are known to go through this override;
        confirm whether nested or JSON serialization should match.
        """
        data = super().model_dump(**kwargs)
        if not self.tool_calls:
            # The key may already be absent (e.g. exclude={"tool_calls"}),
            # so pop with a default instead of risking a KeyError.
            data.pop("tool_calls", None)
        return data

    @field_validator("tool_calls", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Treat an explicit ``None`` for ``tool_calls`` as an empty list."""
        return [] if v is None else v


class ToolMessage(Message):
    """Message with the fixed role ``"tool"`` tied to a tool call id."""

    role: Literal["tool"] = Field(default="tool", init=False)
    tool_call_id: str


MessageTypes = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
MessageAdapter = TypeAdapter(MessageTypes)


class MessageHistory(BaseModel):
    """An ordered list of messages with list-like helpers.

    Because ``__len__`` is defined, an empty history is falsy. ``__iter__``
    yields messages, not the ``(field, value)`` pairs of a plain model.
    """

    messages: list[MessageTypes] = Field(default_factory=list)

    def append(self, message: MessageTypes, label: Optional[str] = None):
        """Append ``message``, first overwriting its label if one is given."""
        if label is not None:
            message.label = label
        self.messages.append(message)

    def pop(self) -> MessageTypes:
        """Remove and return the last message."""
        return self.messages.pop()

    def extend(self, history: MessageHistory):
        """Append every message of ``history`` to this history."""
        self.messages.extend(history.messages)

    def __reversed__(self):
        """Return a new history with the messages in reverse order."""
        return MessageHistory(messages=self.messages[::-1])

    def __getitem__(self, index):
        """Index or slice the underlying message list."""
        return self.messages[index]

    def __len__(self) -> int:
        """Return the number of messages."""
        return len(self.messages)

    def __iter__(self) -> Iterator[MessageTypes]:
        """Iterate over the messages in order."""
        return iter(self.messages)

    def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
        """Dump every message with ``model_dump``.

        Args:
            exclude: Field names to leave out of each dumped message.

        Returns:
            One dict per message, in history order.
        """
        exclude = exclude or set()
        return [message.model_dump(exclude=exclude) for message in self.messages]

    def history_view(
        self,
        limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
    ) -> MessageHistory:
        """Return a copy of the history trimmed by per-label limits.

        Messages are scanned from newest to oldest. For each label present
        in ``limits``, only the most recent ``limits[label]`` messages are
        kept; messages whose label is not in ``limits`` (including unlabeled
        ones) are always kept. The original ordering is preserved.

        Args:
            limits: Mapping from label to the maximum number of messages
                with that label to retain. It is only read, never mutated.

        Returns:
            A new ``MessageHistory`` that references the retained messages.
        """
        seen = dict.fromkeys(limits, 0)
        kept = []
        for message in reversed(self.messages):
            label = message.label
            if label in limits:
                if seen[label] >= limits[label]:
                    continue
                seen[label] += 1
            kept.append(message)
        # Restore chronological order and hand the model a real list rather
        # than a one-shot iterator.
        kept.reverse()
        return MessageHistory(messages=kept)

    def __add__(self, other: MessageHistory) -> MessageHistory:
        """Return a new history holding this history's messages then ``other``'s."""
        new_history = MessageHistory()
        new_history.extend(self)
        new_history.extend(other)
        return new_history

    def __iadd__(self, other: MessageHistory) -> MessageHistory:
        """Extend this history in place with ``other`` and return it."""
        self.extend(other)
        return self