darabos commited on
Commit
bc2b550
·
1 Parent(s): d994c06

Move workspace functions to workspace.py.

Browse files
Files changed (2) hide show
  1. server/main.py +9 -88
  2. server/workspace.py +92 -0
server/main.py CHANGED
@@ -1,47 +1,12 @@
1
- from typing import Optional
2
  import dataclasses
3
  import fastapi
4
- import os
5
  import pathlib
6
- import pydantic
7
- import tempfile
8
- import traceback
9
  from . import ops
 
10
  from . import basic_ops
11
- from . import networkx_ops
12
- from . import pytorch_model_ops
13
-
14
- class BaseConfig(pydantic.BaseModel):
15
- model_config = pydantic.ConfigDict(
16
- extra='allow',
17
- )
18
-
19
- class Position(BaseConfig):
20
- x: float
21
- y: float
22
-
23
- class WorkspaceNodeData(BaseConfig):
24
- title: str
25
- params: dict
26
- display: Optional[object] = None
27
- error: Optional[str] = None
28
-
29
- class WorkspaceNode(BaseConfig):
30
- id: str
31
- type: str
32
- data: WorkspaceNodeData
33
- position: Position
34
- parentNode: Optional[str] = None
35
-
36
- class WorkspaceEdge(BaseConfig):
37
- id: str
38
- source: str
39
- target: str
40
-
41
- class Workspace(BaseConfig):
42
- nodes: list[WorkspaceNode]
43
- edges: list[WorkspaceEdge]
44
-
45
 
46
  app = fastapi.FastAPI()
47
 
@@ -50,56 +15,15 @@ app = fastapi.FastAPI()
50
  def get_catalog():
51
  return [op.to_json() for op in ops.ALL_OPS.values()]
52
 
53
- def execute(ws):
54
- # Nodes are responsible for interpreting/executing their child nodes.
55
- nodes = [n for n in ws.nodes if not n.parentNode]
56
- children = {}
57
- for n in ws.nodes:
58
- if n.parentNode:
59
- children.setdefault(n.parentNode, []).append(n)
60
- outputs = {}
61
- failed = 0
62
- while len(outputs) + failed < len(nodes):
63
- for node in nodes:
64
- if node.id in outputs:
65
- continue
66
- inputs = [edge.source for edge in ws.edges if edge.target == node.id]
67
- if all(input in outputs for input in inputs):
68
- inputs = [outputs[input] for input in inputs]
69
- data = node.data
70
- op = ops.ALL_OPS[data.title]
71
- params = {**data.params}
72
- if op.sub_nodes:
73
- sub_nodes = children.get(node.id, [])
74
- sub_node_ids = [node.id for node in sub_nodes]
75
- sub_edges = [edge for edge in ws.edges if edge.source in sub_node_ids]
76
- params['sub_flow'] = {'nodes': sub_nodes, 'edges': sub_edges}
77
- try:
78
- output = op(*inputs, **params)
79
- except Exception as e:
80
- traceback.print_exc()
81
- data.error = str(e)
82
- failed += 1
83
- continue
84
- data.error = None
85
- outputs[node.id] = output
86
- if op.type == 'graph_view' or op.type == 'table_view':
87
- data.view = output
88
 
89
-
90
- class SaveRequest(BaseConfig):
91
  path: str
92
- ws: Workspace
93
 
94
  def save(req: SaveRequest):
95
  path = DATA_PATH / req.path
96
  assert path.is_relative_to(DATA_PATH)
97
- j = req.ws.model_dump_json(indent=2)
98
- with tempfile.NamedTemporaryFile('w', delete_on_close=False) as f:
99
- f.write(j)
100
- f.close()
101
- os.replace(f.name, path)
102
-
103
 
104
  @app.post("/api/save")
105
  def save_and_execute(req: SaveRequest):
@@ -113,11 +37,8 @@ def load(path: str):
113
  path = DATA_PATH / path
114
  assert path.is_relative_to(DATA_PATH)
115
  if not path.exists():
116
- return Workspace(nodes=[], edges=[])
117
- with open(path) as f:
118
- j = f.read()
119
- ws = Workspace.model_validate_json(j)
120
- return ws
121
 
122
  DATA_PATH = pathlib.Path.cwd() / 'data'
123
 
 
 
1
  import dataclasses
2
  import fastapi
 
3
  import pathlib
 
 
 
4
  from . import ops
