darabos commited on
Commit
56cf2e9
·
1 Parent(s): e9bf53e

Split lynxkite_ops.py into two.

Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graph analytics executor and data types."""
2
+
3
+ from lynxkite.core import ops
4
+ import dataclasses
5
+ import functools
6
+ import networkx as nx
7
+ import pandas as pd
8
+ import polars as pl
9
+ import traceback
10
+ import typing
11
+
12
+
13
+ ENV = "LynxKite Graph Analytics"
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class RelationDefinition:
18
+ """Defines a set of edges."""
19
+
20
+ df: str # The DataFrame that contains the edges.
21
+ source_column: (
22
+ str # The column in the edge DataFrame that contains the source node ID.
23
+ )
24
+ target_column: (
25
+ str # The column in the edge DataFrame that contains the target node ID.
26
+ )
27
+ source_table: str # The DataFrame that contains the source nodes.
28
+ target_table: str # The DataFrame that contains the target nodes.
29
+ source_key: str # The column in the source table that contains the node ID.
30
+ target_key: str # The column in the target table that contains the node ID.
31
+ name: str | None = None # Descriptive name for the relation.
32
+
33
+
34
+ @dataclasses.dataclass
35
+ class Bundle:
36
+ """A collection of DataFrames and other data.
37
+
38
+ Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
39
+ It can also carry other data, such as a trained model.
40
+ """
41
+
42
+ dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
43
+ relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
44
+ other: dict[str, typing.Any] = None
45
+
46
+ @classmethod
47
+ def from_nx(cls, graph: nx.Graph):
48
+ edges = nx.to_pandas_edgelist(graph)
49
+ d = dict(graph.nodes(data=True))
50
+ nodes = pd.DataFrame(d.values(), index=d.keys())
51
+ nodes["id"] = nodes.index
52
+ if "index" in nodes.columns:
53
+ nodes.drop(columns=["index"], inplace=True)
54
+ return cls(
55
+ dfs={"edges": edges, "nodes": nodes},
56
+ relations=[
57
+ RelationDefinition(
58
+ df="edges",
59
+ source_column="source",
60
+ target_column="target",
61
+ source_table="nodes",
62
+ target_table="nodes",
63
+ source_key="id",
64
+ target_key="id",
65
+ )
66
+ ],
67
+ )
68
+
69
+ @classmethod
70
+ def from_df(cls, df: pd.DataFrame):
71
+ return cls(dfs={"df": df})
72
+
73
+ def to_nx(self):
74
+ # TODO: Use relations.
75
+ graph = nx.DiGraph()
76
+ if "nodes" in self.dfs:
77
+ df = self.dfs["nodes"]
78
+ if df.index.name != "id":
79
+ df = df.set_index("id")
80
+ graph.add_nodes_from(df.to_dict("index").items())
81
+ if "edges" in self.dfs:
82
+ edges = self.dfs["edges"]
83
+ graph.add_edges_from(
84
+ [
85
+ (
86
+ e["source"],
87
+ e["target"],
88
+ {
89
+ k: e[k]
90
+ for k in edges.columns
91
+ if k not in ["source", "target"]
92
+ },
93
+ )
94
+ for e in edges.to_records()
95
+ ]
96
+ )
97
+ return graph
98
+
99
+ def copy(self):
100
+ """Returns a medium depth copy of the bundle. The Bundle is completely new, but the DataFrames and RelationDefinitions are shared."""
101
+ return Bundle(
102
+ dfs=dict(self.dfs),
103
+ relations=list(self.relations),
104
+ other=dict(self.other) if self.other else None,
105
+ )
106
+
107
+ def to_dict(self, limit: int = 100):
108
+ return {
109
+ "dataframes": {
110
+ name: {
111
+ "columns": [str(c) for c in df.columns],
112
+ "data": df_for_frontend(df, limit).values.tolist(),
113
+ }
114
+ for name, df in self.dfs.items()
115
+ },
116
+ "relations": [dataclasses.asdict(relation) for relation in self.relations],
117
+ "other": self.other,
118
+ }
119
+
120
+
121
+ def nx_node_attribute_func(name):
122
+ """Decorator for wrapping a function that adds a NetworkX node attribute."""
123
+
124
+ def decorator(func):
125
+ @functools.wraps(func)
126
+ def wrapper(graph: nx.Graph, **kwargs):
127
+ graph = graph.copy()
128
+ attr = func(graph, **kwargs)
129
+ nx.set_node_attributes(graph, attr, name)
130
+ return graph
131
+
132
+ return wrapper
133
+
134
+ return decorator
135
+
136
+
137
+ def disambiguate_edges(ws):
138
+ """If an input plug is connected to multiple edges, keep only the last edge."""
139
+ seen = set()
140
+ for edge in reversed(ws.edges):
141
+ if (edge.target, edge.targetHandle) in seen:
142
+ ws.edges.remove(edge)
143
+ seen.add((edge.target, edge.targetHandle))
144
+
145
+
146
+ @ops.register_executor(ENV)
147
+ async def execute(ws):
148
+ catalog: dict[str, ops.Op] = ops.CATALOGS[ws.env]
149
+ disambiguate_edges(ws)
150
+ outputs = {}
151
+ failed = 0
152
+ while len(outputs) + failed < len(ws.nodes):
153
+ for node in ws.nodes:
154
+ if node.id in outputs:
155
+ continue
156
+ # TODO: Take the input/output handles into account.
157
+ inputs = [edge.source for edge in ws.edges if edge.target == node.id]
158
+ if all(input in outputs for input in inputs):
159
+ # All inputs for this node are ready, we can compute the output.
160
+ inputs = [outputs[input] for input in inputs]
161
+ data = node.data
162
+ params = {**data.params}
163
+ op = catalog.get(data.title)
164
+ if not op:
165
+ data.error = "Operation not found in catalog"
166
+ failed += 1
167
+ continue
168
+ try:
169
+ # Convert inputs types to match operation signature.
170
+ for i, (x, p) in enumerate(zip(inputs, op.inputs.values())):
171
+ if p.type == nx.Graph and isinstance(x, Bundle):
172
+ inputs[i] = x.to_nx()
173
+ elif p.type == Bundle and isinstance(x, nx.Graph):
174
+ inputs[i] = Bundle.from_nx(x)
175
+ elif p.type == Bundle and isinstance(x, pd.DataFrame):
176
+ inputs[i] = Bundle.from_df(x)
177
+ result = op(*inputs, **params)
178
+ except Exception as e:
179
+ traceback.print_exc()
180
+ data.error = str(e)
181
+ failed += 1
182
+ continue
183
+ if len(op.inputs) == 1 and op.inputs.get("multi") == "*":
184
+ # It's a flexible input. Create n+1 handles.
185
+ data.inputs = {f"input{i}": None for i in range(len(inputs) + 1)}
186
+ data.error = None
187
+ outputs[node.id] = result.output
188
+ if result.display:
189
+ data.display = result.display
190
+
191
+
192
+ def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
193
+ """Returns a DataFrame with values that are safe to send to the frontend."""
194
+ df = df[:limit]
195
+ if isinstance(df, pl.LazyFrame):
196
+ df = df.collect()
197
+ if isinstance(df, pl.DataFrame):
198
+ df = df.to_pandas()
199
+ # Convert non-numeric columns to strings.
200
+ for c in df.columns:
201
+ if not pd.api.types.is_numeric_dtype(df[c]):
202
+ df[c] = df[c].astype(str)
203
+ return df
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -1,201 +1,21 @@
1
- """Graph analytics operations. To be split into separate files when we have more."""
2
 
3
  import os
4
  import fsspec
5
  from lynxkite.core import ops
6
  from collections import deque
7
- import dataclasses
8
- import functools
9
  import grandcypher
10
  import joblib
11
  import matplotlib
12
  import networkx as nx
13
  import pandas as pd
14
  import polars as pl
15
- import traceback
16
- import typing
17
  import json
18
 
19
 
20
  mem = joblib.Memory("../joblib-cache")
21
- ENV = "LynxKite Graph Analytics"
22
- op = ops.op_registration(ENV)
23
-
24
-
25
- @dataclasses.dataclass
26
- class RelationDefinition:
27
- """Defines a set of edges."""
28
-
29
- df: str # The DataFrame that contains the edges.
30
- source_column: (
31
- str # The column in the edge DataFrame that contains the source node ID.
32
- )
33
- target_column: (
34
- str # The column in the edge DataFrame that contains the target node ID.
35
- )
36
- source_table: str # The DataFrame that contains the source nodes.
37
- target_table: str # The DataFrame that contains the target nodes.
38
- source_key: str # The column in the source table that contains the node ID.
39
- target_key: str # The column in the target table that contains the node ID.
40
- name: str | None = None # Descriptive name for the relation.
41
-
42
-
43
- @dataclasses.dataclass
44
- class Bundle:
45
- """A collection of DataFrames and other data.
46
-
47
- Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
48
- It can also carry other data, such as a trained model.
49
- """
50
-
51
- dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
52
- relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
53
- other: dict[str, typing.Any] = None
54
-
55
- @classmethod
56
- def from_nx(cls, graph: nx.Graph):
57
- edges = nx.to_pandas_edgelist(graph)
58
- d = dict(graph.nodes(data=True))
59
- nodes = pd.DataFrame(d.values(), index=d.keys())
60
- nodes["id"] = nodes.index
61
- if "index" in nodes.columns:
62
- nodes.drop(columns=["index"], inplace=True)
63
- return cls(
64
- dfs={"edges": edges, "nodes": nodes},
65
- relations=[
66
- RelationDefinition(
67
- df="edges",
68
- source_column="source",
69
- target_column="target",
70
- source_table="nodes",
71
- target_table="nodes",
72
- source_key="id",
73
- target_key="id",
74
- )
75
- ],
76
- )
77
-
78
- @classmethod
79
- def from_df(cls, df: pd.DataFrame):
80
- return cls(dfs={"df": df})
81
-
82
- def to_nx(self):
83
- # TODO: Use relations.
84
- graph = nx.DiGraph()
85
- if "nodes" in self.dfs:
86
- df = self.dfs["nodes"]
87
- if df.index.name != "id":
88
- df = df.set_index("id")
89
- graph.add_nodes_from(df.to_dict("index").items())
90
- if "edges" in self.dfs:
91
- edges = self.dfs["edges"]
92
- graph.add_edges_from(
93
- [
94
- (
95
- e["source"],
96
- e["target"],
97
- {
98
- k: e[k]
99
- for k in edges.columns
100
- if k not in ["source", "target"]
101
- },
102
- )
103
- for e in edges.to_records()
104
- ]
105
- )
106
- return graph
107
-
108
- def copy(self):
109
- """Returns a medium depth copy of the bundle. The Bundle is completely new, but the DataFrames and RelationDefinitions are shared."""
110
- return Bundle(
111
- dfs=dict(self.dfs),
112
- relations=list(self.relations),
113
- other=dict(self.other) if self.other else None,
114
- )
115
-
116
- def to_dict(self, limit: int = 100):
117
- return {
118
- "dataframes": {
119
- name: {
120
- "columns": [str(c) for c in df.columns],
121
- "data": df_for_frontend(df, limit).values.tolist(),
122
- }
123
- for name, df in self.dfs.items()
124
- },
125
- "relations": [dataclasses.asdict(relation) for relation in self.relations],
126
- "other": self.other,
127
- }
128
-
129
-
130
- def nx_node_attribute_func(name):
131
- """Decorator for wrapping a function that adds a NetworkX node attribute."""
132
-
133
- def decorator(func):
134
- @functools.wraps(func)
135
- def wrapper(graph: nx.Graph, **kwargs):
136
- graph = graph.copy()
137
- attr = func(graph, **kwargs)
138
- nx.set_node_attributes(graph, attr, name)
139
- return graph
140
-
141
- return wrapper
142
-
143
- return decorator
144
-
145
-
146
- def disambiguate_edges(ws):
147
- """If an input plug is connected to multiple edges, keep only the last edge."""
148
- seen = set()
149
- for edge in reversed(ws.edges):
150
- if (edge.target, edge.targetHandle) in seen:
151
- ws.edges.remove(edge)
152
- seen.add((edge.target, edge.targetHandle))
153
-
154
-
155
- @ops.register_executor(ENV)
156
- async def execute(ws):
157
- catalog: dict[str, ops.Op] = ops.CATALOGS[ENV]
158
- disambiguate_edges(ws)
159
- outputs = {}
160
- failed = 0
161
- while len(outputs) + failed < len(ws.nodes):
162
- for node in ws.nodes:
163
- if node.id in outputs:
164
- continue
165
- # TODO: Take the input/output handles into account.
166
- inputs = [edge.source for edge in ws.edges if edge.target == node.id]
167
- if all(input in outputs for input in inputs):
168
- # All inputs for this node are ready, we can compute the output.
169
- inputs = [outputs[input] for input in inputs]
170
- data = node.data
171
- params = {**data.params}
172
- op = catalog.get(data.title)
173
- if not op:
174
- data.error = "Operation not found in catalog"
175
- failed += 1
176
- continue
177
- try:
178
- # Convert inputs types to match operation signature.
179
- for i, (x, p) in enumerate(zip(inputs, op.inputs.values())):
180
- if p.type == nx.Graph and isinstance(x, Bundle):
181
- inputs[i] = x.to_nx()
182
- elif p.type == Bundle and isinstance(x, nx.Graph):
183
- inputs[i] = Bundle.from_nx(x)
184
- elif p.type == Bundle and isinstance(x, pd.DataFrame):
185
- inputs[i] = Bundle.from_df(x)
186
- result = op(*inputs, **params)
187
- except Exception as e:
188
- traceback.print_exc()
189
- data.error = str(e)
190
- failed += 1
191
- continue
192
- if len(op.inputs) == 1 and op.inputs.get("multi") == "*":
193
- # It's a flexible input. Create n+1 handles.
194
- data.inputs = {f"input{i}": None for i in range(len(inputs) + 1)}
195
- data.error = None
196
- outputs[node.id] = result.output
197
- if result.display:
198
- data.display = result.display
199
 
200
 
201
  @op("Import Parquet")
@@ -246,14 +66,14 @@ def create_scale_free_graph(*, nodes: int = 10):
246
 
247
 
248
  @op("Compute PageRank")
249
- @nx_node_attribute_func("pagerank")
250
  def compute_pagerank(graph: nx.Graph, *, damping=0.85, iterations=100):
251
  # TODO: This requires scipy to be installed.
252
  return nx.pagerank(graph, alpha=damping, max_iter=iterations)
253
 
254
 
255
  @op("Compute betweenness centrality")
256
- @nx_node_attribute_func("betweenness_centrality")
257
  def compute_betweenness_centrality(graph: nx.Graph, *, k=10):
258
  return nx.betweenness_centrality(graph, k=k)
259
 
@@ -271,7 +91,7 @@ def discard_parallel_edges(graph: nx.Graph):
271
 
272
 
273
  @op("SQL")
274
- def sql(bundle: Bundle, *, query: ops.LongStr, save_as: str = "result"):
275
  """Run a SQL query on the DataFrames in the bundle. Save the results as a new DataFrame."""
276
  bundle = bundle.copy()
277
  if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
@@ -292,7 +112,7 @@ def sql(bundle: Bundle, *, query: ops.LongStr, save_as: str = "result"):
292
 
293
 
294
  @op("Cypher")
295
- def cypher(bundle: Bundle, *, query: ops.LongStr, save_as: str = "result"):
296
  """Run a Cypher query on the graph in the bundle. Save the results as a new DataFrame."""
297
  bundle = bundle.copy()
298
  graph = bundle.to_nx()
@@ -302,7 +122,7 @@ def cypher(bundle: Bundle, *, query: ops.LongStr, save_as: str = "result"):
302
 
303
 
304
  @op("Organize bundle")
305
- def organize_bundle(bundle: Bundle, *, code: ops.LongStr):
306
  """Lets you rename/copy/delete DataFrames, and modify relations.
