File size: 4,039 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import asyncio
from asyncio import AbstractEventLoop
from typing import Any, Dict, Optional

from fastapi import BackgroundTasks

from inference.core.cache import cache
from inference.core.env import API_KEY, MAX_ACTIVE_MODELS
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.workflows.complier.entities import StepExecutionMode
from inference.enterprise.workflows.complier.execution_engine import execute_graph
from inference.enterprise.workflows.complier.graph_parser import prepare_execution_graph
from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import (
    WorkflowsActiveLearningMiddleware,
)
from inference.enterprise.workflows.complier.validator import (
    validate_workflow_specification,
)
from inference.enterprise.workflows.entities.workflows_specification import (
    WorkflowSpecification,
)
from inference.enterprise.workflows.errors import InvalidSpecificationVersionError
from inference.models.utils import ROBOFLOW_MODEL_TYPES


def compile_and_execute(
    workflow_specification: dict,
    runtime_parameters: Dict[str, Any],
    api_key: Optional[str] = None,
    model_manager: Optional[ModelManager] = None,
    loop: Optional[AbstractEventLoop] = None,
    active_learning_middleware: Optional[WorkflowsActiveLearningMiddleware] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    max_concurrent_steps: int = 1,
    step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL,
) -> dict:
    """Synchronous entry point: run `compile_and_execute_async` to completion.

    All arguments except `loop` are forwarded unchanged (by keyword) to
    `compile_and_execute_async`; see that coroutine for how defaults are
    resolved.

    Args:
        workflow_specification: Raw workflow definition to compile.
        runtime_parameters: Values supplied for this particular run.
        api_key: Forwarded as-is; `None` is resolved by the coroutine.
        model_manager: Forwarded as-is; `None` is resolved by the coroutine.
        loop: Event loop used to drive the coroutine. When `None`, the current
            event loop of the calling thread is used; if the thread has no
            current loop, a new one is created and installed as current.
        active_learning_middleware: Forwarded as-is.
        background_tasks: Forwarded as-is.
        max_concurrent_steps: Forwarded as-is.
        step_execution_mode: Forwarded as-is.

    Returns:
        Whatever `compile_and_execute_async` returns.

    Raises:
        RuntimeError: Propagated from `loop.run_until_complete` (for example
            when the chosen loop is already running or closed).
        Any exception raised by `compile_and_execute_async` is propagated.
    """
    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # asyncio.get_event_loop() raises RuntimeError when the calling
            # thread has no current loop (e.g. non-main threads). Create one
            # and register it so later calls from this thread can reuse it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    return loop.run_until_complete(
        compile_and_execute_async(
            workflow_specification=workflow_specification,
            runtime_parameters=runtime_parameters,
            model_manager=model_manager,
            api_key=api_key,
            active_learning_middleware=active_learning_middleware,
            background_tasks=background_tasks,
            max_concurrent_steps=max_concurrent_steps,
            step_execution_mode=step_execution_mode,
        )
    )


async def compile_and_execute_async(
    workflow_specification: dict,
    runtime_parameters: Dict[str, Any],
    model_manager: Optional[ModelManager] = None,
    api_key: Optional[str] = None,
    active_learning_middleware: Optional[WorkflowsActiveLearningMiddleware] = None,
    background_tasks: Optional[BackgroundTasks] = None,
    max_concurrent_steps: int = 1,
    step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL,
) -> dict:
    """Parse, validate, compile and execute a workflow specification.

    Steps, in order:
      1. Fill in defaults for `api_key`, `model_manager` and
         `active_learning_middleware` when they are `None`.
      2. Parse `workflow_specification` with `WorkflowSpecification.parse_obj`.
      3. Reject any specification whose version is not "1.0".
      4. Validate the specification and build the execution graph.
      5. Await `execute_graph` and return its result.

    Args:
        workflow_specification: Raw workflow definition.
        runtime_parameters: Passed through to `execute_graph`.
        model_manager: When `None`, a `ModelManager` backed by
            `RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)` is created and wrapped
            in `WithFixedSizeCache` with `max_size=MAX_ACTIVE_MODELS`.
        api_key: When `None`, the module-level `API_KEY` is used.
        active_learning_middleware: When `None`, a
            `WorkflowsActiveLearningMiddleware` using the shared `cache` is
            created.
        background_tasks: Passed through to `execute_graph`.
        max_concurrent_steps: Passed through to `execute_graph`.
        step_execution_mode: Passed through to `execute_graph`.

    Returns:
        The value produced by `execute_graph`.

    Raises:
        InvalidSpecificationVersionError: If the parsed specification version
            is not "1.0".
    """
    # Resolve defaults first, in the same order as the arguments are consumed.
    if api_key is None:
        api_key = API_KEY
    if model_manager is None:
        model_manager = WithFixedSizeCache(
            ModelManager(model_registry=RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)),
            max_size=MAX_ACTIVE_MODELS,
        )
    if active_learning_middleware is None:
        active_learning_middleware = WorkflowsActiveLearningMiddleware(cache=cache)

    # Only the inner `specification` object is needed from the parsed model.
    specification = WorkflowSpecification.parse_obj(
        workflow_specification
    ).specification
    if specification.version != "1.0":
        raise InvalidSpecificationVersionError(
            f"Only version 1.0 of workflow specification is supported."
        )

    validate_workflow_specification(workflow_specification=specification)
    graph = prepare_execution_graph(workflow_specification=specification)

    result = await execute_graph(
        execution_graph=graph,
        runtime_parameters=runtime_parameters,
        model_manager=model_manager,
        active_learning_middleware=active_learning_middleware,
        background_tasks=background_tasks,
        api_key=api_key,
        max_concurrent_steps=max_concurrent_steps,
        step_execution_mode=step_execution_mode,
    )
    return result