Tai Truong
fix readme
d202ada
import json
from collections.abc import Sequence
from uuid import UUID
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import async_session_scope, session_scope
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
def _get_variable_query(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
):
    """Compose the SELECT statement used by the message readers.

    Rows whose ``error`` column is true are always excluded. Each optional
    filter is applied only when its argument is truthy.

    Args:
        sender: Keep only rows with this ``sender`` value.
        sender_name: Keep only rows with this ``sender_name`` value.
        session_id: Keep only rows with this ``session_id`` value.
        order_by: Name of the ``MessageTable`` attribute to sort on. A falsy
            value disables sorting.
        order: ``"DESC"`` sorts descending; any other value sorts ascending.
        flow_id: Keep only rows with this ``flow_id`` value.
        limit: Maximum number of rows. A falsy value (``None`` or ``0``)
            leaves the statement unlimited.

    Returns:
        The composed statement, ready to be executed by a session.
    """
    stmt = select(MessageTable).where(MessageTable.error == False)  # noqa: E712

    # Equality filters, applied in this order and only for truthy values.
    equality_filters = {
        "sender": sender,
        "sender_name": sender_name,
        "session_id": session_id,
        "flow_id": flow_id,
    }
    for column_name, wanted in equality_filters.items():
        if wanted:
            stmt = stmt.where(getattr(MessageTable, column_name) == wanted)

    if order_by:
        sort_column = getattr(MessageTable, order_by)
        if order == "DESC":
            ordering = sort_column.desc()
        else:
            ordering = sort_column.asc()
        stmt = stmt.order_by(ordering)

    if limit:
        stmt = stmt.limit(limit)
    return stmt
def get_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Fetch stored messages matching the given filters (synchronous).

    The query is built by ``_get_variable_query`` and executed inside a
    ``session_scope()`` block.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        list[Message]: One ``Message`` per matching row, built from the row's
        ``model_dump()``.
    """
    with session_scope() as session:
        query = _get_variable_query(
            sender=sender,
            sender_name=sender_name,
            session_id=session_id,
            order_by=order_by,
            order=order,
            flow_id=flow_id,
            limit=limit,
        )
        rows = session.exec(query)
        # Convert while the session is still open: the result is consumed lazily.
        return [Message(**row.model_dump()) for row in rows]
async def aget_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Fetch stored messages matching the given filters (asynchronous).

    The query is built by ``_get_variable_query`` and executed inside an
    ``async_session_scope()`` block.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        list[Message]: One ``Message`` per matching row, created with
        ``Message.create`` from the row's ``model_dump()``.
    """
    async with async_session_scope() as session:
        query = _get_variable_query(
            sender=sender,
            sender_name=sender_name,
            session_id=session_id,
            order_by=order_by,
            order=order,
            flow_id=flow_id,
            limit=limit,
        )
        rows = await session.exec(query)
        found: list[Message] = []
        # Messages are created one at a time, in result order.
        for row in rows:
            found.append(await Message.create(**row.model_dump()))
        return found
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Persist one or more messages (synchronous).

    Args:
        messages: A single ``Message`` or a list of them. A single message is
            wrapped in a list.
        flow_id: Flow ID forwarded to ``MessageTable.from_message``.

    Returns:
        A list of ``Message`` objects rebuilt from the stored rows.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and
            re-raised unchanged.
    """
    batch = messages if isinstance(messages, list) else [messages]

    if any(not isinstance(item, Message) for item in batch):
        types = ", ".join(str(type(item)) for item in batch)
        msg = f"The messages must be instances of Message. Found: {types}"
        raise ValueError(msg)

    try:
        rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        with session_scope() as session:
            stored = add_messagetables(rows, session)
        return [Message(**entry.model_dump()) for entry in stored]
    except Exception as e:
        logger.exception(e)
        raise
async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Persist one or more messages (asynchronous).

    Args:
        messages: A single ``Message`` or a list of them. A single message is
            wrapped in a list.
        flow_id: Flow ID forwarded to ``MessageTable.from_message``.

    Returns:
        A list of ``Message`` objects created via ``Message.create`` from the
        stored rows.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and
            re-raised unchanged.
    """
    batch = messages if isinstance(messages, list) else [messages]

    if any(not isinstance(item, Message) for item in batch):
        types = ", ".join(str(type(item)) for item in batch)
        msg = f"The messages must be instances of Message. Found: {types}"
        raise ValueError(msg)

    try:
        rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        async with async_session_scope() as session:
            stored = await aadd_messagetables(rows, session)
        created = []
        for entry in stored:
            created.append(await Message.create(**entry.model_dump()))
        return created
    except Exception as e:
        logger.exception(e)
        raise