307
 
308
  TODO: Use a declarative solution instead of Python code. Add UI.
@@ -351,13 +171,13 @@ def _map_color(value):
351
 
352
  @op("Visualize graph", view="visualization")
353
  def visualize_graph(
354
- graph: Bundle,
355
  *,
356
  color_nodes_by: ops.NodeAttribute = None,
357
  label_by: ops.NodeAttribute = None,
358
  color_edges_by: ops.EdgeAttribute = None,
359
  ):
360
- nodes = df_for_frontend(graph.dfs["nodes"], 10_000)
361
  if color_nodes_by:
362
  nodes["color"] = _map_color(nodes[color_nodes_by])
363
  for cols in ["x y", "long lat"]:
@@ -387,7 +207,7 @@ def visualize_graph(
387
  )
388
  curveness = 0.3
389
  nodes = nodes.to_records()
390
- edges = df_for_frontend(
391
  graph.dfs["edges"].drop_duplicates(["source", "target"]), 10_000
392
  )
393
  if color_edges_by:
@@ -446,22 +266,8 @@ def visualize_graph(
446
  return v
447
 
448
 
449
- def df_for_frontend(df: pd.DataFrame, limit: int) -> pd.DataFrame:
450
- """Returns a DataFrame with values that are safe to send to the frontend."""
451
- df = df[:limit]
452
- if isinstance(df, pl.LazyFrame):
453
- df = df.collect()
454
- if isinstance(df, pl.DataFrame):
455
- df = df.to_pandas()
456
- # Convert non-numeric columns to strings.
457
- for c in df.columns:
458
- if not pd.api.types.is_numeric_dtype(df[c]):
459
- df[c] = df[c].astype(str)
460
- return df
461
-
462
-
463
  @op("View tables", view="table_view")
464
- def view_tables(bundle: Bundle, *, limit: int = 100):
465
  return bundle.to_dict(limit=limit)
466
 
467
 
@@ -470,7 +276,7 @@ def view_tables(bundle: Bundle, *, limit: int = 100):
470
  view="graph_creation_view",
471
  outputs=["output"],
472
  )
