darabos commited on
Commit
0280171
·
1 Parent(s): 10c9dc3

Take NX parameter types from the docstrings.

Browse files
lynxkite-core/src/lynxkite/core/ops.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
4
  import enum
5
  import functools
6
  import inspect
 
7
  import pydantic
8
  import typing
9
  from dataclasses import dataclass
@@ -123,6 +124,26 @@ def basic_outputs(*names):
123
  return {name: Output(name=name, type=None) for name in names}
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  class Op(BaseConfig):
127
  func: typing.Callable = pydantic.Field(exclude=True)
128
  name: str
@@ -136,12 +157,8 @@ class Op(BaseConfig):
136
  # Convert parameters.
137
  for p in params:
138
  if p in self.params:
139
- if self.params[p].type is int:
140
- params[p] = int(params[p])
141
- elif self.params[p].type is float:
142
- params[p] = float(params[p])
143
- elif isinstance(self.params[p].type, enum.EnumMeta):
144
- params[p] = self.params[p].type[params[p]]
145
  res = self.func(*inputs, **params)
146
  if not isinstance(res, Result):
147
  # Automatically wrap the result in a Result object, if it isn't already.
 
4
  import enum
5
  import functools
6
  import inspect
7
+ import types
8
  import pydantic
9
  import typing
10
  from dataclasses import dataclass
 
124
  return {name: Output(name=name, type=None) for name in names}
125
 
126
 
127
+ def _param_to_type(name, value, type):
128
+ value = value or ""
129
+ print(f'Converting "{name}" {value} to {type}')
130
+ if type is int:
131
+ assert value != "", f"{name} is unset."
132
+ return int(value)
133
+ if type is float:
134
+ assert value != "", f"{name} is unset."
135
+ return float(value)
136
+ if isinstance(type, enum.EnumMeta):
137
+ return type[value]
138
+ if isinstance(type, types.UnionType):
139
+ match type.__args__:
140
+ case (None, type):
141
+ return None if value == "" else _param_to_type(name, value, type)
142
+ case (type, None):
143
+ return None if value == "" else _param_to_type(name, value, type)
144
+ return value
145
+
146
+
147
  class Op(BaseConfig):
148
  func: typing.Callable = pydantic.Field(exclude=True)
149
  name: str
 
157
  # Convert parameters.
158
  for p in params:
159
  if p in self.params:
160
+ params[p] = _param_to_type(p, params[p], self.params[p].type)
161
+ print(self.name, p, params[p])
 
 
 
 
162
  res = self.func(*inputs, **params)
163
  if not isinstance(res, Result):
164
  # Automatically wrap the result in a Result object, if it isn't already.
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py CHANGED
@@ -4,10 +4,149 @@ from lynxkite.core import ops
4
  import functools
5
  import inspect
6
  import networkx as nx
 
7
 
8
  ENV = "LynxKite Graph Analytics"
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def wrapped(name: str, func):
12
  @functools.wraps(func)
13
  def wrapper(*args, **kwargs):
@@ -27,36 +166,45 @@ def wrapped(name: str, func):
27
 
28
  def register_networkx(env: str):
29
  cat = ops.CATALOGS.setdefault(env, {})
 
30
  for name, func in nx.__dict__.items():
31
  if hasattr(func, "graphs"):
32
  sig = inspect.signature(func)
 
 
 
 
 
 
 
 
 
 
 
33
  inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
34
  params = {
35
  name: ops.Parameter.basic(
36
- name,
37
- str(param.default)
38
  if type(param.default) in [str, int, float]
39
  else None,
40
- param.annotation,
41
  )
42
  for name, param in sig.parameters.items()
43
- if name not in ["G", "backend", "backend_kwargs", "create_using"]
44
  }
45
- for p in params.values():
46
- if not p.type:
47
- # Guess the type based on the name.
48
- if len(p.name) == 1:
49
- p.type = int
50
- name = "NX › " + name.replace("_", " ").title()
51
  op = ops.Op(
52
  func=wrapped(name, func),
53
- name=name,
54
  params=params,
55
  inputs=inputs,
56
  outputs={"output": ops.Output(name="output", type=nx.Graph)},
57
  type="basic",
58
  )
59
- cat[name] = op
 
 
60
 
61
 
62
  register_networkx(ENV)
 
4
  import functools
5
  import inspect
6
  import networkx as nx
7
+ import re
8
 
9
  ENV = "LynxKite Graph Analytics"
10
 
11
 
