Spaces:
Running
Running
import warnings | |
from collections.abc import AsyncIterator, Iterator | |
from typing import Any, TypeAlias, get_args | |
from pandas import DataFrame | |
from pydantic import Field, field_validator | |
from langflow.inputs.validators import CoalesceBool | |
from langflow.schema.data import Data | |
from langflow.schema.message import Message | |
from langflow.services.database.models.message.model import MessageBase | |
from langflow.template.field.base import Input | |
from .input_mixin import ( | |
BaseInputMixin, | |
DatabaseLoadMixin, | |
DropDownMixin, | |
FieldTypes, | |
FileMixin, | |
InputTraceMixin, | |
LinkMixin, | |
ListableInputMixin, | |
MetadataTraceMixin, | |
MultilineMixin, | |
RangeMixin, | |
SerializableFieldTypes, | |
SliderMixin, | |
TableMixin, | |
ToolModeMixin, | |
) | |
class TableInput(BaseInputMixin, MetadataTraceMixin, TableMixin, ListableInputMixin):
    """Represents a table field whose value is a list of row records.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TABLE.
        is_list (bool): Defaults to True.
    """

    field_type: SerializableFieldTypes = FieldTypes.TABLE
    is_list: bool = True

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, _info):
        """Validates that the value is a list of dictionaries or Data objects.

        A pandas DataFrame is converted to a list of row dictionaries before the check.

        Args:
            v (Any): The value to be validated.
            _info: Additional information about the input (unused).

        Returns:
            The list of records.

        Raises:
            ValueError: If the value is not a list, or if an item is neither a dictionary nor a Data.
        """
        if isinstance(v, DataFrame):
            v = v.to_dict(orient="records")
        if not isinstance(v, list):
            msg = f"TableInput value must be a list of dictionaries or Data. Value '{v}' is not a list."
            raise ValueError(msg)  # noqa: TRY004
        for item in v:
            if not isinstance(item, dict | Data):
                msg = (
                    "TableInput value must be a list of dictionaries or Data. "
                    f"Item '{item}' is not a dictionary or Data."
                )
                raise ValueError(msg)  # noqa: TRY004
        return v
class HandleInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
    """Represents an Input that has a Handle to a specific type (e.g. BaseLanguageModel, BaseRetriever, etc.).

    This class inherits from the `BaseInputMixin`, `ListableInputMixin` and `MetadataTraceMixin` classes.

    Attributes:
        input_types (list[str]): A list of input types. Defaults to an empty list.
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.OTHER.
    """

    input_types: list[str] = Field(default_factory=list)
    field_type: SerializableFieldTypes = FieldTypes.OTHER
class DataInput(HandleInput, InputTraceMixin, ListableInputMixin, ToolModeMixin):
    """Represents an Input that has a Handle that receives a Data object.

    Attributes:
        input_types (list[str]): A list of input types supported by this data input. Defaults to ["Data"].
    """

    input_types: list[str] = ["Data"]
class DataFrameInput(HandleInput, InputTraceMixin, ListableInputMixin, ToolModeMixin):
    """Represents a Handle input whose declared input type is "DataFrame".

    Attributes:
        input_types (list[str]): A list of input types supported by this input. Defaults to ["DataFrame"].
    """

    input_types: list[str] = ["DataFrame"]
class PromptInput(BaseInputMixin, ListableInputMixin, InputTraceMixin, ToolModeMixin):
    """Represents a prompt field.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.PROMPT.
    """

    field_type: SerializableFieldTypes = FieldTypes.PROMPT
class CodeInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
    """Represents a code field.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.CODE.
    """

    field_type: SerializableFieldTypes = FieldTypes.CODE
