Spaces:
Runtime error
Runtime error
import asyncio | |
from datetime import datetime | |
from typing import Any, Dict, List, Optional, Set | |
import networkx as nx | |
from fastapi import BackgroundTasks | |
from networkx import DiGraph | |
from inference.core import logger | |
from inference.core.managers.base import ModelManager | |
from inference.enterprise.workflows.complier.entities import StepExecutionMode | |
from inference.enterprise.workflows.complier.flow_coordinator import ( | |
ParallelStepExecutionCoordinator, | |
SerialExecutionCoordinator, | |
) | |
from inference.enterprise.workflows.complier.runtime_input_validator import ( | |
prepare_runtime_parameters, | |
) | |
from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import ( | |
WorkflowsActiveLearningMiddleware, | |
) | |
from inference.enterprise.workflows.complier.steps_executors.auxiliary import ( | |
run_active_learning_data_collector, | |
run_condition_step, | |
run_crop_step, | |
run_detection_filter, | |
run_detection_offset_step, | |
run_detections_consensus_step, | |
run_static_crop_step, | |
) | |
from inference.enterprise.workflows.complier.steps_executors.constants import ( | |
PARENT_COORDINATES_SUFFIX, | |
) | |
from inference.enterprise.workflows.complier.steps_executors.models import ( | |
run_clip_comparison_step, | |
run_ocr_model_step, | |
run_roboflow_model_step, | |
run_yolo_world_model_step, | |
) | |
from inference.enterprise.workflows.complier.steps_executors.types import OutputsLookup | |
from inference.enterprise.workflows.complier.steps_executors.utils import make_batches | |
from inference.enterprise.workflows.complier.utils import ( | |
get_nodes_of_specific_kind, | |
get_step_selector_from_its_output, | |
is_condition_step, | |
) | |
from inference.enterprise.workflows.constants import OUTPUT_NODE_KIND | |
from inference.enterprise.workflows.entities.outputs import CoordinatesSystem | |
from inference.enterprise.workflows.entities.validators import get_last_selector_chunk | |
from inference.enterprise.workflows.errors import ( | |
ExecutionEngineError, | |
WorkflowsCompilerRuntimeError, | |
) | |
# Dispatch table: step definition ``type`` -> coroutine that executes the step.
# ``execute_step`` indexes this mapping directly, so an unknown step type
# surfaces as a ``KeyError`` at execution time.
STEP_TYPE2EXECUTOR_MAPPING = {
    # Every model step type below is served by the same executor.
    **{
        model_step_type: run_roboflow_model_step
        for model_step_type in (
            "ClassificationModel",
            "MultiLabelClassificationModel",
            "ObjectDetectionModel",
            "KeypointsDetectionModel",
            "InstanceSegmentationModel",
        )
    },
    "OCRModel": run_ocr_model_step,
    "Crop": run_crop_step,
    "Condition": run_condition_step,
    "DetectionFilter": run_detection_filter,
    "DetectionOffset": run_detection_offset_step,
    # Both static-crop variants share one executor.
    "AbsoluteStaticCrop": run_static_crop_step,
    "RelativeStaticCrop": run_static_crop_step,
    "ClipComparison": run_clip_comparison_step,
    "DetectionsConsensus": run_detections_consensus_step,
    "ActiveLearningDataCollector": run_active_learning_data_collector,
    "YoloWorld": run_yolo_world_model_step,
}
async def execute_graph(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
    model_manager: ModelManager,
    active_learning_middleware: WorkflowsActiveLearningMiddleware,
    background_tasks: Optional[BackgroundTasks] = None,
    api_key: Optional[str] = None,
    max_concurrent_steps: int = 1,
    step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL,
) -> dict:
    """Run every step of a compiled workflow graph and build the response.

    The runtime parameters are first passed through
    ``prepare_runtime_parameters``. A ``ParallelStepExecutionCoordinator`` is
    used when ``max_concurrent_steps > 1``; otherwise a
    ``SerialExecutionCoordinator`` is used. The coordinator is repeatedly asked
    for the next group of steps. Each group is run with ``execute_steps``, and
    the set of steps it reports for discarding is passed back to the
    coordinator on the following request. The loop ends when the coordinator
    returns ``None``.

    Args:
        execution_graph: Compiled workflow graph.
        runtime_parameters: Raw runtime inputs of the workflow.
        model_manager: Forwarded to every step executor.
        active_learning_middleware: Forwarded to active-learning steps.
        background_tasks: Forwarded to active-learning steps.
        api_key: Forwarded to every step executor.
        max_concurrent_steps: Batch size used when running a group of steps;
            values above 1 also select the parallel coordinator.
        step_execution_mode: Forwarded to every step executor.

    Returns:
        The mapping built by ``construct_response`` from the collected outputs.
    """
    prepared_parameters = prepare_runtime_parameters(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    # Shared output store, filled as steps execute.
    outputs_lookup = {}
    coordinator_class = (
        ParallelStepExecutionCoordinator
        if max_concurrent_steps > 1
        else SerialExecutionCoordinator
    )
    coordinator = coordinator_class.init(execution_graph=execution_graph)
    next_steps = coordinator.get_steps_to_execute_next(steps_to_discard=set())
    while next_steps is not None:
        # NOTE(review): only the discards from the latest group are forwarded;
        # this assumes the coordinator accumulates them internally. Confirm.
        discarded_steps = await execute_steps(
            steps=next_steps,
            max_concurrent_steps=max_concurrent_steps,
            execution_graph=execution_graph,
            runtime_parameters=prepared_parameters,
            outputs_lookup=outputs_lookup,
            model_manager=model_manager,
            api_key=api_key,
            step_execution_mode=step_execution_mode,
            active_learning_middleware=active_learning_middleware,
            background_tasks=background_tasks,
        )
        next_steps = coordinator.get_steps_to_execute_next(
            steps_to_discard=discarded_steps
        )
    return construct_response(
        execution_graph=execution_graph, outputs_lookup=outputs_lookup
    )
async def execute_steps(
    steps: List[str],
    max_concurrent_steps: int,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
    active_learning_middleware: WorkflowsActiveLearningMiddleware,
    background_tasks: Optional[BackgroundTasks],
) -> Set[str]:
    """Run a group of steps batch by batch and collect the nodes to discard.

    ``steps`` is split by ``make_batches`` using ``batch_size=max_concurrent_steps``.
    The steps of one batch are awaited together with ``asyncio.gather``.
    Batches run one after another. ``outputs_lookup`` is mutated during
    execution, so only independent steps may be run together.

    Returns:
        The union of the node sets returned by ``safe_execute_step`` for every
        step in ``steps``.

    Raises:
        Whatever ``safe_execute_step`` raises. ``asyncio.gather`` propagates the
        first such error of a batch.
    """
    logger.info(f"Executing steps: {steps}. Execution mode: {step_execution_mode}")
    discarded_nodes: Set[str] = set()
    for batch in make_batches(iterable=steps, batch_size=max_concurrent_steps):
        logger.info(f"Steps batch: {batch}")
        per_step_discards = await asyncio.gather(
            *(
                safe_execute_step(
                    step=step_name,
                    execution_graph=execution_graph,
                    runtime_parameters=runtime_parameters,
                    outputs_lookup=outputs_lookup,
                    model_manager=model_manager,
                    api_key=api_key,
                    step_execution_mode=step_execution_mode,
                    active_learning_middleware=active_learning_middleware,
                    background_tasks=background_tasks,
                )
                for step_name in batch
            )
        )
        for step_discards in per_step_discards:
            discarded_nodes.update(step_discards)
    return discarded_nodes
async def safe_execute_step(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
    active_learning_middleware: WorkflowsActiveLearningMiddleware,
    background_tasks: Optional[BackgroundTasks],
) -> Set[str]:
    """Run ``execute_step`` and normalise any failure into ``ExecutionEngineError``.

    Returns:
        The set of nodes to discard, exactly as returned by ``execute_step``.

    Raises:
        ExecutionEngineError: wraps any ``Exception`` raised by ``execute_step``.
            The message names the step and the original error type, and the
            original error is chained as ``__cause__``.
    """
    try:
        discarded_nodes = await execute_step(
            step=step,
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            model_manager=model_manager,
            api_key=api_key,
            step_execution_mode=step_execution_mode,
            active_learning_middleware=active_learning_middleware,
            background_tasks=background_tasks,
        )
    except Exception as error:
        error_type_name = type(error).__name__
        raise ExecutionEngineError(
            f"Error during execution of step: {step}. "
            f"Type of error: {error_type_name}. "
            f"Cause: {error}"
        ) from error
    return discarded_nodes
async def execute_step(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
    active_learning_middleware: WorkflowsActiveLearningMiddleware,
    background_tasks: Optional[BackgroundTasks],
) -> Set[str]:
    """Execute one step and return the graph nodes that must be discarded.

    The executor is chosen from ``STEP_TYPE2EXECUTOR_MAPPING`` by the step
    definition's ``type``. ``ActiveLearningDataCollector`` steps additionally
    receive ``active_learning_middleware`` and ``background_tasks``.

    For condition steps, the executor's returned ``next_step`` is compared
    with ``step_if_true``. On a match, the ``step_if_false`` branch and all
    of its descendants are returned. Otherwise the ``step_if_true`` branch and
    its descendants are returned. For any other step an empty set is returned.

    Raises:
        KeyError: if the step's ``type`` has no entry in the executor mapping.
    """
    logger.info(f"started execution of: {step} - {datetime.now().isoformat()}")
    step_definition = execution_graph.nodes[step]["definition"]
    step_type = step_definition.type
    executor = STEP_TYPE2EXECUTOR_MAPPING[step_type]
    extra_executor_kwargs = (
        {
            "active_learning_middleware": active_learning_middleware,
            "background_tasks": background_tasks,
        }
        if step_type == "ActiveLearningDataCollector"
        else {}
    )
    # NOTE(review): the second element returned by the executor (the outputs
    # lookup) is not used here. This relies on executors writing into
    # ``outputs_lookup`` in place. Confirm for every executor.
    next_step, _ = await executor(
        step=step_definition,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
        model_manager=model_manager,
        api_key=api_key,
        step_execution_mode=step_execution_mode,
        **extra_executor_kwargs,
    )
    if is_condition_step(execution_graph=execution_graph, node=step):
        # Whichever branch the executor did not select is returned for discarding.
        skipped_branch = (
            step_definition.step_if_false
            if step_definition.step_if_true == next_step
            else step_definition.step_if_true
        )
        nodes_to_discard = get_all_nodes_in_execution_path(
            execution_graph=execution_graph,
            source=skipped_branch,
        )
    else:
        nodes_to_discard = set()
    logger.info(f"finished execution of: {step} - {datetime.now().isoformat()}")
    return nodes_to_discard
def get_all_nodes_in_execution_path(
    execution_graph: DiGraph,
    source: str,
) -> Set[str]:
    """Return a new set holding ``source`` and every node reachable from it.

    Reachability is computed with ``networkx.descendants``. An error is raised
    if ``source`` is not a node of ``execution_graph``.
    """
    reachable_nodes = nx.descendants(execution_graph, source)
    return {*reachable_nodes, source}
def construct_response(
    execution_graph: nx.DiGraph,
    outputs_lookup: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the workflow response from the collected step outputs.

    For every output node of ``execution_graph`` (nodes of kind
    ``OUTPUT_NODE_KIND``), the selector in its definition is resolved against
    ``outputs_lookup``:

    * When the output requests ``CoordinatesSystem.PARENT``, the field name is
      looked up with ``PARENT_COORDINATES_SUFFIX`` appended. The plain field
      name serves as a fallback.
    * A step with no entry (or a ``None`` entry) in ``outputs_lookup`` yields
      ``None`` in the response.
    * List results are extracted element-wise; any other result is treated as
      a single dict.

    Returns:
        A dict mapping each output definition's ``name`` to its extracted value.
    """
    response: Dict[str, Any] = {}
    output_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph, kind=OUTPUT_NODE_KIND
    )
    for output_node in output_nodes:
        definition = execution_graph.nodes[output_node]["definition"]
        requested_selector = definition.selector
        wants_parent_coordinates = (
            definition.coordinates_system is CoordinatesSystem.PARENT
        )
        lookup_selector = (
            f"{requested_selector}{PARENT_COORDINATES_SUFFIX}"
            if wants_parent_coordinates
            else requested_selector
        )
        fallback_selector = requested_selector if wants_parent_coordinates else None
        step_selector = get_step_selector_from_its_output(
            step_output_selector=lookup_selector
        )
        step_field = get_last_selector_chunk(selector=lookup_selector)
        if fallback_selector is None:
            fallback_step_field = None
        else:
            fallback_step_field = get_last_selector_chunk(selector=fallback_selector)
        step_output = outputs_lookup.get(step_selector)
        if step_output is not None:
            extractor = (
                extract_step_result_from_list
                if isinstance(step_output, list)
                else extract_step_result_from_dict
            )
            step_output = extractor(
                result=step_output,
                step_field=step_field,
                fallback_step_field=fallback_step_field,
                step_selector=step_selector,
            )
        response[definition.name] = step_output
    return response
def extract_step_result_from_list(
    result: List[Dict[str, Any]],
    step_field: str,
    fallback_step_field: Optional[str],
    step_selector: str,
) -> List[Any]:
    """Apply ``extract_step_result_from_dict`` to every element of ``result``.

    Returns:
        A new list with one extracted value per element, in the original order.
        An empty input yields an empty list.
    """

    def _extract_single(element: Dict[str, Any]) -> Any:
        return extract_step_result_from_dict(
            result=element,
            step_field=step_field,
            fallback_step_field=fallback_step_field,
            step_selector=step_selector,
        )

    return list(map(_extract_single, result))
def extract_step_result_from_dict(
    result: Dict[str, Any],
    step_field: str,
    fallback_step_field: Optional[str],
    step_selector: str,
) -> Any:
    """Return ``result[step_field]``, falling back to ``fallback_step_field``.

    The fallback key is only consulted when ``step_field`` is absent from
    ``result``. A value of ``None`` (whether from the key present with
    ``None`` or from neither key being found) is treated as missing. Falsy
    values such as ``0`` or ``[]`` are returned as-is.

    Args:
        result: Output dict of a single step (or one element of a list output).
        step_field: Preferred field name.
        fallback_step_field: Alternative field name, or ``None`` if there is
            no fallback.
        step_selector: Selector of the step; used only in the error message.

    Raises:
        WorkflowsCompilerRuntimeError: if no non-``None`` value was found.
    """
    step_result = result.get(step_field, result.get(fallback_step_field))
    if step_result is None:
        # Name only the fields actually searched: previously a missing
        # fallback rendered as "... nor None", which was misleading.
        if fallback_step_field is None:
            message = (
                f"Cannot find field {step_field} in result of step {step_selector}"
            )
        else:
            message = (
                f"Cannot find either field {step_field} or {fallback_step_field} "
                f"in result of step {step_selector}"
            )
        raise WorkflowsCompilerRuntimeError(message)
    return step_result