'''For working with LynxKite workspaces.'''
from typing import Optional
import dataclasses
import os
import pydantic
import tempfile
import traceback
from . import ops


class BaseConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(
        extra='allow',
    )


class Position(BaseConfig):
    x: float
    y: float


class WorkspaceNodeData(BaseConfig):
    title: str
    params: dict
    display: Optional[object] = None
    error: Optional[str] = None


class WorkspaceNode(BaseConfig):
    id: str
    type: str
    data: WorkspaceNodeData
    position: Position
    parentId: Optional[str] = None


class WorkspaceEdge(BaseConfig):
    id: str
    source: str
    target: str


class Workspace(BaseConfig):
    nodes: list[WorkspaceNode] = dataclasses.field(default_factory=list)
    edges: list[WorkspaceEdge] = dataclasses.field(default_factory=list)
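

# A minimal sketch of building a workspace in code. The node type, operation
# title and params below are illustrative; real titles must match operations
# registered in ops.ALL_OPS.
#
#   ws = Workspace(
#       nodes=[
#           WorkspaceNode(
#               id='node_1',
#               type='basic',
#               data=WorkspaceNodeData(title='Create scale-free graph', params={'nodes': 10}),
#               position=Position(x=0, y=0),
#           ),
#       ],
#       edges=[],
#   )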


def execute(ws):
    # Nodes are responsible for interpreting/executing their child nodes.
    # Only top-level nodes (no parentId) are executed here; their children are
    # collected and passed to them via params['sub_flow'].
    nodes = [n for n in ws.nodes if not n.parentId]
    children = {}
    for n in ws.nodes:
        if n.parentId:
            children.setdefault(n.parentId, []).append(n)
    outputs = {}
    failed = 0
    # Keep sweeping over the nodes until each one has either produced an output
    # or failed.
    while len(outputs) + failed < len(nodes):
        for node in nodes:
            if node.id in outputs:
                continue
            # A node is ready once all of its upstream nodes have outputs.
            inputs = [edge.source for edge in ws.edges if edge.target == node.id]
            if all(input in outputs for input in inputs):
                inputs = [outputs[input] for input in inputs]
                data = node.data
                op = ops.ALL_OPS[data.title]
                params = {**data.params}
                if op.sub_nodes:
                    # Pass the child nodes and the edges between them to the operation.
                    sub_nodes = children.get(node.id, [])
                    sub_node_ids = [node.id for node in sub_nodes]
                    sub_edges = [edge for edge in ws.edges if edge.source in sub_node_ids]
                    params['sub_flow'] = {'nodes': sub_nodes, 'edges': sub_edges}
                try:
                    output = op(*inputs, **params)
                except Exception as e:
                    traceback.print_exc()
                    data.error = str(e)
                    failed += 1
                    continue
                if len(op.inputs) == 1 and op.inputs.get('multi') == '*':
                    # It's a flexible input. Create n+1 handles.
                    data.inputs = {f'input{i}': None for i in range(len(inputs) + 1)}
                data.error = None
                outputs[node.id] = output
                if op.type == 'visualization' or op.type == 'table_view':
                    data.view = output
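

# A minimal sketch of running a workspace. execute() returns nothing; it
# mutates the node data in place (errors, views, flexible input handles).
# The path below is hypothetical.
#
#   ws = load('workspaces/demo.json')
#   execute(ws)
#   for node in ws.nodes:
#       if node.data.error:
#           print(node.id, 'failed:', node.data.error)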


def save(ws: Workspace, path: str):
    j = ws.model_dump_json(indent=2)
    dirname, basename = os.path.split(path)
    # Create the temp file in the same directory to make sure it's on the same
    # filesystem, so os.replace() is an atomic rename.
    with tempfile.NamedTemporaryFile(
            'w', prefix=f'.{basename}.', dir=dirname, delete_on_close=False) as f:
        f.write(j)
        f.close()
        os.replace(f.name, path)


def load(path: str):
    with open(path) as f:
        j = f.read()
    ws = Workspace.model_validate_json(j)
    # Metadata is added after loading. This way code changes take effect on old boxes too.
    _update_metadata(ws)
    return ws


def _update_metadata(ws):
    nodes = {node.id: node for node in ws.nodes}
    done = set()
    # Process parents before children: a child's metadata comes from its
    # parent's operation, so children wait until their parent is done.
    while len(done) < len(nodes):
        for node in ws.nodes:
            if node.id in done:
                continue
            data = node.data
            if node.parentId is None:
                op = ops.ALL_OPS.get(data.title)
            elif node.parentId not in nodes:
                data.error = f'Parent not found: {node.parentId}'
                done.add(node.id)
                continue
            elif node.parentId in done:
                op = nodes[node.parentId].data.meta.sub_nodes[data.title]
            else:
                continue
            if op:
                data.meta = op
                if data.error == 'Unknown operation.':
                    data.error = None
            else:
                data.error = 'Unknown operation.'
            done.add(node.id)
    return ws
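

# A minimal sketch of inspecting the metadata attached by load(). The path is
# hypothetical; data.meta holds the resolved operation (or sub-node) definition
# and data.error flags titles or parents that could not be resolved.
#
#   ws = load('workspaces/demo.json')
#   for node in ws.nodes:
#       if node.data.error:
#           print(node.id, 'problem:', node.data.error)
#       else:
#           print(node.id, 'metadata:', node.data.meta)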