# Applying mixins to a specific input type | |
class StrInput(BaseInputMixin, ListableInputMixin, DatabaseLoadMixin, MetadataTraceMixin):
    """Represents a text field.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
        load_from_db (CoalesceBool): Defaults to False.
    """

    field_type: SerializableFieldTypes = FieldTypes.TEXT
    load_from_db: CoalesceBool = False
    # NOTE(review): the string below talks about a text editor, which does not
    # match `load_from_db`; it is kept verbatim - confirm the intended wording.
    """Defines if the field will allow the user to open a text editor. Default is False."""

    @staticmethod
    def _validate_value(v: Any, info):
        """Warns about a non-string value and returns the value unchanged.

        Strings and None pass silently. Any other value triggers a warning; the
        warning lists the expected types only when `input_types` is set and the
        value's class name is not among them. No exception is raised.

        Args:
            v (Any): The value to be validated.
            info: Additional information about the input; `info.data` is read for
                `input_types` and `name`.

        Returns:
            The value, unmodified.
        """
        if not isinstance(v, str) and v is not None:
            expected_types = info.data.get("input_types")
            # Keep the warning for now, but we should change it to an error
            if expected_types and v.__class__.__name__ not in expected_types:
                warnings.warn(
                    f"Invalid value type {type(v)} for input {info.data.get('name')}. "
                    f"Expected types: {info.data.get('input_types')}",
                    stacklevel=4,
                )
            else:
                warnings.warn(
                    f"Invalid value type {type(v)} for input {info.data.get('name')}.",
                    stacklevel=4,
                )
        return v

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, info):
        """Validates the given value and returns the processed value.

        When `is_list` is set, `_validate_value` is applied to every element of
        the value; otherwise it is applied to the value itself. Subclasses
        customize the per-item behavior by overriding `_validate_value`.

        Args:
            v (Any): The value to be validated.
            info: Additional information about the input; `info.data["is_list"]` is read.

        Returns:
            The processed value, or a list of processed values when `is_list` is set.
        """
        is_list = info.data["is_list"]
        if is_list:
            return [cls._validate_value(vv, info) for vv in v]
        return cls._validate_value(v, info)
class MessageInput(StrInput, InputTraceMixin):
    """Represents an input whose value is converted to a Message.

    Attributes:
        input_types (list[str]): A list of input types supported by this input. Defaults to ["Message"].
    """

    input_types: list[str] = ["Message"]

    @staticmethod
    def _validate_value(v: Any, _info):
        """Converts the given value to a Message.

        Args:
            v (Any): The value to be validated.
            _info: Additional information about the input (unused).

        Returns:
            A Message: built from the dict's items, returned as is when already a
            Message, wrapping a str/iterator as its text, or rebuilt from a
            MessageBase dump.

        Raises:
            ValueError: If the value is of any other type.
        """
        if isinstance(v, dict):
            return Message(**v)
        # A Message instance is already in the target form.
        if isinstance(v, Message):
            return v
        if isinstance(v, str | AsyncIterator | Iterator):
            return Message(text=v)
        if isinstance(v, MessageBase):
            return Message(**v.model_dump())
        msg = f"Invalid value type {type(v)}"
        raise ValueError(msg)
class MessageTextInput(StrInput, MetadataTraceMixin, InputTraceMixin, ToolModeMixin):
    """Represents a text input component for the Langflow system.

    This component is used to handle text inputs in the Langflow system.
    It provides methods for validating and processing text values.

    Attributes:
        input_types (list[str]): A list of input types that this component supports.
            In this case, it supports the "Message" input type.
    """

    input_types: list[str] = ["Message"]

    @staticmethod
    def _validate_value(v: Any, info):
        """Extracts the text (or stream) carried by the given value.

        Args:
            v (Any): The value to be validated. A dict is first turned into a Message.
            info: Additional information about the input; `info.data["name"]` is
                read when building the missing-key error.

        Returns:
            The string itself, the Message's text, the Data entry stored under its
            `text_key`, or the iterator unchanged.

        Raises:
            ValueError: If the value is not of a valid type or if a Data value is
                missing its `text_key`.
        """
        value: str | AsyncIterator | Iterator | None = None
        if isinstance(v, dict):
            v = Message(**v)
        if isinstance(v, str):
            value = v
        elif isinstance(v, Message):
            value = v.text
        elif isinstance(v, Data):
            if v.text_key in v.data:
                value = v.data[v.text_key]
            else:
                keys = ", ".join(v.data.keys())
                input_name = info.data["name"]
                # The first fragment ends with a space so the two sentences do not run together.
                msg = (
                    f"The input to '{input_name}' must contain the key '{v.text_key}'. "
                    f"You can set `text_key` to one of the following keys: {keys} "
                    "or set the value using another Component."
                )
                raise ValueError(msg)
        elif isinstance(v, AsyncIterator | Iterator):
            value = v
        else:
            msg = f"Invalid value type {type(v)}"
            raise ValueError(msg)  # noqa: TRY004
        return value
