# One-by-one executor: runs workspace operations one task at a time.
from .. import ops | |
from .. import workspace | |
import orjson | |
import pandas as pd | |
import pydantic | |
import traceback | |
import inspect | |
import typing | |
class Context(ops.BaseConfig):
    """Per-node execution context.

    An operation receives this object as its "_ctx" argument, but only when
    the operation's function declares a parameter with that name.
    """

    # The workspace node this context belongs to.
    node: workspace.WorkspaceNode
    # Most recent successful result produced for this node; None until then.
    last_result: typing.Any = None
class Output(ops.BaseConfig):
    """Wrapper an operation can return to route a value to a specific output of its node."""

    # Name of the output handle the value is addressed to.
    output_handle: str
    # The payload to deliver on that output.
    value: dict
def df_to_list(df):
    """Convert a DataFrame into a list with one ``{column: value}`` dict per row."""
    records = df.to_dict(orient="records")
    return records
def has_ctx(op):
    """Tell whether the function wrapped by ``op`` declares a "_ctx" parameter."""
    parameter_names = inspect.signature(op.func).parameters
    return "_ctx" in parameter_names
# Per-environment result caches, keyed by environment name (filled in by register()).
CACHES: dict = {}
def register(env: str, cache: bool = True):
    """Install the one-by-one executor for the environment ``env``.

    When ``cache`` is true, a fresh, empty result cache is stored in
    ``CACHES[env]`` (replacing any previous one) and shared by every run of
    the registered executor. Otherwise the executor runs without a cache.
    """
    env_cache = None
    if cache:
        # Same dict object is both published in CACHES and captured below.
        env_cache = CACHES[env] = {}

    def run(ws):
        # The catalog is looked up at call time, not at registration time.
        return execute(ws, ops.CATALOGS[env], cache=env_cache)

    ops.EXECUTORS[env] = run
def get_stages(ws, catalog):
    """Inputs on top/bottom are batch inputs. We decompose the graph into a DAG of components along these edges.

    For every node that has at least one batch input, one stage is produced:
    the set of all node ids that are (transitively) upstream of that node's
    batch inputs. Stages are ordered by size, smallest first, and a final
    stage containing every node id is always appended.

    Args:
        ws: Workspace-like object with ``nodes`` (each having ``id`` and
            ``data.title``) and ``edges`` (each having ``source``, ``target``
            and ``targetHandle``).
        catalog: Mapping from operation title to an operation whose
            ``inputs`` maps handle names to objects with a ``position``.

    Returns:
        A list of sets of node ids.

    Raises:
        KeyError: If an edge targets an unknown node, a node whose title is
            not in ``catalog``, or a handle the operation does not have.
    """
    # Exact membership: a substring test against "top or bottom" would also
    # accept values such as "", "to" or "or".
    batch_positions = ("top", "bottom")
    nodes = {n.id: n for n in ws.nodes}
    batch_inputs = {}  # target node id -> sources connected to its batch inputs
    inputs = {}  # target node id -> sources connected to any of its inputs
    for edge in ws.edges:
        inputs.setdefault(edge.target, []).append(edge.source)
        op = catalog[nodes[edge.target].data.title]
        if op.inputs[edge.targetHandle].position in batch_positions:
            batch_inputs.setdefault(edge.target, []).append(edge.source)
    stages = []
    for sources in batch_inputs.values():
        # Collect everything reachable backwards from the batch sources.
        upstream = set(sources)
        frontier = set(sources)
        while frontier:
            current = frontier.pop()
            for parent in inputs.get(current, []):
                if parent not in upstream:
                    upstream.add(parent)
                    frontier.add(parent)
        stages.append(upstream)
    stages.sort(key=len)
    # The last stage lets every remaining node run.
    stages.append(set(nodes))
    return stages
def _default_serializer(obj):
    """Fallback used when building cache keys for objects the JSON encoder cannot handle.

    Pydantic models are turned into their dict form. Anything else is
    replaced by a placeholder carrying the object's ``id()``, so such values
    only compare equal by identity.
    """
    if not isinstance(obj, pydantic.BaseModel):
        return {"__nonserializable__": id(obj)}
    return obj.dict()
def make_cache_key(obj):
    """Serialize ``obj`` to JSON bytes suitable for use as a cache dictionary key."""
    encoded = orjson.dumps(obj, default=_default_serializer)
    return encoded
