from __future__ import annotations

import asyncio
import json
import time
import traceback
import typing
import uuid
from typing import TYPE_CHECKING, Annotated

from fastapi import APIRouter, BackgroundTasks, Body, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from sqlmodel import select
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
from starlette.types import Receive

from langflow.api.utils import (
    CurrentActiveUser,
    DbSession,
    build_and_cache_graph_from_data,
    build_graph_from_data,
    build_graph_from_db,
    build_graph_from_db_no_cache,
    format_elapsed_time,
    format_exception_message,
    get_top_level_vertices,
    parse_exception,
)
from langflow.api.v1.schemas import (
    FlowDataRequest,
    InputValueRequest,
    ResultDataResponse,
    StreamData,
    VertexBuildResponse,
    VerticesOrderResponse,
)
from langflow.events.event_manager import EventManager, create_default_event_manager
from langflow.exceptions.component import ComponentBuildError
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload

if TYPE_CHECKING:
    from langflow.graph.vertex.types import InterfaceVertex

router = APIRouter(tags=["Chat"])


async def try_running_celery_task(vertex, user_id):
    """Dispatch a vertex build to Celery, falling back to a local build.

    Args:
        vertex: The vertex to build. Its ``task_id`` attribute is set to the
            Celery task id on success, or to ``None`` when dispatching fails.
        user_id: Forwarded to ``vertex.build`` when the build runs locally.

    Returns:
        The same ``vertex`` object that was passed in.
    """
    # Try running the task in celery
    # and set the task_id to the local vertex
    # if it fails, run the task locally
    try:
        from langflow.worker import build_vertex

        task = build_vertex.delay(vertex)
        vertex.task_id = task.id
    except Exception:  # noqa: BLE001
        logger.opt(exception=True).debug("Error running task in celery")
        vertex.task_id = None
        await vertex.build(user_id=user_id)
    return vertex


