Spaces:
Running
Running
| '''For working with LynxKite workspaces.''' | |
| from typing import Optional | |
| import dataclasses | |
| import os | |
| import pydantic | |
| import tempfile | |
| import traceback | |
| from . import ops | |
| class BaseConfig(pydantic.BaseModel): | |
| model_config = pydantic.ConfigDict( | |
| extra='allow', | |
| ) | |
| class Position(BaseConfig): | |
| x: float | |
| y: float | |
| class WorkspaceNodeData(BaseConfig): | |
| title: str | |
| params: dict | |
| display: Optional[object] = None | |
| error: Optional[str] = None | |
| class WorkspaceNode(BaseConfig): | |
| id: str | |
| type: str | |
| data: WorkspaceNodeData | |
| position: Position | |
| parentId: Optional[str] = None | |
| class WorkspaceEdge(BaseConfig): | |
| id: str | |
| source: str | |
| target: str | |
| class Workspace(BaseConfig): | |
| nodes: list[WorkspaceNode] = dataclasses.field(default_factory=list) | |
| edges: list[WorkspaceEdge] = dataclasses.field(default_factory=list) | |
| def execute(ws): | |
| # Nodes are responsible for interpreting/executing their child nodes. | |
| nodes = [n for n in ws.nodes if not n.parentId] | |
| children = {} | |
| for n in ws.nodes: | |
| if n.parentId: | |
| children.setdefault(n.parentId, []).append(n) | |
| outputs = {} | |
| failed = 0 | |
| while len(outputs) + failed < len(nodes): | |
| for node in nodes: | |
| if node.id in outputs: | |
| continue | |
| inputs = [edge.source for edge in ws.edges if edge.target == node.id] | |
| if all(input in outputs for input in inputs): | |
| inputs = [outputs[input] for input in inputs] | |
| data = node.data | |
| op = ops.ALL_OPS[data.title] | |
| params = {**data.params} | |
| if op.sub_nodes: | |
| sub_nodes = children.get(node.id, []) | |
| sub_node_ids = [node.id for node in sub_nodes] | |
| sub_edges = [edge for edge in ws.edges if edge.source in sub_node_ids] | |
| params['sub_flow'] = {'nodes': sub_nodes, 'edges': sub_edges} | |
| try: | |
| output = op(*inputs, **params) | |
| except Exception as e: | |
| traceback.print_exc() | |
| data.error = str(e) | |
| failed += 1 | |
| continue | |
| if len(op.inputs) == 1 and op.inputs.get('multi') == '*': | |
| # It's a flexible input. Create n+1 handles. | |
| data.inputs = {f'input{i}': None for i in range(len(inputs) + 1)} | |
| data.error = None | |
| outputs[node.id] = output | |
| if op.type == 'visualization' or op.type == 'table_view': | |
| data.view = output | |
| def save(ws: Workspace, path: str): | |
| j = ws.model_dump_json(indent=2) | |
| dirname, basename = os.path.split(path) | |
| # Create temp file in the same directory to make sure it's on the same filesystem. | |
| with tempfile.NamedTemporaryFile('w', prefix=f'.{basename}.', dir=dirname, delete_on_close=False) as f: | |
| f.write(j) | |
| f.close() | |
| os.replace(f.name, path) | |
| def load(path: str): | |
| with open(path) as f: | |
| j = f.read() | |
| ws = Workspace.model_validate_json(j) | |
| # Metadata is added after loading. This way code changes take effect on old boxes too. | |
| _update_metadata(ws) | |
| return ws | |
| def _update_metadata(ws): | |
| nodes = {node.id: node for node in ws.nodes} | |
| done = set() | |
| while len(done) < len(nodes): | |
| for node in ws.nodes: | |
| if node.id in done: | |
| continue | |
| data = node.data | |
| if node.parentId is None: | |
| op = ops.ALL_OPS.get(data.title) | |
| elif node.parentId not in nodes: | |
| data.error = f'Parent not found: {node.parentId}' | |
| done.add(node.id) | |
| continue | |
| elif node.parentId in done: | |
| op = nodes[node.parentId].data.meta.sub_nodes[data.title] | |
| else: | |
| continue | |
| if op: | |
| data.meta = op | |
| if data.error == 'Unknown operation.': | |
| data.error = None | |
| else: | |
| data.error = 'Unknown operation.' | |
| done.add(node.id) | |
| return ws | |