Tai Truong
fix readme
d202ada
from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from loguru import logger
from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict
from langflow.schema.schema import INPUT_FIELD_NAME
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
class Edge:
    """A directed connection between two vertices of a graph.

    An edge is built from the raw ``EdgeData`` dictionary. Two layouts are handled:

    * current format: ``edge["data"]`` holds ``sourceHandle`` / ``targetHandle``
      dictionaries, parsed into ``SourceHandle`` / ``TargetHandle`` objects;
    * legacy format: ``edge["sourceHandle"]`` / ``edge["targetHandle"]`` are
      ``|``-separated strings such as
      ``'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'``.

    Handles and types are validated in ``__init__`` so an invalid edge fails fast
    with a ``ValueError``.
    """

    def __init__(self, source: Vertex, target: Vertex, edge: EdgeData):
        """Create an edge and validate it against its endpoints.

        Args:
            source: Vertex the edge starts from; if falsy, ``source_id`` is ``""``.
            target: Vertex the edge points to; if falsy, ``target_id`` is ``""``.
            edge: Raw edge dictionary (a shallow copy is kept, see ``to_data``).

        Raises:
            ValueError: If the target handle has the wrong shape, the handles are
                incompatible, or no source output type matches a target input.
        """
        self.source_id: str = source.id if source else ""
        self.target_id: str = target.id if target else ""
        self.valid_handles: bool = False
        self.target_param: str | None = None
        self._target_handle: TargetHandleDict | str | None = None
        self._data = edge.copy()
        self.is_cycle = False
        if data := edge.get("data", {}):
            # Current format: handles are dictionaries nested under "data".
            self._source_handle = data.get("sourceHandle", {})
            self._target_handle = cast("TargetHandleDict", data.get("targetHandle", {}))
            self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
            if isinstance(self._target_handle, dict):
                try:
                    self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
                except Exception as e:
                    if "inputTypes" in self._target_handle and self._target_handle["inputTypes"] is None:
                        # .get() so that a missing "fieldName" cannot replace the
                        # original error with a KeyError while building the message.
                        field_name = self._target_handle.get("fieldName")
                        if hasattr(target, "custom_component"):
                            display_name = getattr(target.custom_component, "display_name", "")
                            msg = f"Component {display_name} field '{field_name}' might not be a valid input."
                            raise ValueError(msg) from e
                        msg = f"Field '{field_name}' on {target.display_name} might not be a valid input."
                        raise ValueError(msg) from e
                    raise
            else:
                msg = "Target handle is not a dictionary"
                raise ValueError(msg)
            self.target_param = self.target_handle.field_name
            # validate handles
            self.validate_handles(source, target)
        else:
            # Logging here because this is a breaking change
            logger.error("Edge data is empty")
            self._source_handle = edge.get("sourceHandle", "")  # type: ignore[assignment]
            self._target_handle = edge.get("targetHandle", "")  # type: ignore[assignment]
            # Legacy target handle, e.g.
            # 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
            # -> target_param is "documents" (the second "|"-separated field).
            if isinstance(self._target_handle, str):
                self.target_param = self._target_handle.split("|")[1]
                self.source_handle = None
                self.target_handle = None
            else:
                msg = "Target handle is not a string"
                raise ValueError(msg)
        # Validate in __init__ to fail fast
        self.validate_edge(source, target)

    def to_data(self):
        """Return the copy of the raw edge dictionary taken in ``__init__``."""
        return self._data

    def validate_handles(self, source, target) -> None:
        """Check that the source handle's types are accepted by the target handle.

        Uses the legacy ``base_classes`` check when the source handle is a legacy
        string or declares ``base_classes``; otherwise the ``output_types`` check.

        Raises:
            ValueError: If the handles are incompatible.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_handles(source, target)
        else:
            self._validate_handles(source, target)

    def _validate_handles(self, source, target) -> None:
        """Validate handles using the source handle's ``output_types``.

        When the target declares no ``input_types`` its ``type`` must be one of
        the source output types. Otherwise any shared type (or the target ``type``)
        is accepted. If the target declares input types but the source has no
        ``output_types``, ``valid_handles`` keeps its previous value.

        Raises:
            ValueError: If ``valid_handles`` ends up falsy.
        """
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.output_types
        elif self.source_handle.output_types is not None:
            self.valid_handles = (
                any(output_type in self.target_handle.input_types for output_type in self.source_handle.output_types)
                or self.target_handle.type in self.source_handle.output_types
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.display_name} and {target.display_name} has invalid handles"
            raise ValueError(msg)

    def _legacy_validate_handles(self, source, target) -> None:
        """Validate handles using the source handle's legacy ``base_classes``.

        Raises:
            ValueError: If no base class of the source is accepted by the target.
        """
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.base_classes
        else:
            self.valid_handles = (
                any(base_class in self.target_handle.input_types for base_class in self.source_handle.base_classes)
                or self.target_handle.type in self.source_handle.base_classes
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has invalid handles"
            raise ValueError(msg)

    def __setstate__(self, state):
        """Restore an edge from its pickled ``__dict__``.

        Core attributes are restored explicitly. Missing optional ones default to
        ``None``, and ``is_cycle`` defaults to ``False``. Every other key in ``state``
        is then restored too. Examples are ``valid`` and state added by subclasses
        such as ``CycleEdge.is_fulfilled`` / ``result``. Without this, those keys
        would be silently dropped on unpickle.
        """
        self.source_id = state["source_id"]
        self.target_id = state["target_id"]
        self.target_param = state["target_param"]
        self.source_handle = state.get("source_handle")
        self.target_handle = state.get("target_handle")
        self._source_handle = state.get("_source_handle")
        self._target_handle = state.get("_target_handle")
        self._data = state.get("_data")
        self.valid_handles = state.get("valid_handles")
        self.source_types = state.get("source_types")
        self.target_reqs = state.get("target_reqs")
        self.matched_type = state.get("matched_type")
        self.is_cycle = state.get("is_cycle", False)
        # Keep any attribute not covered above instead of discarding it.
        for key, value in state.items():
            self.__dict__.setdefault(key, value)

    def validate_edge(self, source, target) -> None:
        """Check that some source output type satisfies a target input.

        Legacy handles (string source handle, or a handle with ``base_classes``)
        use the legacy check.

        Raises:
            ValueError: If no matching type is found.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_edge(source, target)
        else:
            self._validate_edge(source, target)

    def _validate_edge(self, source, target) -> None:
        """Match the types of the source output named by the handle against the target's inputs.

        ``source.outputs`` is a list of output dictionaries; only those whose
        ``"name"`` equals ``source_handle.name`` are considered, via their ``"types"``.

        Raises:
            ValueError: If no output type matches any target requirement.
        """
        self.source_types = [output for output in source.outputs if output["name"] == self.source_handle.name]
        self.target_reqs = target.required_inputs + target.optional_inputs
        # A requirement may merely *contain* the type we look for,
        # e.g. source type "Chain" and target requirement "LLMChain",
        # so match by substring rather than by equality.
        self.valid = any(
            any(output_type in target_req for output_type in output["types"])
            for output in self.source_types
            for target_req in self.target_reqs
        )
        # The matched type is the first output type that satisfies a requirement.
        self.matched_type = next(
            (
                output_type
                for output in self.source_types
                for output_type in output["types"]
                for target_req in self.target_reqs
                if output_type in target_req
            ),
            None,
        )
        if self.matched_type is None:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type."
            raise ValueError(msg)

    def _legacy_validate_edge(self, source, target) -> None:
        """Match ``source.output`` against the target's required and optional inputs.

        Raises:
            ValueError: If no source output matches any target requirement.
        """
        self.source_types = source.output
        self.target_reqs = target.required_inputs + target.optional_inputs
        # A requirement may merely *contain* the type we look for,
        # e.g. source type "Chain" and target requirement "LLMChain",
        # so match by substring rather than by equality.
        self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
        # Prefer an exact match (the historical behaviour). Then fall back to the
        # same substring match used for ``self.valid``, so an edge judged valid
        # is not rejected for lacking a matched type.
        self.matched_type = next(
            (output for output in self.source_types if output in self.target_reqs),
            None,
        )
        if self.matched_type is None:
            self.matched_type = next(
                (
                    output
                    for output in self.source_types
                    for target_req in self.target_reqs
                    if output in target_req
                ),
                None,
            )
        if self.matched_type is None:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type"
            raise ValueError(msg)

    def __repr__(self) -> str:
        """Render ``source -[source_output->target_field]-> target``, or ``-[target_param]->`` without handle objects."""
        if (hasattr(self, "source_handle") and self.source_handle) and (
            hasattr(self, "target_handle") and self.target_handle
        ):
            return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
        return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"

    def __hash__(self) -> int:
        """Hash on the textual representation."""
        return hash(self.__repr__())

    def __eq__(self, /, other: object) -> bool:
        """Edges are equal when their raw handles and target parameter are equal."""
        if not isinstance(other, Edge):
            return False
        return (
            self._source_handle == other._source_handle
            and self._target_handle == other._target_handle
            and self.target_param == other.target_param
        )

    def __str__(self) -> str:
        return self.__repr__()
