File size: 1,621 Bytes
0c44583
 
 
 
 
 
 
e7fa7ee
0c44583
 
 
 
 
 
 
 
 
 
5826642
0c44583
 
 
e7fa7ee
 
 
 
 
620531b
e7fa7ee
 
 
 
 
 
 
620531b
e7fa7ee
 
 
 
 
 
620531b
 
 
 
 
 
 
 
e7fa7ee
0c44583
e7fa7ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Automatically wraps all NetworkX functions as LynxKite operations."""
from . import ops
import functools
import inspect
import networkx as nx


def wrapped(name: str, func):
  """Wrap a NetworkX function so that its result is always a graph.

  Before calling ``func``, any keyword argument whose value is the string
  ``'None'`` is replaced by Python ``None``. If ``func`` returns an
  ``nx.Graph``, that graph is returned unchanged. Any other result is
  attached to a *copy* of the input graph with ``nx.set_node_attributes``
  under the attribute ``name``, and the copy is returned.

  Args:
    name: Node attribute name used to store non-graph results.
    func: The NetworkX callable to wrap.

  Returns:
    A wrapper callable that carries ``func``'s metadata (via
    ``functools.wraps``).

  Raises:
    TypeError: (from the wrapper) if ``func`` returns a non-graph result
      and no input graph was passed, either positionally or by keyword.
  """
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    # Parameter values presumably arrive as strings (TODO confirm with the
    # executor); the literal 'None' stands for Python None. Compare only
    # strings: `==` on array-like values can yield an array whose truth
    # value is ambiguous.
    kwargs = {
        k: None if isinstance(v, str) and v == 'None' else v
        for k, v in kwargs.items()}
    res = func(*args, **kwargs)
    if isinstance(res, nx.Graph):
      return res
    # Otherwise it's a node attribute: attach it to a copy of the input graph.
    if args:
      source = args[0]
    else:
      # The graph may have been passed by keyword (e.g. G=...).
      keyword_graphs = [v for v in kwargs.values() if isinstance(v, nx.Graph)]
      if not keyword_graphs:
        raise TypeError(f'{name}: no input graph to attach the result to')
      source = keyword_graphs[0]
    graph = source.copy()
    nx.set_node_attributes(graph, values=res, name=name)
    return graph
  return wrapper

def register_networkx(env: str):
  """Register NetworkX functions as LynxKite ops in the catalog ``env``.

  Every object in the top-level ``networkx`` namespace that has a ``graphs``
  attribute is turned into an ``ops.Op`` with:

  * one ``ops.Input`` (type ``nx.Graph``) per name in ``func.graphs``;
  * one ``ops.Parameter`` per remaining signature parameter. The default is
    the stringified original default when that default's exact type is
    ``str``, ``int`` or ``float``, and ``None`` otherwise. Untyped
    single-letter parameters are guessed to be ``int``;
  * a single ``'output'`` output of type ``nx.Graph``.

  Each op is stored in ``ops.CATALOGS[env]`` under the display name
  ``"NX › <Title Cased Function Name>"``.

  Args:
    env: Catalog key to register the ops under; created if missing.
  """
  catalog = ops.CATALOGS.setdefault(env, {})
  for (name, func) in nx.__dict__.items():
    if not hasattr(func, 'graphs'):
      continue
    sig = inspect.signature(func)
    inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
    # Graph arguments are already exposed as inputs, so they must not also
    # become parameters. Before this exclusion, only the literal 'G' was
    # skipped, so e.g. 'H', 'G1', 'G2' showed up twice.
    skipped = {'G', 'backend', 'backend_kwargs', 'create_using', *func.graphs}
    params = {
      pname: ops.Parameter.basic(
          pname,
          # Exact type check: a bool default (a subclass of int) is not
          # stringified.
          str(param.default)
          if type(param.default) in [str, int, float]
          else None,
          param.annotation)
      for pname, param in sig.parameters.items()
      if pname not in skipped}
    for p in params.values():
      if not p.type:
        # Guess the type based on the name.
        if len(p.name) == 1:
          p.type = int
    display_name = "NX › " + name.replace('_', ' ').title()
    # NOTE(review): the display name is also passed to `wrapped` as the node
    # attribute name for non-graph results; consider using the raw function
    # name instead (would change attribute names seen downstream).
    catalog[display_name] = ops.Op(
      func=wrapped(display_name, func),
      name=display_name,
      params=params,
      inputs=inputs,
      outputs={'output': ops.Output(name='output', type=nx.Graph)},
      type='basic',
    )

# Populate the 'LynxKite' op catalog as a side effect of importing this module.
register_networkx(env='LynxKite')