class MultilineInput(MessageTextInput, MultilineMixin, InputTraceMixin, ToolModeMixin):
    """Represents a multiline input field.

    Attributes:
        field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.TEXT.
        multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
    """

    field_type: SerializableFieldTypes = FieldTypes.TEXT
    multiline: CoalesceBool = True
class MultilineSecretInput(MessageTextInput, MultilineMixin, InputTraceMixin):
    """Represents a multiline input field with password field type.

    Attributes:
        field_type (SerializableFieldTypes): The type of the field. Defaults to FieldTypes.PASSWORD.
        multiline (CoalesceBool): Indicates whether the input field should support multiple lines. Defaults to True.
        password (CoalesceBool): A boolean indicating whether the input is a password. Defaults to True.
    """

    field_type: SerializableFieldTypes = FieldTypes.PASSWORD
    multiline: CoalesceBool = True
    password: CoalesceBool = Field(default=True)
class SecretStrInput(BaseInputMixin, DatabaseLoadMixin):
    """Represents a field with password field type.

    This class inherits from `BaseInputMixin` and `DatabaseLoadMixin`.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to `FieldTypes.PASSWORD`.
        password (CoalesceBool): A boolean indicating whether the input is a password. Defaults to `True`.
        input_types (list[str]): A list of input types associated with this input. Defaults to `["Message"]`.
        load_from_db (CoalesceBool): Defaults to `True`.
    """

    field_type: SerializableFieldTypes = FieldTypes.PASSWORD
    password: CoalesceBool = Field(default=True)
    input_types: list[str] = ["Message"]
    load_from_db: CoalesceBool = True

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, info):
        """Validates the given value and returns the processed value.

        Args:
            v (Any): The value to be validated.
            info: Additional information about the input; `info.data["name"]` is
                read when building error messages.

        Returns:
            The string itself, the Message's text, the Data entry stored under its
            `text_key`, the iterator unchanged, or None when the value is None.

        Raises:
            ValueError: If the value is not of a valid type or if a Data value is
                missing its `text_key`.
        """
        value: str | AsyncIterator | Iterator | None = None
        if isinstance(v, str):
            value = v
        elif isinstance(v, Message):
            value = v.text
        elif isinstance(v, Data):
            if v.text_key in v.data:
                value = v.data[v.text_key]
            else:
                keys = ", ".join(v.data.keys())
                input_name = info.data["name"]
                # The first fragment ends with a space so the two sentences do not run together.
                msg = (
                    f"The input to '{input_name}' must contain the key '{v.text_key}'. "
                    f"You can set `text_key` to one of the following keys: {keys} "
                    "or set the value using another Component."
                )
                raise ValueError(msg)
        elif isinstance(v, AsyncIterator | Iterator):
            value = v
        elif v is None:
            value = None
        else:
            msg = f"Invalid value type `{type(v)}` for input `{info.data['name']}`"
            raise ValueError(msg)
        return value
class IntInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixin):
    """Represents an integer field.

    This class represents an integer input and provides functionality for handling integer values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin`, `RangeMixin` and `MetadataTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.INTEGER.
    """

    field_type: SerializableFieldTypes = FieldTypes.INTEGER

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, info):
        """Validates the given value and returns the processed value.

        Falsy values skip the type check. A float is converted with `int()`.

        Args:
            v (Any): The value to be validated.
            info: Additional information about the input; `info.data` is read for `name`.

        Returns:
            The processed value.

        Raises:
            ValueError: If the value is truthy and is neither an int nor a float.
        """
        if v and not isinstance(v, int | float):
            msg = f"Invalid value type {type(v)} for input {info.data.get('name')}."
            raise ValueError(msg)
        if isinstance(v, float):
            v = int(v)
        return v