@router.post("/build/{flow_id}/vertices")
async def retrieve_vertices_order(
    *,
    flow_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
    stop_component_id: str | None = None,
    start_component_id: str | None = None,
    session: DbSession,
) -> VerticesOrderResponse:
    """Retrieve the vertices order for a given flow.

    Args:
        flow_id (uuid.UUID): The ID of the flow.
        background_tasks (BackgroundTasks): The background tasks.
        data (Optional[FlowDataRequest], optional): The flow data. Defaults to None.
        stop_component_id (str, optional): The ID of the stop component. Defaults to None.
        start_component_id (str, optional): The ID of the start component. Defaults to None.
        session (AsyncSession, optional): The session dependency.

    Returns:
        VerticesOrderResponse: The response containing the ordered vertex IDs and the run ID.

    Raises:
        HTTPException: 400 when the error message mentions stream/streaming being
            set to True, 500 for any other failure.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    start_time = time.perf_counter()
    components_count = None
    try:
        flow_id_str = str(flow_id)
        # First, we need to check if the flow_id is in the cache
        if not data:
            graph = await build_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
        else:
            graph = await build_and_cache_graph_from_data(
                flow_id=flow_id_str, graph_data=data.model_dump(), chat_service=chat_service
            )
        graph = graph.prepare(stop_component_id, start_component_id)

        # Now vertices is a list of lists
        # We need to get the id of each vertex
        # and return the same structure but only with the ids
        components_count = len(graph.vertices)
        vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
        await chat_service.set_cache(str(flow_id), graph)
        background_tasks.add_task(
            telemetry_service.log_package_playground,
            PlaygroundPayload(
                playground_seconds=int(time.perf_counter() - start_time),
                playground_component_count=components_count,
                playground_success=True,
            ),
        )
        return VerticesOrderResponse(ids=graph.first_layer, run_id=graph.run_id, vertices_to_run=vertices_to_run)
    except Exception as exc:
        background_tasks.add_task(
            telemetry_service.log_package_playground,
            PlaygroundPayload(
                playground_seconds=int(time.perf_counter() - start_time),
                playground_component_count=components_count,
                playground_success=False,
                playground_error_message=str(exc),
            ),
        )
        if "stream or streaming set to True" in str(exc):
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        logger.exception("Error checking build status")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/build/{flow_id}/flow")
async def build_flow(
    *,
    background_tasks: BackgroundTasks,
    flow_id: uuid.UUID,
    inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
    data: Annotated[FlowDataRequest | None, Body(embed=True)] = None,
    files: list[str] | None = None,
    stop_component_id: str | None = None,
    start_component_id: str | None = None,
    log_builds: bool | None = True,
    current_user: CurrentActiveUser,
    session: DbSession,
):
    """Build a whole flow and stream the build events back to the client.

    The graph is built from ``data`` when provided, otherwise from the database.
    Vertices are then built starting from the first layer, and each finished
    vertex schedules the vertices it reports as next runnable. Events are
    streamed as ``application/x-ndjson``.

    Args:
        background_tasks: Collects telemetry, build-log and trace-closing tasks.
        flow_id: The ID of the flow to build.
        inputs: Input values for the build. When falsy, an ``InputValueRequest``
            whose session is the flow id is used.
        data: Optional flow data to build the graph from instead of the database.
        files: Files forwarded to each vertex build.
        stop_component_id: Optional component to stop the build at.
        start_component_id: Optional component to start the build from.
        log_builds: When truthy, non-streaming vertex builds are logged.
        current_user: The authenticated user; its id is passed to the builds.
        session: Database session used when the graph is loaded from the database.

    Returns:
        A ``DisconnectHandlerStreamingResponse`` that cancels the build when the
        client disconnects.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    if not inputs:
        inputs = InputValueRequest(session=str(flow_id))

    async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]:
        """Build the graph and return ``(first_layer, vertices_to_run, graph)``.

        Raises:
            HTTPException: 400 for stream/streaming configuration errors, 500 otherwise.
        """
        start_time = time.perf_counter()
        components_count = None
        try:
            flow_id_str = str(flow_id)
            if not data:
                graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
            else:
                async with async_session_scope() as new_session:
                    result = await new_session.exec(select(Flow.name).where(Flow.id == flow_id_str))
                    flow_name = result.first()
                graph = await build_graph_from_data(
                    flow_id_str, data.model_dump(), user_id=str(current_user.id), flow_name=flow_name
                )
            graph.validate_stream()
            if stop_component_id or start_component_id:
                try:
                    first_layer = graph.sort_vertices(stop_component_id, start_component_id)
                except Exception:  # noqa: BLE001
                    # Fall back to the full ordering when the partial sort fails.
                    logger.exception("Error sorting vertices")
                    first_layer = graph.sort_vertices()
            else:
                first_layer = graph.sort_vertices()

            if inputs is not None and hasattr(inputs, "session") and inputs.session is not None:
                graph.session_id = inputs.session

            for vertex_id in first_layer:
                graph.run_manager.add_to_vertices_being_run(vertex_id)

            # Now vertices is a list of lists
            # We need to get the id of each vertex
            # and return the same structure but only with the ids
            components_count = len(graph.vertices)
            vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playground_seconds=int(time.perf_counter() - start_time),
                    playground_component_count=components_count,
                    playground_success=True,
                ),
            )
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playground_seconds=int(time.perf_counter() - start_time),
                    playground_component_count=components_count,
                    playground_success=False,
                    playground_error_message=str(exc),
                ),
            )
            if "stream or streaming set to True" in str(exc):
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            logger.exception("Error checking build status")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        return first_layer, vertices_to_run, graph

    async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse:
        """Build one vertex and describe the outcome as a ``VertexBuildResponse``.

        A failure inside the vertex build itself is reported in the response
        (``valid=False`` plus an error output); failures around it are raised.

        Raises:
            HTTPException: 500 when anything outside the vertex build itself fails.
        """
        flow_id_str = str(flow_id)

        next_runnable_vertices = []
        top_level_vertices = []
        start_time = time.perf_counter()
        error_message = None
        try:
            vertex = graph.get_vertex(vertex_id)
            try:
                lock = chat_service.async_cache_locks[flow_id_str]
                vertex_build_result = await graph.build_vertex(
                    vertex_id=vertex_id,
                    user_id=str(current_user.id),
                    inputs_dict=inputs.model_dump() if inputs else {},
                    files=files,
                    get_cache=chat_service.get_cache,
                    set_cache=chat_service.set_cache,
                    event_manager=event_manager,
                )
                result_dict = vertex_build_result.result_dict
                params = vertex_build_result.params
                valid = vertex_build_result.valid
                artifacts = vertex_build_result.artifacts
                next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
                top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

                result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
            except Exception as exc:  # noqa: BLE001
                if isinstance(exc, ComponentBuildError):
                    params = exc.message
                    tb = exc.formatted_traceback
                else:
                    tb = traceback.format_exc()
                    logger.exception("Error building Component")
                    params = format_exception_message(exc)
                message = {"errorMessage": params, "stackTrace": tb}
                valid = False
                error_message = params
                output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
                outputs = {output_label: OutputValue(message=message, type="error")}
                result_data_response = ResultDataResponse(results={}, outputs=outputs)
                artifacts = {}
                background_tasks.add_task(graph.end_all_traces, error=exc)

            result_data_response.message = artifacts

            # Log the vertex build
            if not vertex.will_stream and log_builds:
                background_tasks.add_task(
                    log_vertex_build,
                    flow_id=flow_id_str,
                    vertex_id=vertex_id,
                    valid=valid,
                    params=params,
                    data=result_data_response,
                    artifacts=artifacts,
                )
            else:
                await chat_service.set_cache(flow_id_str, graph)

            timedelta = time.perf_counter() - start_time
            duration = format_elapsed_time(timedelta)
            result_data_response.duration = duration
            result_data_response.timedelta = timedelta
            vertex.add_build_time(timedelta)
            inactivated_vertices = list(graph.inactivated_vertices)
            graph.reset_inactivated_vertices()
            graph.reset_activated_vertices()
            # graph.stop_vertex tells us if the user asked
            # to stop the build of the graph at a certain vertex
            # if it is in next_vertices_ids, we need to remove other
            # vertices from next_vertices_ids
            if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
                next_runnable_vertices = [graph.stop_vertex]

            if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
                background_tasks.add_task(graph.end_all_traces)

            build_response = VertexBuildResponse(
                inactivated_vertices=list(set(inactivated_vertices)),
                next_vertices_ids=list(set(next_runnable_vertices)),
                top_level_vertices=list(set(top_level_vertices)),
                valid=valid,
                params=params,
                id=vertex.id,
                data=result_data_response,
            )
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    component_name=vertex_id.split("-")[0],
                    component_seconds=int(time.perf_counter() - start_time),
                    component_success=valid,
                    component_error_message=error_message,
                ),
            )
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    component_name=vertex_id.split("-")[0],
                    component_seconds=int(time.perf_counter() - start_time),
                    component_success=False,
                    component_error_message=str(exc),
                ),
            )
            logger.exception("Error building Component")
            message = parse_exception(exc)
            raise HTTPException(status_code=500, detail=message) from exc

        return build_response

    async def build_vertices(
        vertex_id: str,
        graph: Graph,
        client_consumed_queue: asyncio.Queue,
        event_manager: EventManager,
    ) -> None:
        """Build ``vertex_id``, emit its end event, then build its successors concurrently.

        Raises:
            ValueError: If the vertex build response cannot be serialized to JSON.
        """
        build_task = asyncio.create_task(_build_vertex(vertex_id, graph, event_manager))
        try:
            await build_task
        except asyncio.CancelledError as exc:
            logger.exception(exc)
            build_task.cancel()
            return

        vertex_build_response: VertexBuildResponse = build_task.result()
        # send built event or error event
        try:
            vertex_build_response_json = vertex_build_response.model_dump_json()
            build_data = json.loads(vertex_build_response_json)
        except Exception as exc:
            msg = f"Error serializing vertex build response: {exc}"
            raise ValueError(msg) from exc
        event_manager.on_end_vertex(data={"build_data": build_data})
        # Wait until the streaming side reports that an event was consumed.
        await client_consumed_queue.get()
        if vertex_build_response.valid and vertex_build_response.next_vertices_ids:
            tasks = []
            for next_vertex_id in vertex_build_response.next_vertices_ids:
                task = asyncio.create_task(build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager))
                tasks.append(task)
            try:
                await asyncio.gather(*tasks)
            except asyncio.CancelledError:
                for task in tasks:
                    task.cancel()
                return

    async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
        """Drive the whole build and publish its events through ``event_manager``.

        Errors are reported with ``event_manager.on_error`` and then re-raised.
        The end-of-stream sentinel is enqueued on every exit path.
        """
        try:
            if not data:
                # using another task since the build_graph_and_get_order is now an async function
                vertices_task = asyncio.create_task(build_graph_and_get_order())
                try:
                    await vertices_task
                except asyncio.CancelledError:
                    vertices_task.cancel()
                    return
                except Exception as e:
                    if isinstance(e, HTTPException):
                        event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
                        raise
                    event_manager.on_error(data={"error": str(e)})
                    raise

                ids, vertices_to_run, graph = vertices_task.result()
            else:
                try:
                    ids, vertices_to_run, graph = await build_graph_and_get_order()
                except Exception as e:
                    if isinstance(e, HTTPException):
                        event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
                        raise
                    event_manager.on_error(data={"error": str(e)})
                    raise
            event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
            await client_consumed_queue.get()

            tasks = []
            for vertex_id in ids:
                task = asyncio.create_task(build_vertices(vertex_id, graph, client_consumed_queue, event_manager))
                tasks.append(task)
            try:
                await asyncio.gather(*tasks)
            except asyncio.CancelledError:
                background_tasks.add_task(graph.end_all_traces)
                for task in tasks:
                    task.cancel()
                return
            except Exception as e:
                logger.error(f"Error building vertices: {e}")
                event_manager.on_error(data={"error": str(e)})
                raise
            event_manager.on_end(data={})
        finally:
            # Always enqueue the end-of-stream sentinel. Without it,
            # consume_and_yield keeps waiting on the queue after a failed build
            # and the response only ends once the client disconnects.
            # The queue is unbounded, so put_nowait cannot raise QueueFull and
            # needs no await point (safe while this task is being cancelled).
            event_manager.queue.put_nowait((None, None, time.time()))

    async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
        """Yield event payloads from ``queue`` until a ``None`` payload is received.

        After each payload is yielded, its event id is put on
        ``client_consumed_queue`` so the build side can proceed.
        """
        while True:
            event_id, value, put_time = await queue.get()
            if value is None:
                break
            get_time = time.time()
            yield value
            get_time_yield = time.time()
            client_consumed_queue.put_nowait(event_id)
            logger.debug(
                f"consumed event {event_id} "
                f"(time in queue, {get_time - put_time:.4f}, "
                f"client {get_time_yield - get_time:.4f})"
            )

    asyncio_queue: asyncio.Queue = asyncio.Queue()
    asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
    event_manager = create_default_event_manager(queue=asyncio_queue)
    main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))

    def on_disconnect() -> None:
        logger.debug("Client disconnected, closing tasks")
        main_task.cancel()

    return DisconnectHandlerStreamingResponse(
        consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
        media_type="application/x-ndjson",
        on_disconnect=on_disconnect,
    )


