File size: 3,240 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Any, List, Set

from networkx import DiGraph

from inference.enterprise.workflows.entities.outputs import JsonField
from inference.enterprise.workflows.entities.validators import is_selector
from inference.enterprise.workflows.entities.workflows_specification import (
    InputType,
    StepType,
)


def get_input_parameters_selectors(inputs: List[InputType]) -> Set[str]:
    """Return the set of ``$inputs.<name>`` selectors for the given inputs.

    Args:
        inputs: Input definitions; only their ``name`` attribute is read.

    Returns:
        One selector per distinct input name, built with
        ``construct_input_selector``.
    """
    input_names = (definition.name for definition in inputs)
    return {construct_input_selector(input_name=name) for name in input_names}


def construct_input_selector(input_name: str) -> str:
    """Build the selector that references the workflow input ``input_name``.

    >>> construct_input_selector(input_name="image")
    '$inputs.image'
    """
    selector = f"$inputs.{input_name}"
    return selector


def get_steps_selectors(steps: List[StepType]) -> Set[str]:
    """Return the set of ``$steps.<name>`` selectors for the given steps.

    Args:
        steps: Step definitions; only their ``name`` attribute is read.

    Returns:
        One selector per distinct step name, built with
        ``construct_step_selector``.
    """
    step_names = (step.name for step in steps)
    return {construct_step_selector(step_name=name) for name in step_names}


def construct_step_selector(step_name: str) -> str:
    """Build the selector that references the workflow step ``step_name``.

    >>> construct_step_selector(step_name="detect")
    '$steps.detect'
    """
    selector = f"$steps.{step_name}"
    return selector


def get_steps_input_selectors(steps: List[StepType]) -> Set[str]:
    """Return the union of all selectors referenced by the inputs of ``steps``.

    Each step is inspected with ``get_step_input_selectors``; duplicates
    across steps collapse into a single entry.
    """
    return {
        selector
        for step in steps
        for selector in get_step_input_selectors(step=step)
    }


def get_step_input_selectors(step: StepType) -> Set[str]:
    """Collect every selector referenced by the inputs of a single step.

    For each name returned by ``step.get_input_names()`` the attribute of that
    name is read from ``step``. A list-valued attribute is scanned element by
    element; any other value is treated as a single element. Only elements
    that ``is_selector`` accepts are kept; literal values are ignored.

    Args:
        step: The step definition to inspect.

    Returns:
        The set of selector values found across all of the step's inputs.
    """
    selectors = set()
    for input_name in step.get_input_names():
        value = getattr(step, input_name)
        # Normalise scalar inputs to a one-element list so both shapes share
        # the same filtering path below.
        candidates = value if isinstance(value, list) else [value]
        selectors.update(
            candidate
            for candidate in candidates
            if is_selector(selector_or_value=candidate)
        )
    return selectors


def get_steps_output_selectors(steps: List[StepType]) -> Set[str]:
    """Return selectors for every declared output of every step.

    Each selector has the form ``$steps.<step_name>.<output_name>``, where the
    output names come from ``step.get_output_names()``.

    Args:
        steps: Step definitions to enumerate.

    Returns:
        The set of all step-output selectors.
    """
    # Build the "$steps.<name>" prefix with construct_step_selector rather
    # than re-spelling the format here, so step selectors are defined in
    # exactly one place.
    return {
        f"{construct_step_selector(step_name=step.name)}.{output_name}"
        for step in steps
        for output_name in step.get_output_names()
    }


def get_output_names(outputs: List[JsonField]) -> Set[str]:
    """Return the ``$outputs.<name>`` form of every output's name.

    Args:
        outputs: Output definitions; only their ``name`` attribute is read.

    Returns:
        One entry per distinct output name, built with
        ``construct_output_name``.
    """
    names = (output.name for output in outputs)
    return {construct_output_name(name=name) for name in names}


def construct_output_name(name: str) -> str:
    """Build the ``$outputs.<name>`` reference for a workflow output.

    >>> construct_output_name(name="predictions")
    '$outputs.predictions'
    """
    output_reference = f"$outputs.{name}"
    return output_reference


def get_output_selectors(outputs: List[JsonField]) -> Set[str]:
    """Return the distinct ``selector`` values declared by the given outputs.

    Args:
        outputs: Output definitions; only their ``selector`` attribute is read.

    Returns:
        The set of selectors the outputs point at.
    """
    return {entry.selector for entry in outputs}


def is_input_selector(selector_or_value: Any) -> bool:
    """Check whether a value is a selector that references a workflow input.

    Args:
        selector_or_value: Any value taken from a workflow definition. Only
            values accepted by ``is_selector`` can qualify.

    Returns:
        True if the value is a selector in the ``$inputs.<name>`` namespace
        (the form produced by ``construct_input_selector``), otherwise False.
    """
    if not is_selector(selector_or_value=selector_or_value):
        return False
    # Include the trailing dot so that a selector merely sharing the prefix
    # (e.g. "$inputs_extra.x") is not treated as an input reference. This
    # mirrors the "$steps." check in is_step_output_selector.
    return selector_or_value.startswith("$inputs.")


def construct_selector_pointing_step_output(selector: str, new_output: str) -> str:
    """Return a selector that points at output ``new_output`` of a step.

    ``selector`` may be either a step selector (``$steps.<step>``) or a
    step-output selector (``$steps.<step>.<output>``). In the latter case the
    existing output part is dropped first, so that e.g. ``"$steps.a.x"`` with
    ``new_output="y"`` becomes ``"$steps.a.y"``.
    """
    base_selector = (
        get_step_selector_from_its_output(step_output_selector=selector)
        if is_step_output_selector(selector_or_value=selector)
        else selector
    )
    return f"{base_selector}.{new_output}"


def is_step_output_selector(selector_or_value: Any) -> bool:
    """Check whether a value is a ``$steps.<step>.<output>`` selector.

    The value must be accepted by ``is_selector``, must start with
    ``"$steps."``, and must consist of exactly three dot-separated segments.
    """
    if not is_selector(selector_or_value=selector_or_value):
        return False
    if not selector_or_value.startswith("$steps."):
        return False
    # Exactly two dots <=> exactly three dot-separated segments.
    return selector_or_value.count(".") == 2


def get_step_selector_from_its_output(step_output_selector: str) -> str:
    """Strip the output part from a step-output selector.

    Keeps only the first two dot-separated segments, e.g.
    ``"$steps.detect.predictions"`` -> ``"$steps.detect"``. Inputs with fewer
    than two segments are returned unchanged.
    """
    # Only the first two segments are needed, so stop splitting after them.
    leading_segments = step_output_selector.split(".", 2)[:2]
    return ".".join(leading_segments)


def get_nodes_of_specific_kind(execution_graph: DiGraph, kind: str) -> Set[str]:
    """Return the names of graph nodes whose ``kind`` attribute equals ``kind``.

    Node attributes come from ``execution_graph.nodes(data=True)``. A node
    without a ``kind`` attribute is compared as if its kind were ``None``.
    """
    return {
        node_name
        for node_name, node_attributes in execution_graph.nodes(data=True)
        if node_attributes.get("kind") == kind
    }


def is_condition_step(execution_graph: DiGraph, node: str) -> bool:
    """Check whether ``node``'s definition has ``type == "Condition"``.

    The node's ``definition`` attribute is read without any guard, so a node
    that is missing from the graph or lacks a ``definition`` entry raises
    rather than returning False.
    """
    node_definition = execution_graph.nodes[node]["definition"]
    return node_definition.type == "Condition"