Spaces:
Sleeping
Sleeping
File size: 3,113 Bytes
ca01fa3 0c44583 ca01fa3 a18645a ca01fa3 9e91869 b6d30cb 0c44583 9e91869 b6d30cb a18645a ca01fa3 b6d30cb ca01fa3 b6d30cb ca01fa3 a18645a 0c44583 b6d30cb a18645a b6d30cb a18645a 0c44583 ca01fa3 0c44583 ca01fa3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
'''API for implementing LynxKite operations.'''
import dataclasses
import functools
import inspect
import networkx as nx
import pandas as pd
# Registry of operations declared with the @op decorator, keyed by operation
# name; each value is the Op object built for the decorated function.
ALL_OPS = {}
@dataclasses.dataclass
class Op:
    '''A registered operation: a Python function plus metadata describing it.'''
    func: callable  # The implementation invoked by __call__.
    name: str  # Operation name.
    params: dict  # Keyword-only parameter name -> default (None if no default), as built by @op.
    inputs: dict  # Non-keyword-only parameter name -> annotation, as built by @op.
    outputs: dict  # Output name -> placeholder value.
    type: str  # The `view` the operation was declared with.

    def __call__(self, *inputs, **params):
        '''Invoke ``func`` after coercing parameters and graph inputs.

        Keyword parameters listed in ``self.params`` are passed through ``int()``
        or ``float()`` when the function annotates them with that type, or (if
        unannotated) when their declared default has that type. Other keyword
        arguments are passed through unchanged.

        Positional inputs are matched to the function's parameters in order. A
        ``Bundle`` passed where ``nx.Graph`` is annotated is converted with
        ``Bundle.to_nx()``, and an ``nx.Graph`` passed where ``Bundle`` is
        annotated is converted with ``Bundle.from_nx()``.

        Returns whatever ``func`` returns.
        '''
        sig = inspect.signature(self.func)
        # Convert parameters.
        for p in params:
            if p not in self.params:
                continue
            t = sig.parameters[p].annotation
            if t is inspect.Parameter.empty:
                # No annotation: fall back to the type of the declared default.
                t = type(self.params[p])
            if t == int:
                params[p] = int(params[p])
            elif t == float:
                params[p] = float(params[p])
        # Convert inputs between the two graph representations as annotated.
        inputs = list(inputs)
        for i, (x, p) in enumerate(zip(inputs, sig.parameters.values())):
            t = p.annotation
            if t == nx.Graph and isinstance(x, Bundle):
                inputs[i] = x.to_nx()
            elif t == Bundle and isinstance(x, nx.Graph):
                inputs[i] = Bundle.from_nx(x)
        return self.func(*inputs, **params)
@dataclasses.dataclass()
class RelationDefinition:
    '''Describes how the rows of one table link rows of two other tables.

    The names are presumably keys into a Bundle's ``dfs`` mapping. For example,
    Bundle.from_nx uses ``df='edges'`` with ``source_table``/``target_table``
    set to ``'nodes'`` and both keys set to ``'id'``.
    '''
    # Table whose rows are the relation instances (e.g. an edge list).
    df: str
    # Columns of `df` that reference the source and target sides.
    source_column: str
    target_column: str
    # Tables that those columns point into.
    source_table: str
    target_table: str
    # Columns of the source/target tables matched against the referencing columns.
    source_key: str
    target_key: str