class FloatInput(BaseInputMixin, ListableInputMixin, RangeMixin, MetadataTraceMixin):
    """Represents a float field.

    This class represents a float input and provides functionality for handling float values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin`, `RangeMixin` and `MetadataTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FLOAT.
    """

    field_type: SerializableFieldTypes = FieldTypes.FLOAT

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, info):
        """Validates the given value and returns the processed value.

        Falsy values skip the type check. An int is converted with `float()`.

        Args:
            v (Any): The value to be validated.
            info: Additional information about the input; `info.data` is read for `name`.

        Returns:
            The processed value.

        Raises:
            ValueError: If the value is truthy and is neither an int nor a float.
        """
        if v and not isinstance(v, int | float):
            msg = f"Invalid value type {type(v)} for input {info.data.get('name')}."
            raise ValueError(msg)
        if isinstance(v, int):
            v = float(v)
        return v
class BoolInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin):
    """Represents a boolean field.

    This class represents a boolean input and provides functionality for handling boolean values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin` and `MetadataTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.BOOLEAN.
        value (CoalesceBool): The value of the boolean input. Defaults to False.
    """

    field_type: SerializableFieldTypes = FieldTypes.BOOLEAN
    value: CoalesceBool = False
class NestedDictInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin, InputTraceMixin):
    """Represents a nested dictionary field.

    This class represents a nested dictionary input and provides functionality for handling dictionary values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin`, `MetadataTraceMixin` and `InputTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.NESTED_DICT.
        value (dict | Data | None): The value of the input. Defaults to an empty dictionary.
    """

    field_type: SerializableFieldTypes = FieldTypes.NESTED_DICT
    value: dict | Data | None = {}
class DictInput(BaseInputMixin, ListableInputMixin, InputTraceMixin):
    """Represents a dictionary field.

    This class represents a dictionary input and provides functionality for handling dictionary values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin` and `InputTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.DICT.
        value (dict | None): The value of the dictionary input. Defaults to an empty dictionary.
    """

    field_type: SerializableFieldTypes = FieldTypes.DICT
    value: dict | None = {}
class DropdownInput(BaseInputMixin, DropDownMixin, MetadataTraceMixin, ToolModeMixin):
    """Represents a dropdown input field.

    This class represents a dropdown input field and provides functionality for handling dropdown values.
    It inherits from the `BaseInputMixin`, `DropDownMixin`, `MetadataTraceMixin` and `ToolModeMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
        options (list[str]): List of options for the field. Defaults to an empty list.
        combobox (CoalesceBool): Defaults to False.
    """

    field_type: SerializableFieldTypes = FieldTypes.TEXT
    options: list[str] = Field(default_factory=list)
    combobox: CoalesceBool = False
class MultiselectInput(BaseInputMixin, ListableInputMixin, DropDownMixin, MetadataTraceMixin):
    """Represents a multiselect input field.

    This class represents a multiselect input field and provides functionality for handling multiselect values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin`, `DropDownMixin` and `MetadataTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.TEXT.
        options (list[str]): List of options for the field. Defaults to an empty list.
        is_list (bool): Defaults to True; serialized under the alias "list".
        combobox (CoalesceBool): Defaults to False.
    """

    field_type: SerializableFieldTypes = FieldTypes.TEXT
    options: list[str] = Field(default_factory=list)
    is_list: bool = Field(default=True, serialization_alias="list")
    combobox: CoalesceBool = False

    # NOTE(review): assumes the `value` field is declared by one of the base mixins - confirm.
    @field_validator("value")
    @classmethod
    def validate_value(cls, v: Any, _info):
        """Validates that the value is a list of strings.

        Args:
            v (Any): The value to be validated.
            _info: Additional information about the input (unused).

        Returns:
            The list, unmodified.

        Raises:
            ValueError: If the value is not a list or if any item is not a string.
        """
        # Check if value is a list of strings
        if not isinstance(v, list):
            msg = f"MultiselectInput value must be a list. Value: '{v}'"
            raise ValueError(msg)  # noqa: TRY004
        for item in v:
            if not isinstance(item, str):
                msg = f"MultiselectInput value must be a list of strings. Item: '{item}' is not a string"
                raise ValueError(msg)  # noqa: TRY004
        return v