# Module-level cache dictionary; starts out empty.
EXECUTOR_OUTPUT_CACHE: dict = {}
async def await_if_needed(obj):
    """Return ``obj`` itself, or the result of awaiting it when it is awaitable."""
    return await obj if inspect.isawaitable(obj) else obj
async def execute(ws, catalog, cache=None):
    """Run the operations of a workspace one task at a time.

    Nodes whose inputs are all batch inputs (positioned on top or bottom),
    including nodes with no inputs, start with a single task. Every result
    of a node becomes a separate task for each downstream non-batch input,
    while results sent to a batch input are collected into one list. Nodes
    are processed stage by stage, following ``get_stages``.

    Args:
        ws: The workspace; its nodes' ``data.error`` (and ``data.display``
            for visualization-like operations) are updated in place.
        catalog: Mapping from operation title to operation.
        cache: Optional dict used to memoize operation results. ``None``
            disables caching.

    Returns:
        A dict mapping node id to the ``Context`` created for that node.
    """
    # Exact membership: a substring test against "top or bottom" would also
    # accept values such as "", "to" or "or".
    batch_positions = ("top", "bottom")
    display_types = ("visualization", "table_view", "image")
    nodes = {n.id: n for n in ws.nodes}
    contexts = {n.id: Context(node=n) for n in ws.nodes}
    edges = {n.id: [] for n in ws.nodes}
    for e in ws.edges:
        edges[e.source].append(e)
    tasks = {}
    NO_INPUT = object()  # Marker for initial tasks.
    for node in ws.nodes:
        node.data.error = None
        op = catalog.get(node.data.title)
        if op is None:
            node.data.error = f'Operation "{node.data.title}" not found.'
            continue
        # Start tasks for nodes that have no non-batch inputs.
        if all(i.position in batch_positions for i in op.inputs.values()):
            tasks[node.id] = [NO_INPUT]
    batch_inputs = {}
    # Run the rest until we run out of tasks.
    stages = get_stages(ws, catalog)
    for stage in stages:
        next_stage = {}
        while tasks:
            n, ts = tasks.popitem()
            if n not in stage:
                # Not runnable yet; postpone to a later stage.
                next_stage.setdefault(n, []).extend(ts)
                continue
            node = nodes[n]
            data = node.data
            op = catalog[data.title]
            params = {**data.params}
            if has_ctx(op):
                params["_ctx"] = contexts[node.id]
            results = []
            for task in ts:
                try:
                    inputs = []
                    for i in op.inputs.values():
                        if i.position in batch_positions:
                            # Explicit raise instead of assert, so the check
                            # survives "python -O". The error is reported on
                            # the node by the handler below.
                            if (n, i.name) not in batch_inputs:
                                raise ValueError(f"{i.name} is missing")
                            inputs.append(batch_inputs[(n, i.name)])
                        else:
                            inputs.append(task)
                    if cache is not None:
                        # The operation title is part of the key. Otherwise
                        # different operations called with equal inputs and
                        # parameters would share each other's results.
                        key = make_cache_key((data.title, inputs, params))
                        if key not in cache:
                            cache[key] = await await_if_needed(op(*inputs, **params))
                        result = cache[key]
                    else:
                        result = await await_if_needed(op(*inputs, **params))
                except Exception as e:
                    # Report the failure on the node and drop its partial
                    # results; other nodes keep running.
                    traceback.print_exc()
                    data.error = str(e)
                    break
                contexts[node.id].last_result = result
                # Returned lists and DataFrames are considered multiple tasks.
                if isinstance(result, pd.DataFrame):
                    result = df_to_list(result)
                elif not isinstance(result, list):
                    result = [result]
                results.extend(result)
            else:  # Finished all tasks without errors.
                if op.type in display_types:
                    # NOTE(review): raises IndexError if the node produced no
                    # results (e.g. an empty list or DataFrame) - confirm
                    # whether that can happen for display operations.
                    data.display = results[0]
                for edge in edges[node.id]:
                    target = nodes[edge.target]
                    target_op = catalog[target.data.title]
                    target_input = target_op.inputs[edge.targetHandle]
                    if target_input.position in batch_positions:
                        batch_inputs.setdefault(
                            (edge.target, edge.targetHandle), []
                        ).extend(results)
                    else:
                        tasks.setdefault(edge.target, []).extend(results)
        tasks = next_stage
    return contexts