# Spaces: Running
from __future__ import annotations

import base64
from collections.abc import Iterator
from enum import Enum
from typing import Any, Literal, Optional, Set, Union

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, field_validator
class MessageLabel(str, Enum):
    """Tag identifying what kind of entry a message is within a history.

    Members also subclass ``str``, so each one compares equal to (and hashes
    like) its raw string value.
    """

    SYSTEM = "system"
    USER_INPUT = "user_input"
    SCREENSHOT = "screenshot"
    AGENT_MODEL_RESPONSE = "agent_model_response"
# Default per-label caps used by MessageHistory.history_view: for each label
# listed here, at most this many of the most recent messages carrying that
# label are kept. Labels that are not listed are never capped.
MAX_MESSAGES_FOR_CONTEXT_WINDOW = {
    MessageLabel.SCREENSHOT: 1,
}
class MessageContent(BaseModel):
    """Common base class for the individual content parts of a message."""
class Text(MessageContent):
    """Plain-text content part, tagged with ``type == "text"``."""

    # Fixed discriminator value; callers only need to supply ``text``.
    type: Literal["text"] = Field(default="text", init=False)
    text: str
class ImageUrl(BaseModel):
    """Wrapper holding the URL of an image.

    ``Message.from_media`` fills ``url`` with a base64 ``data:`` URL.
    """

    url: str
class Image(MessageContent):
    """Image content part, tagged with ``type == "image_url"``."""

    # Fixed discriminator value; callers only need to supply ``image_url``.
    type: Literal["image_url"] = Field(default="image_url", init=False)
    image_url: ImageUrl
class Message(BaseModel):
    """Base message: an optional label plus an ordered list of content parts.

    Role-specific subclasses add a fixed ``role`` field. ``len(message)`` is
    the number of content parts, so a message with no content is falsy.
    """

    label: Optional[MessageLabel] = None
    content: list[Union[Text, Image]] = Field(default_factory=list)

    # Store enum-typed fields (e.g. ``label``) as their plain values on
    # validation instead of as enum members.
    model_config = ConfigDict(use_enum_values=True)

    @property
    def images(self) -> list[Image]:
        """All image parts of the message, in content order."""
        return [part for part in self.content if isinstance(part, Image)]

    @property
    def texts(self) -> list[Text]:
        """All text parts of the message, in content order."""
        return [part for part in self.content if isinstance(part, Text)]

    @property
    def first_image(self) -> Optional[Image]:
        """The first image part, or ``None`` when the message has no image."""
        return next(iter(self.images), None)

    @property
    def first_text(self) -> Optional[Text]:
        """The first text part, or ``None`` when the message has no text."""
        return next(iter(self.texts), None)

    def __len__(self) -> int:
        """Return the number of content parts."""
        return len(self.content)

    @classmethod
    def from_media(
        cls,
        text: Optional[str] = None,
        image: Optional[bytes | str] = None,
        is_base64: bool = False,
    ) -> Message:
        """Build a message from optional text and an optional image.

        Args:
            text: Text for a leading ``Text`` part; omitted when ``None``.
            image: Image payload; omitted when ``None``. Raw bytes are
                base64-encoded unless ``is_base64`` is true, in which case the
                payload is used as the already-encoded base64 data.
            is_base64: Whether ``image`` is already base64-encoded.

        Returns:
            An instance of ``cls`` whose content is the text part (if any)
            followed by the image part (if any). With neither argument the
            content is empty.

        Raises:
            TypeError: If ``image`` is a ``str`` and ``is_base64`` is false,
                because only bytes-like data can be base64-encoded.
        """
        parts: list[Union[Text, Image]] = []
        if text is not None:
            parts.append(Text(text=text))
        if image is not None:
            if is_base64:
                # Already-encoded data may arrive as bytes; decode it so the
                # URL does not end up containing a ``b'...'`` repr.
                encoded = image.decode("ascii") if isinstance(image, bytes) else image
            else:
                encoded = base64.b64encode(image).decode("utf-8")
            # The media type is always declared as JPEG, whatever the bytes are.
            data_url = f"data:image/jpeg;base64,{encoded}"
            parts.append(Image(image_url=ImageUrl(url=data_url)))
        return cls(content=parts)
class SystemMessage(Message):
    """Message whose ``role`` is fixed to ``"system"``."""

    role: Literal["system"] = Field(default="system", init=False)
class UserMessage(Message):
    """Message whose ``role`` is fixed to ``"user"``."""

    role: Literal["user"] = Field(default="user", init=False)
