Spaces:
Running
Running
from __future__ import annotations | |
import ast | |
import asyncio | |
import inspect | |
import os | |
import traceback | |
import types | |
from collections.abc import AsyncIterator, Callable, Iterator, Mapping | |
from enum import Enum | |
from typing import TYPE_CHECKING, Any | |
import pandas as pd | |
from loguru import logger | |
from langflow.exceptions.component import ComponentBuildError | |
from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData | |
from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction | |
from langflow.interface import initialize | |
from langflow.interface.listing import lazy_load_dict | |
from langflow.schema.artifact import ArtifactType | |
from langflow.schema.data import Data | |
from langflow.schema.message import Message | |
from langflow.schema.schema import INPUT_FIELD_NAME, OutputValue, build_output_logs | |
from langflow.services.deps import get_storage_service | |
from langflow.utils.constants import DIRECT_TYPES | |
from langflow.utils.schemas import ChatOutputResponse | |
from langflow.utils.util import sync_to_async, unescape_string | |
if TYPE_CHECKING: | |
from uuid import UUID | |
from langflow.custom import Component | |
from langflow.events.event_manager import EventManager | |
from langflow.graph.edge.base import CycleEdge, Edge | |
from langflow.graph.graph.base import Graph | |
from langflow.graph.vertex.schema import NodeData | |
from langflow.services.tracing.schema import Log | |
class VertexStates(str, Enum):
    """State of a vertex: active, inactive, or in an error state.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"
    ERROR = "ERROR"
class Vertex: | |
def __init__(
    self,
    data: NodeData,
    graph: Graph,
    *,
    base_type: str | None = None,
    is_task: bool = False,
    params: dict | None = None,
) -> None:
    """Create a vertex from its node payload.

    Args:
        data: Node payload; ``data["id"]`` becomes the vertex id and a shallow
            copy of the payload is kept in ``self.full_data``.
        graph: Graph this vertex belongs to; stored on ``self.graph``.
        base_type: Pre-resolved base type. When ``None``, ``parse_data``
            tries to look it up.
        is_task: Stored unchanged on ``self.is_task``.
        params: Initial parameters; ``None`` (or any falsy value) becomes ``{}``.
    """
    # ---- Attributes that must exist before parse_data() runs ----
    vertex_id: str = data["id"]
    self.id: str = vertex_id
    self.base_name = vertex_id.split("-")[0]
    self.graph = graph
    self.full_data = data.copy()
    self.base_type: str | None = base_type
    self.outputs: list[dict] = []
    self._lock = asyncio.Lock()
    self.will_stream = False
    self.updated_raw_params = False
    self.is_state = False
    # Initial guess from the id; parse_data() may override with template flags.
    self.is_input = any(name in vertex_id for name in INPUT_COMPONENTS)
    self.is_output = any(name in vertex_id for name in OUTPUT_COMPONENTS)
    self.has_session_id = None
    self.custom_component = None
    self.has_external_input = False
    self.has_external_output = False

    self.parse_data()

    # ---- Build state ----
    self.built_object: Any = UnbuiltObject()
    self.built_result: Any = None
    self.built = False
    self.use_result = False
    self.steps: list[Callable] = [self._build]
    self.steps_ran: list[Callable] = []
    self.build_times: list[float] = []
    self.state = VertexStates.ACTIVE

    # ---- Results, artifacts and logs ----
    self.artifacts: dict[str, Any] = {}
    self.artifacts_raw: dict[str, Any] = {}
    self.artifacts_type: dict[str, str] = {}
    self.result: ResultData | None = None
    self.results: dict[str, Any] = {}
    self.outputs_logs: dict[str, OutputValue] = {}
    self.logs: dict[str, list[Log]] = {}
    self.log_transaction_tasks: set[asyncio.Task] = set()

    # ---- Parameters and task bookkeeping ----
    self.params = params or {}
    self.load_from_db_fields: list[str] = []
    self.task_id: str | None = None
    self.is_task = is_task

    # ---- Position in the graph ----
    self._successors_ids: list[str] | None = None
    self.parent_node_id: str | None = self.full_data.get("parent_node_id")
    self.parent_is_top_level = False
    self.layer = None
    self.has_cycle_edges = False

    # vertex_type is set by parse_data(); a ValueError from the membership
    # test is treated as "not an interface component".
    try:
        self.is_interface_component = self.vertex_type in InterfaceComponentTypes
    except ValueError:
        self.is_interface_component = False
def set_input_value(self, name: str, value: Any) -> None:
    """Forward an input value to the attached component instance.

    Raises:
        ValueError: If no component instance is attached to this vertex.
    """
    component = self.custom_component
    if component is not None:
        component._set_input_value(name, value)
        return
    msg = f"Vertex {self.id} does not have a component instance."
    raise ValueError(msg)
def to_data(self):
    """Return the node payload stored in ``self.full_data`` (not a copy)."""
    payload = self.full_data
    return payload
def add_component_instance(self, component_instance: Component) -> None:
    """Attach ``component_instance`` to this vertex.

    The component is first told about this vertex via ``set_vertex`` and is
    then stored as ``self.custom_component``.
    """
    component_instance.set_vertex(self)
    self.custom_component = component_instance
def add_result(self, name: str, result: Any) -> None:
    """Store ``result`` under ``name`` in ``self.results``, replacing any previous entry."""
    self.results[name] = result
def update_graph_state(self, key, new_state, *, append: bool) -> None:
    """Push ``new_state`` for ``key`` to the graph, identifying this vertex as caller.

    Args:
        key: State key passed through to the graph.
        new_state: Value passed through to the graph.
        append: When true ``graph.append_state`` is used, otherwise
            ``graph.update_state``.
    """
    writer = self.graph.append_state if append else self.graph.update_state
    writer(key, new_state, caller=self.id)
def set_state(self, state: str) -> None:
    """Set the vertex state by member name and sync the graph's inactive set.

    Args:
        state: Name of a ``VertexStates`` member (looked up with
            ``VertexStates[state]``).
    """
    new_state = VertexStates[state]
    self.state = new_state
    graph = self.graph
    if new_state == VertexStates.INACTIVE:
        # With at most one incoming edge the vertex is not a merge point in
        # the graph, so it can be recorded as inactivated.
        if graph.in_degree_map[self.id] <= 1:
            graph.inactivated_vertices.add(self.id)
    elif new_state == VertexStates.ACTIVE and self.id in graph.inactivated_vertices:
        graph.inactivated_vertices.remove(self.id)
def is_active(self):
    """Return True when the vertex state equals ``VertexStates.ACTIVE``."""
    current = self.state
    return current == VertexStates.ACTIVE
def avg_build_time(self):
    """Return the mean of the recorded build times, or 0 when none were recorded."""
    times = self.build_times
    if not times:
        return 0
    return sum(times) / len(times)
def add_build_time(self, time) -> None:
    """Append ``time`` to the list of recorded build times."""
    recorded = self.build_times
    recorded.append(time)
def set_result(self, result: ResultData) -> None:
    """Store ``result`` as the vertex's current ``ResultData``."""
    self.result = result
