darabos commited on
Commit
3010d5b
·
1 Parent(s): cb1e802

Nodes that can contain flows.

Browse files
server/main.py CHANGED
@@ -9,6 +9,7 @@ import traceback
9
  from . import ops
10
  from . import basic_ops
11
  from . import networkx_ops
 
12
 
13
  class BaseConfig(pydantic.BaseModel):
14
  model_config = pydantic.ConfigDict(
@@ -30,6 +31,7 @@ class WorkspaceNode(BaseConfig):
30
  type: str
31
  data: WorkspaceNodeData
32
  position: Position
 
33
 
34
  class WorkspaceEdge(BaseConfig):
35
  id: str
@@ -46,17 +48,11 @@ app = fastapi.FastAPI()
46
 
47
  @app.get("/api/catalog")
48
  def get_catalog():
49
- return [
50
- {
51
- 'type': op.type,
52
- 'data': { 'title': op.name, 'params': op.params },
53
- 'targetPosition': 'left' if op.inputs else None,
54
- 'sourcePosition': 'right' if op.outputs else None,
55
- }
56
- for op in ops.ALL_OPS.values()]
57
 
58
  def execute(ws):
59
- nodes = ws.nodes
 
60
  outputs = {}
61
  failed = 0
62
  while len(outputs) + failed < len(nodes):
 
9
  from . import ops
10
  from . import basic_ops
11
  from . import networkx_ops
12
+ from . import pytorch_model_ops
13
 
14
  class BaseConfig(pydantic.BaseModel):
15
  model_config = pydantic.ConfigDict(
 
31
  type: str
32
  data: WorkspaceNodeData
33
  position: Position
34
+ parentNode: Optional[str] = None
35
 
36
  class WorkspaceEdge(BaseConfig):
37
  id: str
 
48
 
49
  @app.get("/api/catalog")
50
  def get_catalog():
51
+ return [op.to_json() for op in ops.ALL_OPS.values()]
 
 
 
 
 
 
 
52
 
53
  def execute(ws):
54
+ # Nodes are responsible for interpreting/executing their child nodes.
55
+ nodes = [n for n in ws.nodes if not n.parentNode]
56
  outputs = {}
57
  failed = 0
58
  while len(outputs) + failed < len(nodes):
server/ops.py CHANGED
@@ -11,10 +11,11 @@ ALL_OPS = {}
11
  class Op:
12
  func: callable
13
  name: str
14
- params: dict
15
- inputs: dict
16
- outputs: dict
17
- type: str
 
18
 
19
  def __call__(self, *inputs, **params):
20
  # Convert parameters.
@@ -39,20 +40,37 @@ class Op:
39
  res = self.func(*inputs, **params)
40
  return res
41
 
 
 
 
 
 
 
 
 
 
 
42
  @dataclasses.dataclass
43
  class RelationDefinition:
44
- df: str
45
- source_column: str
46
- target_column: str
47
- source_table: str
48
- target_table: str
49
- source_key: str
50
- target_key: str
 
51
 
52
  @dataclasses.dataclass
53
  class Bundle:
 
 
 
 
 
54
  dfs: dict
55
  relations: list[RelationDefinition]
 
56
 
57
  @classmethod
58
  def from_nx(cls, graph: nx.Graph):
@@ -94,7 +112,7 @@ def nx_node_attribute_func(name):
94
  return decorator
95
 
96
 
97
- def op(name, *, view='basic'):
98
  '''Decorator for defining an operation.'''
99
  def decorator(func):
100
  sig = inspect.signature(func)
@@ -109,6 +127,9 @@ def op(name, *, view='basic'):
109
  if param.kind == param.KEYWORD_ONLY}
110
  outputs = {'output': 'yes'} if view == 'basic' else {} # Maybe more fancy later.
111
  op = Op(func, name, params=params, inputs=inputs, outputs=outputs, type=view)
 
 
 
112
  ALL_OPS[name] = op
113
  return func
114
  return decorator
 
11
  class Op:
12
  func: callable
13
  name: str
14
+ params: dict # name -> default
15
+ inputs: dict # name -> type
16
+ outputs: dict # name -> type
17
+ type: str # The UI to use for this operation.
18
+ sub_nodes: list = None # If set, these nodes can be placed inside the operation's node.
19
 
20
  def __call__(self, *inputs, **params):
21
  # Convert parameters.
 
40
  res = self.func(*inputs, **params)
41
  return res
42
 
43
+ def to_json(self):
44
+ return {
45
+ 'type': self.type,
46
+ 'data': { 'title': self.name, 'params': self.params },
47
+ 'targetPosition': 'left' if self.inputs else None,
48
+ 'sourcePosition': 'right' if self.outputs else None,
49
+ 'sub_nodes': [sub.to_json() for sub in self.sub_nodes.values()] if self.sub_nodes else None,
50
+ }
51
+
52
+
53
  @dataclasses.dataclass
54
  class RelationDefinition:
55
+ '''Defines a set of edges.'''
56
+ df: str # The DataFrame that contains the edges.
57
+ source_column: str # The column in the edge DataFrame that contains the source node ID.
58
+ target_column: str # The column in the edge DataFrame that contains the target node ID.
59
+ source_table: str # The DataFrame that contains the source nodes.
60
+ target_table: str # The DataFrame that contains the target nodes.
61
+ source_key: str # The column in the source table that contains the node ID.
62
+ target_key: str # The column in the target table that contains the node ID.
63
 
64
  @dataclasses.dataclass
65
  class Bundle:
66
+ '''A collection of DataFrames and other data.
67
+
68
+ Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
69
+ It can also carry other data, such as a trained model.
70
+ '''
71
  dfs: dict
72
  relations: list[RelationDefinition]
73
+ other: dict = None
74
 
75
  @classmethod
76
  def from_nx(cls, graph: nx.Graph):
 
112
  return decorator
113
 
114
 
115
+ def op(name, *, view='basic', sub_nodes=None):
116
  '''Decorator for defining an operation.'''
117
  def decorator(func):
118
  sig = inspect.signature(func)
 
127
  if param.kind == param.KEYWORD_ONLY}
