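"""FastAPI backend for a workspace editor.

It serves the operation catalog, executes a workspace (a DAG of operation
nodes connected by edges), and loads/saves workspaces as JSON files under
the local ``data`` directory."""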
from typing import Optional
import dataclasses
import fastapi
import os
import pathlib
import pydantic
import tempfile
import traceback
from . import ops
# Imported for their side effect of registering operations in ops.ALL_OPS.
from . import basic_ops
from . import networkx_ops


class BaseConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(
        extra='allow',
    )


class Position(BaseConfig):
    x: float
    y: float


class WorkspaceNodeData(BaseConfig):
    title: str
    params: dict
    display: Optional[object] = None
    error: Optional[str] = None


class WorkspaceNode(BaseConfig):
    id: str
    type: str
    data: WorkspaceNodeData
    position: Position


class WorkspaceEdge(BaseConfig):
    id: str
    source: str
    target: str


class Workspace(BaseConfig):
    nodes: list[WorkspaceNode]
    edges: list[WorkspaceEdge]
app = fastapi.FastAPI()
@app.get("/api/catalog")
def get_catalog():
return [
{
'type': op.type,
'data': { 'title': op.name, 'params': op.params },
'targetPosition': 'left' if op.inputs else None,
'sourcePosition': 'right' if op.outputs else None,
}
for op in ops.ALL_OPS.values()]


def execute(ws):
    # Execute the workspace as a DAG: a node runs once every node feeding
    # into it (the source of each incoming edge) has produced an output.
    nodes = ws.nodes
    outputs = {}
    failed = 0
    while len(outputs) + failed < len(nodes):
        progress = False
        for node in nodes:
            if node.id in outputs:
                continue
            inputs = [edge.source for edge in ws.edges if edge.target == node.id]
            if all(input in outputs for input in inputs):
                inputs = [outputs[input] for input in inputs]
                data = node.data
                op = ops.ALL_OPS[data.title]
                try:
                    output = op(*inputs, **data.params)
                except Exception as e:
                    traceback.print_exc()
                    data.error = str(e)
                    failed += 1
                    progress = True
                    continue
                data.error = None
                outputs[node.id] = output
                progress = True
                if op.type in ('graph_view', 'table_view'):
                    # View nodes keep their output on the node so it can be displayed.
                    data.display = output
        if not progress:
            # Nothing ran in a full pass: the remaining nodes depend on failed
            # nodes (or form a cycle), so stop instead of looping forever.
            break


class SaveRequest(BaseConfig):
    path: str
    ws: Workspace


def save(req: SaveRequest):
    # Resolve the path and confirm it stays inside DATA_PATH (rejects '..' traversal).
    path = (DATA_PATH / req.path).resolve()
    assert path.is_relative_to(DATA_PATH)
    j = req.ws.model_dump_json(indent=2)
    # Write to a temporary file in the same directory, then atomically swap it
    # into place so readers never see a partially written workspace.
    with tempfile.NamedTemporaryFile('w', dir=path.parent, delete_on_close=False) as f:
        f.write(j)
        f.close()
        os.replace(f.name, path)
@app.post("/api/save")
def save_and_execute(req: SaveRequest):
save(req)
execute(req.ws)
save(req)
return req.ws
@app.get("/api/load")
def load(path: str):
path = DATA_PATH / path
assert path.is_relative_to(DATA_PATH)
if not path.exists():
return Workspace(nodes=[], edges=[])
with open(path) as f:
j = f.read()
ws = Workspace.model_validate_json(j)
return ws
DATA_PATH = pathlib.Path.cwd() / 'data'


@dataclasses.dataclass(order=True)
class DirectoryEntry:
    name: str
    type: str


@app.get("/api/dir/list")
def list_dir(path: str):
    path = (DATA_PATH / path).resolve()
    assert path.is_relative_to(DATA_PATH)
    return sorted([
        DirectoryEntry(
            str(p.relative_to(DATA_PATH)),
            'directory' if p.is_dir() else 'workspace')
        for p in path.iterdir()])
@app.post("/api/dir/mkdir")
def make_dir(req: dict):
path = DATA_PATH / req['path']
assert path.is_relative_to(DATA_PATH)
assert not path.exists()
path.mkdir()
return list_dir(path.parent)
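

# Rough local-development sketch. Assumptions: the module name `main` and the
# workspace file name `example.json` below are placeholders, not taken from
# this repo; adjust them to the actual module path and data files.
#   uvicorn main:app --reload
#   curl 'http://localhost:8000/api/catalog'
#   curl 'http://localhost:8000/api/load?path=example.json'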