File size: 4,536 Bytes
e8a8341
 
0213da5
e8a8341
0213da5
e8a8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03a6805
e8a8341
 
 
 
 
a509341
 
 
e8a8341
a509341
 
 
 
 
 
e8a8341
 
9cc1fee
e8a8341
 
 
 
 
 
 
 
9cc1fee
e8a8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0213da5
 
 
 
 
 
 
 
 
e8a8341
 
a0194e7
 
 
 
 
 
e8a8341
 
 
 
 
 
 
 
 
 
eda8f97
9cc1fee
e8a8341
 
 
eda8f97
 
e8a8341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cc1fee
e8a8341
0213da5
 
e8a8341
a0194e7
e8a8341
 
a0194e7
e8a8341
 
 
 
 
 
 
 
 
 
 
 
4d72daa
e8a8341
 
 
 
 
9cc1fee
e8a8341
 
 
 
b34d742
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from .. import ops
from .. import workspace
import orjson
import pandas as pd
import pydantic
import traceback
import inspect
import typing

class Context(ops.BaseConfig):
  '''Passed to operation functions as "_ctx" if they have such a parameter.

  execute() in this module creates one fresh Context per node on every run.
  '''
  # The workspace node this context belongs to.
  node: workspace.WorkspaceNode
  # The most recent raw value returned by the node's op during the current run
  # (stored before lists/DataFrames are split into individual tasks).
  # None until the op has returned at least once.
  last_result: typing.Any = None

class Output(ops.BaseConfig):
  '''Return this to send values to specific outputs of a node.'''
  # NOTE(review): execute() in this module does not special-case Output
  # results. A returned Output is wrapped in a list and sent along every
  # outgoing edge, like any other value. Confirm whether routing by
  # output_handle is implemented elsewhere.
  # Name of the output handle the value is intended for.
  output_handle: str
  # The value to send; declared as a dict.
  value: dict


def df_to_list(df):
  '''Converts a DataFrame to a list holding one {column: value} dict per row.'''
  records = df.to_dict('records')
  return records

def has_ctx(op):
  '''Tells whether the op's wrapped function declares a `_ctx` parameter.'''
  params = inspect.signature(op.func).parameters
  return '_ctx' in params

# Per-environment result caches, keyed by env name. register() stores a fresh
# dict here when called with cache=True.
CACHES = {}

def register(env: str, cache: bool = True):
  '''Registers the one-by-one executor for `env` in ops.EXECUTORS.

  If `cache` is truthy, a new empty dict replaces CACHES[env] and every run of
  the registered executor shares it for memoizing op results. Otherwise no
  cache is used. The catalog (ops.CATALOGS[env]) is looked up each time the
  executor is called, not at registration time.
  '''
  env_cache = {} if cache else None
  if env_cache is not None:
    CACHES[env] = env_cache
  ops.EXECUTORS[env] = lambda ws: execute(ws, ops.CATALOGS[env], cache=env_cache)

def get_stages(ws, catalog):
  '''Decomposes the workspace graph into execution stages along batch-input edges.

  Inputs on top/bottom are batch inputs. For every node that has batch inputs,
  one stage is built: its batch sources plus every node transitively upstream
  of them (following all edges). Stages are sorted smallest first, and a final
  stage containing every node is appended.

  Args:
    ws: workspace with `nodes` (having `id` and `data.title`) and `edges`
      (having `source`, `target`, `targetHandle`).
    catalog: maps a node title to its op; `op.inputs[handle].position` is read.

  Returns:
    A list of sets of node IDs.
  '''
  nodes = {n.id: n for n in ws.nodes}
  batch_inputs = {}  # target node ID -> sources feeding its batch inputs
  inputs = {}  # target node ID -> all sources feeding it
  for edge in ws.edges:
    inputs.setdefault(edge.target, []).append(edge.source)
    node = nodes[edge.target]
    op = catalog[node.data.title]
    i = op.inputs[edge.targetHandle]
    # Exact match. A substring test against 'top or bottom' would also
    # accept '', 'o', 'or', 'to', ...
    if i.position in ('top', 'bottom'):
      batch_inputs.setdefault(edge.target, []).append(edge.source)
  stages = []
  for sources in batch_inputs.values():
    # Walk upstream from the batch sources to collect everything they depend on.
    upstream = set(sources)
    frontier = set(sources)
    while frontier:
      n = frontier.pop()
      for i in inputs.get(n, []):
        if i not in upstream:
          upstream.add(i)
          frontier.add(i)
    stages.append(upstream)
  stages.sort(key=len)
  stages.append(set(nodes))
  return stages


