Spaces:
Running
Running
| """For working with LynxKite workspaces.""" | |
| from typing import Optional | |
| import dataclasses | |
| import os | |
| import pydantic | |
| import tempfile | |
| from . import ops | |
class BaseConfig(pydantic.BaseModel):
    """Common base class for every workspace model in this module.

    ``extra="allow"`` keeps undeclared attributes instead of rejecting them,
    so ad-hoc fields (such as the ``meta`` attribute attached to node data
    after loading) can be set on instances and are included when dumped.
    """

    model_config = pydantic.ConfigDict(extra="allow")
class Position(BaseConfig):
    """2-D position of a workspace node, as stored in the workspace file."""

    # Field order is kept as-is: it determines key order in the dumped JSON.
    x: float
    y: float
class WorkspaceNodeData(BaseConfig):
    """The ``data`` payload of a workspace node.

    Besides the declared fields, node data sent out also carries an extra
    ``meta`` attribute (the catalog operation, attached by ``_update_metadata``).
    Incoming ``meta`` values are not relied upon; presumably callers re-attach
    it from the catalog after validation. NOTE(review): confirm for the
    frontend round-trip path, which is not visible here.
    """

    title: str  # Operation name; looked up in the environment's catalog.
    params: dict  # Free-form parameter mapping (presumably the op's params).
    display: Optional[object] = None  # Arbitrary display payload, if any.
    error: Optional[str] = None  # Error message, e.g. "Unknown operation.".
class WorkspaceNode(BaseConfig):
    """A single box in the workspace graph."""

    id: str  # Node identifier; edges presumably reference it via source/target.
    type: str  # Replaced by the catalog operation's ``type`` when metadata is updated.
    data: WorkspaceNodeData
    position: Position
class WorkspaceEdge(BaseConfig):
    """A connection between two workspace nodes.

    The camelCase field names are the keys used in the serialized workspace
    JSON, so they must not be renamed.
    """

    id: str
    source: str  # presumably the ``id`` of the upstream node
    target: str  # presumably the ``id`` of the downstream node
    sourceHandle: str  # handle name on the source node (assumed: an output)
    targetHandle: str  # handle name on the target node (assumed: an input)
class Workspace(BaseConfig):
    """A whole workspace: its environment name plus the node/edge graph.

    ``env`` selects the executor (``ops.EXECUTORS``, see ``execute``) and the
    operation catalog (``ops.CATALOGS``, see ``_update_metadata``) used for
    this workspace.
    """

    env: str = ""
    # pydantic.Field is the documented way to give a BaseModel field a fresh
    # per-instance default; dataclasses.field is meant for dataclasses.
    nodes: list[WorkspaceNode] = pydantic.Field(default_factory=list)
    edges: list[WorkspaceEdge] = pydantic.Field(default_factory=list)
async def execute(ws: Workspace):
    """Run ``ws`` with the executor registered for its environment.

    If no executor is registered under ``ws.env``, this is a silent no-op.
    """
    if ws.env not in ops.EXECUTORS:
        return
    await ops.EXECUTORS[ws.env](ws)
def save(ws: Workspace, path: str):
    """Write ``ws`` to ``path`` as indented UTF-8 JSON, replacing any old file.

    The JSON is first written to a hidden temporary file (``.<basename>.*``)
    in the same directory, then moved over ``path`` with ``os.replace``. On
    POSIX that rename is atomic, so readers see either the old or the new
    file, never a partial one. If writing or replacing fails, the temporary
    file is removed and the exception is re-raised.
    """
    j = ws.model_dump_json(indent=2)
    dirname, basename = os.path.split(path)
    # Create temp file in the same directory to make sure it's on the same
    # filesystem; os.replace() cannot atomically move across filesystems.
    # Explicit UTF-8: pydantic emits non-ASCII characters unescaped, and the
    # locale default encoding (e.g. cp1252 on Windows) may not encode them.
    f = tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", prefix=f".{basename}.", dir=dirname, delete=False
    )
    temp_name = f.name
    try:
        with f:
            f.write(j)
        # Replace only after the temp file is closed (flushed; required on Windows).
        os.replace(temp_name, path)
    except BaseException:
        # Don't leave a stray hidden temp file behind on failure.
        try:
            os.unlink(temp_name)
        except OSError:
            pass
        raise


def load(path: str):
    """Read the workspace stored as UTF-8 JSON at ``path`` (as written by ``save``).

    After parsing, node metadata is refreshed via ``_update_metadata``.
    """
    # Same encoding as save(), independent of the platform's locale default.
    with open(path, encoding="utf-8") as f:
        j = f.read()
    ws = Workspace.model_validate_json(j)
    # Metadata is added after loading. This way code changes take effect on old boxes too.
    _update_metadata(ws)
    return ws
def _update_metadata(ws):
    """Attach current catalog metadata to every node of ``ws``, in place.

    The catalog is ``ops.CATALOGS[ws.env]``; an unknown environment uses an
    empty catalog.

    For each node whose ``data.title`` is found in that catalog:
    - sets ``node.data.meta`` to the operation;
    - copies the operation's ``type`` to ``node.type``;
    - clears a stale ``"Unknown operation."`` error.

    Every other node gets ``data.error = "Unknown operation."``.

    Returns ``ws`` for convenience.
    """
    catalog = ops.CATALOGS.get(ws.env, {})
    # A single pass suffices: each node's update reads only that node. Every
    # node is visited, including any that share an id with an earlier node.
    for node in ws.nodes:
        data = node.data
        op = catalog.get(data.title)
        if op:
            data.meta = op
            node.type = op.type
            if data.error == "Unknown operation.":
                data.error = None
        else:
            data.error = "Unknown operation."
    return ws