class DisconnectHandlerStreamingResponse(StreamingResponse):
    """A ``StreamingResponse`` that invokes a callback when the client disconnects."""

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        on_disconnect: typing.Callable | None = None,
    ):
        """Store ``on_disconnect`` and pass the remaining arguments to ``StreamingResponse``."""
        super().__init__(content, status_code, headers, media_type, background)
        self.on_disconnect = on_disconnect

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Wait for an ``http.disconnect`` message, then run the disconnect callback.

        The callback may be sync or async: if calling it returns a coroutine,
        that coroutine is awaited.
        """
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                if self.on_disconnect:
                    coro = self.on_disconnect()
                    if asyncio.iscoroutine(coro):
                        await coro
                break


@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
    *,
    flow_id: uuid.UUID,
    vertex_id: str,
    background_tasks: BackgroundTasks,
    inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
    files: list[str] | None = None,
    current_user: CurrentActiveUser,
) -> VertexBuildResponse:
    """Build a vertex instead of the entire graph.

    Args:
        flow_id (uuid.UUID): The ID of the flow.
        vertex_id (str): The ID of the vertex to build.
        background_tasks (BackgroundTasks): The background tasks dependency.
        inputs (Optional[InputValueRequest], optional): The input values for the vertex. Defaults to None.
        files (List[str], optional): The files to use. Defaults to None.
        current_user: The authenticated user; its id is passed to the vertex build.

    Returns:
        VertexBuildResponse: The response containing the built vertex information.

    Raises:
        HTTPException: If there is an error building the vertex.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    flow_id_str = str(flow_id)

    next_runnable_vertices = []
    top_level_vertices = []
    start_time = time.perf_counter()
    error_message = None
    try:
        cache = await chat_service.get_cache(flow_id_str)
        if isinstance(cache, CacheMiss):
            # If there's no cache
            logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
            # Close the session generator explicitly once the graph is built;
            # otherwise its cleanup is deferred until garbage collection.
            session_generator = get_session()
            try:
                graph: Graph = await build_graph_from_db(
                    flow_id=flow_id_str, session=await anext(session_generator), chat_service=chat_service
                )
            finally:
                await session_generator.aclose()
        else:
            graph = cache.get("result")
            await graph.initialize_run()
        vertex = graph.get_vertex(vertex_id)

        try:
            lock = chat_service.async_cache_locks[flow_id_str]
            vertex_build_result = await graph.build_vertex(
                vertex_id=vertex_id,
                user_id=str(current_user.id),
                inputs_dict=inputs.model_dump() if inputs else {},
                files=files,
                get_cache=chat_service.get_cache,
                set_cache=chat_service.set_cache,
            )
            result_dict = vertex_build_result.result_dict
            params = vertex_build_result.params
            valid = vertex_build_result.valid
            artifacts = vertex_build_result.artifacts
            next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
            top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

            result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
        except Exception as exc:  # noqa: BLE001
            if isinstance(exc, ComponentBuildError):
                params = exc.message
                tb = exc.formatted_traceback
            else:
                tb = traceback.format_exc()
                logger.exception("Error building Component")
                params = format_exception_message(exc)
            message = {"errorMessage": params, "stackTrace": tb}
            valid = False
            error_message = params
            output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
            outputs = {output_label: OutputValue(message=message, type="error")}
            result_data_response = ResultDataResponse(results={}, outputs=outputs)
            artifacts = {}
            background_tasks.add_task(graph.end_all_traces, error=exc)
            # If there's an error building the vertex
            # we need to clear the cache
            await chat_service.clear_cache(flow_id_str)

        result_data_response.message = artifacts

        # Log the vertex build
        if not vertex.will_stream:
            background_tasks.add_task(
                log_vertex_build,
                flow_id=flow_id_str,
                vertex_id=vertex_id,
                valid=valid,
                params=params,
                data=result_data_response,
                artifacts=artifacts,
            )

        timedelta = time.perf_counter() - start_time
        duration = format_elapsed_time(timedelta)
        result_data_response.duration = duration
        result_data_response.timedelta = timedelta
        vertex.add_build_time(timedelta)
        inactivated_vertices = list(graph.inactivated_vertices)
        graph.reset_inactivated_vertices()
        graph.reset_activated_vertices()

        await chat_service.set_cache(flow_id_str, graph)

        # graph.stop_vertex tells us if the user asked
        # to stop the build of the graph at a certain vertex
        # if it is in next_vertices_ids, we need to remove other
        # vertices from next_vertices_ids
        if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
            next_runnable_vertices = [graph.stop_vertex]

        if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
            background_tasks.add_task(graph.end_all_traces)

        build_response = VertexBuildResponse(
            inactivated_vertices=list(set(inactivated_vertices)),
            next_vertices_ids=list(set(next_runnable_vertices)),
            top_level_vertices=list(set(top_level_vertices)),
            valid=valid,
            params=params,
            id=vertex.id,
            data=result_data_response,
        )
        background_tasks.add_task(
            telemetry_service.log_package_component,
            ComponentPayload(
                component_name=vertex_id.split("-")[0],
                component_seconds=int(time.perf_counter() - start_time),
                component_success=valid,
                component_error_message=error_message,
            ),
        )
    except Exception as exc:
        background_tasks.add_task(
            telemetry_service.log_package_component,
            ComponentPayload(
                component_name=vertex_id.split("-")[0],
                component_seconds=int(time.perf_counter() - start_time),
                component_success=False,
                component_error_message=str(exc),
            ),
        )
        logger.exception("Error building Component")
        message = parse_exception(exc)
        raise HTTPException(status_code=500, detail=message) from exc

    return build_response