def get_built_result(self):
    """Return the value this vertex exposes as its result.

    For an interface component that has a built object, the built object is
    returned; if it is neither a dict nor a str and has a ``content``
    attribute, that attribute is returned instead.

    Otherwise the stored built result is used (a str built object replaces
    it first). An ``UnbuiltResult`` yields ``{}``, a dict is returned as is,
    and anything else is wrapped as ``{"result": value}``.
    """
    built = self.built_object
    if self.is_interface_component and not isinstance(built, UnbuiltObject):
        has_content = not isinstance(built, dict | str) and hasattr(built, "content")
        return built.content if has_content else built

    # A plain string build doubles as the result.
    if isinstance(built, str):
        self.built_result = built

    outcome = self.built_result
    if isinstance(outcome, UnbuiltResult):
        return {}
    if isinstance(outcome, dict):
        return outcome
    return {"result": outcome}
def set_artifacts(self) -> None:
    """Do nothing; called by ``finalize_build`` before it reads ``artifacts_raw``."""
    return None
@property
def edges(self) -> list[CycleEdge]:
    """Edges attached to this vertex, as reported by the graph.

    This must be a property: the rest of the class reads it as an attribute
    (``for edge in self.edges``, ``edge not in self.edges``,
    ``self.edges.append(...)``). As a plain method those expressions would
    operate on the bound method object and fail.

    Returns:
        Whatever ``graph.get_vertex_edges(self.id)`` returns.
    """
    return self.graph.get_vertex_edges(self.id)
def outgoing_edges(self) -> list[CycleEdge]:
    """Return the edges of this vertex whose ``source_id`` is this vertex's id."""
    selected = []
    for edge in self.edges:
        if edge.source_id == self.id:
            selected.append(edge)
    return selected
def incoming_edges(self) -> list[CycleEdge]:
    """Return the edges of this vertex whose ``target_id`` is this vertex's id."""
    selected = []
    for edge in self.edges:
        if edge.target_id == self.id:
            selected.append(edge)
    return selected
def edges_source_names(self) -> set[str | None]:
    """Return the set of ``source_handle.name`` values over this vertex's edges."""
    names: set[str | None] = set()
    for edge in self.edges:
        names.add(edge.source_handle.name)
    return names
def predecessors(self) -> list[Vertex]:
    """Return ``graph.get_predecessors(self)``."""
    graph = self.graph
    return graph.get_predecessors(self)
def successors(self) -> list[Vertex]:
    """Return ``graph.get_successors(self)``."""
    graph = self.graph
    return graph.get_successors(self)
