darabos commited on
Commit
d1d859f
·
1 Parent(s): f98d0ad

Fix node with image.

Browse files
server/executors/one_by_one.py CHANGED
@@ -7,142 +7,166 @@ import traceback
7
  import inspect
8
  import typing
9
 
 
10
  class Context(ops.BaseConfig):
11
- '''Passed to operation functions as "_ctx" if they have such a parameter.'''
12
- node: workspace.WorkspaceNode
13
- last_result: typing.Any = None
 
 
14
 
15
  class Output(ops.BaseConfig):
16
- '''Return this to send values to specific outputs of a node.'''
17
- output_handle: str
18
- value: dict
 
19
 
20
 
21
  def df_to_list(df):
22
- return df.to_dict(orient='records')
 
23
 
24
  def has_ctx(op):
25
- sig = inspect.signature(op.func)
26
- return '_ctx' in sig.parameters
 
27
 
28
  CACHES = {}
29
 
 
30
  def register(env: str, cache: bool = True):
31
- '''Registers the one-by-one executor.'''
32
- if cache:
33
- CACHES[env] = {}
34
- cache = CACHES[env]
35
- else:
36
- cache = None
37
- ops.EXECUTORS[env] = lambda ws: execute(ws, ops.CATALOGS[env], cache=cache)
 
38
 
39
  def get_stages(ws, catalog):
40
- '''Inputs on top/bottom are batch inputs. We decompose the graph into a DAG of components along these edges.'''
41
- nodes = {n.id: n for n in ws.nodes}
42
- batch_inputs = {}
43
- inputs = {}
44
- for edge in ws.edges:
45
- inputs.setdefault(edge.target, []).append(edge.source)
46
- node = nodes[edge.target]
47
- op = catalog[node.data.title]
48
- i = op.inputs[edge.targetHandle]
49
- if i.position in 'top or bottom':
50
- batch_inputs.setdefault(edge.target, []).append(edge.source)
51
- stages = []
52
- for bt, bss in batch_inputs.items():
53
- upstream = set(bss)
54
- new = set(bss)
55
- while new:
56
- n = new.pop()
57
- for i in inputs.get(n, []):
58
- if i not in upstream:
59
- upstream.add(i)
60
- new.add(i)
61
- stages.append(upstream)
62
- stages.sort(key=lambda s: len(s))
63
- stages.append(set(nodes))
64
- return stages
65
 
66
 
67
  def _default_serializer(obj):
68
- if isinstance(obj, pydantic.BaseModel):
69
- return obj.dict()
70
- return {"__nonserializable__": id(obj)}
 
71
 
72
  def make_cache_key(obj):
73
- return orjson.dumps(obj, default=_default_serializer)
 
74
 
75
  EXECUTOR_OUTPUT_CACHE = {}
76
 
 
77
  async def await_if_needed(obj):
78
- if inspect.isawaitable(obj):
79
- return await obj
80
- return obj
 
81
 
82
  async def execute(ws, catalog, cache=None):
83
- nodes = {n.id: n for n in ws.nodes}
84
- contexts = {n.id: Context(node=n) for n in ws.nodes}
85
- edges = {n.id: [] for n in ws.nodes}
86
- for e in ws.edges:
87
- edges[e.source].append(e)
88
- tasks = {}
89
- NO_INPUT = object() # Marker for initial tasks.
90
- for node in ws.nodes:
91
- node.data.error = None
92
- op = catalog[node.data.title]
93
- # Start tasks for nodes that have no non-batch inputs.
94
- if all([i.position in 'top or bottom' for i in op.inputs.values()]):
95
- tasks[node.id] = [NO_INPUT]
96
- batch_inputs = {}
97
- # Run the rest until we run out of tasks.
98
- stages = get_stages(ws, catalog)
99
- for stage in stages:
100
- next_stage = {}
101
- while tasks:
102
- n, ts = tasks.popitem()
103
- if n not in stage:
104
- next_stage.setdefault(n, []).extend(ts)
105
- continue
106
- node = nodes[n]
107
- data = node.data
108
- op = catalog[data.title]
109
- params = {**data.params}
110
- if has_ctx(op):
111
- params['_ctx'] = contexts[node.id]
112
- results = []
113
- for task in ts:
114
- try:
115
- inputs = [
116
- batch_inputs[(n, i.name)] if i.position in 'top or bottom' else task
117
- for i in op.inputs.values()]
118
- if cache is not None:
119
- key = make_cache_key((inputs, params))
120
- if key not in cache:
121
- cache[key] = await await_if_needed(op(*inputs, **params))
122
- result = cache[key]
123
- else:
124
- result = await await_if_needed(op(*inputs, **params))
125
- except Exception as e:
126
- traceback.print_exc()
127
- data.error = str(e)
128
- break
129
- contexts[node.id].last_result = result
130
- # Returned lists and DataFrames are considered multiple tasks.
131
- if isinstance(result, pd.DataFrame):
132
- result = df_to_list(result)
133
- elif not isinstance(result, list):
134
- result = [result]
135
- results.extend(result)
136
- else: # Finished all tasks without errors.
137
- if op.type == 'visualization' or op.type == 'table_view' or op.type == 'image':
138
- data.display = results[0]
139
- for edge in edges[node.id]:
140
- t = nodes[edge.target]
141
- op = catalog[t.data.title]
142
- i = op.inputs[edge.targetHandle]
143
- if i.position in 'top or bottom':
144
- batch_inputs.setdefault((edge.target, edge.targetHandle), []).extend(results)
145
- else:
146
- tasks.setdefault(edge.target, []).extend(results)
147
- tasks = next_stage
148
- return contexts
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import inspect
8
  import typing
