Tai Truong
fix readme
d202ada
import enum
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import TypedDict
from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
class File(TypedDict, total=True):
    """Shape of one file attachment entry; every key is required."""

    # Path to the file.
    path: str
    # File name (ChatOutputResponse.validate_files fills it from the path's last "/" segment when absent).
    name: str
    # File type (validate_files fills it from the path's extension or a known type when absent).
    type: str
class ChatOutputResponse(BaseModel):
    """Chat output response schema.

    ``message`` is either a plain string or a list of string/dict parts.
    ``validate_files`` completes partially specified ``files`` entries, and
    ``validate_message`` rewrites line breaks in string messages sent by
    ``MESSAGE_SENDER_AI`` so they render as Markdown paragraph breaks.
    """

    message: str | list[str | dict]
    sender: str | None = MESSAGE_SENDER_AI
    sender_name: str | None = MESSAGE_SENDER_NAME_AI
    session_id: str | None = None
    stream_url: str | None = None
    component_id: str | None = None
    # Pydantic copies mutable defaults per instance, so this ``[]`` is not shared.
    files: list[File] = []
    # Required: no default is declared.
    type: str

    @field_validator("files", mode="before")
    @classmethod
    def validate_files(cls, files):
        """Complete raw ``files`` entries before field validation.

        Every entry must be a dict. An entry missing any of ``path``, ``name``
        or ``type`` is completed from its ``path``:

        * ``name`` defaults to the last ``/``-separated segment of the path;
        * ``type`` defaults to the path's extension when that is a known text
          or image type, otherwise to the first known type (in
          ``TEXT_FILE_TYPES`` then ``IMG_FILE_TYPES`` order) that occurs as a
          substring of the path.

        Incomplete entries are copied before being filled in, so the caller's
        dicts are left unmodified. Falsy ``files`` is returned unchanged.

        Raises:
            ValueError: if an entry is not a dict, has no ``path``, or no type
                can be determined for it.
        """
        if not files:
            return files
        # Ordered and de-duplicated. Iterating a ``set`` here would make the
        # substring fallback depend on string-hash order, so the chosen type
        # could differ between runs when several known types match.
        known_types = list(dict.fromkeys(TEXT_FILE_TYPES + IMG_FILE_TYPES))
        normalized = []
        for file in files:
            if not isinstance(file, dict):
                msg = "Files must be a list of dictionaries."
                raise ValueError(msg)  # noqa: TRY004
            if all(key in file for key in ("path", "name", "type")):
                normalized.append(file)
                continue
            # Work on a copy so the caller's input dict is not mutated.
            entry = dict(file)
            path = entry.get("path")
            if not path:
                msg = "File path is required."
                raise ValueError(msg)
            if not entry.get("name"):
                entry["name"] = path.split("/")[-1]
            type_ = entry.get("type")
            if not type_:
                # Prefer the extension; fall back to any known type in the path.
                extension = path.split(".")[-1]
                if extension and extension in known_types:
                    type_ = extension
                else:
                    type_ = next((known for known in known_types if known in path), None)
            if not type_:
                msg = "File type is required."
                raise ValueError(msg)
            entry["type"] = type_
            normalized.append(entry)
        return normalized

    @classmethod
    def from_message(
        cls,
        message: BaseMessage,
        sender: str | None = MESSAGE_SENDER_AI,
        sender_name: str | None = MESSAGE_SENDER_NAME_AI,
        *,
        type_: str | None = None,
    ):
        """Build a chat output response from a LangChain message.

        Args:
            message: Source message; its ``content`` becomes ``message``.
            sender: Value for the ``sender`` field.
            sender_name: Value for the ``sender_name`` field.
            type_: Value for the required ``type`` field. When omitted, falls
                back to ``message.type`` (the LangChain message's type string).
                NOTE(review): confirm that fallback is what consumers expect.

        Returns:
            A validated ``ChatOutputResponse``.
        """
        # ``type`` has no default on this model, so it must always be supplied;
        # omitting it made every call fail validation.
        return cls(
            message=message.content,
            sender=sender,
            sender_name=sender_name,
            type=message.type if type_ is None else type_,
        )

    @model_validator(mode="after")
    def validate_message(self):
        """Make line breaks in AI string messages Markdown-compliant.

        Applies only when ``sender`` is ``MESSAGE_SENDER_AI`` and ``message``
        is a ``str``. Single and double line breaks both become a double line
        break (a Markdown paragraph break). List messages are left untouched.
        """
        # e.g. "a\nb" -> "a\n\nb" and "a\n\nb" -> "a\n\nb"
        if self.sender != MESSAGE_SENDER_AI:
            return self
        # ``message`` may be a list of parts, which has no ``replace``.
        if not isinstance(self.message, str):
            return self
        # Collapse existing "\n\n" first so they are not doubled to "\n\n\n\n".
        message = self.message.replace("\n\n", "\n")
        self.message = message.replace("\n", "\n\n")
        return self
class DataOutputResponse(BaseModel):
    """Schema for a response carrying a list of data records.

    Individual entries may be ``None`` as well as dicts.
    """

    # One entry per record; the annotation permits ``None`` entries.
    data: list[dict | None]
class ContainsEnumMeta(enum.EnumMeta):
    """Enum metaclass whose ``in`` operator tests membership by value.

    ``item in SomeEnum`` is True when ``SomeEnum(item)`` succeeds, i.e. when
    ``item`` is a member, the value of a member, or is accepted by the enum's
    ``_missing_`` hook. It is False otherwise and never raises for
    rejected values.
    """

    def __contains__(cls, item) -> bool:
        try:
            cls(item)
        except (ValueError, TypeError):
            # ValueError: not a valid value. TypeError: raised by Enum lookup
            # when a custom ``_missing_`` hook raises it or returns a
            # non-member. A membership test should answer False, not propagate.
            return False
        return True