Spaces:
Running
Running
from __future__ import annotations | |
import asyncio | |
import json | |
import time | |
import traceback | |
import typing | |
import uuid | |
from typing import TYPE_CHECKING, Annotated | |
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException | |
from fastapi.responses import StreamingResponse | |
from loguru import logger | |
from sqlmodel import select | |
from starlette.background import BackgroundTask | |
from starlette.responses import ContentStream | |
from starlette.types import Receive | |
from langflow.api.utils import ( | |
CurrentActiveUser, | |
DbSession, | |
build_and_cache_graph_from_data, | |
build_graph_from_data, | |
build_graph_from_db, | |
build_graph_from_db_no_cache, | |
format_elapsed_time, | |
format_exception_message, | |
get_top_level_vertices, | |
parse_exception, | |
) | |
from langflow.api.v1.schemas import ( | |
FlowDataRequest, | |
InputValueRequest, | |
ResultDataResponse, | |
StreamData, | |
VertexBuildResponse, | |
VerticesOrderResponse, | |
) | |
from langflow.events.event_manager import EventManager, create_default_event_manager | |
from langflow.exceptions.component import ComponentBuildError | |
from langflow.graph.graph.base import Graph | |
from langflow.graph.utils import log_vertex_build | |
from langflow.schema.schema import OutputValue | |
from langflow.services.cache.utils import CacheMiss | |
from langflow.services.chat.service import ChatService | |
from langflow.services.database.models.flow.model import Flow | |
from langflow.services.deps import async_session_scope, get_chat_service, get_session, get_telemetry_service | |
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload | |
if TYPE_CHECKING: | |
from langflow.graph.vertex.types import InterfaceVertex | |
# FastAPI router for this module; its routes are grouped under the "Chat" tag in the OpenAPI docs.
router = APIRouter(tags=["Chat"])
async def try_running_celery_task(vertex, user_id):
    """Hand a vertex build to Celery, falling back to an in-process build.

    When dispatching succeeds, the Celery task id is stored on ``vertex.task_id`` and the
    vertex is NOT built here. When importing the worker or dispatching raises, the error is
    logged at debug level, ``vertex.task_id`` is reset to ``None`` and the vertex is built
    locally for ``user_id``.

    Args:
        vertex: The vertex to build.
        user_id: The ID of the user on whose behalf a local build runs.

    Returns:
        The same ``vertex`` object that was passed in.
    """
    try:
        # Imported lazily so a missing/unconfigured worker falls through to the local build.
        from langflow.worker import build_vertex as celery_build_vertex

        vertex.task_id = celery_build_vertex.delay(vertex).id
    except Exception:  # noqa: BLE001
        logger.opt(exception=True).debug("Error running task in celery")
        vertex.task_id = None
        await vertex.build(user_id=user_id)
    return vertex