9
 
10
+
11
  class Context(ops.BaseConfig):
12
+ """Passed to operation functions as "_ctx" if they have such a parameter."""
13
+
14
+ node: workspace.WorkspaceNode
15
+ last_result: typing.Any = None
16
+
17
 
18
  class Output(ops.BaseConfig):
19
+ """Return this to send values to specific outputs of a node."""
20
+
21
+ output_handle: str
22
+ value: dict
23
 
24
 
25
  def df_to_list(df):
26
+ return df.to_dict(orient="records")
27
+
28
 
29
  def has_ctx(op):
30
+ sig = inspect.signature(op.func)
31
+ return "_ctx" in sig.parameters
32
+
33
 
34
  CACHES = {}
35
 
36
+
37
  def register(env: str, cache: bool = True):
38
+ """Registers the one-by-one executor."""
39
+ if cache:
40
+ CACHES[env] = {}
41
+ cache = CACHES[env]
42
+ else:
43
+ cache = None
44
+ ops.EXECUTORS[env] = lambda ws: execute(ws, ops.CATALOGS[env], cache=cache)
45
+
46
 
47
  def get_stages(ws, catalog):
48
+ """Inputs on top/bottom are batch inputs. We decompose the graph into a DAG of components along these edges."""
49
+ nodes = {n.id: n for n in ws.nodes}
50
+ batch_inputs = {}
51
+ inputs = {}
52
+ for edge in ws.edges:
53
+ inputs.setdefault(edge.target, []).append(edge.source)
54
+ node = nodes[edge.target]
55
+ op = catalog[node.data.title]
56
+ i = op.inputs[edge.targetHandle]
57
+ if i.position in "top or bottom":
58
+ batch_inputs.setdefault(edge.target, []).append(edge.source)
59
+ stages = []
60
+ for bt, bss in batch_inputs.items():
61
+ upstream = set(bss)
62
+ new = set(bss)
63
+ while new:
64
+ n = new.pop()
65
+ for i in inputs.get(n, []):
66
+ if i not in upstream:
67
+ upstream.add(i)
68
+ new.add(i)
69
+ stages.append(upstream)
70
+ stages.sort(key=lambda s: len(s))
71
+ stages.append(set(nodes))
72
+ return stages
73
 
74
 
75
  def _default_serializer(obj):
76
+ if isinstance(obj, pydantic.BaseModel):
77
+ return obj.dict()
78
+ return {"__nonserializable__": id(obj)}
79
+
80
 
81
  def make_cache_key(obj):
82
+ return orjson.dumps(obj, default=_default_serializer)
83
+
84
 
85
  EXECUTOR_OUTPUT_CACHE = {}
86
 
87
+
88
  async def await_if_needed(obj):
89
+ if inspect.isawaitable(obj):
90
+ return await obj
91
+ return obj
92
+
93
 
94
  async def execute(ws, catalog, cache=None):