5
+ from . import workspace
6
  from . import basic_ops
7
+ # from . import networkx_ops
8
+ # from . import pytorch_model_ops
9
+ from . import lynxscribe_ops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  app = fastapi.FastAPI()
12
 
 
15
  def get_catalog():
16
  return [op.to_json() for op in ops.ALL_OPS.values()]
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ class SaveRequest(workspace.BaseConfig):
 
20
  path: str
21
+ ws: workspace.Workspace
22
 
23
  def save(req: SaveRequest):
24
  path = DATA_PATH / req.path
25
  assert path.is_relative_to(DATA_PATH)
26
+ workspace.save(req.ws, path)
 
 
 
 
 
27
 
28
  @app.post("/api/save")
29
  def save_and_execute(req: SaveRequest):
 
37
  path = DATA_PATH / path
38
  assert path.is_relative_to(DATA_PATH)
39
  if not path.exists():
40
+ return workspace.Workspace()
41
+ return workspace.load(path)
 
 
 
42
 
43
  DATA_PATH = pathlib.Path.cwd() / 'data'
44
 
server/workspace.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''For working with LynxKite workspaces.'''
2
+ from typing import Optional
3
+ import dataclasses
4
+ import os
5
+ import pydantic
6
+ import tempfile
7
+ import traceback
8
+ from . import ops
9
+
10
+ class BaseConfig(pydantic.BaseModel):
11
+ model_config = pydantic.ConfigDict(
12
+ extra='allow',
13
+ )
14
+
15
+ class Position(BaseConfig):
16
+ x: float
17
+ y: float
18
+
19
+ class WorkspaceNodeData(BaseConfig):
20
+ title: str
21
+ params: dict
22
+ display: Optional[object] = None
23
+ error: Optional[str] = None
24
+
25
+ class WorkspaceNode(BaseConfig):
26
+ id: str
27
+ type: str
28
+ data: WorkspaceNodeData
29
+ position: Position
30
+ parentNode: Optional[str] = None
31
+
32
+ class WorkspaceEdge(BaseConfig):
33
+ id: str
34
+ source: str
35
+ target: str
36
+
37
+ class Workspace(BaseConfig):
38
+ nodes: list[WorkspaceNode] = dataclasses.field(default_factory=list)
39
+ edges: list[WorkspaceEdge] = dataclasses.field(default_factory=list)
40
+
41
+
42
+ def execute(ws):
43
+ # Nodes are responsible for interpreting/executing their child nodes.
44
+ nodes = [n for n in ws.nodes if not n.parentNode]
45
+ print(nodes)
46
+ children = {}
47
+ for n in ws.nodes:
48
+ if n.parentNode:
49
+ children.setdefault(n.parentNode, []).append(n)
50
+ outputs = {}
51
+ failed = 0
52
+ while len(outputs) + failed < len(nodes):
53
+ for node in nodes:
54
+ if node.id in outputs:
55
+ continue
56
+ inputs = [edge.source for edge in ws.edges if edge.target == node.id]
57
+ if all(input in outputs for input in inputs):
58
+ inputs = [outputs[input] for input in inputs]
59
+ data = node.data
60
+ op = ops.ALL_OPS[data.title]
61
+ params = {**data.params}
62
+ if op.sub_nodes:
63
+ sub_nodes = children.get(node.id, [])
64
+ sub_node_ids = [node.id for node in sub_nodes]
65
+ sub_edges = [edge for edge in ws.edges if edge.source in sub_node_ids]
66
+ params['sub_flow'] = {'nodes': sub_nodes, 'edges': sub_edges}
67
+ try:
68
+ output = op(*inputs, **params)
69
+ except Exception as e:
70
+ traceback.print_exc()
71
+ data.error = str(e)
72
+ failed += 1
73
+ continue
74
+ data.error = None
75
+ outputs[node.id] = output
76
+ if op.type == 'graph_view' or op.type == 'table_view':
77
+ data.view = output
78
+
79
+
80
+ def save(ws: Workspace, path: str):
81
+ j = ws.model_dump_json(indent=2)
82
+ with tempfile.NamedTemporaryFile('w', delete_on_close=False) as f:
83
+ f.write(j)
84
+ f.close()
85
+ os.replace(f.name, path)
86
+
87
+
88
+ def load(path: str):
89
+ with open(path) as f:
90
+ j = f.read()
91
+ ws = Workspace.model_validate_json(j)
92
+ return ws