async def retrieve_vertices_order(
    *,
    flow_id: uuid.UUID,
    background_tasks: BackgroundTasks,
    data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
    stop_component_id: str | None = None,
    start_component_id: str | None = None,
    session: DbSession,
) -> VerticesOrderResponse:
    """Prepare a flow's graph and report the order in which its vertices run.

    The graph is loaded through ``build_graph_from_db`` when no ``data`` is given, and
    through ``build_and_cache_graph_from_data`` otherwise. It is then prepared for the
    optional stop/start components and stored in the chat-service cache under the flow ID.
    A playground telemetry payload is queued as a background task on success and on failure.

    Args:
        flow_id: The ID of the flow.
        background_tasks: FastAPI background tasks used for telemetry.
        data: Optional flow data to build the graph from instead of the database.
        stop_component_id: Optional ID of the component at which to stop.
        start_component_id: Optional ID of the component from which to start.
        session: The database session dependency.

    Returns:
        VerticesOrderResponse: The graph's first layer of vertex IDs, its run ID and the
        vertices to run.

    Raises:
        HTTPException: 400 when the error message contains "stream or streaming set to True",
            500 for any other error.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    start_time = time.perf_counter()
    components_count = None
    try:
        flow_key = str(flow_id)
        if data:
            graph = await build_and_cache_graph_from_data(
                flow_id=flow_key, graph_data=data.model_dump(), chat_service=chat_service
            )
        else:
            # No flow data submitted: use the stored flow.
            graph = await build_graph_from_db(flow_id=flow_key, session=session, chat_service=chat_service)
        graph = graph.prepare(stop_component_id, start_component_id)

        components_count = len(graph.vertices)
        top_level = get_top_level_vertices(graph, graph.vertices_to_run)
        vertices_to_run = list(graph.vertices_to_run.union(top_level))

        await chat_service.set_cache(flow_key, graph)
        success_payload = PlaygroundPayload(
            playground_seconds=int(time.perf_counter() - start_time),
            playground_component_count=components_count,
            playground_success=True,
        )
        background_tasks.add_task(telemetry_service.log_package_playground, success_payload)
        return VerticesOrderResponse(ids=graph.first_layer, run_id=graph.run_id, vertices_to_run=vertices_to_run)
    except Exception as exc:
        error_text = str(exc)
        failure_payload = PlaygroundPayload(
            playground_seconds=int(time.perf_counter() - start_time),
            playground_component_count=components_count,
            playground_success=False,
            playground_error_message=error_text,
        )
        background_tasks.add_task(telemetry_service.log_package_playground, failure_payload)
        if "stream or streaming set to True" not in error_text:
            logger.exception("Error checking build status")
            raise HTTPException(status_code=500, detail=error_text) from exc
        raise HTTPException(status_code=400, detail=error_text) from exc
async def build_flow(
    *,
    background_tasks: BackgroundTasks,
    flow_id: uuid.UUID,
    inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
    data: Annotated[FlowDataRequest | None, Body(embed=True)] = None,
    files: list[str] | None = None,
    stop_component_id: str | None = None,
    start_component_id: str | None = None,
    log_builds: bool | None = True,
    current_user: CurrentActiveUser,
    session: DbSession,
):
    """Build a flow vertex by vertex and stream the build events to the client.

    The graph is built from the stored flow when ``data`` is empty, or from ``data``
    otherwise. After sorting, every vertex of the first layer is built concurrently, and each
    valid build schedules the next runnable vertices. Events emitted through the
    ``EventManager`` are streamed as newline-delimited JSON.

    Args:
        background_tasks: FastAPI background tasks used for telemetry, build logging and
            trace shutdown.
        flow_id: The ID of the flow to build.
        inputs: Input values for the run. Defaults to an ``InputValueRequest`` whose
            ``session`` is the flow ID.
        data: Optional flow data; when empty the flow is loaded from the database.
        files: Optional files passed to every vertex build.
        stop_component_id: Optional ID of the component at which to stop.
        start_component_id: Optional ID of the component from which to start.
        log_builds: When truthy, builds of non-streaming vertices are recorded with
            ``log_vertex_build``. In every other case the graph is written back to the chat
            cache instead.
        current_user: The authenticated user running the build.
        session: The database session dependency.

    Returns:
        DisconnectHandlerStreamingResponse: An ``application/x-ndjson`` stream. The build
        task is cancelled when the client disconnects.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    if not inputs:
        inputs = InputValueRequest(session=str(flow_id))

    async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]:
        """Build and sort the graph.

        Returns:
            ``(first_layer, vertices_to_run, graph)``.

        Raises:
            HTTPException: 400 when the error message contains "stream or streaming set to
                True", 500 for any other error.
        """
        start_time = time.perf_counter()
        components_count = None
        try:
            flow_id_str = str(flow_id)
            if not data:
                graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
            else:
                # Only the stored flow name is read from the database; the graph itself is
                # built from the submitted data.
                async with async_session_scope() as new_session:
                    result = await new_session.exec(select(Flow.name).where(Flow.id == flow_id_str))
                    flow_name = result.first()
                graph = await build_graph_from_data(
                    flow_id_str, data.model_dump(), user_id=str(current_user.id), flow_name=flow_name
                )
            graph.validate_stream()
            if stop_component_id or start_component_id:
                try:
                    first_layer = graph.sort_vertices(stop_component_id, start_component_id)
                except Exception:  # noqa: BLE001
                    # Fall back to sorting the whole graph if the partial sort fails.
                    logger.exception("Error sorting vertices")
                    first_layer = graph.sort_vertices()
            else:
                first_layer = graph.sort_vertices()

            if inputs is not None and hasattr(inputs, "session") and inputs.session is not None:
                graph.session_id = inputs.session

            for vertex_id in first_layer:
                graph.run_manager.add_to_vertices_being_run(vertex_id)

            components_count = len(graph.vertices)
            vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playground_seconds=int(time.perf_counter() - start_time),
                    playground_component_count=components_count,
                    playground_success=True,
                ),
            )
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playground_seconds=int(time.perf_counter() - start_time),
                    playground_component_count=components_count,
                    playground_success=False,
                    playground_error_message=str(exc),
                ),
            )
            if "stream or streaming set to True" in str(exc):
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            logger.exception("Error checking build status")
            raise HTTPException(status_code=500, detail=str(exc)) from exc
        return first_layer, vertices_to_run, graph

    async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse:
        """Build one vertex of ``graph`` and describe the outcome.

        Errors raised while building the component are captured in the response
        (``valid=False`` plus an error output) instead of being raised. Any other failure is
        logged and raised as ``HTTPException(500)``.
        """
        flow_id_str = str(flow_id)
        next_runnable_vertices = []
        top_level_vertices = []
        start_time = time.perf_counter()
        error_message = None
        try:
            vertex = graph.get_vertex(vertex_id)
            try:
                lock = chat_service.async_cache_locks[flow_id_str]
                vertex_build_result = await graph.build_vertex(
                    vertex_id=vertex_id,
                    user_id=str(current_user.id),
                    inputs_dict=inputs.model_dump() if inputs else {},
                    files=files,
                    get_cache=chat_service.get_cache,
                    set_cache=chat_service.set_cache,
                    event_manager=event_manager,
                )
                result_dict = vertex_build_result.result_dict
                params = vertex_build_result.params
                valid = vertex_build_result.valid
                artifacts = vertex_build_result.artifacts
                next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
                top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

                result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
            except Exception as exc:  # noqa: BLE001
                # Turn the build failure into an "error" output on the vertex's first output.
                if isinstance(exc, ComponentBuildError):
                    params = exc.message
                    tb = exc.formatted_traceback
                else:
                    tb = traceback.format_exc()
                    logger.exception("Error building Component")
                    params = format_exception_message(exc)
                message = {"errorMessage": params, "stackTrace": tb}
                valid = False
                error_message = params
                output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
                outputs = {output_label: OutputValue(message=message, type="error")}
                result_data_response = ResultDataResponse(results={}, outputs=outputs)
                artifacts = {}
                background_tasks.add_task(graph.end_all_traces, error=exc)

            result_data_response.message = artifacts

            # Log the vertex build
            if not vertex.will_stream and log_builds:
                background_tasks.add_task(
                    log_vertex_build,
                    flow_id=flow_id_str,
                    vertex_id=vertex_id,
                    valid=valid,
                    params=params,
                    data=result_data_response,
                    artifacts=artifacts,
                )
            else:
                await chat_service.set_cache(flow_id_str, graph)

            timedelta = time.perf_counter() - start_time
            duration = format_elapsed_time(timedelta)
            result_data_response.duration = duration
            result_data_response.timedelta = timedelta
            vertex.add_build_time(timedelta)
            inactivated_vertices = list(graph.inactivated_vertices)
            graph.reset_inactivated_vertices()
            graph.reset_activated_vertices()
            # graph.stop_vertex tells us if the user asked
            # to stop the build of the graph at a certain vertex
            # if it is in next_vertices_ids, we need to remove other
            # vertices from next_vertices_ids
            if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
                next_runnable_vertices = [graph.stop_vertex]

            if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
                background_tasks.add_task(graph.end_all_traces)

            build_response = VertexBuildResponse(
                inactivated_vertices=list(set(inactivated_vertices)),
                next_vertices_ids=list(set(next_runnable_vertices)),
                top_level_vertices=list(set(top_level_vertices)),
                valid=valid,
                params=params,
                id=vertex.id,
                data=result_data_response,
            )
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    component_name=vertex_id.split("-")[0],
                    component_seconds=int(time.perf_counter() - start_time),
                    component_success=valid,
                    component_error_message=error_message,
                ),
            )
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    component_name=vertex_id.split("-")[0],
                    component_seconds=int(time.perf_counter() - start_time),
                    component_success=False,
                    component_error_message=str(exc),
                ),
            )
            logger.exception("Error building Component")
            message = parse_exception(exc)
            raise HTTPException(status_code=500, detail=message) from exc
        return build_response

    async def build_vertices(
        vertex_id: str,
        graph: Graph,
        client_consumed_queue: asyncio.Queue,
        event_manager: EventManager,
    ) -> None:
        """Build ``vertex_id``, emit its end-vertex event, then build its successors.

        After emitting the event it waits for one item on ``client_consumed_queue``, which
        ``consume_and_yield`` fills after each streamed event. Successors are built only when
        the build is valid and has next vertices.
        """
        build_task = asyncio.create_task(_build_vertex(vertex_id, graph, event_manager))
        try:
            await build_task
        except asyncio.CancelledError as exc:
            logger.exception(exc)
            build_task.cancel()
            return

        vertex_build_response: VertexBuildResponse = build_task.result()
        # send built event or error event
        try:
            vertex_build_response_json = vertex_build_response.model_dump_json()
            build_data = json.loads(vertex_build_response_json)
        except Exception as exc:
            msg = f"Error serializing vertex build response: {exc}"
            raise ValueError(msg) from exc
        event_manager.on_end_vertex(data={"build_data": build_data})
        await client_consumed_queue.get()
        if vertex_build_response.valid and vertex_build_response.next_vertices_ids:
            tasks = []
            for next_vertex_id in vertex_build_response.next_vertices_ids:
                task = asyncio.create_task(build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager))
                tasks.append(task)
            try:
                await asyncio.gather(*tasks)
            except asyncio.CancelledError:
                for task in tasks:
                    task.cancel()
                return

    async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
        """Drive the whole build, reporting progress through ``event_manager``.

        The end-of-stream sentinel is enqueued in ``finally`` so that ``consume_and_yield``
        (and therefore the HTTP response) terminates on success, on error and on
        cancellation alike. Previously a failing build never enqueued it, leaving the
        response waiting until the client disconnected.
        """
        try:
            if not data:
                # using another task since the build_graph_and_get_order is now an async function
                vertices_task = asyncio.create_task(build_graph_and_get_order())
                try:
                    await vertices_task
                except asyncio.CancelledError:
                    vertices_task.cancel()
                    return
                except Exception as e:
                    if isinstance(e, HTTPException):
                        event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
                        raise
                    event_manager.on_error(data={"error": str(e)})
                    raise

                ids, vertices_to_run, graph = vertices_task.result()
            else:
                try:
                    ids, vertices_to_run, graph = await build_graph_and_get_order()
                except Exception as e:
                    if isinstance(e, HTTPException):
                        event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
                        raise
                    event_manager.on_error(data={"error": str(e)})
                    raise
            event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
            await client_consumed_queue.get()

            tasks = []
            for vertex_id in ids:
                task = asyncio.create_task(build_vertices(vertex_id, graph, client_consumed_queue, event_manager))
                tasks.append(task)
            try:
                await asyncio.gather(*tasks)
            except asyncio.CancelledError:
                background_tasks.add_task(graph.end_all_traces)
                for task in tasks:
                    task.cancel()
                return
            except Exception as e:
                logger.error(f"Error building vertices: {e}")
                event_manager.on_error(data={"error": str(e)})
                # gather() does not cancel the siblings of a failed task; cancel them so they
                # don't stay blocked on client_consumed_queue after the stream has ended.
                for task in tasks:
                    task.cancel()
                raise
            event_manager.on_end(data={})
        finally:
            # Sentinel (event_id, value, put_time): a None value makes consume_and_yield stop.
            event_manager.queue.put_nowait((None, None, time.time()))

    async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
        """Yield queued event payloads until the ``None`` sentinel value arrives.

        After each payload is taken by the consumer, its event id is pushed onto
        ``client_consumed_queue`` so the build tasks waiting on it can continue.
        """
        while True:
            event_id, value, put_time = await queue.get()
            if value is None:
                break
            get_time = time.time()
            yield value
            get_time_yield = time.time()
            client_consumed_queue.put_nowait(event_id)
            logger.debug(
                f"consumed event {event_id} "
                f"(time in queue, {get_time - put_time:.4f}, "
                f"client {get_time_yield - get_time:.4f})"
            )

    asyncio_queue: asyncio.Queue = asyncio.Queue()
    asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
    event_manager = create_default_event_manager(queue=asyncio_queue)
    main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))

    def on_disconnect() -> None:
        logger.debug("Client disconnected, closing tasks")
        main_task.cancel()

    return DisconnectHandlerStreamingResponse(
        consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
        media_type="application/x-ndjson",
        on_disconnect=on_disconnect,
    )