def update_messages(messages: Message | list[Message]) -> list[Message]:
    """Apply the set, non-None fields of each message to its stored row (synchronous).

    Each message is looked up by ``message.id``. Rows that exist are updated,
    committed and refreshed one at a time; ids that are not found are logged
    with a warning and skipped.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The updated rows validated as ``MessageRead`` models, in input order
        (skipped messages are omitted).
    """
    batch = messages if isinstance(messages, list) else [messages]

    with session_scope() as session:
        refreshed: list[MessageTable] = []
        for message in batch:
            row = session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            # Only fields that were explicitly set and are not None overwrite the row.
            changes = message.model_dump(exclude_unset=True, exclude_none=True)
            row.sqlmodel_update(changes)
            session.add(row)
            session.commit()
            session.refresh(row)
            refreshed.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in refreshed]
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
    """Apply the set, non-None fields of each message to its stored row (asynchronous).

    Each message is looked up by ``message.id``. Rows that exist are updated,
    committed and refreshed one at a time; ids that are not found are logged
    with a warning and skipped.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The updated rows validated as ``MessageRead`` models, in input order
        (skipped messages are omitted).
    """
    batch = messages if isinstance(messages, list) else [messages]

    async with async_session_scope() as session:
        refreshed: list[MessageTable] = []
        for message in batch:
            row = await session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            # Only fields that were explicitly set and are not None overwrite the row.
            changes = message.model_dump(exclude_unset=True, exclude_none=True)
            row.sqlmodel_update(changes)
            session.add(row)
            await session.commit()
            await session.refresh(row)
            refreshed.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in refreshed]
def add_messagetables(messages: list[MessageTable], session: Session):
    """Persist message rows with one commit and return them as ``MessageRead`` models.

    All rows are staged first and written with a single ``commit``, matching
    ``aadd_messagetables``. Committing per row inside the loop meant that a
    failure on a later row left the earlier rows of the same batch already
    committed; with a single commit, a failure while staging or committing
    leaves none of the batch committed by this call.

    Args:
        messages: Rows to insert. They are refreshed and normalised in place.
        session: Open session used for the add/commit/refresh calls.

    Returns:
        The stored rows validated as ``MessageRead`` models, in input order.

    Raises:
        Exception: Any error from add/commit/refresh is logged and re-raised
            unchanged.
    """
    try:
        for message in messages:
            session.add(message)
        session.commit()
        # Reload each row after the commit so the returned models reflect stored values.
        for message in messages:
            session.refresh(message)
    except Exception as e:
        logger.exception(e)
        raise

    # Normalise JSON-carrying fields: values that come back as JSON strings are decoded.
    for msg in messages:
        msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties  # type: ignore[arg-type]
        msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks]  # type: ignore[arg-type]
        msg.category = msg.category or ""
    return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
    """Persist message rows with one commit and return them as ``MessageRead`` models.

    Args:
        messages: Rows to insert. They are refreshed and normalised in place.
        session: Open async session used for the add/commit/refresh calls.

    Returns:
        The stored rows validated as ``MessageRead`` models, in input order.

    Raises:
        Exception: Any error from add/commit/refresh is logged and re-raised
            unchanged.
    """
    try:
        for row in messages:
            session.add(row)
        await session.commit()
        for row in messages:
            await session.refresh(row)
    except Exception as e:
        logger.exception(e)
        raise

    # Normalise JSON-carrying fields: values that come back as JSON strings are decoded.
    for row in messages:
        raw_properties = row.properties
        row.properties = json.loads(raw_properties) if isinstance(raw_properties, str) else raw_properties  # type: ignore[arg-type]
        row.content_blocks = [json.loads(block) if isinstance(block, str) else block for block in row.content_blocks]  # type: ignore[arg-type]
        row.category = row.category or ""

    return [MessageRead.model_validate(row, from_attributes=True) for row in messages]