class ToolCall(BaseModel):
    """One tool invocation carried in ``AssistantMessage.tool_calls``."""

    id: str
    type: str
    # Free-form mapping describing the call; its schema is not constrained here.
    function: dict[str, Any]
class AssistantMessage(Message):
    """Message with ``role == "assistant"`` and an optional list of tool calls."""

    role: Literal["assistant"] = Field(default="assistant", init=False)
    tool_calls: list[ToolCall] = Field(default_factory=list)

    def model_dump(self, **kwargs):
        """Dump the message, leaving out ``tool_calls`` when it is empty.

        Accepts the same keyword arguments as ``BaseModel.model_dump``.
        """
        data = super().model_dump(**kwargs)
        if not self.tool_calls:
            # The key may already be missing (e.g. the caller passed
            # ``exclude={"tool_calls"}``), so never raise KeyError here.
            data.pop("tool_calls", None)
        return data

    @field_validator("tool_calls", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Treat an explicit ``tool_calls=None`` as an empty list."""
        return [] if v is None else v
class ToolMessage(Message):
    """Message whose ``role`` is fixed to ``"tool"``."""

    role: Literal["tool"] = Field(default="tool", init=False)
    # NOTE(review): presumably the ``ToolCall.id`` this message answers — confirm.
    tool_call_id: str
# Union of every concrete message class that a history can hold.
MessageTypes = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
# Pydantic adapter for validating arbitrary data as one of the MessageTypes.
MessageAdapter = TypeAdapter(MessageTypes)
class MessageHistory(BaseModel):
    """Ordered, list-like container of messages.

    Supports ``len``, indexing, iteration, ``reversed``, ``+`` and ``+=``.
    ``len(history)`` is the number of messages, so an empty history is falsy.
    """

    messages: list[MessageTypes] = Field(default_factory=list)

    def append(self, message: MessageTypes, label: Optional[str] = None):
        """Add ``message`` at the end, overwriting its label when one is given."""
        if label is not None:
            message.label = label
        self.messages.append(message)

    def pop(self) -> MessageTypes:
        """Remove and return the last message.

        Raises:
            IndexError: If the history is empty.
        """
        return self.messages.pop()

    def extend(self, history: MessageHistory):
        """Append every message of ``history`` (the same objects, not copies)."""
        self.messages.extend(history.messages)

    def __reversed__(self):
        """Return a new history holding the messages in reverse order."""
        return MessageHistory(messages=self.messages[::-1])

    def __getitem__(self, index):
        """Return the message (or list slice) at ``index``."""
        return self.messages[index]

    def __len__(self):
        """Return the number of messages."""
        return len(self.messages)

    def __iter__(self) -> Iterator[MessageTypes]:
        """Iterate over the messages, oldest first."""
        return iter(self.messages)

    def to_dict(self, exclude: Set[str] | None = None) -> list[dict]:
        """Dump every message with ``model_dump``.

        Args:
            exclude: Field names to leave out of each dump; nothing is
                excluded when ``None`` or empty.

        Returns:
            One dict per message, in history order.
        """
        exclude = exclude or set()
        return [message.model_dump(exclude=exclude) for message in self.messages]

    def history_view(
        self,
        limits: dict = MAX_MESSAGES_FOR_CONTEXT_WINDOW,
    ) -> MessageHistory:
        """Context window management.

        Walks the messages from newest to oldest and, for each label present
        in ``limits``, keeps only that many of the most recent messages with
        that label. Messages whose label is not in ``limits`` (including
        unlabelled ones) are always kept.

        Args:
            limits: Mapping from label to the maximum number of messages with
                that label to keep. It is only read, never modified.

        Returns:
            A new history with the surviving messages in their original order.
        """
        kept_per_label = {label: 0 for label in limits}
        kept = []
        for message in reversed(self.messages):
            label = message.label
            if label in limits:
                if kept_per_label[label] >= limits[label]:
                    # Cap reached for this label: drop the older message.
                    continue
                kept_per_label[label] += 1
            kept.append(message)
        # ``kept`` is newest-first; restore chronological order and hand the
        # model a real list rather than a one-shot iterator.
        kept.reverse()
        return MessageHistory(messages=kept)

    def __add__(self, other: MessageHistory) -> MessageHistory:
        """Return a new history with this history's messages, then ``other``'s."""
        new_history = MessageHistory()
        new_history.extend(self)
        new_history.extend(other)
        return new_history

    def __iadd__(self, other: MessageHistory) -> MessageHistory:
        """Extend this history in place with ``other``'s messages."""
        self.extend(other)
        return self