def successors_ids(self) -> list[str]:
    """Return this vertex's entry in ``graph.successor_map``, or ``[]`` when absent."""
    successor_map = self.graph.successor_map
    return successor_map.get(self.id, [])
def __getstate__(self):
    """Return a copy of the instance dict prepared for serialization.

    The lock is replaced by ``None`` (locks are not serializable), and the
    ``UnbuiltObject`` / ``UnbuiltResult`` placeholders are stored as ``None``.
    """
    built_object = self.built_object
    built_result = self.built_result
    state = dict(self.__dict__)
    state["_lock"] = None
    state["built_object"] = None if isinstance(built_object, UnbuiltObject) else built_object
    state["built_result"] = None if isinstance(built_result, UnbuiltResult) else built_result
    return state
def __setstate__(self, state):
    """Restore serialized state, recreating what ``__getstate__`` stripped.

    ``__getstate__`` stores ``None`` for the lock and for the unbuilt
    placeholders, so those are rebuilt here. Only ``None`` is treated as
    "unbuilt": falsy-but-real values such as ``""``, ``0`` or ``[]`` are kept,
    and objects whose truth value is ambiguous (e.g. a pandas DataFrame) no
    longer raise during restoration.

    Args:
        state: Mapping produced by ``__getstate__``.
    """
    self.__dict__.update(state)
    self._lock = asyncio.Lock()  # The lock is never serialized; make a fresh one.
    built_object = state.get("built_object")
    self.built_object = UnbuiltObject() if built_object is None else built_object
    built_result = state.get("built_result")
    self.built_result = UnbuiltResult() if built_result is None else built_result
def set_top_level(self, top_level_vertices: list[str]) -> None:
    """Record whether this vertex's parent node id is among ``top_level_vertices``."""
    parent_id = self.parent_node_id
    self.parent_is_top_level = parent_id in top_level_vertices
def parse_data(self) -> None:
    """Populate vertex attributes from the node payload in ``self.full_data``.

    Sets ``data``, ``outputs`` (and ``output`` for non-``Component``
    templates), ``display_name``, ``icon``, ``description``, ``frozen``,
    ``is_input``, ``is_output``, ``has_session_id``, ``required_inputs``,
    ``optional_inputs`` and ``vertex_type``. When ``self.base_type`` is
    ``None`` it is looked up in ``lazy_load_dict.all_types_dict``.

    Raises:
        ValueError: If the template type is ``"Component"`` and the node has
            no ``"outputs"`` entry.
    """
    self.data = self.full_data["data"]
    node = self.data["node"]
    template = node["template"]
    id_prefix = self.id.split("-")[0]

    # Resolve the display name first: the error message below interpolates
    # it, and reading it before assignment raised AttributeError instead of
    # the intended ValueError.
    self.display_name: str = node.get("display_name", id_prefix)

    if template["_type"] == "Component":
        if "outputs" not in node:
            msg = f"Outputs not found for {self.display_name}"
            raise ValueError(msg)
        self.outputs = node["outputs"]
    else:
        self.outputs = node.get("outputs", [])
        self.output = node["base_classes"]

    self.icon: str = node.get("icon", id_prefix)
    self.description: str = node.get("description", "")
    self.frozen: bool = node.get("frozen", False)

    # Explicit template flags win; otherwise keep the value already set.
    self.is_input = node.get("is_input") or self.is_input
    self.is_output = node.get("is_output") or self.is_output

    # Only dict-valued template entries describe fields ("_type" is a str).
    template_dicts = {key: value for key, value in template.items() if isinstance(value, dict)}
    self.has_session_id = "session_id" in template_dicts

    self.required_inputs: list[str] = []
    self.optional_inputs: list[str] = []
    for value_dict in template_dicts.values():
        bucket = self.required_inputs if value_dict.get("required") else self.optional_inputs
        if "type" in value_dict:
            bucket.append(value_dict["type"])
        if "input_types" in value_dict:
            bucket.extend(value_dict["input_types"])

    # Use the node's own type unless it outputs a "Tool" and the template
    # "_type" is not all-lowercase; then the template "_type" is used.
    output_types = [type_ for out in self.outputs for type_ in out["types"]]
    if "Tool" not in output_types or template["_type"].islower():
        self.vertex_type = self.data["type"]
    else:
        self.vertex_type = template["_type"]

    if self.base_type is None:
        for base_type, value in lazy_load_dict.all_types_dict.items():
            if self.vertex_type in value:
                self.base_type = base_type
                break
def get_value_from_template_dict(self, key: str):
    """Return the ``"value"`` of template field ``key``.

    Raises:
        ValueError: If ``key`` is not present in the node template.
    """
    node = self.data.get("node", {})
    template = node.get("template", {})
    if key in template:
        return template.get(key, {}).get("value")
    msg = f"Key {key} not found in template dict"
    raise ValueError(msg)