def _default_serializer(obj):
  '''Fallback passed as `default=` to orjson.dumps for otherwise unsupported objects.

  Pydantic models are expanded with .dict(). Anything else is replaced by a
  marker dict holding the object's id().
  '''
  # NOTE(review): id() values can be reused after an object is garbage
  # collected, so two distinct objects may map to the same cache key.
  if not isinstance(obj, pydantic.BaseModel):
    return {"__nonserializable__": id(obj)}
  return obj.dict()

def make_cache_key(obj):
  '''Serializes `obj` to orjson bytes so it can be used as a dict key.

  Values orjson cannot handle natively are converted by _default_serializer.
  '''
  key = orjson.dumps(obj, default=_default_serializer)
  return key

# NOTE(review): not referenced anywhere in this file; possibly unused. Confirm
# no other module imports it before removing.
EXECUTOR_OUTPUT_CACHE = {}

async def await_if_needed(obj):
  '''Awaits `obj` when it is awaitable; otherwise hands it back as-is.'''
  if not inspect.isawaitable(obj):
    return obj
  return await obj

async def execute(ws, catalog, cache=None):
  '''Runs a workspace one task at a time and returns the per-node contexts.

  Every value a node produces becomes a separate task for the nodes connected
  to its non-batch inputs. Inputs on top/bottom are batch inputs: they receive
  the accumulated list of all upstream values. To make sure those lists are
  complete, nodes are run in the stages computed by get_stages().

  Args:
    ws: the workspace. Each node's `data.error` is reset and may be set to an
      error message. `data.display` is set for visualization/table_view/image ops.
    catalog: maps node titles to ops.
    cache: optional dict memoizing op results, keyed by make_cache_key() of
      the op's inputs and params.

  Returns:
    A dict mapping node IDs to their Context.
  '''
  nodes = {n.id: n for n in ws.nodes}
  contexts = {n.id: Context(node=n) for n in ws.nodes}
  edges = {n.id: [] for n in ws.nodes}
  for e in ws.edges:
    edges[e.source].append(e)
  tasks = {}
  NO_INPUT = object()  # Marker for initial tasks.
  for node in ws.nodes:
    node.data.error = None
    op = catalog[node.data.title]
    # Start tasks for nodes that have no non-batch inputs.
    if all(i.position in ('top', 'bottom') for i in op.inputs.values()):
      tasks[node.id] = [NO_INPUT]
  batch_inputs = {}  # (node ID, input name) -> accumulated upstream values
  # Run the rest until we run out of tasks.
  stages = get_stages(ws, catalog)
  for stage in stages:
    next_stage = {}
    while tasks:
      n, ts = tasks.popitem()
      if n not in stage:
        # Not runnable in this stage: carry its tasks over to the next one.
        next_stage.setdefault(n, []).extend(ts)
        continue
      node = nodes[n]
      data = node.data
      op = catalog[data.title]
      params = {**data.params}
      if has_ctx(op):
        params['_ctx'] = contexts[node.id]
      results = []
      for task in ts:
        try:
          inputs = [
            batch_inputs[(n, i.name)] if i.position in ('top', 'bottom') else task
            for i in op.inputs.values()]
          if cache is not None:
            key = make_cache_key((inputs, params))
            if key not in cache:
              cache[key] = await await_if_needed(op(*inputs, **params))
            result = cache[key]
          else:
            result = await await_if_needed(op(*inputs, **params))
        except Exception as e:
          traceback.print_exc()
          data.error = str(e)
          break
        contexts[node.id].last_result = result
        # Returned lists and DataFrames are considered multiple tasks.
        if isinstance(result, pd.DataFrame):
          result = df_to_list(result)
        elif not isinstance(result, list):
          result = [result]
        results.extend(result)
      else:  # Finished all tasks without errors.
        # `results` is empty when upstream produced no values (e.g. an empty
        # list or DataFrame). Indexing it would abort the whole run.
        if op.type in ('visualization', 'table_view', 'image') and results:
          data.display = results[0]
        for edge in edges[node.id]:
          target_op = catalog[nodes[edge.target].data.title]
          target_input = target_op.inputs[edge.targetHandle]
          if target_input.position in ('top', 'bottom'):
            batch_inputs.setdefault((edge.target, edge.targetHandle), []).extend(results)
          else:
            tasks.setdefault(edge.target, []).extend(results)
    tasks = next_stage
  return contexts