File size: 5,501 Bytes
f11c54d
 
573f5c6
f11c54d
573f5c6
f11c54d
 
 
 
c7f22d1
f11c54d
c7f22d1
 
 
 
 
f11c54d
 
c7f22d1
 
 
 
f11c54d
 
 
c7f22d1
 
f11c54d
 
c7f22d1
 
 
f11c54d
07d94c5
 
c7f22d1
07d94c5
c7f22d1
 
 
 
 
 
 
 
f11c54d
 
c7f22d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f11c54d
573f5c6
 
c7f22d1
 
 
 
573f5c6
 
c7f22d1
 
573f5c6
f11c54d
 
c7f22d1
eee9365
c7f22d1
 
 
 
eee9365
 
c7f22d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from .. import ops
from .. import workspace
import orjson
import pandas as pd
import pydantic
import traceback
import inspect
import typing


class Context(ops.BaseConfig):
    """Passed to operation functions as "_ctx" if they have such a parameter.

    ``execute`` creates one Context per workspace node and, during a run,
    passes that same instance to every call of the node's operation.
    """

    # The workspace node this context belongs to.
    node: workspace.WorkspaceNode
    # Most recent result of the node's operation in the current run (possibly
    # served from the cache); stored before lists/DataFrames are split into
    # individual items.
    last_result: typing.Any = None


class Output(ops.BaseConfig):
    """Return this to send values to specific outputs of a node.

    NOTE(review): ``execute`` in this module does not special-case Output (an
    Output returned there is treated like any other single result); it is
    presumably interpreted by other code — verify.
    """

    # Name of the output handle the value should be sent to.
    output_handle: str
    # The value to send on that handle.
    value: dict


def df_to_list(df):
    """Return the rows of ``df`` as a list of ``{column: value}`` dicts."""
    rows = df.to_dict("records")
    return rows


def has_ctx(op):
    """Tell whether the operation's function declares a ``_ctx`` parameter."""
    parameters = inspect.signature(op.func).parameters
    return "_ctx" in parameters


# Per-environment result caches created by ``register``, keyed by env name.
CACHES = {}


def register(env: str, cache: bool = True):
    """Register the one-by-one executor for ``env``.

    When ``cache`` is true, a new empty cache dict is stored in
    ``CACHES[env]`` and passed to every execution in that environment;
    otherwise executions run without a cache.
    """
    env_cache = None
    if cache:
        env_cache = CACHES[env] = {}

    def _run(ws):
        # The catalog is looked up on every call rather than at registration.
        return execute(ws, ops.CATALOGS[env], cache=env_cache)

    ops.EXECUTORS[env] = _run


def get_stages(ws, catalog):
    """Split the workspace into execution stages along batch inputs.

    Inputs on the top or bottom of a node are batch inputs. For every node
    that has batch inputs, the set of all nodes upstream of those inputs forms
    one stage. Stages are sorted by size (so a stage contained in another one
    comes first), and a final stage containing every node is appended.

    Args:
        ws: Workspace with ``nodes`` (each having ``id`` and ``data.title``)
            and ``edges`` (each having ``source``, ``target``, ``targetHandle``).
        catalog: Mapping from operation title to operation.

    Returns:
        A list of sets of node ids.
    """
    nodes = {n.id: n for n in ws.nodes}
    batch_inputs = {}
    inputs = {}
    for edge in ws.edges:
        inputs.setdefault(edge.target, []).append(edge.source)
        op = catalog.get(nodes[edge.target].data.title)
        if op is None:
            # Unknown operation: ``execute`` reports it as a node error, so it
            # must not crash here. The edge still counts for upstream traversal.
            continue
        # Exact match: a substring test against "top or bottom" would also
        # accept values such as "", "or" or "t".
        if op.inputs[edge.targetHandle].position in ("top", "bottom"):
            batch_inputs.setdefault(edge.target, []).append(edge.source)
    stages = []
    for sources in batch_inputs.values():
        # Transitive closure of everything feeding the batch inputs.
        upstream = set(sources)
        frontier = set(sources)
        while frontier:
            n = frontier.pop()
            for i in inputs.get(n, []):
                if i not in upstream:
                    upstream.add(i)
                    frontier.add(i)
        stages.append(upstream)
    stages.sort(key=len)
    stages.append(set(nodes))
    return stages


def _default_serializer(obj):
    """Fallback used by ``orjson.dumps`` when building cache keys.

    Pydantic models are converted to plain dicts. Any other value orjson
    cannot serialize is replaced by a marker holding the object's ``id``.
    """
    if isinstance(obj, pydantic.BaseModel):
        # Pydantic v2 deprecates ``.dict()`` in favor of ``model_dump()``;
        # fall back to ``.dict()`` for pydantic v1 models.
        dump = getattr(obj, "model_dump", None)
        return dump() if dump is not None else obj.dict()
    # NOTE(review): ids may be reused after an object is garbage-collected, so a
    # different, later object could map to the same key.
    return {"__nonserializable__": id(obj)}