class CycleEdge(Edge):
    """An ``Edge`` that is part of a cycle in the graph.

    Instead of being resolved eagerly, it carries a one-shot "contract": the
    source vertex's built output is copied into the target's parameters the first
    time ``honor`` succeeds, and is cached in ``result``.
    """

    def __init__(self, source: Vertex, target: Vertex, raw_edge: EdgeData):
        super().__init__(source, target, raw_edge)
        self.is_fulfilled = False  # Whether the contract has been fulfilled.
        self.result: Any = None
        self.is_cycle = True
        # Flag both endpoints as participating in a cycle.
        source.has_cycle_edges = True
        target.has_cycle_edges = True

    def __setstate__(self, state):
        """Restore a pickled cycle edge, including its contract state.

        The base restore does not necessarily know about the fields added here,
        so they are restored explicitly with the same defaults as ``__init__``.
        """
        super().__setstate__(state)
        self.is_fulfilled = state.get("is_fulfilled", False)
        self.result = state.get("result")
        self.is_cycle = True

    async def honor(self, source: Vertex, target: Vertex) -> None:
        """Fulfil the contract by copying the source's built output into the target.

        When ``matched_type`` is ``"Text"`` the source's ``built_result`` is used,
        otherwise its ``built_object``. The value is stored in ``self.result`` and
        in ``target.params[self.target_param]``. Does nothing if the contract is
        already fulfilled.

        Raises:
            ValueError: If the source vertex has not been built yet.
        """
        if self.is_fulfilled:
            return
        if not source.built:
            # The system should be read-only, so we should not be building vertices
            # that are not already built.
            msg = f"Source vertex {source.id} is not built."
            raise ValueError(msg)
        self.result = source.built_result if self.matched_type == "Text" else source.built_object
        target.params[self.target_param] = self.result
        self.is_fulfilled = True

    async def get_result_from_source(self, source: Vertex, target: Vertex):
        """Return the value carried by this edge, honoring the contract first if needed."""
        if not self.is_fulfilled:
            await self.honor(source, target)
        return self.result

    def __repr__(self) -> str:
        # Append a marker so cycle edges stand out in logs.
        return f"{super().__repr__()} 🔄"