@dataclasses.dataclass
class Bundle:
    '''A set of named DataFrames plus the relations that link them.

    Graphs are represented by a ``nodes`` table (one row per node, with the
    node identifier in an ``id`` column) and an ``edges`` table whose
    ``source``/``target`` columns hold node identifiers.
    '''
    dfs: dict  # Table name -> pandas DataFrame.
    relations: list[RelationDefinition]  # How the tables in `dfs` reference each other.

    @classmethod
    def from_nx(cls, graph: nx.Graph):
        '''Build a Bundle from a NetworkX graph.

        The ``edges`` table comes from ``nx.to_pandas_edgelist`` (``source``,
        ``target`` and one column per edge attribute). The ``nodes`` table has
        one column per node attribute plus an ``id`` column with the node
        identifier.
        '''
        edges = nx.to_pandas_edgelist(graph)
        node_data = dict(graph.nodes(data=True))
        nodes = pd.DataFrame(list(node_data.values()), index=list(node_data.keys()))
        nodes['id'] = nodes.index
        return cls(
            dfs={'edges': edges, 'nodes': nodes},
            relations=[
                RelationDefinition(
                    df='edges',
                    source_column='source',
                    target_column='target',
                    source_table='nodes',
                    target_table='nodes',
                    source_key='id',
                    target_key='id',
                )
            ]
        )

    def to_nx(self):
        '''Convert the ``nodes`` and ``edges`` tables back into an ``nx.Graph``.

        Every row of ``nodes`` becomes a node whose attributes are its non-``id``
        columns. Every row of ``edges`` becomes an edge whose attributes are its
        columns other than ``source`` and ``target``.

        Raises:
            KeyError: if the ``nodes``/``edges`` tables or their key columns
                are missing.
            ValueError: if the ``id`` column contains duplicates (from
                ``DataFrame.to_dict(orient='index')``).
        '''
        graph = nx.Graph()
        node_table = self.dfs['nodes'].set_index('id')
        # Add nodes explicitly. Building the graph from the edge list alone
        # would lose nodes without edges, and nx.set_node_attributes would
        # silently skip them.
        for node_id, attrs in node_table.to_dict('index').items():
            graph.add_node(node_id)
            graph.nodes[node_id].update(attrs)
        # Copy every non-endpoint column as an edge attribute, mirroring the
        # columns that from_nx writes via nx.to_pandas_edgelist.
        for attrs in self.dfs['edges'].to_dict('records'):
            source = attrs.pop('source')
            target = attrs.pop('target')
            graph.add_edge(source, target)
            graph[source][target].update(attrs)
        return graph
def nx_node_attribute_func(name):
    '''Return a decorator for functions that compute a NetworkX node attribute.

    The decorated function is called with a copy of the input graph (plus any
    keyword arguments). Its return value is passed to ``nx.set_node_attributes``
    to store it on that copy under the attribute ``name``, and the copy is
    returned. The wrapper itself writes only to the copy.
    '''
    def attach(compute):
        @functools.wraps(compute)
        def annotated(graph: nx.Graph, **kwargs):
            copied = graph.copy()
            values = compute(copied, **kwargs)
            nx.set_node_attributes(copied, values, name)
            return copied
        return annotated
    return attach
def op(name, *, view='basic'):
    '''Decorator that registers a function as a LynxKite operation.

    Inspects the decorated function's signature:
      * every parameter that is not keyword-only becomes an input
        (name -> annotation);
      * every keyword-only parameter becomes a param (name -> default, or None
        when it has no default).
    An Op built from this metadata is stored in ALL_OPS under ``name``, and the
    function itself is returned unchanged.

    Args:
        name: Operation name; also its key in ALL_OPS.
        view: Stored as the Op's ``type``. Only ``'basic'`` operations get the
            single ``'output'`` output.
    '''
    def decorator(func):
        sig = inspect.signature(func)
        # Positional arguments are inputs. Anything not keyword-only counts,
        # so NOTE(review): *args/**kwargs would also be listed here.
        # Loop variables are named `pname` so they do not shadow the
        # operation `name` used below.
        inputs = {
            pname: param.annotation
            for pname, param in sig.parameters.items()
            if param.kind != param.KEYWORD_ONLY}
        params = {
            pname: None if param.default is inspect.Parameter.empty else param.default
            for pname, param in sig.parameters.items()
            if param.kind == param.KEYWORD_ONLY}
        outputs = {'output': 'yes'} if view == 'basic' else {}  # Maybe more fancy later.
        ALL_OPS[name] = Op(func, name, params=params, inputs=inputs, outputs=outputs, type=view)
        return func
    return decorator
|