12
+ class UnsupportedType(Exception):
13
+ pass
14
+
15
+
16
+ nx.ladder_graph
17
+
18
+
19
+ def doc_to_type(name: str, t: str) -> type:
20
+ t = t.lower()
21
+ t = re.sub("[(][^)]+[)]", "", t).strip().strip(".")
22
+ if " " in name or "http" in name:
23
+ return None # Not a parameter type.
24
+ if t.endswith(", optional"):
25
+ w = doc_to_type(name, t.removesuffix(", optional").strip())
26
+ if w is None:
27
+ return None
28
+ return w | None
29
+ if t in [
30
+ "a digraph or multidigraph",
31
+ "a graph g",
32
+ "graph",
33
+ "graphs",
34
+ "networkx graph instance",
35
+ "networkx graph",
36
+ "networkx undirected graph",
37
+ "nx.graph",
38
+ "undirected graph",
39
+ "undirected networkx graph",
40
+ ] or t.startswith("networkx graph"):
41
+ return nx.Graph
42
+ elif t in [
43
+ "digraph-like",
44
+ "digraph",
45
+ "directed graph",
46
+ "networkx digraph",
47
+ "networkx directed graph",
48
+ "nx.digraph",
49
+ ]:
50
+ return nx.DiGraph
51
+ elif t == "node":
52
+ raise UnsupportedType(t)
53
+ elif t == '"node (optional)"':
54
+ return None
55
+ elif t == '"edge"':
56
+ raise UnsupportedType(t)
57
+ elif t == '"edge (optional)"':
58
+ return None
59
+ elif t in ["class", "data type"]:
60
+ raise UnsupportedType(t)
61
+ elif t in ["string", "str", "node label"]:
62
+ return str
63
+ elif t in ["string or none", "none or string", "string, or none"]:
64
+ return str | None
65
+ elif t in ["int", "integer"]:
66
+ return int
67
+ elif t in ["bool", "boolean"]:
68
+ return bool
69
+ elif t == "tuple":
70
+ raise UnsupportedType(t)
71
+ elif t == "set":
72
+ raise UnsupportedType(t)
73
+ elif t == "list of floats":
74
+ raise UnsupportedType(t)
75
+ elif t == "list of floats or float":
76
+ return float
77
+ elif t in ["dict", "dictionary"]:
78
+ raise UnsupportedType(t)
79
+ elif t == "scalar or dictionary":
80
+ return float
81
+ elif t == "none or dict":
82
+ return None
83
+ elif t in ["function", "callable"]:
84
+ raise UnsupportedType(t)
85
+ elif t in [
86
+ "collection",
87
+ "container of nodes",
88
+ "list of nodes",
89
+ ]:
90
+ raise UnsupportedType(t)
91
+ elif t in [
92
+ "container",
93
+ "generator",
94
+ "iterable",
95
+ "iterator",
96
+ "list or iterable container",
97
+ "list or iterable",
98
+ "list or set",
99
+ "list or tuple",
100
+ "list",
101
+ ]:
102
+ raise UnsupportedType(t)
103
+ elif t == "generator of sets":
104
+ raise UnsupportedType(t)
105
+ elif t == "dict or a set of 2 or 3 tuples":
106
+ raise UnsupportedType(t)
107
+ elif t == "set of 2 or 3 tuples":
108
+ raise UnsupportedType(t)
109
+ elif t == "none, string or function":
110
+ return str | None
111
+ elif t == "string or function" and name == "weight":
112
+ return str
113
+ elif t == "integer, float, or none":
114
+ return float | None
115
+ elif t in [
116
+ "float",
117
+ "int or float",
118
+ "integer or float",
119
+ "integer, float",
120
+ "number",
121
+ "numeric",
122
+ "real",
123
+ "scalar",
124
+ ]:
125
+ return float
126
+ elif t in ["integer or none", "int or none"]:
127
+ return int | None
128
+ elif name == "seed":
129
+ return int | None
130
+ elif name == "weight":
131
+ return str
132
+ elif t == "object":
133
+ raise UnsupportedType(t)
134
+ return None
135
+
136
+
137
+ def types_from_doc(doc: str) -> dict[str, type]:
138
+ types = {}
139
+ for line in doc.splitlines():
140
+ if ":" in line:
141
+ a, b = line.split(":", 1)
142
+ for a in a.split(","):
143
+ a = a.strip()
144
+ t = doc_to_type(a, b)
145
+ if t is not None:
146
+ types[a] = t
147
+ return types
148
+
149
+
150
  def wrapped(name: str, func):
151
  @functools.wraps(func)
152
  def wrapper(*args, **kwargs):
 
166
 
167
  def register_networkx(env: str):
168
  cat = ops.CATALOGS.setdefault(env, {})
169
+ counter = 0
170
  for name, func in nx.__dict__.items():
171
  if hasattr(func, "graphs"):
172
  sig = inspect.signature(func)
173
+ try:
174
+ types = types_from_doc(func.__doc__)
175
+ except UnsupportedType:
176
+ continue
177
+ for k, param in sig.parameters.items():
178
+ if k in types:
179
+ continue
180
+ if param.annotation is not param.empty:
181
+ types[k] = param.annotation
182
+ if k in ["i", "j", "n"]:
183
+ types[k] = int
184
  inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
185
  params = {
186
  name: ops.Parameter.basic(
187
+ name=name,
188
+ default=str(param.default)
189
  if type(param.default) in [str, int, float]
190
  else None,
191
+ type=types[name],
192
  )
193
  for name, param in sig.parameters.items()
194
+ if name in types and types[name] not in [nx.Graph, nx.DiGraph]
195
  }
196
+ nicename = name.replace("_", " ").title()
 
 
 
 
 
197
  op = ops.Op(
198
  func=wrapped(name, func),
199
+ name=nicename,
200
  params=params,
201
  inputs=inputs,
202
  outputs={"output": ops.Output(name="output", type=nx.Graph)},
203
  type="basic",
204
  )
205
+ cat[nicename] = op
206
+ counter += 1
207
+ print(f"Registered {counter} NetworkX operations.")
208
 
209
 
210
  register_networkx(ENV)