def get_task(self):
    """Return the celery ``AsyncResult`` handle for ``self.task_id``."""
    # Imported lazily inside the method, so celery is only needed when this
    # method is actually called.
    from celery.result import AsyncResult

    task_handle = AsyncResult(self.task_id)
    return task_handle
def _set_params_from_normal_edge(self, params: dict, edge: Edge, template_dict: dict):
    """Fill ``params`` with the source vertex of ``edge`` for its target field.

    Nothing is changed unless the edge's ``target_param`` is a template field
    and the edge targets this vertex; the target check keeps same-named
    params of other vertices from being overwritten.

    - list fields: the source vertex is appended to the param's list;
    - fields whose current value is a dict with exactly one key: the param
      becomes ``{that_key: source_vertex}``;
    - everything else: the param becomes the source vertex.

    Returns:
        The same ``params`` dict that was passed in.
    """
    param_key = edge.target_param
    if param_key not in template_dict or edge.target_id != self.id:
        return params

    field = template_dict[param_key]
    if field.get("list"):
        params.setdefault(param_key, []).append(self.graph.get_vertex(edge.source_id))
        return params

    current_value = field.get("value")
    if isinstance(current_value, dict) and len(current_value) == 1:
        # The dict's key is not known in advance; keep it and point it at
        # the vertex that is the source of the edge.
        params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in current_value}
    else:
        params[param_key] = self.graph.get_vertex(edge.source_id)
    return params
def build_params(self) -> None:
    """Compute the build parameters of this vertex from its edges and template.

    Edge-connected fields are filled first (see
    ``_set_params_from_normal_edge``). Every remaining template field that is
    shown (or is the ``code`` field) is then resolved from its own ``value``
    and coerced according to its ``type``. The result is stored in
    ``self.params`` and ``self.raw_params``; names of fields flagged
    ``load_from_db`` go to ``self.load_from_db_fields``.

    If ``update_raw_params`` ran since the last call, the params are left
    untouched and only the ``updated_raw_params`` flag is cleared.

    Raises:
        ValueError: If the vertex has no graph, or a ``table`` field holds
            something other than a list of dicts.
    """
    # sourcery skip: merge-list-append, remove-redundant-if
    if self.graph is None:
        msg = "Graph not found"
        raise ValueError(msg)

    if self.updated_raw_params:
        self.updated_raw_params = False
        return

    template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
    params: dict = {}

    # Fields fed by an edge take the source vertex as their value.
    for edge in self.edges:
        if not hasattr(edge, "target_param"):
            continue
        params = self._set_params_from_normal_edge(params, edge, template_dict)

    load_from_db_fields = []
    for field_name, field in template_dict.items():
        if field_name in params:
            continue
        # Skip _type and any field that is hidden, except "code": it may be
        # hidden in the UI but is still needed.
        if field_name == "_type" or (not field.get("show") and field_name != "code"):
            continue

        field_type = field.get("type")
        if field_type == "file":
            if file_path := field.get("file_path"):
                storage_service = get_storage_service()
                try:
                    flow_id, file_name = os.path.split(file_path)
                    full_path = storage_service.build_full_path(flow_id, file_name)
                except ValueError as e:
                    if "too many values to unpack" in str(e):
                        full_path = file_path
                    else:
                        raise
                params[field_name] = full_path
            elif field.get("required"):
                field_display_name = field.get("display_name")
                logger.warning(
                    f"File path not found for {field_display_name} in component {self.display_name}. "
                    "Setting to None."
                )
                params[field_name] = None
            elif field.get("list"):
                # .get(): templates without a "list" key must not raise KeyError.
                params[field_name] = []
            else:
                params[field_name] = None

        elif field_type in DIRECT_TYPES and params.get(field_name) is None:
            val = field.get("value")
            if field_type == "code":
                try:
                    if field_name == "code":
                        params[field_name] = val
                    else:
                        params[field_name] = ast.literal_eval(val) if val else None
                except Exception:  # noqa: BLE001
                    logger.debug(f"Error evaluating code for {field_name}")
                    params[field_name] = val
            elif field_type in {"dict", "NestedDict"}:
                # When a dict comes from the frontend it arrives as a list of
                # dicts, so merge it into a single dict before the build.
                if isinstance(val, list):
                    params[field_name] = {k: v for item in val for k, v in item.items()}
                elif isinstance(val, dict):
                    params[field_name] = val
            elif field_type == "int" and val is not None:
                try:
                    params[field_name] = int(val)
                except (TypeError, ValueError):
                    params[field_name] = val
            elif field_type == "float" and val is not None:
                # Keep the converted value; an unconditional re-assignment of
                # the raw value used to discard the conversion.
                try:
                    params[field_name] = float(val)
                except (TypeError, ValueError):
                    params[field_name] = val
            elif field_type == "str" and val is not None:
                # val may contain escaped \n, \t, etc., so unescape it.
                if isinstance(val, list):
                    params[field_name] = [unescape_string(v) for v in val]
                elif isinstance(val, str):
                    params[field_name] = unescape_string(val)
                elif isinstance(val, Data):
                    params[field_name] = unescape_string(val.get_text())
            elif field_type == "bool" and val is not None:
                if isinstance(val, bool):
                    params[field_name] = val
                elif isinstance(val, str):
                    # bool("false") is True, so explicit negative words are
                    # parsed; any other non-empty string stays truthy.
                    params[field_name] = bool(val) and val.strip().lower() not in {"false", "0", "no", "off"}
            elif field_type == "table" and val is not None:
                # A table must be a list of dicts; it becomes a DataFrame.
                if isinstance(val, list) and all(isinstance(item, dict) for item in val):
                    params[field_name] = pd.DataFrame(val)
                else:
                    msg = f"Invalid value type {type(val)} for field {field_name}"
                    raise ValueError(msg)
            elif val is not None and val != "":
                params[field_name] = val

            if field.get("load_from_db"):
                load_from_db_fields.append(field_name)

            # Optional fields without a value fall back to their default, or
            # are dropped from the params entirely.
            if not field.get("required") and params.get(field_name) is None:
                if field.get("default"):
                    params[field_name] = field.get("default")
                else:
                    params.pop(field_name, None)

    self.params = params
    self.load_from_db_fields = load_from_db_fields
    self.raw_params = params.copy()
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False) -> None: | |
"""Update the raw parameters of the vertex with the given new parameters. | |
Args: | |
new_params (Dict[str, Any]): The new parameters to update. | |
overwrite (bool, optional): Whether to overwrite the existing parameters. | |
Defaults to False. | |
Raises: | |
ValueError: If any key in new_params is not found in self.raw_params. | |
""" | |
# First check if the input_value in raw_params is not a vertex | |
if not new_params: | |
return | |
if any(isinstance(self.raw_params.get(key), Vertex) for key in new_params): | |
return | |
if not overwrite: | |
for key in new_params.copy(): # type: ignore[attr-defined] | |
if key not in self.raw_params: | |
new_params.pop(key) # type: ignore[attr-defined] | |
self.raw_params.update(new_params) | |
self.params = self.raw_params.copy() | |
self.updated_raw_params = True | |
def instantiate_component(self, user_id=None) -> None:
    """Create and store the component instance unless one is already attached."""
    if self.custom_component:
        return
    instance, _ = initialize.loading.instantiate_class(
        user_id=user_id,
        vertex=self,
    )
    self.custom_component = instance