def delete_messages(session_id: str) -> None:
    """Delete every stored message belonging to a session (synchronous).

    Args:
        session_id (str): The session ID associated with the messages to delete.
    """
    # NOTE(review): no explicit commit here; presumably session_scope() commits on exit - confirm.
    stmt = (
        delete(MessageTable)
        .where(col(MessageTable.session_id) == session_id)
        .execution_options(synchronize_session="fetch")
    )
    with session_scope() as session:
        session.exec(stmt)
async def adelete_messages(session_id: str) -> None:
    """Delete every stored message belonging to a session (asynchronous).

    Args:
        session_id (str): The session ID associated with the messages to delete.
    """
    # NOTE(review): no explicit commit here; presumably async_session_scope() commits on exit - confirm.
    async with async_session_scope() as session:
        await session.exec(
            delete(MessageTable)
            .where(col(MessageTable.session_id) == session_id)
            .execution_options(synchronize_session="fetch")
        )
async def delete_message(id_: str) -> None:
    """Delete a single stored message by its ID (asynchronous).

    Does nothing when no row with that ID exists.

    Args:
        id_ (str): The ID of the message to delete.
    """
    async with async_session_scope() as session:
        row = await session.get(MessageTable, id_)
        if not row:
            return
        await session.delete(row)
        await session.commit()
def store_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message, updating it when it already has an ID (synchronous).

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The result of ``update_messages`` when ``message.id`` is
        set, otherwise the result of ``add_messages``. An empty list when
        ``message`` is falsy.

    Raises:
        ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    # Required fields, in reporting order, with the description shown when missing.
    required_descriptions = {
        "session_id": "session_id (unique conversation identifier)",
        "sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
        "sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
    }
    absent = [name for name in required_descriptions if not getattr(message, name)]
    if absent:
        missing = ", ".join(required_descriptions[name] for name in absent)
        msg = (
            f"It looks like we're missing some important information: {missing}. "
            "Please ensure that your message includes all the required fields."
        )
        raise ValueError(msg)

    # A message that already carries an ID is treated as an update of an existing row.
    if getattr(message, "id", None):
        return update_messages([message])
    return add_messages([message], flow_id=flow_id)
async def astore_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message, updating it when it already has an ID (asynchronous).

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The result of ``aupdate_messages`` when ``message.id``
        is set, otherwise the result of ``aadd_messages``. An empty list when
        ``message`` is falsy.

    Raises:
        ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    has_required = message.session_id and message.sender and message.sender_name
    if not has_required:
        msg = "All of session_id, sender, and sender_name must be provided."
        raise ValueError(msg)

    # A message that already carries an ID is treated as an update of an existing row.
    if getattr(message, "id", None):
        return await aupdate_messages([message])
    return await aadd_messages([message], flow_id=flow_id)
class LCBuiltinChatMemory(BaseChatMessageHistory):
    """LangChain chat-history adapter backed by this module's message store.

    Reads and writes are scoped to a single ``session_id``; new messages are
    stored with the configured ``flow_id``.
    """

    def __init__(
        self,
        flow_id: str,
        session_id: str,
    ) -> None:
        """Remember which flow and session this history belongs to.

        Args:
            flow_id: Flow ID passed along when storing new messages.
            session_id: Session whose messages are read, written and cleared.
        """
        self.flow_id = flow_id
        self.session_id = session_id

    @property
    def messages(self) -> list[BaseMessage]:
        """Return the session's messages as LangChain messages, oldest first.

        ``get_messages`` defaults to descending timestamp order, which would
        hand the conversation to LangChain newest-first. Ascending order is
        requested explicitly so the history reads chronologically.
        """
        messages = get_messages(
            session_id=self.session_id,
            order="ASC",
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    async def aget_messages(self) -> list[BaseMessage]:
        """Async variant of ``messages``: the session's messages, oldest first."""
        messages = await aget_messages(
            session_id=self.session_id,
            order="ASC",
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Convert each LangChain message and store it under this session.

        Args:
            messages: LangChain messages to store, in order.
        """
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            store_message(message, flow_id=self.flow_id)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Async variant of ``add_messages``; messages are stored one at a time.

        Args:
            messages: LangChain messages to store, in order.
        """
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            await astore_message(message, flow_id=self.flow_id)

    def clear(self) -> None:
        """Delete every stored message for this session."""
        delete_messages(self.session_id)

    async def aclear(self) -> None:
        """Async variant of ``clear``."""
        await adelete_messages(self.session_id)