class DisconnectHandlerStreamingResponse(StreamingResponse):
    """``StreamingResponse`` that runs a callback when the client disconnects.

    ``on_disconnect`` may be a plain callable or one that returns a coroutine; a returned
    coroutine is awaited before the disconnect listener finishes.
    """

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        on_disconnect: typing.Callable | None = None,
    ):
        super().__init__(content, status_code, headers, media_type, background)
        # Invoked by listen_for_disconnect once an "http.disconnect" message is received.
        self.on_disconnect = on_disconnect

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Wait for the ``http.disconnect`` ASGI message, then fire ``on_disconnect``."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()
        if self.on_disconnect:
            outcome = self.on_disconnect()
            if asyncio.iscoroutine(outcome):
                await outcome
async def build_vertex(
    *,
    flow_id: uuid.UUID,
    vertex_id: str,
    background_tasks: BackgroundTasks,
    inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
    files: list[str] | None = None,
    current_user: CurrentActiveUser,
) -> VertexBuildResponse:
    """Build a vertex instead of the entire graph.

    The graph is taken from the chat-service cache. On a cache miss it is rebuilt from the
    database, using a session that is closed explicitly once the graph is built. Errors
    raised while building the component are reported in the response (``valid=False``), and
    in that case the flow's cache is cleared.

    Args:
        flow_id (uuid.UUID): The ID of the flow.
        vertex_id (str): The ID of the vertex to build.
        background_tasks (BackgroundTasks): The background tasks dependency.
        inputs (Optional[InputValueRequest], optional): The input values for the vertex. Defaults to None.
        files (List[str], optional): The files to use. Defaults to None.
        current_user (Any, optional): The current user dependency. Defaults to Depends(get_current_active_user).

    Returns:
        VertexBuildResponse: The response containing the built vertex information.

    Raises:
        HTTPException: If there is an error building the vertex.
    """
    chat_service = get_chat_service()
    telemetry_service = get_telemetry_service()
    flow_id_str = str(flow_id)

    next_runnable_vertices = []
    top_level_vertices = []
    start_time = time.perf_counter()
    error_message = None
    try:
        cache = await chat_service.get_cache(flow_id_str)
        if isinstance(cache, CacheMiss):
            # If there's no cache
            logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
            # Keep a reference to the session generator and close it once the graph is built.
            # Previously the generator was dropped right after anext(), leaving the session's
            # cleanup to garbage collection at an arbitrary point.
            session_generator = get_session()
            try:
                graph: Graph = await build_graph_from_db(
                    flow_id=flow_id_str, session=await anext(session_generator), chat_service=chat_service
                )
            finally:
                await session_generator.aclose()
        else:
            graph = cache.get("result")
            await graph.initialize_run()
        vertex = graph.get_vertex(vertex_id)

        try:
            lock = chat_service.async_cache_locks[flow_id_str]
            vertex_build_result = await graph.build_vertex(
                vertex_id=vertex_id,
                user_id=str(current_user.id),
                inputs_dict=inputs.model_dump() if inputs else {},
                files=files,
                get_cache=chat_service.get_cache,
                set_cache=chat_service.set_cache,
            )
            result_dict = vertex_build_result.result_dict
            params = vertex_build_result.params
            valid = vertex_build_result.valid
            artifacts = vertex_build_result.artifacts
            next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
            top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)
            result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
        except Exception as exc:  # noqa: BLE001
            # Turn the build failure into an "error" output on the vertex's first output.
            if isinstance(exc, ComponentBuildError):
                params = exc.message
                tb = exc.formatted_traceback
            else:
                tb = traceback.format_exc()
                logger.exception("Error building Component")
                params = format_exception_message(exc)
            message = {"errorMessage": params, "stackTrace": tb}
            valid = False
            error_message = params
            output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
            outputs = {output_label: OutputValue(message=message, type="error")}
            result_data_response = ResultDataResponse(results={}, outputs=outputs)
            artifacts = {}
            background_tasks.add_task(graph.end_all_traces, error=exc)
            # If there's an error building the vertex
            # we need to clear the cache
            await chat_service.clear_cache(flow_id_str)

        result_data_response.message = artifacts

        # Log the vertex build
        if not vertex.will_stream:
            background_tasks.add_task(
                log_vertex_build,
                flow_id=flow_id_str,
                vertex_id=vertex_id,
                valid=valid,
                params=params,
                data=result_data_response,
                artifacts=artifacts,
            )

        timedelta = time.perf_counter() - start_time
        duration = format_elapsed_time(timedelta)
        result_data_response.duration = duration
        result_data_response.timedelta = timedelta
        vertex.add_build_time(timedelta)
        inactivated_vertices = list(graph.inactivated_vertices)
        graph.reset_inactivated_vertices()
        graph.reset_activated_vertices()

        await chat_service.set_cache(flow_id_str, graph)

        # graph.stop_vertex tells us if the user asked
        # to stop the build of the graph at a certain vertex
        # if it is in next_vertices_ids, we need to remove other
        # vertices from next_vertices_ids
        if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
            next_runnable_vertices = [graph.stop_vertex]

        if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
            background_tasks.add_task(graph.end_all_traces)

        build_response = VertexBuildResponse(
            inactivated_vertices=list(set(inactivated_vertices)),
            next_vertices_ids=list(set(next_runnable_vertices)),
            top_level_vertices=list(set(top_level_vertices)),
            valid=valid,
            params=params,
            id=vertex.id,
            data=result_data_response,
        )
        background_tasks.add_task(
            telemetry_service.log_package_component,
            ComponentPayload(
                component_name=vertex_id.split("-")[0],
                component_seconds=int(time.perf_counter() - start_time),
                component_success=valid,
                component_error_message=error_message,
            ),
        )
    except Exception as exc:
        background_tasks.add_task(
            telemetry_service.log_package_component,
            ComponentPayload(
                component_name=vertex_id.split("-")[0],
                component_seconds=int(time.perf_counter() - start_time),
                component_success=False,
                component_error_message=str(exc),
            ),
        )
        logger.exception("Error building Component")
        message = parse_exception(exc)
        raise HTTPException(status_code=500, detail=message) from exc

    return build_response