def make_cache_key(obj):
    """Serialize ``obj`` to bytes with orjson so it can be used as a dict key.

    Values orjson cannot handle natively are passed to ``_default_serializer``.
    """
    serialized = orjson.dumps(obj, default=_default_serializer)
    return serialized


# NOTE(review): not referenced by the code in this module (the executor uses
# ``CACHES``); presumably used elsewhere — verify before removing.
EXECUTOR_OUTPUT_CACHE = {}


async def await_if_needed(obj):
    """Await ``obj`` if it is awaitable; otherwise hand it back unchanged.

    This lets the executor call synchronous and asynchronous operations
    uniformly.
    """
    if not inspect.isawaitable(obj):
        return obj
    return await obj


async def execute(ws, catalog, cache=None):
    """Execute a workspace with the one-by-one executor.

    Every node is called once per item ("task") arriving on its non-batch
    input. Inputs positioned on the top or bottom of a node are batch inputs:
    they collect the full list of upstream results, which is passed as a
    whole. Nodes are processed stage by stage (see ``get_stages``); tasks for
    nodes outside the current stage are deferred to the next stage.

    Args:
        ws: Workspace whose ``nodes`` and ``edges`` are executed. Each node's
            ``data.error`` is reset and set to a message on failure;
            ``data.display`` is set for visualization/table/image operations.
        catalog: Mapping from operation title to operation.
        cache: Optional dict used to memoize operation results, keyed by the
            operation title, its inputs and its parameters.

    Returns:
        A dict mapping each node id to its ``Context``.
    """
    BATCH_POSITIONS = ("top", "bottom")
    nodes = {n.id: n for n in ws.nodes}
    contexts = {n.id: Context(node=n) for n in ws.nodes}
    edges = {n.id: [] for n in ws.nodes}
    for e in ws.edges:
        edges[e.source].append(e)
    tasks = {}
    NO_INPUT = object()  # Marker for initial tasks.
    for node in ws.nodes:
        node.data.error = None
        op = catalog.get(node.data.title)
        if op is None:
            node.data.error = f'Operation "{node.data.title}" not found.'
            continue
        # Start tasks for nodes that have no non-batch inputs.
        if all(i.position in BATCH_POSITIONS for i in op.inputs.values()):
            tasks[node.id] = [NO_INPUT]
    # Results collected for batch inputs, keyed by (node id, input name).
    batch_inputs = {}
    # Run the rest until we run out of tasks.
    stages = get_stages(ws, catalog)
    for stage in stages:
        next_stage = {}
        while tasks:
            n, ts = tasks.popitem()
            if n not in stage:
                next_stage.setdefault(n, []).extend(ts)
                continue
            node = nodes[n]
            data = node.data
            op = catalog[data.title]
            params = {**data.params}
            if has_ctx(op):
                params["_ctx"] = contexts[node.id]
            results = []
            for task in ts:
                try:
                    inputs = [
                        batch_inputs[(n, i.name)]
                        if i.position in BATCH_POSITIONS
                        else task
                        for i in op.inputs.values()
                    ]
                    if cache is not None:
                        # The operation title is part of the key so that
                        # different operations called with equal inputs and
                        # params do not share a cache entry.
                        key = make_cache_key((data.title, inputs, params))
                        if key not in cache:
                            cache[key] = await await_if_needed(op(*inputs, **params))
                        result = cache[key]
                    else:
                        result = await await_if_needed(op(*inputs, **params))
                except Exception as e:
                    traceback.print_exc()
                    data.error = str(e)
                    break
                contexts[node.id].last_result = result
                # Returned lists and DataFrames are considered multiple tasks.
                if isinstance(result, pd.DataFrame):
                    result = df_to_list(result)
                elif not isinstance(result, list):
                    result = [result]
                results.extend(result)
            else:  # Finished all tasks without errors.
                if op.type in ("visualization", "table_view", "image"):
                    # An operation may return an empty list or DataFrame;
                    # there is then nothing to display.
                    data.display = results[0] if results else None
                for edge in edges[node.id]:
                    target_op = catalog.get(nodes[edge.target].data.title)
                    if target_op is None:
                        # The target was already flagged as "not found" above.
                        continue
                    target_input = target_op.inputs[edge.targetHandle]
                    if target_input.position in BATCH_POSITIONS:
                        batch_inputs.setdefault(
                            (edge.target, edge.targetHandle), []
                        ).extend(results)
                    else:
                        tasks.setdefault(edge.target, []).extend(results)
        tasks = next_stage
    return contexts