darabos commited on
Commit
dbf89c5
·
1 Parent(s): 56cf2e9

Update node display/error in CRDT one by one during execution.

Browse files
lynxkite-app/src/lynxkite_app/crdt.py CHANGED
@@ -224,17 +224,16 @@ async def execute(
224
  assert path.is_relative_to(config.DATA_PATH), "Provided workspace path is invalid"
225
  # Save user changes before executing, in case the execution fails.
226
  workspace.save(ws_pyd, path)
 
227
  await workspace.execute(ws_pyd)
228
  workspace.save(ws_pyd, path)
229
- # Execution happened on the Python object, we need to replicate
230
- # the results to the CRDT object.
231
- with ws_crdt.doc.transaction():
232
- for nc, np in zip(ws_crdt["nodes"], ws_pyd.nodes):
233
- if "data" not in nc:
234
- nc["data"] = pycrdt.Map()
235
- # Display is added as a non collaborative field.
236
- nc["data"]["display"] = np.data.display
237
- nc["data"]["error"] = np.data.error
238
 
239
 
240
  @contextlib.asynccontextmanager
 
224
  assert path.is_relative_to(config.DATA_PATH), "Provided workspace path is invalid"
225
  # Save user changes before executing, in case the execution fails.
226
  workspace.save(ws_pyd, path)
227
+ add_crdt_bindings(ws_pyd, ws_crdt)
228
  await workspace.execute(ws_pyd)
229
  workspace.save(ws_pyd, path)
230
+
231
+
232
+ def add_crdt_bindings(ws_pyd: workspace.Workspace, ws_crdt: pycrdt.Map):
233
+ for nc, np in zip(ws_crdt["nodes"], ws_pyd.nodes):
234
+ if "data" not in nc:
235
+ nc["data"] = pycrdt.Map()
236
+ np._crdt = nc
 
 
237
 
238
 
239
  @contextlib.asynccontextmanager
lynxkite-app/src/lynxkite_app/main.py CHANGED
@@ -5,6 +5,7 @@ import shutil
5
  import pydantic
6
  import fastapi
7
  import importlib
 
8
  import pathlib
9
  import pkgutil
10
  from fastapi.staticfiles import StaticFiles
@@ -18,6 +19,8 @@ if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
18
 
19
  cudf.pandas.install()
20
 
 
 
21
 
22
  def detect_plugins():
23
  plugins = {}
 
5
  import pydantic
6
  import fastapi
7
  import importlib
8
+ import pandas as pd
9
  import pathlib
10
  import pkgutil
11
  from fastapi.staticfiles import StaticFiles
 
19
 
20
  cudf.pandas.install()
21
 
22
+ pd.options.mode.copy_on_write = True # Prepare for Pandas 3.0.
23
+
24
 
25
  def detect_plugins():
26
  plugins = {}
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -94,8 +94,9 @@ class Result:
94
  JSON-serializable.
95
  """
96
 
97
- output: typing.Any
98
  display: ReadOnlyJSON | None = None
 
99
 
100
 
101
  MULTI_INPUT = Input(name="multi", type="*")
 
94
  JSON-serializable.
95
  """
96
 
97
+ output: typing.Any = None
98
  display: ReadOnlyJSON | None = None
99
+ error: str | None = None
100
 
101
 
102
  MULTI_INPUT = Input(name="multi", type="*")
lynxkite-core/src/lynxkite/core/workspace.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  from typing import Optional
5
  import dataclasses
6
  import os
 
7
  import pydantic
8
  import tempfile
9
  from . import ops
@@ -36,6 +37,16 @@ class WorkspaceNode(BaseConfig):
36
  type: str
37
  data: WorkspaceNodeData
38
  position: Position
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  class WorkspaceEdge(BaseConfig):
 
4
  from typing import Optional
5
  import dataclasses
6
  import os
7
+ import pycrdt
8
  import pydantic
9
  import tempfile
10
  from . import ops
 
37
  type: str
38
  data: WorkspaceNodeData
39
  position: Position
40
+ _crdt: pycrdt.Map
41
+
42
+ def publish_result(self, result: ops.Result):
43
+ """Sends the result to the frontend. Call this in an executor when the result is available."""
44
+ with self._crdt.doc.transaction():
45
+ self._crdt["data"]["display"] = result.display
46
+ self._crdt["data"]["error"] = result.error
47
+
48
+ def publish_error(self, error: Exception | str):
49
+ self.publish_result(ops.Result(error=str(error)))
50
 
51
 
52
  class WorkspaceEdge(BaseConfig):
lynxkite-graph-analytics/src/lynxkite_graph_analytics/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from .lynxkite_ops import * # noqa (imported to trigger registration)
 
2
  from . import networkx_ops # noqa (imported to trigger registration)
3
  from . import pytorch_model_ops # noqa (imported to trigger registration)
 
1
+ from .core import * # noqa (easier access for core classes)
2
+ from . import lynxkite_ops # noqa (imported to trigger registration)
3
  from . import networkx_ops # noqa (imported to trigger registration)
4
  from . import pytorch_model_ops # noqa (imported to trigger registration)
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py CHANGED
@@ -158,11 +158,10 @@ async def execute(ws):
158
  if all(input in outputs for input in inputs):
159
  # All inputs for this node are ready, we can compute the output.
160
  inputs = [outputs[input] for input in inputs]
161
- data = node.data
162
- params = {**data.params}
163
- op = catalog.get(data.title)
164
  if not op:
165
- data.error = "Operation not found in catalog"
166
  failed += 1
167
  continue
168
  try:
@@ -177,16 +176,11 @@ async def execute(ws):
177
  result = op(*inputs, **params)
178
  except Exception as e:
179
  traceback.print_exc()
180
- data.error = str(e)
181
  failed += 1
182
  continue
183
- if len(op.inputs) == 1 and op.inputs.get("multi") == "*":
184
- # It's a flexible input. Create n+1 handles.
185
- data.inputs = {f"input{i}": None for i in range(len(inputs) + 1)}
186
- data.error = None
187
  outputs[node.id] = result.output
188
- if result.display:
189
- data.display = result.display
190
 
191
 
192
  def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
 
158
  if all(input in outputs for input in inputs):
159
  # All inputs for this node are ready, we can compute the output.
160
  inputs = [outputs[input] for input in inputs]
161
+ params = {**node.data.params}
162
+ op = catalog.get(node.data.title)
 
163
  if not op:
164
+ node.publish_error("Operation not found in catalog")
165
  failed += 1
166
  continue
167
  try:
 
176
  result = op(*inputs, **params)
177
  except Exception as e:
178
  traceback.print_exc()
179
+ node.publish_error(e)
180
  failed += 1
181
  continue
 
 
 
 
182
  outputs[node.id] = result.output
183
+ node.publish_result(result)
 
184
 
185
 
186
  def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -152,7 +152,7 @@ def _map_color(value):
152
  if pd.api.types.is_numeric_dtype(value):
153
  cmap = matplotlib.cm.get_cmap("viridis")
154
  value = (value - value.min()) / (value.max() - value.min())
155
- rgba = cmap(value)
156
  return [
157
  "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
158
  for r, g, b in rgba[:, :3]
 
152
  if pd.api.types.is_numeric_dtype(value):
153
  cmap = matplotlib.cm.get_cmap("viridis")
154
  value = (value - value.min()) / (value.max() - value.min())
155
+ rgba = cmap(value.values)
156
  return [
157
  "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
158
  for r, g, b in rgba[:, :3]