async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService):
    """Yield ``StreamData`` strings that stream one vertex of a cached graph.

    Problems are reported as ``error`` events rather than raised. These include a cache
    lookup failure, a cache miss, an unknown vertex, a vertex without ``stream``, a failing
    stream, and a missing result. On normal completion a final ``close`` event is yielded.
    In all cases the graph, if it was loaded, is written back to the cache.

    Args:
        flow_id: The flow ID used as the cache key.
        vertex_id: The ID of the vertex to stream.
        chat_service: The service providing ``get_cache``/``set_cache``.
    """
    graph = None
    # Set when the consumer closes this generator early (aclose() / GC finalization).
    consumer_closed = False
    try:
        try:
            cache = await chat_service.get_cache(flow_id)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Error building Component")
            yield str(StreamData(event="error", data={"error": str(exc)}))
            return

        if isinstance(cache, CacheMiss):
            # If there's no cache
            msg = f"No cache found for {flow_id}."
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return
        graph = cache.get("result")

        try:
            vertex: InterfaceVertex = graph.get_vertex(vertex_id)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Error building Component")
            yield str(StreamData(event="error", data={"error": str(exc)}))
            return

        if not hasattr(vertex, "stream"):
            msg = f"Vertex {vertex_id} does not support streaming"
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return

        if isinstance(vertex.built_result, str) and vertex.built_result:
            # Already-built string result: send it as a single chunk.
            stream_data = StreamData(
                event="message",
                data={"message": f"Streaming vertex {vertex_id}"},
            )
            yield str(stream_data)
            stream_data = StreamData(
                event="message",
                data={"chunk": vertex.built_result},
            )
            yield str(stream_data)

        elif not vertex.frozen or not vertex.built:
            logger.debug(f"Streaming vertex {vertex_id}")
            stream_data = StreamData(
                event="message",
                data={"message": f"Streaming vertex {vertex_id}"},
            )
            yield str(stream_data)
            try:
                async for chunk in vertex.stream():
                    stream_data = StreamData(
                        event="message",
                        data={"chunk": chunk},
                    )
                    yield str(stream_data)
            except Exception as exc:  # noqa: BLE001
                logger.exception("Error building Component")
                exc_message = parse_exception(exc)
                if exc_message == "The message must be an iterator or an async iterator.":
                    exc_message = "This stream has already been closed."
                yield str(StreamData(event="error", data={"error": exc_message}))
        elif vertex.result is not None:
            stream_data = StreamData(
                event="message",
                data={"chunk": vertex.built_result},
            )
            yield str(stream_data)
        else:
            msg = f"No result found for vertex {vertex_id}"
            logger.error(msg)
            yield str(StreamData(event="error", data={"error": msg}))
            return
    except GeneratorExit:
        # Yielding while GeneratorExit propagates raises
        # RuntimeError("async generator ignored GeneratorExit"), so skip the close event.
        consumer_closed = True
        raise
    finally:
        logger.debug("Closing stream")
        if graph:
            await chat_service.set_cache(flow_id, graph)
        if not consumer_closed:
            yield str(StreamData(event="close", data={"message": "Stream closed"}))
async def build_vertex_stream(
    flow_id: uuid.UUID,
    vertex_id: str,
):
    """Stream the output of a single cached vertex as server-sent events.

    The work is delegated to ``_stream_vertex``, which reads the graph for ``flow_id`` from
    the chat-service cache and emits ``StreamData`` events. Depending on the vertex, these
    carry its already-built string result, chunks from ``vertex.stream()`` (when the vertex
    is not frozen or not built), or its built result otherwise. Failures inside the stream
    are sent as ``error`` events rather than raised.

    Args:
        flow_id: The ID of the flow whose cached graph contains the vertex.
        vertex_id: The ID of the vertex to stream.

    Returns:
        StreamingResponse: The streamed vertex data with media type ``text/event-stream``.

    Raises:
        HTTPException: 500 if the streaming response cannot be created.
    """
    try:
        event_stream = _stream_vertex(str(flow_id), vertex_id, get_chat_service())
        response = StreamingResponse(event_stream, media_type="text/event-stream")
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Error building Component") from exc
    return response