95
+ nodes = {n.id: n for n in ws.nodes}
96
+ contexts = {n.id: Context(node=n) for n in ws.nodes}
97
+ edges = {n.id: [] for n in ws.nodes}
98
+ for e in ws.edges:
99
+ edges[e.source].append(e)
100
+ tasks = {}
101
+ NO_INPUT = object() # Marker for initial tasks.
102
+ for node in ws.nodes:
103
+ node.data.error = None
104
+ op = catalog.get(node.data.title)
105
+ if op is None:
106
+ node.data.error = f'Operation "{node.data.title}" not found.'
107
+ continue
108
+ # Start tasks for nodes that have no non-batch inputs.
109
+ if all([i.position in "top or bottom" for i in op.inputs.values()]):
110
+ tasks[node.id] = [NO_INPUT]
111
+ batch_inputs = {}
112
+ # Run the rest until we run out of tasks.
113
+ stages = get_stages(ws, catalog)
114
+ for stage in stages:
115
+ next_stage = {}
116
+ while tasks:
117
+ n, ts = tasks.popitem()
118
+ if n not in stage:
119
+ next_stage.setdefault(n, []).extend(ts)
120
+ continue
121
+ node = nodes[n]
122
+ data = node.data
123
+ op = catalog[data.title]
124
+ params = {**data.params}
125
+ if has_ctx(op):
126
+ params["_ctx"] = contexts[node.id]
127
+ results = []
128
+ for task in ts:
129
+ try:
130
+ inputs = [
131
+ batch_inputs[(n, i.name)]
132
+ if i.position in "top or bottom"
133
+ else task
134
+ for i in op.inputs.values()
135
+ ]
136
+ if cache is not None:
137
+ key = make_cache_key((inputs, params))
138
+ if key not in cache:
139
+ cache[key] = await await_if_needed(op(*inputs, **params))
140
+ result = cache[key]
141
+ else:
142
+ result = await await_if_needed(op(*inputs, **params))
143
+ except Exception as e:
144
+ traceback.print_exc()
145
+ data.error = str(e)
146
+ break
147
+ contexts[node.id].last_result = result
148
+ # Returned lists and DataFrames are considered multiple tasks.
149
+ if isinstance(result, pd.DataFrame):
150
+ result = df_to_list(result)
151
+ elif not isinstance(result, list):
152
+ result = [result]
153
+ results.extend(result)
154
+ else: # Finished all tasks without errors.
155
+ if (
156
+ op.type == "visualization"
157
+ or op.type == "table_view"
158
+ or op.type == "image"
159
+ ):
160
+ data.display = results[0]
161
+ for edge in edges[node.id]:
162
+ t = nodes[edge.target]
163
+ op = catalog[t.data.title]
164
+ i = op.inputs[edge.targetHandle]
165
+ if i.position in "top or bottom":
166
+ batch_inputs.setdefault(
167
+ (edge.target, edge.targetHandle), []
168
+ ).extend(results)
169
+ else:
170
+ tasks.setdefault(edge.target, []).extend(results)
171
+ tasks = next_stage
172
+ return contexts
web/src/workspace/Workspace.tsx CHANGED
@@ -15,7 +15,6 @@ import {
15
  type Node,
16
  type Edge,
17
  type Connection,
18
- type NodeTypes,
19
  useReactFlow,
20
  MiniMap,
21
  } from '@xyflow/react';
@@ -28,17 +27,14 @@ import Atom from '~icons/tabler/atom.jsx';
28
  import { syncedStore, getYjsDoc } from "@syncedstore/core";
29
  import { WebsocketProvider } from "y-websocket";
30
  import NodeWithParams from './nodes/NodeWithParams';
31
- // import NodeWithVisualization from './NodeWithVisualization';
32
- // import NodeWithImage from './NodeWithImage';
33
  // import NodeWithTableView from './NodeWithTableView';
34
- // import NodeWithSubFlow from './NodeWithSubFlow';
35
- // import NodeWithArea from './NodeWithArea';
36
- // import NodeSearch from './NodeSearch';
37
  import EnvironmentSelector from './EnvironmentSelector';
38
  import { LynxKiteState } from './LynxKiteState';
39
  import '@xyflow/react/dist/style.css';
40
  import { Workspace, WorkspaceNode } from "../apiTypes.ts";
41
  import NodeSearch, { OpsOp, Catalog, Catalogs } from "./NodeSearch.tsx";
 
 
42
 
