Spaces:
Running
Running
File size: 1,965 Bytes
83b1026 5eed07a 83b1026 5eed07a 83b1026 da1ea6b 83b1026 5eed07a 83b1026 5eed07a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
"""Automatically wraps all NetworkX functions as LynxKite operations."""
from lynxkite.core import ops
import functools
import inspect
import networkx as nx
# Name of the LynxKite environment whose operation catalog receives the NetworkX ops
# (passed to register_networkx, which keys ops.CATALOGS by it).
ENV = "LynxKite Graph Analytics"
def wrapped(name: str, func):
    """Wrap a NetworkX function so it always produces a graph.

    The returned wrapper:
      * replaces any keyword argument whose value is the string ``"None"`` with
        Python ``None`` (presumably because parameter values arrive as strings
        from the UI — confirm against the op runner);
      * calls ``func``;
      * returns the result unchanged if it is an ``nx.Graph``; otherwise stores
        the result as node attribute ``name`` on a copy of the input graph and
        returns that copy.

    Args:
        name: Node-attribute name used when ``func`` returns a non-graph result.
        func: The NetworkX function to wrap.

    Returns:
        The wrapper, carrying ``func``'s metadata via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Only strings can spell "None". Comparing arbitrary values with == could
        # raise (e.g. array-likes with ambiguous truth values) or match by accident.
        kwargs = {
            k: None if isinstance(v, str) and v == "None" else v
            for k, v in kwargs.items()
        }
        res = func(*args, **kwargs)
        if isinstance(res, nx.Graph):
            return res
        # Otherwise it's a node attribute (assumed to be a {node: value} mapping or
        # a value nx.set_node_attributes accepts). Attach it to a copy of the input
        # graph so the caller's graph is not mutated.
        if args:
            source = args[0]
        else:
            # The graph may have been passed by keyword instead of positionally.
            source = next(
                (v for v in kwargs.values() if isinstance(v, nx.Graph)), None
            )
        if source is None:
            raise TypeError(f"{name}: no input graph to attach the result to")
        graph = source.copy()
        nx.set_node_attributes(graph, values=res, name=name)
        return graph

    return wrapper
def register_networkx(env: str):
    """Register NetworkX functions as ops in the catalog for ``env``.

    Every attribute of the ``networkx`` module that has a ``graphs`` attribute is
    registered. The names in ``func.graphs`` become graph inputs. The remaining
    signature parameters become op parameters, except framework-handled ones
    (``backend``, ``backend_kwargs``, ``create_using``).

    Parameter defaults of type ``str``, ``int`` or ``float`` are stringified.
    Every other default becomes ``None``, including ``bool`` (``type(True)`` is
    ``bool``, not ``int``) and missing defaults.

    Each op is stored under a display name such as ``"NX › Pagerank"``. Non-graph
    results are attached as a node attribute named after the NetworkX function
    (e.g. ``"pagerank"``).

    Args:
        env: Name of the LynxKite environment; its catalog in ``ops.CATALOGS``
            is created if missing.
    """
    cat = ops.CATALOGS.setdefault(env, {})
    # Arguments that are never exposed as user-editable parameters.
    skipped = {"G", "backend", "backend_kwargs", "create_using"}
    for func_name, func in nx.__dict__.items():
        if not hasattr(func, "graphs"):
            continue
        sig = inspect.signature(func)
        inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
        params = {
            pname: ops.Parameter.basic(
                pname,
                str(param.default)
                if type(param.default) in [str, int, float]
                else None,
                param.annotation,
            )
            for pname, param in sig.parameters.items()
            # Every graph argument is already an input (not only "G", e.g. "H" or
            # "G1"/"G2"), so it must not also appear as a parameter.
            if pname not in skipped and pname not in func.graphs
        }
        for p in params.values():
            if not p.type:
                # Guess the type based on the name: single-letter names are
                # assumed to be ints.
                if len(p.name) == 1:
                    p.type = int
        # Keep the display name separate so the wrapper receives the raw function
        # name to use as the node-attribute name.
        op_name = "NX › " + func_name.replace("_", " ").title()
        op = ops.Op(
            func=wrapped(func_name, func),
            name=op_name,
            params=params,
            inputs=inputs,
            outputs={"output": ops.Output(name="output", type=nx.Graph)},
            type="basic",
        )
        cat[op_name] = op
# Populate the ENV catalog as an import-time side effect of this module.
register_networkx(ENV)
|