Spaces:
Running
Running
Take NX parameter types from the docstrings.
Browse files
lynxkite-core/src/lynxkite/core/ops.py
CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
4 |
import enum
|
5 |
import functools
|
6 |
import inspect
|
|
|
7 |
import pydantic
|
8 |
import typing
|
9 |
from dataclasses import dataclass
|
@@ -123,6 +124,26 @@ def basic_outputs(*names):
|
|
123 |
return {name: Output(name=name, type=None) for name in names}
|
124 |
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
class Op(BaseConfig):
|
127 |
func: typing.Callable = pydantic.Field(exclude=True)
|
128 |
name: str
|
@@ -136,12 +157,8 @@ class Op(BaseConfig):
|
|
136 |
# Convert parameters.
|
137 |
for p in params:
|
138 |
if p in self.params:
|
139 |
-
|
140 |
-
|
141 |
-
elif self.params[p].type is float:
|
142 |
-
params[p] = float(params[p])
|
143 |
-
elif isinstance(self.params[p].type, enum.EnumMeta):
|
144 |
-
params[p] = self.params[p].type[params[p]]
|
145 |
res = self.func(*inputs, **params)
|
146 |
if not isinstance(res, Result):
|
147 |
# Automatically wrap the result in a Result object, if it isn't already.
|
|
|
4 |
import enum
|
5 |
import functools
|
6 |
import inspect
|
7 |
+
import types
|
8 |
import pydantic
|
9 |
import typing
|
10 |
from dataclasses import dataclass
|
|
|
124 |
return {name: Output(name=name, type=None) for name in names}
|
125 |
|
126 |
|
127 |
+
def _param_to_type(name, value, type):
|
128 |
+
value = value or ""
|
129 |
+
print(f'Converting "{name}" {value} to {type}')
|
130 |
+
if type is int:
|
131 |
+
assert value != "", f"{name} is unset."
|
132 |
+
return int(value)
|
133 |
+
if type is float:
|
134 |
+
assert value != "", f"{name} is unset."
|
135 |
+
return float(value)
|
136 |
+
if isinstance(type, enum.EnumMeta):
|
137 |
+
return type[value]
|
138 |
+
if isinstance(type, types.UnionType):
|
139 |
+
match type.__args__:
|
140 |
+
case (None, type):
|
141 |
+
return None if value == "" else _param_to_type(name, value, type)
|
142 |
+
case (type, None):
|
143 |
+
return None if value == "" else _param_to_type(name, value, type)
|
144 |
+
return value
|
145 |
+
|
146 |
+
|
147 |
class Op(BaseConfig):
|
148 |
func: typing.Callable = pydantic.Field(exclude=True)
|
149 |
name: str
|
|
|
157 |
# Convert parameters.
|
158 |
for p in params:
|
159 |
if p in self.params:
|
160 |
+
params[p] = _param_to_type(p, params[p], self.params[p].type)
|
161 |
+
print(self.name, p, params[p])
|
|
|
|
|
|
|
|
|
162 |
res = self.func(*inputs, **params)
|
163 |
if not isinstance(res, Result):
|
164 |
# Automatically wrap the result in a Result object, if it isn't already.
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py
CHANGED
@@ -4,10 +4,149 @@ from lynxkite.core import ops
|
|
4 |
import functools
|
5 |
import inspect
|
6 |
import networkx as nx
|
|
|
7 |
|
8 |
ENV = "LynxKite Graph Analytics"
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def wrapped(name: str, func):
|
12 |
@functools.wraps(func)
|
13 |
def wrapper(*args, **kwargs):
|
@@ -27,36 +166,45 @@ def wrapped(name: str, func):
|
|
27 |
|
28 |
def register_networkx(env: str):
|
29 |
cat = ops.CATALOGS.setdefault(env, {})
|
|
|
30 |
for name, func in nx.__dict__.items():
|
31 |
if hasattr(func, "graphs"):
|
32 |
sig = inspect.signature(func)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
|
34 |
params = {
|
35 |
name: ops.Parameter.basic(
|
36 |
-
name,
|
37 |
-
str(param.default)
|
38 |
if type(param.default) in [str, int, float]
|
39 |
else None,
|
40 |
-
|
41 |
)
|
42 |
for name, param in sig.parameters.items()
|
43 |
-
if name not in [
|
44 |
}
|
45 |
-
|
46 |
-
if not p.type:
|
47 |
-
# Guess the type based on the name.
|
48 |
-
if len(p.name) == 1:
|
49 |
-
p.type = int
|
50 |
-
name = "NX › " + name.replace("_", " ").title()
|
51 |
op = ops.Op(
|
52 |
func=wrapped(name, func),
|
53 |
-
name=
|
54 |
params=params,
|
55 |
inputs=inputs,
|
56 |
outputs={"output": ops.Output(name="output", type=nx.Graph)},
|
57 |
type="basic",
|
58 |
)
|
59 |
-
cat[
|
|
|
|
|
60 |
|
61 |
|
62 |
register_networkx(ENV)
|
|
|
4 |
import functools
|
5 |
import inspect
|
6 |
import networkx as nx
|
7 |
+
import re
|
8 |
|
9 |
ENV = "LynxKite Graph Analytics"
|
10 |
|
11 |
|
12 |
+
class UnsupportedType(Exception):
|
13 |
+
pass
|
14 |
+
|
15 |
+
|
16 |
+
nx.ladder_graph
|
17 |
+
|
18 |
+
|
19 |
+
def doc_to_type(name: str, t: str) -> type:
|
20 |
+
t = t.lower()
|
21 |
+
t = re.sub("[(][^)]+[)]", "", t).strip().strip(".")
|
22 |
+
if " " in name or "http" in name:
|
23 |
+
return None # Not a parameter type.
|
24 |
+
if t.endswith(", optional"):
|
25 |
+
w = doc_to_type(name, t.removesuffix(", optional").strip())
|
26 |
+
if w is None:
|
27 |
+
return None
|
28 |
+
return w | None
|
29 |
+
if t in [
|
30 |
+
"a digraph or multidigraph",
|
31 |
+
"a graph g",
|
32 |
+
"graph",
|
33 |
+
"graphs",
|
34 |
+
"networkx graph instance",
|
35 |
+
"networkx graph",
|
36 |
+
"networkx undirected graph",
|
37 |
+
"nx.graph",
|
38 |
+
"undirected graph",
|
39 |
+
"undirected networkx graph",
|
40 |
+
] or t.startswith("networkx graph"):
|
41 |
+
return nx.Graph
|
42 |
+
elif t in [
|
43 |
+
"digraph-like",
|
44 |
+
"digraph",
|
45 |
+
"directed graph",
|
46 |
+
"networkx digraph",
|
47 |
+
"networkx directed graph",
|
48 |
+
"nx.digraph",
|
49 |
+
]:
|
50 |
+
return nx.DiGraph
|
51 |
+
elif t == "node":
|
52 |
+
raise UnsupportedType(t)
|
53 |
+
elif t == '"node (optional)"':
|
54 |
+
return None
|
55 |
+
elif t == '"edge"':
|
56 |
+
raise UnsupportedType(t)
|
57 |
+
elif t == '"edge (optional)"':
|
58 |
+
return None
|
59 |
+
elif t in ["class", "data type"]:
|
60 |
+
raise UnsupportedType(t)
|
61 |
+
elif t in ["string", "str", "node label"]:
|
62 |
+
return str
|
63 |
+
elif t in ["string or none", "none or string", "string, or none"]:
|
64 |
+
return str | None
|
65 |
+
elif t in ["int", "integer"]:
|
66 |
+
return int
|
67 |
+
elif t in ["bool", "boolean"]:
|
68 |
+
return bool
|
69 |
+
elif t == "tuple":
|
70 |
+
raise UnsupportedType(t)
|
71 |
+
elif t == "set":
|
72 |
+
raise UnsupportedType(t)
|
73 |
+
elif t == "list of floats":
|
74 |
+
raise UnsupportedType(t)
|
75 |
+
elif t == "list of floats or float":
|
76 |
+
return float
|
77 |
+
elif t in ["dict", "dictionary"]:
|
78 |
+
raise UnsupportedType(t)
|
79 |
+
elif t == "scalar or dictionary":
|
80 |
+
return float
|
81 |
+
elif t == "none or dict":
|
82 |
+
return None
|
83 |
+
elif t in ["function", "callable"]:
|
84 |
+
raise UnsupportedType(t)
|
85 |
+
elif t in [
|
86 |
+
"collection",
|
87 |
+
"container of nodes",
|
88 |
+
"list of nodes",
|
89 |
+
]:
|
90 |
+
raise UnsupportedType(t)
|
91 |
+
elif t in [
|
92 |
+
"container",
|
93 |
+
"generator",
|
94 |
+
"iterable",
|
95 |
+
"iterator",
|
96 |
+
"list or iterable container",
|
97 |
+
"list or iterable",
|
98 |
+
"list or set",
|
99 |
+
"list or tuple",
|
100 |
+
"list",
|
101 |
+
]:
|
102 |
+
raise UnsupportedType(t)
|
103 |
+
elif t == "generator of sets":
|
104 |
+
raise UnsupportedType(t)
|
105 |
+
elif t == "dict or a set of 2 or 3 tuples":
|
106 |
+
raise UnsupportedType(t)
|
107 |
+
elif t == "set of 2 or 3 tuples":
|
108 |
+
raise UnsupportedType(t)
|
109 |
+
elif t == "none, string or function":
|
110 |
+
return str | None
|
111 |
+
elif t == "string or function" and name == "weight":
|
112 |
+
return str
|
113 |
+
elif t == "integer, float, or none":
|
114 |
+
return float | None
|
115 |
+
elif t in [
|
116 |
+
"float",
|
117 |
+
"int or float",
|
118 |
+
"integer or float",
|
119 |
+
"integer, float",
|
120 |
+
"number",
|
121 |
+
"numeric",
|
122 |
+
"real",
|
123 |
+
"scalar",
|
124 |
+
]:
|
125 |
+
return float
|
126 |
+
elif t in ["integer or none", "int or none"]:
|
127 |
+
return int | None
|
128 |
+
elif name == "seed":
|
129 |
+
return int | None
|
130 |
+
elif name == "weight":
|
131 |
+
return str
|
132 |
+
elif t == "object":
|
133 |
+
raise UnsupportedType(t)
|
134 |
+
return None
|
135 |
+
|
136 |
+
|
137 |
+
def types_from_doc(doc: str) -> dict[str, type]:
|
138 |
+
types = {}
|
139 |
+
for line in doc.splitlines():
|
140 |
+
if ":" in line:
|
141 |
+
a, b = line.split(":", 1)
|
142 |
+
for a in a.split(","):
|
143 |
+
a = a.strip()
|
144 |
+
t = doc_to_type(a, b)
|
145 |
+
if t is not None:
|
146 |
+
types[a] = t
|
147 |
+
return types
|
148 |
+
|
149 |
+
|
150 |
def wrapped(name: str, func):
|
151 |
@functools.wraps(func)
|
152 |
def wrapper(*args, **kwargs):
|
|
|
166 |
|
167 |
def register_networkx(env: str):
|
168 |
cat = ops.CATALOGS.setdefault(env, {})
|
169 |
+
counter = 0
|
170 |
for name, func in nx.__dict__.items():
|
171 |
if hasattr(func, "graphs"):
|
172 |
sig = inspect.signature(func)
|
173 |
+
try:
|
174 |
+
types = types_from_doc(func.__doc__)
|
175 |
+
except UnsupportedType:
|
176 |
+
continue
|
177 |
+
for k, param in sig.parameters.items():
|
178 |
+
if k in types:
|
179 |
+
continue
|
180 |
+
if param.annotation is not param.empty:
|
181 |
+
types[k] = param.annotation
|
182 |
+
if k in ["i", "j", "n"]:
|
183 |
+
types[k] = int
|
184 |
inputs = {k: ops.Input(name=k, type=nx.Graph) for k in func.graphs}
|
185 |
params = {
|
186 |
name: ops.Parameter.basic(
|
187 |
+
name=name,
|
188 |
+
default=str(param.default)
|
189 |
if type(param.default) in [str, int, float]
|
190 |
else None,
|
191 |
+
type=types[name],
|
192 |
)
|
193 |
for name, param in sig.parameters.items()
|
194 |
+
if name in types and types[name] not in [nx.Graph, nx.DiGraph]
|
195 |
}
|
196 |
+
nicename = name.replace("_", " ").title()
|
|
|
|
|
|
|
|
|
|
|
197 |
op = ops.Op(
|
198 |
func=wrapped(name, func),
|
199 |
+
name=nicename,
|
200 |
params=params,
|
201 |
inputs=inputs,
|
202 |
outputs={"output": ops.Output(name="output", type=nx.Graph)},
|
203 |
type="basic",
|
204 |
)
|
205 |
+
cat[nicename] = op
|
206 |
+
counter += 1
|
207 |
+
print(f"Registered {counter} NetworkX operations.")
|
208 |
|
209 |
|
210 |
register_networkx(ENV)
|