43
  export default function (props: any) {
44
  return (
@@ -144,6 +140,8 @@ function LynxKiteFlow() {
144
  const nodeTypes = useMemo(() => ({
145
  basic: NodeWithParams,
146
  table_view: NodeWithParams,
 
 
147
  }), []);
148
  const closeNodeSearch = useCallback(() => {
149
  setNodeSearchSettings(undefined);
@@ -160,7 +158,7 @@ function LynxKiteFlow() {
160
  pos: { x: event.clientX, y: event.clientY },
161
  boxes: catalog.data![state.workspace.env!],
162
  });
163
- }, [setNodeSearchSettings, suppressSearchUntil]);
164
  const addNode = useCallback((meta: OpsOp) => {
165
  const node: Partial<WorkspaceNode> = {
166
  type: meta.type,
@@ -184,7 +182,7 @@ function LynxKiteFlow() {
184
  wnodes.push(node as WorkspaceNode);
185
  setNodes([...nodes, node as WorkspaceNode]);
186
  closeNodeSearch();
187
- }, [state, reactFlow, setNodes]);
188
 
189
  const onConnect = useCallback((connection: Connection) => {
190
  setSuppressSearchUntil(Date.now() + 200);
 
15
  type Node,
16
  type Edge,
17
  type Connection,
 
18
  useReactFlow,
19
  MiniMap,
20
  } from '@xyflow/react';
 
27
  import { syncedStore, getYjsDoc } from "@syncedstore/core";
28
  import { WebsocketProvider } from "y-websocket";
29
  import NodeWithParams from './nodes/NodeWithParams';
 
 
30
  // import NodeWithTableView from './NodeWithTableView';
 
 
 
31
  import EnvironmentSelector from './EnvironmentSelector';
32
  import { LynxKiteState } from './LynxKiteState';
33
  import '@xyflow/react/dist/style.css';
34
  import { Workspace, WorkspaceNode } from "../apiTypes.ts";
35
  import NodeSearch, { OpsOp, Catalog, Catalogs } from "./NodeSearch.tsx";
36
+ import NodeWithVisualization from "./nodes/NodeWithVisualization.tsx";
37
+ import NodeWithImage from "./nodes/NodeWithImage.tsx";
38
 
39
  export default function (props: any) {
40
  return (
 
140
  const nodeTypes = useMemo(() => ({
141
  basic: NodeWithParams,
142
  table_view: NodeWithParams,
143
+ visualization: NodeWithVisualization,
144
+ image: NodeWithImage,
145
  }), []);
146
  const closeNodeSearch = useCallback(() => {
147
  setNodeSearchSettings(undefined);
 
158
  pos: { x: event.clientX, y: event.clientY },
159
  boxes: catalog.data![state.workspace.env!],
160
  });
161
+ }, [catalog, state, setNodeSearchSettings, suppressSearchUntil]);
162
  const addNode = useCallback((meta: OpsOp) => {
163
  const node: Partial<WorkspaceNode> = {
164
  type: meta.type,
 
182
  wnodes.push(node as WorkspaceNode);
183
  setNodes([...nodes, node as WorkspaceNode]);
184
  closeNodeSearch();
185
+ }, [nodeSearchSettings, state, reactFlow, setNodes]);
186
 
187
  const onConnect = useCallback((connection: Connection) => {
188
  setSuppressSearchUntil(Date.now() + 200);
web/src/workspace/nodes/NodeWithImage.tsx ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import NodeWithParams from './NodeWithParams';
2
+
3
+ const NodeWithImage = (props: any) => {
4
+ return (
5
+ <NodeWithParams {...props}>
6
+ {props.data.display && <img src={props.data.display} alt="Node Display" />}
7
+ </NodeWithParams>
8
+ );
9
+ };
10
+
11
+ export default NodeWithImage;
web/src/workspace/nodes/NodeWithParams.tsx CHANGED
@@ -23,6 +23,7 @@ function NodeWithParams(props: any) {
23
  onChange={(value: any, opts?: UpdateOptions) => setParam(name, value, opts || {})}
24
  />
25
  )}
 
26
  </LynxKiteNode >
27
  );
28
  }
 
23
  onChange={(value: any, opts?: UpdateOptions) => setParam(name, value, opts || {})}
24
  />
25
  )}
26
+ {props.children}
27
  </LynxKiteNode >
28
  );
29
  }
web/src/workspace/nodes/NodeWithVisualization.tsx ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect } from 'react';
2
+ import NodeWithParams from './NodeWithParams';
3
+ import * as echarts from 'echarts';
4
+
5
+ const NodeWithVisualization = (props: any) => {
6
+ const chartsRef = React.useRef<HTMLDivElement>(null);
7
+ const chartsInstanceRef = React.useRef<echarts.ECharts>();
8
+ useEffect(() => {
9
+ const opts = props.data?.display?.value;
10
+ if (!opts || !chartsRef.current) return;
11
+ console.log(chartsRef.current);
12
+ chartsInstanceRef.current = echarts.init(chartsRef.current, null, { renderer: 'canvas', width: 250, height: 250 });
13
+ chartsInstanceRef.current.setOption(opts);
14
+ const onResize = () => chartsInstanceRef.current?.resize();
15
+ window.addEventListener('resize', onResize);
16
+ return () => {
17
+ window.removeEventListener('resize', onResize);
18
+ chartsInstanceRef.current?.dispose();
19
+ };
20
+ }, [props.data?.display?.value]);
21
+ return (
22
+ <NodeWithParams {...props}>
23
+ <div className="box" draggable={false} ref={chartsRef} />;
24
+ </NodeWithParams>
25
+ );
26
+ };
27
+
28
+ export default NodeWithVisualization;