473
- def create_graph(bundle: Bundle, *, relations: str = None) -> Bundle:
474
  """Replace relations of the given bundle
475
 
476
  relations is a stringified JSON, instead of a dict, because complex Yjs types (arrays, maps)
@@ -489,6 +295,6 @@ def create_graph(bundle: Bundle, *, relations: str = None) -> Bundle:
489
  bundle = bundle.copy()
490
  if not (relations is None or relations.strip() == ""):
491
  bundle.relations = [
492
- RelationDefinition(**r) for r in json.loads(relations).values()
493
  ]
494
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
 
1
+ """Graph analytics operations."""
2
 
3
  import os
4
  import fsspec
5
  from lynxkite.core import ops
6
  from collections import deque
7
+ from . import core
 
8
  import grandcypher
9
  import joblib
10
  import matplotlib
11
  import networkx as nx
12
  import pandas as pd
13
  import polars as pl
 
 
14
  import json
15
 
16
 
17
  mem = joblib.Memory("../joblib-cache")
18
+ op = ops.op_registration(core.ENV)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  @op("Import Parquet")
 
66
 
67
 
68
  @op("Compute PageRank")
69
+ @core.nx_node_attribute_func("pagerank")
70
  def compute_pagerank(graph: nx.Graph, *, damping=0.85, iterations=100):
71
  # TODO: This requires scipy to be installed.
72
  return nx.pagerank(graph, alpha=damping, max_iter=iterations)
73
 
74
 
75
  @op("Compute betweenness centrality")
76
+ @core.nx_node_attribute_func("betweenness_centrality")
77
  def compute_betweenness_centrality(graph: nx.Graph, *, k=10):
78
  return nx.betweenness_centrality(graph, k=k)
79
 
 
91
 
92
 
93
  @op("SQL")
94
+ def sql(bundle: core.Bundle, *, query: ops.LongStr, save_as: str = "result"):
95
  """Run a SQL query on the DataFrames in the bundle. Save the results as a new DataFrame."""
96
  bundle = bundle.copy()
97
  if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
 
112
 
113
 
114
  @op("Cypher")
115
+ def cypher(bundle: core.Bundle, *, query: ops.LongStr, save_as: str = "result"):
116
  """Run a Cypher query on the graph in the bundle. Save the results as a new DataFrame."""
117
  bundle = bundle.copy()
118
  graph = bundle.to_nx()
 
122
 
123
 
124
  @op("Organize bundle")
125
+ def organize_bundle(bundle: core.Bundle, *, code: ops.LongStr):
126
  """Lets you rename/copy/delete DataFrames, and modify relations.
