Spaces:
Running
Running
from __future__ import annotations | |
from typing import TYPE_CHECKING, Any, cast | |
from loguru import logger | |
from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict | |
from langflow.schema.schema import INPUT_FIELD_NAME | |
if TYPE_CHECKING: | |
from langflow.graph.vertex.base import Vertex | |
class Edge:
    """Connection from a source vertex to a target vertex, built from raw edge data.

    Construction parses the source/target handles out of the raw edge, checks
    that the handles are compatible and then checks that the source vertex's
    outputs match one of the target vertex's inputs. Any failure raises
    ``ValueError`` so that an invalid edge never gets constructed.
    """

    def __init__(self, source: Vertex, target: Vertex, edge: EdgeData):
        """Parse ``edge`` and validate it against ``source`` and ``target``.

        Args:
            source: Vertex the edge starts from; a falsy value yields an empty ``source_id``.
            target: Vertex the edge points to; a falsy value yields an empty ``target_id``.
            edge: Raw edge mapping. When it has a non-empty ``"data"`` entry the handles
                are read from ``data["sourceHandle"]`` / ``data["targetHandle"]``;
                otherwise the legacy top-level string handles are used.

        Raises:
            ValueError: If the target handle has the wrong shape, cannot be parsed,
                or if handle/edge validation fails.
        """
        self.source_id: str = source.id if source else ""
        self.target_id: str = target.id if target else ""
        self.valid_handles: bool = False
        self.target_param: str | None = None
        self._target_handle: TargetHandleDict | str | None = None
        # Shallow copy, so later top-level mutations of the caller's mapping do not leak in.
        self._data = edge.copy()
        self.is_cycle = False

        if data := edge.get("data", {}):
            self._source_handle = data.get("sourceHandle", {})
            self._target_handle = cast("TargetHandleDict", data.get("targetHandle", {}))
            self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
            if not isinstance(self._target_handle, dict):
                msg = "Target handle is not a dictionary"
                raise ValueError(msg)
            try:
                self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
            except Exception as e:
                # A null "inputTypes" gets a friendlier error naming the offending field;
                # every other construction failure propagates untouched.
                if "inputTypes" in self._target_handle and self._target_handle["inputTypes"] is None:
                    if hasattr(target, "custom_component"):
                        display_name = getattr(target.custom_component, "display_name", "")
                        msg = (
                            f"Component {display_name} field '{self._target_handle['fieldName']}' "
                            "might not be a valid input."
                        )
                        raise ValueError(msg) from e
                    msg = (
                        f"Field '{self._target_handle['fieldName']}' on {target.display_name} "
                        "might not be a valid input."
                    )
                    raise ValueError(msg) from e
                raise
            self.target_param = self.target_handle.field_name
            # validate handles
            self.validate_handles(source, target)
        else:
            # Logging here because this is a breaking change
            logger.error("Edge data is empty")
            self._source_handle = edge.get("sourceHandle", "")  # type: ignore[assignment]
            self._target_handle = edge.get("targetHandle", "")  # type: ignore[assignment]
            if not isinstance(self._target_handle, str):
                msg = "Target handle is not a string"
                raise ValueError(msg)
            # Legacy handles look like
            # 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
            # where the second '|'-separated part ('documents') is the target param.
            handle_parts = self._target_handle.split("|")
            if len(handle_parts) < 2:  # noqa: PLR2004
                # Without this guard a malformed (or defaulted empty) handle would
                # surface as a bare IndexError instead of a validation error.
                msg = f"Target handle '{self._target_handle}' is missing the '|'-separated parameter name"
                raise ValueError(msg)
            self.target_param = handle_parts[1]
            self.source_handle = None
            self.target_handle = None

        # Validate in __init__ to fail fast
        self.validate_edge(source, target)

    def to_data(self):
        """Return the (shallow) copy of the raw edge mapping taken at construction."""
        return self._data

    def validate_handles(self, source, target) -> None:
        """Check handle compatibility, dispatching to the legacy or current rules.

        The legacy rules apply when the raw source handle is a string or when the
        parsed source handle carries ``base_classes``.

        Raises:
            ValueError: If the handles are not compatible.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_handles(source, target)
        else:
            self._validate_handles(source, target)

    def _validate_handles(self, source, target) -> None:
        """Validate handles by comparing source ``output_types`` with the target handle.

        The handles are valid when the target handle's ``type`` is one of the source
        output types, or when any source output type is listed in the target's
        ``input_types``. If the source has no ``output_types`` the current value of
        ``valid_handles`` is left as it is.

        Raises:
            ValueError: If ``valid_handles`` ends up falsy.
        """
        output_types = self.source_handle.output_types
        input_types = self.target_handle.input_types
        # Guarding on output_types first avoids a TypeError from ``x in None`` when
        # both the source output types and the target input types are missing.
        if output_types is not None:
            self.valid_handles = self.target_handle.type in output_types or (
                input_types is not None and any(output_type in input_types for output_type in output_types)
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.display_name} and {target.display_name} has invalid handles"
            raise ValueError(msg)

    def _legacy_validate_handles(self, source, target) -> None:
        """Validate handles by comparing source ``base_classes`` with the target handle.

        Raises:
            ValueError: If neither the target ``type`` nor any of its ``input_types``
                appears among the source ``base_classes``.
        """
        base_classes = self.source_handle.base_classes
        input_types = self.target_handle.input_types
        if input_types is None:
            self.valid_handles = self.target_handle.type in base_classes
        else:
            self.valid_handles = (
                any(base_class in input_types for base_class in base_classes)
                or self.target_handle.type in base_classes
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has invalid handles"
            raise ValueError(msg)

    def __setstate__(self, state):
        """Restore the edge from a pickled / deep-copied attribute mapping.

        ``source_id``, ``target_id`` and ``target_param`` are required keys; every
        other attribute falls back to a default when absent from ``state``.
        """
        self.source_id = state["source_id"]
        self.target_id = state["target_id"]
        self.target_param = state["target_param"]
        self.source_handle = state.get("source_handle")
        self.target_handle = state.get("target_handle")
        self._source_handle = state.get("_source_handle")
        self._target_handle = state.get("_target_handle")
        self._data = state.get("_data")
        self.valid_handles = state.get("valid_handles")
        self.source_types = state.get("source_types")
        self.target_reqs = state.get("target_reqs")
        self.matched_type = state.get("matched_type")
        # These are set by __init__ / validation as well; without restoring them a
        # round-tripped edge would raise AttributeError on ``edge.is_cycle``.
        self.is_cycle = state.get("is_cycle", False)
        self.valid = state.get("valid")

    def validate_edge(self, source, target) -> None:
        """Check that the source's outputs satisfy the target's inputs.

        Uses the legacy rules when the raw source handle is a string or the parsed
        source handle has ``base_classes``; otherwise uses the current rules.

        Raises:
            ValueError: If no source output type matches a target input.
        """
        # If the self.source_handle has base_classes, then we are using the legacy
        # way of defining the source and target handles
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_edge(source, target)
        else:
            self._validate_edge(source, target)

    def _validate_edge(self, source, target) -> None:
        """Match the types of the selected source output against the target inputs.

        Sets ``source_types`` (the ``source.outputs`` entries whose ``"name"`` equals
        the source handle name), ``target_reqs`` (required + optional inputs),
        ``valid`` and ``matched_type`` (the first output type that is contained in
        a target requirement).

        Raises:
            ValueError: If no output type matches any target requirement.
        """
        # .outputs is a list of Output objects as dictionaries; each one is
        # expected to carry a "types" key.
        handle_name = self.source_handle.name
        self.source_types = [output for output in source.outputs if output["name"] == handle_name]
        self.target_reqs = target.required_inputs + target.optional_inputs
        # Matching uses ``in`` on the requirement itself, so when requirements are
        # strings an output type such as "Chain" also matches a requirement "LLMChain".
        self.valid = any(
            any(output_type in target_req for output_type in output["types"])
            for output in self.source_types
            for target_req in self.target_reqs
        )
        # Get what type of input the target node is expecting
        # Update the matched type to be the first found match
        self.matched_type = next(
            (
                output_type
                for output in self.source_types
                for output_type in output["types"]
                for target_req in self.target_reqs
                if output_type in target_req
            ),
            None,
        )
        if self.matched_type is None:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type."
            raise ValueError(msg)

    def _legacy_validate_edge(self, source, target) -> None:
        """Match ``source.output`` against the target inputs (legacy handle format).

        Sets ``source_types``, ``target_reqs``, ``valid`` and ``matched_type``.

        Raises:
            ValueError: If no entry of ``source.output`` is a member of the target
                requirements.
        """
        self.source_types = source.output
        self.target_reqs = target.required_inputs + target.optional_inputs
        # ``valid`` accepts containment inside a requirement (e.g. "Chain" within
        # "LLMChain") ...
        self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
        # ... whereas ``matched_type`` requires the output to be an element of
        # target_reqs itself.
        # NOTE(review): the two checks can disagree (valid=True yet no matched type);
        # behaviour is kept as-is here - confirm which rule is intended.
        self.matched_type = next(
            (output for output in self.source_types if output in self.target_reqs),
            None,
        )
        if self.matched_type is None:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type"
            raise ValueError(msg)

    def __repr__(self) -> str:
        """Render ``source -[output->field]-> target``, or the legacy form using ``target_param``."""
        has_source_handle = hasattr(self, "source_handle") and self.source_handle
        has_target_handle = hasattr(self, "target_handle") and self.target_handle
        if has_source_handle and has_target_handle:
            return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
        return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"

    def __hash__(self) -> int:
        """Hash the textual representation of the edge."""
        return hash(self.__repr__())

    def __eq__(self, /, other: object) -> bool:
        """Edges are equal when their raw handles and target parameter are equal."""
        if not isinstance(other, Edge):
            return False
        return (
            self._source_handle == other._source_handle
            and self._target_handle == other._target_handle
            and self.target_param == other.target_param
        )

    def __str__(self) -> str:
        """Same text as ``repr(self)``."""
        return self.__repr__()
class CycleEdge(Edge):
    """Edge that takes part in a cycle and hands the source's result to the target.

    On top of the base edge it tracks whether the hand-over has happened
    (``is_fulfilled``) and what value was handed over (``result``).
    """

    def __init__(self, source: Vertex, target: Vertex, raw_edge: EdgeData):
        """Build the edge and flag it, and both vertices, as part of a cycle.

        Args:
            source: Vertex the edge starts from; gets ``has_cycle_edges = True``.
            target: Vertex the edge points to; gets ``has_cycle_edges = True``.
            raw_edge: Raw edge mapping forwarded to the base constructor.

        Raises:
            ValueError: Propagated from the base constructor when the edge is invalid.
        """
        super().__init__(source, target, raw_edge)
        self.is_fulfilled = False  # Whether the contract has been fulfilled.
        self.result: Any = None
        self.is_cycle = True
        source.has_cycle_edges = True
        target.has_cycle_edges = True

    def __setstate__(self, state):
        """Restore base-edge state plus the cycle-specific attributes.

        The base implementation only restores its own fields, so without this
        override a round-tripped cycle edge would lack ``is_fulfilled`` and
        ``result`` and fail on first use.
        """
        super().__setstate__(state)
        self.is_cycle = state.get("is_cycle", True)
        self.is_fulfilled = state.get("is_fulfilled", False)
        self.result = state.get("result")

    async def honor(self, source: Vertex, target: Vertex) -> None:
        """Fulfill the contract by copying the source's result into the target's params.

        Does nothing when the contract is already fulfilled. Otherwise the result is
        ``source.built_result`` when the matched type is ``"Text"`` and
        ``source.built_object`` for any other matched type; it is stored on the edge
        and written to ``target.params[self.target_param]``.

        :param source: The already-built vertex providing the value.
        :param target: The vertex whose parameter receives the value.
        :raises ValueError: If ``source`` has not been built.
        """
        if self.is_fulfilled:
            return

        if not source.built:
            # The system should be read-only, so we should not be building vertices
            # that are not already built.
            msg = f"Source vertex {source.id} is not built."
            raise ValueError(msg)

        self.result = source.built_result if self.matched_type == "Text" else source.built_object
        target.params[self.target_param] = self.result
        self.is_fulfilled = True

    async def get_result_from_source(self, source: Vertex, target: Vertex):
        """Return the value carried by this edge, honoring the contract first if needed.

        :raises ValueError: If the contract still has to be honored and ``source`` is not built.
        """
        # Fulfill the contract if it has not been fulfilled.
        if not self.is_fulfilled:
            await self.honor(source, target)
        # The result is returned unconditionally: the former ChatOutput-specific
        # check returned the same value on both of its branches.
        return self.result

    def __repr__(self) -> str:
        """Base representation followed by a marker showing this is a cycle edge."""
        str_repr = super().__repr__()
        # Add a symbol to show this is a cycle edge
        return f"{str_repr} 🔄"