darabos commited on
Commit
c032595
·
1 Parent(s): 9a36487

More flexible handling for NetworkX return values.

Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py CHANGED
@@ -1,11 +1,14 @@
1
  """Automatically wraps all NetworkX functions as LynxKite operations."""
2
 
 
3
  from lynxkite.core import ops
4
  import functools
5
  import inspect
6
  import networkx as nx
7
  import re
8
 
 
 
9
  ENV = "LynxKite Graph Analytics"
10
 
11
 
@@ -156,10 +159,22 @@ def wrapped(name: str, func):
156
  res = func(*args, **kwargs)
157
  if isinstance(res, nx.Graph):
158
  return res
159
- # Otherwise it's a node attribute.
160
- graph = args[0].copy()
161
- nx.set_node_attributes(graph, values=res, name=name)
162
- return graph
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  return wrapper
165
 
@@ -193,7 +208,7 @@ def register_networkx(env: str):
193
  for name, param in sig.parameters.items()
194
  if name in types and types[name] not in [nx.Graph, nx.DiGraph]
195
  }
196
- nicename = name.replace("_", " ").title()
197
  op = ops.Op(
198
  func=wrapped(name, func),
199
  name=nicename,
 
1
  """Automatically wraps all NetworkX functions as LynxKite operations."""
2
 
3
+ import collections
4
  from lynxkite.core import ops
5
  import functools
6
  import inspect
7
  import networkx as nx
8
  import re
9
 
10
+ import pandas as pd
11
+
12
  ENV = "LynxKite Graph Analytics"
13
 
14
 
 
159
  res = func(*args, **kwargs)
160
  if isinstance(res, nx.Graph):
161
  return res
162
+ # Figure out what the returned value is.
163
+ if isinstance(res, nx.Graph):
164
+ return res
165
+ if isinstance(res, collections.abc.Sized):
166
+ for a in args:
167
+ if isinstance(a, nx.Graph):
168
+ if a.number_of_nodes() == len(res):
169
+ graph = a.copy()
170
+ nx.set_node_attributes(graph, values=res, name=name)
171
+ return graph
172
+ if a.number_of_edges() == len(res):
173
+ graph = a.copy()
174
+ nx.set_edge_attributes(graph, values=res, name=name)
175
+ return graph
176
+ return pd.DataFrame({name: res})
177
+ return pd.DataFrame({name: [res]})
178
 
179
  return wrapper
180
 
 
208
  for name, param in sig.parameters.items()
209
  if name in types and types[name] not in [nx.Graph, nx.DiGraph]
210
  }
211
+ nicename = "NX › " + name.replace("_", " ").title()
212
  op = ops.Op(
213
  func=wrapped(name, func),
214
  name=nicename,