127
 
128
  TODO: Use a declarative solution instead of Python code. Add UI.
 
171
 
172
  @op("Visualize graph", view="visualization")
173
  def visualize_graph(
174
+ graph: core.Bundle,
175
  *,
176
  color_nodes_by: ops.NodeAttribute = None,
177
  label_by: ops.NodeAttribute = None,
178
  color_edges_by: ops.EdgeAttribute = None,
179
  ):
180
+ nodes = core.df_for_frontend(graph.dfs["nodes"], 10_000)
181
  if color_nodes_by:
182
  nodes["color"] = _map_color(nodes[color_nodes_by])
183
  for cols in ["x y", "long lat"]:
 
207
  )
208
  curveness = 0.3
209
  nodes = nodes.to_records()
210
+ edges = core.df_for_frontend(
211
  graph.dfs["edges"].drop_duplicates(["source", "target"]), 10_000
212
  )
213
  if color_edges_by:
 
266
  return v
267
 
268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  @op("View tables", view="table_view")
270
+ def view_tables(bundle: core.Bundle, *, limit: int = 100):
271
  return bundle.to_dict(limit=limit)
272
 
273
 
 
276
  view="graph_creation_view",
277
  outputs=["output"],
278
  )
279
+ def create_graph(bundle: core.Bundle, *, relations: str = None) -> core.Bundle:
280
  """Replace relations of the given bundle
281
 
282
  relations is a stringified JSON, instead of a dict, because complex Yjs types (arrays, maps)
 
295
  bundle = bundle.copy()
296
  if not (relations is None or relations.strip() == ""):
297
  bundle.relations = [
298
+ core.RelationDefinition(**r) for r in json.loads(relations).values()
299
  ]
300
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
lynxkite-graph-analytics/tests/test_lynxkite_ops.py CHANGED
@@ -2,12 +2,12 @@ import pandas as pd
2
  import pytest
3
  import networkx as nx
4
 
5
- from lynxkite.core import workspace
6
- from lynxkite_graph_analytics.lynxkite_ops import Bundle, execute, op
7
 
8
 
9
  async def test_execute_operation_not_in_catalog():
10
- ws = workspace.Workspace(env="test")
11
  ws.nodes.append(
12
  workspace.WorkspaceNode(
13
  id="1",
@@ -23,6 +23,8 @@ async def test_execute_operation_not_in_catalog():
23
  async def test_execute_operation_inputs_correct_cast():
24
  # Test that the automatic casting of operation inputs works correctly.
25
 
 
 
26
  @op("Create Bundle")
27
  def create_bundle() -> Bundle:
28
  df = pd.DataFrame({"source": [1, 2, 3], "target": [4, 5, 6]})
 
2
  import pytest
3
  import networkx as nx
4
 
5
+ from lynxkite.core import workspace, ops
6
+ from lynxkite_graph_analytics.core import Bundle, execute, ENV
7
 
8
 
9
  async def test_execute_operation_not_in_catalog():
10
+ ws = workspace.Workspace(env=ENV)
11
  ws.nodes.append(
12
  workspace.WorkspaceNode(
13
  id="1",
 
23
  async def test_execute_operation_inputs_correct_cast():
24
  # Test that the automatic casting of operation inputs works correctly.
25
 
26
+ op = ops.op_registration("test")
27
+
28
  @op("Create Bundle")
29
  def create_bundle() -> Bundle:
30
  df = pd.DataFrame({"source": [1, 2, 3], "target": [4, 5, 6]})