async def _build(
    self,
    fallback_to_env_vars,
    user_id=None,
    event_manager: EventManager | None = None,
) -> None:
    """Run the build of this vertex and mark it as built.

    Resolves vertex-valued params, obtains the component (instantiating it
    when none is attached), builds the results and validates the built
    object.

    Raises:
        ValueError: If ``self.base_type`` is ``None``.
    """
    logger.debug(f"Building {self.display_name}")
    await self._build_each_vertex_in_params_dict()

    if self.base_type is None:
        msg = f"Base type for vertex {self.display_name} not found"
        raise ValueError(msg)

    if self.custom_component:
        # Reuse the attached component; hand it the current event manager
        # when it supports one.
        custom_component = self.custom_component
        if hasattr(self.custom_component, "set_event_manager"):
            self.custom_component.set_event_manager(event_manager)
        custom_params = initialize.loading.get_params(self.params)
    else:
        custom_component, custom_params = initialize.loading.instantiate_class(
            user_id=user_id, vertex=self, event_manager=event_manager
        )

    await self._build_results(
        custom_component=custom_component,
        custom_params=custom_params,
        fallback_to_env_vars=fallback_to_env_vars,
        base_type=self.base_type,
    )
    self._validate_built_object()
    self.built = True
def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dict]:
    """Build the chat-message payload described by ``artifacts``.

    Args:
        artifacts: The artifacts to extract the message from; the message
            text is read from ``artifacts["text"]``.

    Returns:
        A one-element list holding the dumped ``ChatOutputResponse``
        (``None`` fields excluded), or an empty list when a ``KeyError`` is
        raised while building it, e.g. because ``"text"`` is missing.
    """
    try:
        text = artifacts["text"]
        sender_name = artifacts.get("sender_name")
        # Plain string paths are wrapped as {"path": ...}; other entries are
        # passed through unchanged.
        attachments = [
            {"path": entry} if isinstance(entry, str) else entry for entry in artifacts.get("files", [])
        ]
        if isinstance(sender_name, Data | Message):
            sender_name = sender_name.get_text()
        response = ChatOutputResponse(
            message=text,
            sender=artifacts.get("sender"),
            sender_name=sender_name,
            session_id=artifacts.get("session_id"),
            stream_url=artifacts.get("stream_url"),
            files=attachments,
            component_id=self.id,
            type=self.artifacts_type,
        )
        return [response.model_dump(exclude_none=True)]
    except KeyError:
        return []
