import inspect
import traceback
import typing

import orjson
import pandas as pd
import pydantic

from .. import ops
from .. import workspace


class Context(ops.BaseConfig):
    """Passed to operation functions as "_ctx" if they have such a parameter."""

    node: workspace.WorkspaceNode
    last_result: typing.Any = None
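
# A hedged usage sketch (the operation below is hypothetical): a function that
# declares a "_ctx" parameter receives this Context and can read its own
# previous result.
#
#     def cumulative_sum(value: float, _ctx: Context):
#         return (_ctx.last_result or 0) + value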


class Output(ops.BaseConfig):
    """Return this to send values to specific outputs of a node."""

    output_handle: str
    value: dict
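
# A hedged sketch (the handle name "matches" is hypothetical): an operation
# could construct its return value as
#
#     return Output(output_handle="matches", value={"id": 1})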


def df_to_list(df):
    """Converts a DataFrame into a list of dicts, one per row."""
    return df.to_dict(orient="records")


def has_ctx(op):
    """Returns True if the operation's function takes a "_ctx" parameter."""
    sig = inspect.signature(op.func)
    return "_ctx" in sig.parameters


CACHES = {}


def register(env: str, cache: bool = True):
    """Registers the one-by-one executor for the given environment, optionally with a result cache."""
    if cache:
        CACHES[env] = {}
        cache = CACHES[env]
    else:
        cache = None
    ops.EXECUTORS[env] = lambda ws: execute(ws, ops.CATALOGS[env], cache=cache)
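
# A minimal usage sketch; the environment name "demo" is hypothetical:
#
#     register("demo")               # executor with per-environment caching
#     register("demo", cache=False)  # re-run every operation on each execution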


def get_stages(ws, catalog):
    """Inputs on top/bottom are batch inputs. We decompose the graph into a DAG of components along these edges."""
    nodes = {n.id: n for n in ws.nodes}
    batch_inputs = {}
    inputs = {}
    for edge in ws.edges:
        inputs.setdefault(edge.target, []).append(edge.source)
        node = nodes[edge.target]
        op = catalog[node.data.title]
        i = op.inputs[edge.targetHandle]
        if i.position in ("top", "bottom"):
            batch_inputs.setdefault(edge.target, []).append(edge.source)
    stages = []
    for bt, bss in batch_inputs.items():
        upstream = set(bss)
        new = set(bss)
        while new:
            n = new.pop()
            for i in inputs.get(n, []):
                if i not in upstream:
                    upstream.add(i)
                    new.add(i)
        stages.append(upstream)
    stages.sort(key=len)
    stages.append(set(nodes))
    return stages
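
# Worked example: if node C takes a batch input from B, and B depends on A,
# get_stages returns [{A, B}, {A, B, C}]. The upstream component {A, B} runs to
# completion first, so C sees the complete batch on its top/bottom input.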


def _default_serializer(obj):
    if isinstance(obj, pydantic.BaseModel):
        return obj.dict()
    # Fall back to object identity: such values never match across instances,
    # so they effectively disable caching for the tasks that contain them.
    return {"__nonserializable__": id(obj)}


def make_cache_key(obj):
    """Serializes the object into JSON bytes, usable as a dictionary key."""
    return orjson.dumps(obj, default=_default_serializer)
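
# For example, make_cache_key(([1, 2], {"limit": 10})) == b'[[1,2],{"limit":10}]'.
# orjson serializes tuples as JSON arrays, so the positional inputs and keyword
# params of a task produce the same key across calls.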


EXECUTOR_OUTPUT_CACHE = {}


async def await_if_needed(obj):
    """Awaits the object if it is awaitable, otherwise returns it unchanged."""
    if inspect.isawaitable(obj):
        return await obj
    return obj


async def execute(ws, catalog, cache=None):
    """Executes the workspace node by node, treating lists and DataFrames as multiple tasks."""
    nodes = {n.id: n for n in ws.nodes}
    contexts = {n.id: Context(node=n) for n in ws.nodes}
    edges = {n.id: [] for n in ws.nodes}  # Outgoing edges, keyed by source node.
    for e in ws.edges:
        edges[e.source].append(e)
    tasks = {}
    NO_INPUT = object()  # Marker for initial tasks.
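    # The marker is never passed to an operation: nodes scheduled with it have
    # only batch inputs (or none at all), so the "task" value is never read.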
    for node in ws.nodes:
        node.data.error = None
        op = catalog.get(node.data.title)
        if op is None:
            node.data.error = f'Operation "{node.data.title}" not found.'
            continue
        # Start tasks for nodes that have no non-batch inputs.
        if all(i.position in ("top", "bottom") for i in op.inputs.values()):
            tasks[node.id] = [NO_INPUT]
    batch_inputs = {}  # Accumulated batch results, keyed by (node id, input handle).
    # Run the rest until we run out of tasks.
    stages = get_stages(ws, catalog)
    for stage in stages:
        next_stage = {}
        while tasks:
            n, ts = tasks.popitem()
            if n not in stage:
                next_stage.setdefault(n, []).extend(ts)
                continue
            node = nodes[n]
            data = node.data
            op = catalog[data.title]
            params = {**data.params}
            if has_ctx(op):
                params["_ctx"] = contexts[node.id]
            results = []
            for task in ts:
                try:
                    inputs = []
                    for i in op.inputs.values():
                        if i.position in ("top", "bottom"):
                            assert (n, i.name) in batch_inputs, f"{i.name} is missing"
                            inputs.append(batch_inputs[(n, i.name)])
                        else:
                            inputs.append(task)
                    if cache is not None:
                        key = make_cache_key((inputs, params))
                        if key not in cache:
                            cache[key] = await await_if_needed(op(*inputs, **params))
                        result = cache[key]
                    else:
                        result = await await_if_needed(op(*inputs, **params))
                except Exception as e:
                    traceback.print_exc()
                    data.error = str(e)
                    break
                contexts[node.id].last_result = result
                # Returned lists and DataFrames are considered multiple tasks.
                if isinstance(result, pd.DataFrame):
                    result = df_to_list(result)
                elif not isinstance(result, list):
                    result = [result]
                results.extend(result)
            else:  # Finished all tasks without errors.
                if op.type in ("visualization", "table_view", "image"):
                    data.display = results[0]
                for edge in edges[node.id]:
                    t = nodes[edge.target]
                    op = catalog[t.data.title]
                    i = op.inputs[edge.targetHandle]
                    if i.position in ("top", "bottom"):
                        batch_inputs.setdefault(
                            (edge.target, edge.targetHandle), []
                        ).extend(results)
                    else:
                        tasks.setdefault(edge.target, []).extend(results)
        tasks = next_stage
    return contexts
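

# Execution flow sketch (hedged): a host application calls register("demo") at
# startup, then running a workspace resolves ops.EXECUTORS["demo"](ws), which
# awaits execute() above. Results land on node.data.display and errors on
# node.data.error; the environment name "demo" is hypothetical.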