class FileInput(BaseInputMixin, ListableInputMixin, FileMixin, MetadataTraceMixin):
    """Represents a file field.

    This class represents a file input and provides functionality for handling file values.
    It inherits from the `BaseInputMixin`, `ListableInputMixin`, `FileMixin` and `MetadataTraceMixin` classes.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.FILE.
    """

    field_type: SerializableFieldTypes = FieldTypes.FILE
class LinkInput(BaseInputMixin, LinkMixin):
    """Represents a link field.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.LINK.
    """

    field_type: SerializableFieldTypes = FieldTypes.LINK
class SliderInput(BaseInputMixin, RangeMixin, SliderMixin):
    """Represents a slider field.

    Attributes:
        field_type (SerializableFieldTypes): The field type of the input. Defaults to FieldTypes.SLIDER.
    """

    field_type: SerializableFieldTypes = FieldTypes.SLIDER
# Input types used as the default for DefaultPromptField.input_types.
# NOTE(review): "INTUT" looks like a misspelling of "INPUT"; the name is
# module-level and may be imported elsewhere, so it is left unchanged.
DEFAULT_PROMPT_INTUT_TYPES = ["Message", "Text"]


class DefaultPromptField(Input):
    """Input subclass with prompt-field defaults.

    Attributes:
        name (str): Required; has no default.
        display_name (str | None): Defaults to None.
        field_type (str): Defaults to "str".
        advanced (bool): Defaults to False.
        multiline (bool): Defaults to True.
        input_types (list[str]): Defaults to DEFAULT_PROMPT_INTUT_TYPES.
        value (Any): Defaults to an empty string.
    """

    name: str
    display_name: str | None = None
    field_type: str = "str"
    advanced: bool = False
    multiline: bool = True
    input_types: list[str] = DEFAULT_PROMPT_INTUT_TYPES
    value: Any = ""  # Set the value to empty string
# Union of the base `Input`, `DefaultPromptField` and the input classes defined
# in this module. The member order is left exactly as written.
InputTypes: TypeAlias = (
    Input
    | DefaultPromptField
    | BoolInput
    | DataInput
    | DictInput
    | DropdownInput
    | MultiselectInput
    | FileInput
    | FloatInput
    | HandleInput
    | IntInput
    | MultilineInput
    | MultilineSecretInput
    | NestedDictInput
    | PromptInput
    | CodeInput
    | SecretStrInput
    | StrInput
    | MessageTextInput
    | MessageInput
    | TableInput
    | LinkInput
    | SliderInput
    | DataFrameInput
)

# Maps each member's class name (e.g. "StrInput") to the class itself, built
# from the members of the union above.
InputTypesMap: dict[str, type[InputTypes]] = {t.__name__: t for t in get_args(InputTypes)}
def instantiate_input(input_type: str, data: dict) -> InputTypes:
    """Builds an input instance from its class name and a dict of field values.

    Args:
        input_type (str): Name of the input class, looked up in `InputTypesMap`.
        data (dict): Field values passed to the class as keyword arguments. A
            "type" key is passed as `field_type`. The dict is not modified.

    Returns:
        An instance of the class registered under `input_type`.

    Raises:
        ValueError: If `input_type` is not a key of `InputTypesMap`.
    """
    input_type_class = InputTypesMap.get(input_type)
    if input_type_class is None:
        msg = f"Invalid input type: {input_type}"
        raise ValueError(msg)
    # Work on a shallow copy so the caller's dict keeps its "type" key.
    kwargs = dict(data)
    if "type" in kwargs:
        # Replace with field_type
        kwargs["field_type"] = kwargs.pop("type")
    return input_type_class(**kwargs)