def finalize_build(self) -> None:
    """Assemble the ``ResultData`` of the finished build and store it via ``set_result``."""
    built_results = self.get_built_result()
    # Artifacts carry information to the frontend, so set them before they
    # are read.
    self.set_artifacts()
    raw_artifacts = self.artifacts_raw

    messages = []
    if isinstance(raw_artifacts, dict):
        messages = self.extract_messages_from_artifacts(raw_artifacts)

    self.set_result(
        ResultData(
            results=built_results,
            artifacts=raw_artifacts,
            outputs=self.outputs_logs,
            logs=self.logs,
            messages=messages,
            component_display_name=self.display_name,
            component_id=self.id,
        )
    )
async def _build_each_vertex_in_params_dict(self) -> None:
    """Resolve every vertex-valued entry of ``self.raw_params`` into ``self.params``.

    Vertices, lists of vertices and dicts are handed to their dedicated
    helpers. A param that refers to this vertex itself is removed from
    ``self.params``. Any other value is copied into ``self.params`` when the
    key is missing there or when the raw params were just updated.
    """
    for name, raw_value in self.raw_params.items():
        if self._is_vertex(raw_value):
            if raw_value == self:
                # Self reference: nothing to fetch, drop the param.
                del self.params[name]
            else:
                await self._build_vertex_and_update_params(name, raw_value)
        elif isinstance(raw_value, list) and self._is_list_of_vertices(raw_value):
            await self._build_list_of_vertices_and_update_params(name, raw_value)
        elif isinstance(raw_value, dict):
            await self._build_dict_and_update_params(name, raw_value)
        elif name not in self.params or self.updated_raw_params:
            self.params[name] = raw_value
async def _build_dict_and_update_params(
    self,
    key,
    vertices_dict: dict[str, Vertex],
) -> None:
    """Copy ``vertices_dict`` into ``self.params[key]``, resolving vertices to their results.

    Values that are vertices are replaced by ``await value.get_result(...)``;
    all other values are stored unchanged.
    """
    for sub_key, entry in vertices_dict.items():
        resolved = entry
        if self._is_vertex(entry):
            resolved = await entry.get_result(self, target_handle_name=key)
        self.params[key][sub_key] = resolved
def _is_vertex(self, value):
    """Return True when ``value`` is a ``Vertex`` instance."""
    is_vertex = isinstance(value, Vertex)
    return is_vertex
def _is_list_of_vertices(self, value):
    """Return True when every item of ``value`` is a vertex (also for an empty iterable)."""
    for item in value:
        if not self._is_vertex(item):
            return False
    return True
async def get_result(self, requester: Vertex, target_handle_name: str | None = None) -> Any:
    """Return the result of this vertex for ``requester``.

    The vertex lock is held while the result is read. Reading does not build
    the vertex: ``_get_result`` raises if it has not been built yet.

    Returns:
        The result of the vertex.
    """
    async with self._lock:
        outcome = await self._get_result(requester, target_handle_name)
    return outcome
def _log_transaction_async(
    self, flow_id: str | UUID, source: Vertex, status, target: Vertex | None = None, error=None
) -> None:
    """Schedule ``log_transaction`` as a background task.

    The task is kept in ``self.log_transaction_tasks`` until it finishes, so
    a reference to it exists while it runs.
    """
    pending = self.log_transaction_tasks
    task = asyncio.create_task(log_transaction(flow_id, source, status, target, error))
    pending.add(task)
    task.add_done_callback(pending.discard)
async def _get_result(
    self,
    requester: Vertex,
    target_handle_name: str | None = None,  # noqa: ARG002
) -> Any:
    """Return the result of the built component.

    When the graph has a flow id, a transaction with status ``"success"`` or
    ``"error"`` is logged for the (self, requester) pair.

    Returns:
        The built result if ``use_result`` is True, else the built object.

    Raises:
        ValueError: If the component has not been built yet.
    """
    flow_id = self.graph.flow_id
    if self.built:
        outcome = self.built_result if self.use_result else self.built_object
        if flow_id:
            self._log_transaction_async(str(flow_id), source=self, target=requester, status="success")
        return outcome

    if flow_id:
        self._log_transaction_async(str(flow_id), source=self, target=requester, status="error")
    msg = f"Component {self.display_name} has not been built yet"
    raise ValueError(msg)
async def _build_vertex_and_update_params(self, key, vertex: Vertex) -> None:
    """Fetch the result of ``vertex`` and store it as ``self.params[key]``."""
    fetched = await vertex.get_result(self, target_handle_name=key)
    self._handle_func(key, fetched)
    if isinstance(fetched, list):
        self._extend_params_list_with_result(key, fetched)
    # The final assignment always wins over the extension above.
    self.params[key] = fetched