async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService):
    """Yield stringified ``StreamData`` events for one vertex of a cached graph.

    Emits ``error`` events when the cache lookup fails or misses, the vertex
    cannot be found, the vertex has no ``stream`` attribute, streaming raises,
    or no result is available. Otherwise emits ``message`` events. A final
    ``close`` event is emitted on every exit path, and the graph (when one was
    loaded) is written back to the cache.

    Args:
        flow_id: Cache key of the graph.
        vertex_id: The ID of the vertex to stream.
        chat_service: Service used to read and write the graph cache.
    """
    graph = None
    try:
        try:
            cache = await chat_service.get_cache(flow_id)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Error building Component")
            yield str(StreamData(event="error", data={"error": str(exc)}))
            return

        if isinstance(cache, CacheMiss):
            # If there's no cache
            msg = f"No cache found for {flow_id}."
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return
        else:
            graph = cache.get("result")

        try:
            vertex: InterfaceVertex = graph.get_vertex(vertex_id)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Error building Component")
            yield str(StreamData(event="error", data={"error": str(exc)}))
            return

        if not hasattr(vertex, "stream"):
            msg = f"Vertex {vertex_id} does not support streaming"
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return

        if isinstance(vertex.built_result, str) and vertex.built_result:
            # A non-empty string result is already available: send it as a single chunk.
            stream_data = StreamData(
                event="message",
                data={"message": f"Streaming vertex {vertex_id}"},
            )
            yield str(stream_data)
            stream_data = StreamData(
                event="message",
                data={"chunk": vertex.built_result},
            )
            yield str(stream_data)

        elif not vertex.frozen or not vertex.built:
            logger.debug(f"Streaming vertex {vertex_id}")
            stream_data = StreamData(
                event="message",
                data={"message": f"Streaming vertex {vertex_id}"},
            )
            yield str(stream_data)
            try:
                async for chunk in vertex.stream():
                    stream_data = StreamData(
                        event="message",
                        data={"chunk": chunk},
                    )
                    yield str(stream_data)
            except Exception as exc:  # noqa: BLE001
                logger.exception("Error building Component")
                exc_message = parse_exception(exc)
                if exc_message == "The message must be an iterator or an async iterator.":
                    exc_message = "This stream has already been closed."
                yield str(StreamData(event="error", data={"error": exc_message}))
        elif vertex.result is not None:
            stream_data = StreamData(
                event="message",
                data={"chunk": vertex.built_result},
            )
            yield str(stream_data)
        else:
            msg = f"No result found for vertex {vertex_id}"
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return
    finally:
        logger.debug("Closing stream")
        if graph:
            await chat_service.set_cache(flow_id, graph)
        # NOTE(review): yielding inside ``finally`` raises RuntimeError if this
        # generator is closed early (e.g. via ``aclose()``) - confirm whether
        # the close event should only be sent on normal completion.
        yield str(StreamData(event="close", data={"message": "Stream closed"}))


@router.get("/build/{flow_id}/{vertex_id}/stream", response_class=StreamingResponse)
async def build_vertex_stream(
    flow_id: uuid.UUID,
    vertex_id: str,
):
    """Stream the output of a single vertex instead of building the entire graph.

    The graph is read from the cache through the chat service, and the vertex
    is looked up by `vertex_id`. Events are produced as follows:

    - If there is no cached graph, the vertex cannot be found, or the vertex
      does not support streaming, an error event is sent.
    - If the vertex has a non-empty string built result, it is sent as a chunk.
    - If the vertex is not frozen or not built, the vertex data is streamed.
    - If the vertex has a result, it is sent as a chunk.
    - If none of the above conditions are met, an error event is sent.

    Finally, a close event is sent.

    Args:
        flow_id (uuid.UUID): The ID of the flow.
        vertex_id (str): The ID of the vertex to stream.

    Returns:
        A `StreamingResponse` object with the streamed vertex data in text/event-stream format.

    Raises:
        HTTPException: If an error occurs while creating the streaming response.
    """
    try:
        return StreamingResponse(
            _stream_vertex(str(flow_id), vertex_id, get_chat_service()), media_type="text/event-stream"
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Error building Component") from exc