"""Automatically wraps all NetworkX functions as LynxKite operations.""" from lynxkite.core import ops import functools import inspect import networkx as nx import re ENV = "LynxKite Graph Analytics" class UnsupportedType(Exception): pass nx.ladder_graph def doc_to_type(name: str, t: str) -> type: t = t.lower() t = re.sub("[(][^)]+[)]", "", t).strip().strip(".") if " " in name or "http" in name: return None # Not a parameter type. if t.endswith(", optional"): w = doc_to_type(name, t.removesuffix(", optional").strip()) if w is None: return None return w | None if t in [ "a digraph or multidigraph", "a graph g", "graph", "graphs", "networkx graph instance", "networkx graph", "networkx undirected graph", "nx.graph", "undirected graph", "undirected networkx graph", ] or t.startswith("networkx graph"): return nx.Graph elif t in [ "digraph-like", "digraph", "directed graph", "networkx digraph", "networkx directed graph", "nx.digraph", ]: return nx.DiGraph elif t == "node": raise UnsupportedType(t) elif t == '"node (optional)"': return None elif t == '"edge"': raise UnsupportedType(t) elif t == '"edge (optional)"': return None elif t in ["class", "data type"]: raise UnsupportedType(t) elif t in ["string", "str", "node label"]: return str elif t in ["string or none", "none or string", "string, or none"]: return str | None elif t in ["int", "integer"]: return int elif t in ["bool", "boolean"]: return bool elif t == "tuple": raise UnsupportedType(t) elif t == "set": raise UnsupportedType(t) elif t == "list of floats": raise UnsupportedType(t) elif t == "list of floats or float": return float elif t in ["dict", "dictionary"]: raise UnsupportedType(t) elif t == "scalar or dictionary": return float elif t == "none or dict": return None elif t in ["function", "callable"]: raise UnsupportedType(t) elif t in [ "collection", "container of nodes", "list of nodes", ]: raise UnsupportedType(t) elif t in [ "container", "generator", "iterable", "iterator", "list or iterable container", "list or iterable", "list or set", "list or tuple", "list", ]: raise UnsupportedType(t) elif t == "generator of sets": raise UnsupportedType(t) elif t == "dict or a set of 2 or 3 tuples": raise UnsupportedType(t) elif t == "set of 2 or 3 tuples": raise UnsupportedType(t) elif t == "none, string or function": return str | None elif t == "string or function" and name == "weight": return str elif t == "integer, float, or none": return float | None elif t in [ "float", "int or float", "integer or float", "integer, float", "number", "numeric", "real", "scalar", ]: return float elif t in ["integer or none", "int or none"]: return int | None elif name == "seed": return int | None elif name == "weight": return str elif t == "object": raise UnsupportedType(t) return None def types_from_doc(doc: str) -> dict[str, type]: types = {} for line in doc.splitlines(): if ":" in line: a, b = line.split(":", 1) for a in a.split(","): a = a.strip() t = doc_to_type(a, b) if t is not None: types[a] = t return types def wrapped(name: str, func): @functools.wraps(func) def wrapper(*args, **kwargs): for k, v in kwargs.items(): if v == "None": kwargs[k] = None res = func(*args, **kwargs) if isinstance(res, nx.Graph): return res # Otherwise it's a node attribute. graph = args[0].copy() nx.set_node_attributes(graph, values=res, name=name) return graph return wrapper def register_networkx(env: str): cat = ops.CATALOGS.setdefault(env, {}) counter = 0 for name, func in nx.__dict__.items(): if hasattr(func, "graphs"): sig = inspect.signature(func) try: types = types_from_doc(func.__doc__) except UnsupportedType: continue for k, param in sig.parameters.items(): if k in types: continue if param.annotation is not param.empty: types[k] = param.annotation if k in ["i", "j", "n"]: types[k] = int inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs} params = { name: ops.Parameter.basic( name=name, default=str(param.default) if type(param.default) in [str, int, float] else None, type=types[name], ) for name, param in sig.parameters.items() if name in types and types[name] not in [nx.Graph, nx.DiGraph] } nicename = name.replace("_", " ").title() op = ops.Op( func=wrapped(name, func), name=nicename, params=params, inputs=inputs, outputs={"output": ops.Output(name="output", type=nx.Graph)}, type="basic", ) cat[nicename] = op counter += 1 print(f"Registered {counter} NetworkX operations.") register_networkx(ENV)