async def _build_list_of_vertices_and_update_params(
    self,
    key,
    vertices: list[Vertex],
) -> None:
    """Collect the results of ``vertices`` into the list ``self.params[key]``.

    List results are flattened into the param list; any other result is
    appended, unless the param list compares equal to it.

    Raises:
        ValueError: If appending to the param raises ``AttributeError``.
    """
    self.params[key] = []
    for source in vertices:
        result = await source.get_result(self, target_handle_name=key)
        # Defensive check that params[key] is a list, because sometimes it
        # is a Data and breaks the code.
        if not isinstance(self.params[key], list):
            self.params[key] = [self.params[key]]

        if isinstance(result, list):
            self.params[key].extend(result)
            continue

        try:
            if self.params[key] == result:
                continue
            self.params[key].append(result)
        except AttributeError as e:
            logger.exception(e)
            msg = (
                f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
                f"Error building Component {self.display_name}: \n\n{e}"
            )
            raise ValueError(msg) from e
def _handle_func(self, key, result) -> None:
    """Register a ``coroutine`` param when the ``func`` param resolves to a function.

    Only ``key == "func"`` is handled. A coroutine function is stored as is;
    any other plain function is wrapped with ``sync_to_async``. Results that
    are not plain functions leave ``self.params`` untouched.
    """
    if key != "func":
        return

    if isinstance(result, types.FunctionType):
        if inspect.iscoroutinefunction(result):
            self.params["coroutine"] = result
        else:
            self.params["coroutine"] = sync_to_async(result)
        return

    # NOTE(review): the callable resolved below is assigned to a local only
    # and never stored in self.params - confirm whether that is intended.
    if hasattr(result, "run"):
        result = result.run
    elif hasattr(result, "get_function"):
        result = result.get_function()
def _extend_params_list_with_result(self, key, result) -> None:
    """Extend ``self.params[key]`` with ``result`` when that param is a list."""
    current = self.params[key]
    if not isinstance(current, list):
        return
    current.extend(result)
async def _build_results(
    self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False
) -> None:
    """Run the component and store its outputs, built object and artifacts.

    Raises:
        ComponentBuildError: Wraps any exception raised while obtaining or
            storing the results; it carries the formatted traceback.
    """
    try:
        instance_results = await initialize.loading.get_instance_results(
            custom_component=custom_component,
            custom_params=custom_params,
            vertex=self,
            fallback_to_env_vars=fallback_to_env_vars,
            base_type=base_type,
        )
        self.outputs_logs = build_output_logs(self, instance_results)
        self._update_built_object_and_artifacts(instance_results)
    except Exception as exc:
        formatted_tb = traceback.format_exc()
        logger.exception(exc)
        msg = f"Error building Component {self.display_name}: \n\n{exc}"
        raise ComponentBuildError(msg, formatted_tb) from exc
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]) -> None:
    """Store the build ``result`` on the vertex.

    - non-tuple: becomes ``self.built_object``;
    - 2-tuple: ``(built_object, artifacts)``;
    - 3-tuple: ``(component, built_object, artifacts)``; additionally the
      component's output logs are copied to ``self.logs`` and the artifacts
      (raw value, type, and the artifacts dict itself) are keyed by the name
      of the first output.

    Tuples of any other length are ignored.
    """
    if not isinstance(result, tuple):
        self.built_object = result
        return

    size = len(result)
    if size == 2:  # noqa: PLR2004
        self.built_object, self.artifacts = result
    elif size == 3:  # noqa: PLR2004
        self.custom_component, self.built_object, self.artifacts = result
        self.logs = self.custom_component._output_logs
        produced = self.artifacts
        self.artifacts_raw = produced.get("raw", None)
        first_output_name = self.outputs[0]["name"]
        self.artifacts_type = {first_output_name: produced.get("type", None) or ArtifactType.UNKNOWN.value}
        self.artifacts = {self.outputs[0]["name"]: produced}
def _validate_built_object(self) -> None:
    """Validate the object produced by the build.

    A ``None`` built object only produces a warning.

    Raises:
        TypeError: If the built object is still the ``UnbuiltObject`` placeholder.
        ValueError: If the built object is an iterator (sync or async) and the
            vertex is displayed as "Text Output".
    """
    built = self.built_object
    if isinstance(built, UnbuiltObject):
        msg = f"{self.display_name}: {self.built_object_repr()}"
        raise TypeError(msg)

    if built is None:
        message = f"{self.display_name} returned None."
        if self.base_type == "custom_components":
            message += " Make sure your build method returns a component."
        logger.warning(message)
        return

    is_stream = isinstance(built, Iterator | AsyncIterator)
    if is_stream and self.display_name == "Text Output":
        msg = f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
        raise ValueError(msg)
def _reset(self) -> None:
    """Clear the build state and recompute the params for a fresh build."""
    self.built = False
    self.steps_ran = []
    self.artifacts = {}
    self.built_object = UnbuiltObject()
    self.built_result = UnbuiltResult()
    self.build_params()