128
  outputs = {'output': 'yes'} if view == 'basic' else {} # Maybe more fancy later.
129
  op = Op(func, name, params=params, inputs=inputs, outputs=outputs, type=view)
130
+ if sub_nodes is not None:
131
+ op.sub_nodes = sub_nodes
132
+ op.type = 'sub_flow'
133
  ALL_OPS[name] = op
134
  return func
135
  return decorator
server/pytorch_model_ops.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''Boxes for defining and using PyTorch models.'''
2
+ import inspect
3
+ from . import ops
4
+
5
+ LAYERS = {}
6
+
7
+ @ops.op("Define PyTorch model", sub_nodes=LAYERS)
8
+ def define_pytorch_model(*, sub_flow):
9
+ # import torch # Lazy import because it's slow.
10
+ print('sub_flow:', sub_flow)
11
+ return 'hello ' + str(sub_flow)
12
+
13
+ def register_layer(name):
14
+ def decorator(func):
15
+ sig = inspect.signature(func)
16
+ inputs = {
17
+ name: param.annotation
18
+ for name, param in sig.parameters.items()
19
+ if param.kind != param.KEYWORD_ONLY}
20
+ params = {
21
+ name: param.default if param.default is not inspect._empty else None
22
+ for name, param in sig.parameters.items()
23
+ if param.kind == param.KEYWORD_ONLY}
24
+ outputs = {'x': 'tensor'}
25
+ LAYERS[name] = ops.Op(func, name, params=params, inputs=inputs, outputs=outputs, type='vertical')
26
+ return func
27
+ return decorator
28
+
29
+ @register_layer('LayerNorm')
30
+ def normalization():
31
+ return 'LayerNorm'
32
+
33
+ @register_layer('Dropout')
34
+ def dropout(*, p=0.5):
35
+ return f'Dropout ({p})'
36
+
37
+ @register_layer('Linear')
38
+ def linear(*, output_dim: int):
39
+ return f'Linear {output_dim}'
40
+
41
+ @register_layer('Graph Convolution')
42
+ def graph_convolution():
43
+ return 'GraphConv'
44
+
45
+ @register_layer('Nonlinearity')
46
+ def nonlinearity():
47
+ return 'ReLU'
48
+
web/index.html CHANGED
@@ -4,7 +4,7 @@
4
  <meta charset="UTF-8" />
5
  <link rel="icon" type="image/png" href="/public/favicon.ico" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
- <title>Svelte Flow Starter</title>
8
  </head>
9
  <body>
10
  <div id="app"></div>
 
4
  <meta charset="UTF-8" />
5
  <link rel="icon" type="image/png" href="/public/favicon.ico" />
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>LynxKite 2024</title>
8
  </head>
9
  <body>
10
  <div id="app"></div>
web/src/LynxKiteFlow.svelte CHANGED
@@ -14,8 +14,10 @@
14
  type NodeTypes,
15
  } from '@xyflow/svelte';
16
  import NodeWithParams from './NodeWithParams.svelte';
 
17
  import NodeWithGraphView from './NodeWithGraphView.svelte';
18
  import NodeWithTableView from './NodeWithTableView.svelte';
 
19
  import NodeSearch from './NodeSearch.svelte';
20
  import '@xyflow/svelte/dist/style.css';
21
 
@@ -23,8 +25,10 @@
23
 
24
  const nodeTypes: NodeTypes = {
25
  basic: NodeWithParams,
 
26
  graph_view: NodeWithGraphView,
27
  table_view: NodeWithTableView,
 
28
  };
29
 
30
  export let path = '';
@@ -43,23 +47,23 @@
43
  $: fetchWorkspace(path);
44
 
45
  function closeNodeSearch() {
46
- nodeSearchPos = undefined;
47
  }
48
  function toggleNodeSearch({ detail: { event } }) {
49
- if (nodeSearchPos) {
50
  closeNodeSearch();
51
  return;
52
  }
53
  event.preventDefault();
54
- nodeSearchPos = {
55
- top: event.offsetY,
56
- left: event.offsetX - 155,
57
  };
58
  }
59
  function addNode(e) {
60
  const node = {...e.detail};
61
  nodes.update((n) => {
62
- node.position = screenToFlowPosition({x: nodeSearchPos.left, y: nodeSearchPos.top});
63
  const title = node.data.title;
64
  let i = 1;
65
  node.id = `${title} ${i}`;
@@ -67,6 +71,12 @@
67
  i += 1;
68
  node.id = `${title} ${i}`;
69
  }
 
 
 
 
 
 
70
  return [...n, node]
71
  });
72
  closeNodeSearch();
@@ -79,7 +89,11 @@
79
  }
80
  getBoxes();
81
 
82
- let nodeSearchPos: XYPosition | undefined = undefined;
 
 
 
 
83
 
84
  const graph = derived([nodes, edges], ([nodes, edges]) => ({ nodes, edges }));
85
  let backendWorkspace: string;
@@ -120,20 +134,36 @@
120
  return edges.filter((e) => e.source === connection.source || e.target !== connection.target);
121
  });
122
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  </script>
125
 
126
  <div style:height="100%">
127
  <SvelteFlow {nodes} {edges} {nodeTypes} fitView
128
  on:paneclick={toggleNodeSearch}
 
129
  proOptions={{ hideAttribution: true }}
130
  maxZoom={1.5}
131
  minZoom={0.3}
132
  onconnect={onconnect}
133
  >
134
- <Background patternColor="#39bcf3" />
135
  <Controls />
136
  <MiniMap />
137
- {#if nodeSearchPos}<NodeSearch boxes={$boxes} on:cancel={closeNodeSearch} on:add={addNode} pos={nodeSearchPos} />{/if}
 
 
138
  </SvelteFlow>
139
  </div>
 
14
  type NodeTypes,
15
  } from '@xyflow/svelte';
16
  import NodeWithParams from './NodeWithParams.svelte';
17
+ import NodeWithParamsVertical from './NodeWithParamsVertical.svelte';
18
  import NodeWithGraphView from './NodeWithGraphView.svelte';
19
  import NodeWithTableView from './NodeWithTableView.svelte';
20
+ import NodeWithSubFlow from './NodeWithSubFlow.svelte';
21
  import NodeSearch from './NodeSearch.svelte';
22
  import '@xyflow/svelte/dist/style.css';
23
 
 
25
 
26
  const nodeTypes: NodeTypes = {
27
  basic: NodeWithParams,
28
+ vertical: NodeWithParamsVertical,
29
  graph_view: NodeWithGraphView,
30
  table_view: NodeWithTableView,
31
+ sub_flow: NodeWithSubFlow,
32
  };
33
 
34
  export let path = '';
 
47
  $: fetchWorkspace(path);
48
 
49
  function closeNodeSearch() {
50
+ nodeSearchSettings = undefined;
51
  }
52
  function toggleNodeSearch({ detail: { event } }) {
53
+ if (nodeSearchSettings) {
54
  closeNodeSearch();
55
  return;
56
  }
57
  event.preventDefault();
58
+ nodeSearchSettings = {
59
+ pos: { x: event.clientX, y: event.clientY },
60
+ boxes: $boxes,
61
  };
62
  }
63
  function addNode(e) {
64
  const node = {...e.detail};
65
  nodes.update((n) => {
66
+ node.position = screenToFlowPosition({x: nodeSearchSettings.pos.x, y: nodeSearchSettings.pos.y});
67
  const title = node.data.title;
68
  let i = 1;
69
  node.id = `${title} ${i}`;
 
71
  i += 1;
72
  node.id = `${title} ${i}`;
73
  }
74
+ node.parentNode = nodeSearchSettings.parentNode;
75
+ if (node.parentNode) {
76
+ node.extent = 'parent';
77
+ const parent = n.find((x) => x.id === node.parentNode);
78
+ node.position = { x: node.position.x - parent.position.x, y: node.position.y - parent.position.y };
79
+ }
80
  return [...n, node]
81
  });
82
  closeNodeSearch();
 
89
  }
90
  getBoxes();
91
 
92
+ let nodeSearchSettings: {
93
+ pos: XYPosition,
94
+ boxes: any[],
95
+ parentNode: string,
96
+ };
97
 
98
  const graph = derived([nodes, edges], ([nodes, edges]) => ({ nodes, edges }));
99
  let backendWorkspace: string;
 
134
  return edges.filter((e) => e.source === connection.source || e.target !== connection.target);
135
  });
136
  }
137
+ function nodeClick(e) {
138
+ const node = e.detail.node;
139
+ const meta = $boxes.find(m => m.data.title === node.data.title);
140
+ if (!meta) return;
141
+ const sub_nodes = meta.sub_nodes;
142
+ if (!sub_nodes) return;
143
+ const event = e.detail.event;
144
+ if (event.target.classList.contains('title')) return;
145
+ nodeSearchSettings = {
146
+ pos: { x: event.clientX, y: event.clientY },
147
+ boxes: sub_nodes,
148
+ parentNode: node.id,
149
+ };
150
+ }
151
 
152
  </script>
153
 
154
  <div style:height="100%">
155
  <SvelteFlow {nodes} {edges} {nodeTypes} fitView
156
  on:paneclick={toggleNodeSearch}
157
+ on:nodeclick={nodeClick}
158
  proOptions={{ hideAttribution: true }}
159
  maxZoom={1.5}
160
  minZoom={0.3}
161
  onconnect={onconnect}
162
  >
 
163
  <Controls />
164
  <MiniMap />
165
+ {#if nodeSearchSettings}
166
+ <NodeSearch pos={nodeSearchSettings.pos} boxes={nodeSearchSettings.boxes} on:cancel={closeNodeSearch} on:add={addNode} />
167
+ {/if}
168
  </SvelteFlow>
169
  </div>
web/src/LynxKiteNode.svelte CHANGED
@@ -3,6 +3,8 @@
3
 
4
  type $$Props = NodeProps;
5
 
 
 
6
  export let id: $$Props['id']; id;
7
  export let data: $$Props['data'];
8
  export let dragHandle: $$Props['dragHandle'] = undefined; dragHandle;
@@ -17,15 +19,20 @@
17
  export let sourcePosition: $$Props['sourcePosition'] = undefined; sourcePosition;
18
  export let positionAbsoluteX: $$Props['positionAbsoluteX'] = undefined; positionAbsoluteX;
19
  export let positionAbsoluteY: $$Props['positionAbsoluteY'] = undefined; positionAbsoluteY;
 
20
 
21
  let expanded = true;
22
  function titleClicked() {
23
  expanded = !expanded;
 
 
 
 
24
  }
25
  </script>
26
 
27
- <div class="node-container">
28
- <div class="lynxkite-node">
29
  <div class="title" on:click={titleClicked}>
30
  {data.title}
31
  {#if data.error}<span class="error-sign">⚠️</span>{/if}
@@ -56,15 +63,16 @@
56
  }
57
  .node-container {
58
  padding: 8px;
 
 
 
59
  }
60
  .lynxkite-node {
61
  box-shadow: 0px 5px 50px 0px rgba(0, 0, 0, 0.3);
62
  background: white;
63
- min-width: 200px;
64
- max-width: 400px;
65
- max-height: 400px;
66
  overflow-y: auto;
67
  border-radius: 1px;
 
68
  }
69
  .title {
70
  background: #ff8800;
 
3
 
4
  type $$Props = NodeProps;
5
 
6
+ export let nodeStyle = '';
7
+ export let containerStyle = '';
8
  export let id: $$Props['id']; id;
9
  export let data: $$Props['data'];
10
  export let dragHandle: $$Props['dragHandle'] = undefined; dragHandle;
 
19
  export let sourcePosition: $$Props['sourcePosition'] = undefined; sourcePosition;
20
  export let positionAbsoluteX: $$Props['positionAbsoluteX'] = undefined; positionAbsoluteX;
21
  export let positionAbsoluteY: $$Props['positionAbsoluteY'] = undefined; positionAbsoluteY;
22
+ export let onToggle = () => {};
23
 
24
  let expanded = true;
25
  function titleClicked() {
26
  expanded = !expanded;
27
+ onToggle({ expanded });
28
+ }
29
+ function asPx(n: number) {
30
+ return n ? n + 'px' : undefined;
31
  }
32
  </script>
33
 
34
+ <div class="node-container" style:width={asPx(width)} style:height={asPx(height)} style={containerStyle}>
35
+ <div class="lynxkite-node" style={nodeStyle}>
36
  <div class="title" on:click={titleClicked}>
37
  {data.title}
38
  {#if data.error}<span class="error-sign">⚠️</span>{/if}
 
63
  }
64
  .node-container {
65
  padding: 8px;
66
+ min-width: 200px;
67
+ max-width: 400px;
68
+ max-height: 400px;
69
  }
70
  .lynxkite-node {
71
  box-shadow: 0px 5px 50px 0px rgba(0, 0, 0, 0.3);
72
  background: white;
 
 
 
73
  overflow-y: auto;
74
  border-radius: 1px;
75
+ height: 100%;
76
  }
77
  .title {
78
  background: #ff8800;
web/src/NodeSearch.svelte CHANGED
@@ -30,7 +30,8 @@
30
  }
31
  function addSelected() {
32
  const node = {...hits[selectedIndex].item};
33
- node.position = {x: pos.left, y: pos.top};
 
34
  dispatch('add', node);
35
  }
36
  async function lostFocus(e) {
@@ -41,9 +42,7 @@
41
 
42
  </script>
43
 
44
- <div class="node-search"
45
- style="top: {pos.top}px; left: {pos.left}px; right: {pos.right}px; bottom: {pos.bottom}px;">
46
-
47
  <input
48
  bind:this={searchBox}
49
  on:input={onInput}
@@ -83,7 +82,7 @@ style="top: {pos.top}px; left: {pos.left}px; right: {pos.right}px; bottom: {pos.
83
  border-radius: 4px;
84
  }
85
  .node-search {
86
- position: absolute;
87
  width: 300px;
88
  z-index: 5;
89
  padding: 4px;
 
30
  }
31
  function addSelected() {
32
  const node = {...hits[selectedIndex].item};
33
+ delete node.sub_nodes;
34
+ node.position = pos;
35
  dispatch('add', node);
36
  }
37
  async function lostFocus(e) {
 
42
 
43
  </script>
44
 
45
+ <div class="node-search" style="top: {pos.y}px; left: {pos.x}px;">
 
 
46
  <input
47
  bind:this={searchBox}
48
  on:input={onInput}
 
82
  border-radius: 4px;
83
  }
84
  .node-search {
85
+ position: fixed;
86
  width: 300px;
87
  z-index: 5;
88
  padding: 4px;
web/src/NodeWithParamsVertical.svelte ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import { type NodeProps, useSvelteFlow } from '@xyflow/svelte';
3
+ import LynxKiteNode from './LynxKiteNode.svelte';
4
+ type $$Props = NodeProps;
5
+ export let id: $$Props['id'];
6
+ export let data: $$Props['data'];
7
+ const { updateNodeData } = useSvelteFlow();
8
+ </script>
9
+
10
+ <LynxKiteNode {...$$props} sourcePosition="bottom" targetPosition="top">
11
+ {#each Object.entries(data.params) as [name, value]}
12
+ <div class="param">
13
+ <label>
14
+ {name}<br>
15
+ <input
16
+ value={value}
17
+ on:input={(evt) => updateNodeData(id, { params: { ...data.params, [name]: evt.currentTarget.value } })}
18
+ />
19
+ </label>
20
+ </div>
21
+ {/each}
22
+ </LynxKiteNode>
23
+ <style>
24
+ .param {
25
+ padding: 8px;
26
+ }
27
+ .param label {
28
+ font-size: 12px;
29
+ display: block;
30
+ }
31
+ .param input {
32
+ width: calc(100% - 8px);
33
+ }
34
+ </style>
web/src/NodeWithSubFlow.svelte ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import { type NodeProps, useNodes } from '@xyflow/svelte';
3
+ import LynxKiteNode from './LynxKiteNode.svelte';
4
+ type $$Props = NodeProps;
5
+ const nodes = useNodes();
6
+ export let id: $$Props['id'];
7
+ export let data: $$Props['data'];
8
+ let isExpanded = true;
9
+ function onToggle({ expanded }) {
10
+ isExpanded = expanded;
11
+ console.log('onToggle', expanded, height);
12
+ nodes.update((n) =>
13
+ n.map((node) =>
14
+ node.parentNode === id
15
+ ? { ...node, hidden: !expanded }
16
+ : node));
17
+ }
18
+ function computeSize(nodes) {
19
+ let width = 200;
20
+ let height = 200;
21
+ for (const node of nodes) {
22
+ if (node.parentNode === id) {
23
+ width = Math.max(width, node.position.x + 300);
24
+ height = Math.max(height, node.position.y + 200);
25
+ }
26
+ }
27
+ return { width, height };
28
+ }
29
+ $: ({ width, height } = computeSize($nodes));
30
+ </script>
31
+
32
+ <LynxKiteNode
33
+ {...$$props}
34
+ width={isExpanded && width} height={isExpanded && height}
35
+ nodeStyle="background: transparent;" containerStyle="max-width: none; max-height: none;" {onToggle}>
36
+ </LynxKiteNode>
37
+ <style>
38
+ </style>