def _is_chat_input(self) -> bool:
    """Return False: this vertex class does not identify itself as a chat input."""
    is_chat_input = False
    return is_chat_input
def build_inactive(self) -> None:
    """Mark the vertex as built with ``None`` as both built object and built result."""
    self.built_object = None
    self.built_result = None
    self.built = True
async def build(
    self,
    user_id=None,
    inputs: dict[str, Any] | None = None,
    files: list[str] | None = None,
    requester: Vertex | None = None,
    event_manager: EventManager | None = None,
    **kwargs,
) -> Any:
    """Build the vertex (if needed) and return its result for ``requester``.

    Args:
        user_id: Passed to each build step.
        inputs: Optional run inputs; ``"session"`` and the chat input value
            are read from it.
        files: Optional file paths injected into a chat-input vertex.
        requester: Vertex asking for the result; ``None`` means the built
            object itself is returned.
        event_manager: Passed to each build step.
        **kwargs: Forwarded unchanged to each build step.

    Returns:
        ``None`` for an inactive vertex, otherwise the value of
        ``get_requester_result(requester)``.
    """
    async with self._lock:
        if self.state == VertexStates.INACTIVE:
            # Inactive vertices are not built; they only get None results.
            self.build_inactive()
            return None

        # A frozen vertex that is built, or a built vertex asked by a
        # requester, only serves the existing result.
        already_usable = self.built and (self.frozen or requester is not None)
        if already_usable:
            return await self.get_requester_result(requester)

        self._reset()

        # Inject the session id when the template leaves it empty.
        session_given = inputs is not None and "session" in inputs and inputs["session"] is not None
        if session_given and self.has_session_id:
            if self.get_value_from_template_dict("session_id") == "":
                self.update_raw_params({"session_id": inputs["session"]}, overwrite=True)

        if self._is_chat_input() and (inputs or files):
            chat_input = {}
            has_input_value = (
                inputs
                and isinstance(inputs, dict)
                and "input_value" in inputs
                and inputs.get("input_value") is not None
            )
            if has_input_value:
                chat_input.update({"input_value": inputs.get(INPUT_FIELD_NAME, "")})
            if files:
                chat_input.update({"files": files})
            self.update_raw_params(chat_input, overwrite=True)

        # Run each step that has not run yet in this build.
        for step in self.steps:
            if step in self.steps_ran:
                continue
            await step(user_id=user_id, event_manager=event_manager, **kwargs)
            self.steps_ran.append(step)

        self.finalize_build()

    return await self.get_requester_result(requester)
async def get_requester_result(self, requester: Vertex | None):
    """Return the result of this vertex as seen by ``requester``.

    Returns:
        The built object when ``requester`` is ``None``; ``None`` when no edge
        of this vertex targets the requester; otherwise the value the first
        such edge yields from ``get_result_from_source``.
    """
    # No requester: hand back the built object itself.
    if requester is None:
        return self.built_object

    # Use the first edge that targets the requester.
    for edge in self.edges:
        if edge.target_id == requester.id:
            return await edge.get_result_from_source(source=self, target=requester)
    return None
def add_edge(self, edge: CycleEdge) -> None:
    """Append ``edge`` to this vertex's edges unless it is already present."""
    if edge in self.edges:
        return
    self.edges.append(edge)
def __repr__(self) -> str:
    """Return a debug representation with the display name, id and node data."""
    fields = f"display_name={self.display_name}, id={self.id}, data={self.data}"
    return f"Vertex({fields})"
def __eq__(self, /, other: object) -> bool: | |
try: | |
if not isinstance(other, Vertex): | |
return False | |
# We should create a more robust comparison | |
# for the Vertex class | |
ids_are_equal = self.id == other.id | |
# self.data is a dict and we need to compare them | |
# to check if they are equal | |
data_are_equal = self.data == other.data | |
except AttributeError: | |
return False | |
else: | |
return ids_are_equal and data_are_equal | |
def __hash__(self) -> int:
    """Return an identity-based hash (``id(self)``).

    NOTE(review): ``__eq__`` compares ``id`` and ``data``, so two distinct
    but equal vertices hash differently - confirm that set/dict users rely
    on identity only.
    """
    identity = id(self)
    return identity
def built_object_repr(self) -> str:
    """Return a short status message: success unless the built object is ``None``."""
    if self.built_object is None:
        return "Failed to build 😵💫"
    return "Built successfully ✨"
def apply_on_outputs(self, func: Callable[[Any], Any]) -> None:
    """Call ``func`` on every output of the attached component.

    Nothing happens when no component is attached or its ``outputs`` is
    empty. The outputs are taken from the component's ``_outputs_map``
    values; the return values of ``func`` are discarded.
    """
    component = self.custom_component
    if not component or not component.outputs:
        return
    for output in component._outputs_map.values():
        func(output)