darabos commited on
Commit
942065e
·
1 Parent(s): 5882a26

Use Pydantic for op metadata.

Browse files
server/{basic_ops.py → lynxkite_ops.py} RENAMED
@@ -20,6 +20,12 @@ def compute_pagerank(graph: nx.Graph, *, damping=0.85, iterations=100):
20
  return nx.pagerank(graph, alpha=damping, max_iter=iterations)
21
 
22
 
 
 
 
 
 
 
23
  def _map_color(value):
24
  cmap = matplotlib.cm.get_cmap('viridis')
25
  value = (value - value.min()) / (value.max() - value.min())
 
20
  return nx.pagerank(graph, alpha=damping, max_iter=iterations)
21
 
22
 
23
+ @ops.op("Sample graph")
24
+ def create_scale_free_graph(*, nodes: int = 10):
25
+ '''Creates a scale-free graph with the given number of nodes.'''
26
+ return nx.scale_free_graph(nodes)
27
+
28
+
29
  def _map_color(value):
30
  cmap = matplotlib.cm.get_cmap('viridis')
31
  value = (value - value.min()) / (value.max() - value.min())
server/main.py CHANGED
@@ -3,17 +3,17 @@ import fastapi
3
  import pathlib
4
  from . import ops
5
  from . import workspace
6
- from . import basic_ops
7
  # from . import networkx_ops
8
- # from . import pytorch_model_ops
9
- from . import lynxscribe_ops
10
 
11
  app = fastapi.FastAPI()
12
 
13
 
14
  @app.get("/api/catalog")
15
  def get_catalog():
16
- return [op.to_json() for op in ops.ALL_OPS.values()]
17
 
18
 
19
  class SaveRequest(workspace.BaseConfig):
 
3
  import pathlib
4
  from . import ops
5
  from . import workspace
6
+ from . import lynxkite_ops
7
  # from . import networkx_ops
8
+ from . import pytorch_model_ops
9
+ # from . import lynxscribe_ops
10
 
11
  app = fastapi.FastAPI()
12
 
13
 
14
  @app.get("/api/catalog")
15
  def get_catalog():
16
+ return [op.model_dump() for op in ops.ALL_OPS.values()]
17
 
18
 
19
  class SaveRequest(workspace.BaseConfig):
server/networkx_ops.py CHANGED
@@ -26,7 +26,7 @@ for (name, func) in nx.__dict__.items():
26
  sig = inspect.signature(func)
27
  inputs = {k: nx.Graph for k in func.graphs}
28
  params = {
29
- name: ops.Parameter(
30
  name, str(param.default)
31
  if type(param.default) in [str, int, float]
32
  else None,
 
26
  sig = inspect.signature(func)
27
  inputs = {k: nx.Graph for k in func.graphs}
28
  params = {
29
+ name: ops.Parameter.basic(
30
  name, str(param.default)
31
  if type(param.default) in [str, int, float]
32
  else None,
server/ops.py CHANGED
@@ -1,60 +1,64 @@
1
  '''API for implementing LynxKite operations.'''
 
2
  import dataclasses
3
  import enum
4
  import functools
5
  import inspect
6
  import networkx as nx
7
  import pandas as pd
 
8
  import typing
 
9
 
10
  ALL_OPS = {}
11
- PARAM_TYPE = type[typing.Any]
12
  typeof = type # We have some arguments called "type".
13
-
14
- @dataclasses.dataclass
15
- class Parameter:
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  '''Defines a parameter for an operation.'''
17
  name: str
18
  default: any
19
- type: PARAM_TYPE = None
20
 
21
  @staticmethod
22
  def options(name, options, default=None):
23
  e = enum.Enum(f'OptionsFor_{name}', options)
24
- return Parameter(name, e[default or options[0]], e)
25
 
26
  @staticmethod
27
  def collapsed(name, default, type=None):
28
- return Parameter(name, default, ('collapsed', type or typeof(default)))
29
-
30
- def __post_init__(self):
31
- if self.default is inspect._empty:
32
- self.default = None
33
- if self.type is None or self.type is inspect._empty:
34
- self.type = type(self.default)
35
- def to_json(self):
36
- t = str(self.type)
37
- default = self.default
38
- if isinstance(self.type, type) and issubclass(self.type, enum.Enum):
39
- t = {'enum': list(self.type.__members__.keys())}
40
- default = self.default.name if self.default else t['enum'][0]
41
- if isinstance(self.type, tuple) and self.type[0] == 'collapsed':
42
- t = {'collapsed': str(self.type[1])}
43
- return {
44
- 'name': self.name,
45
- 'default': default,
46
- 'type': t,
47
- }
48
 
49
- @dataclasses.dataclass
50
- class Op:
51
- func: callable
 
 
 
 
 
 
 
 
52
  name: str
53
  params: dict[str, Parameter]
54
- inputs: dict # name -> type
55
- outputs: dict # name -> type
56
  type: str # The UI to use for this operation.
57
- sub_nodes: list = None # If set, these nodes can be placed inside the operation's node.
58
 
59
  def __call__(self, *inputs, **params):
60
  # Convert parameters.
@@ -74,18 +78,6 @@ class Op:
74
  res = self.func(*inputs, **params)
75
  return res
76
 
77
- def to_json(self):
78
- return {
79
- 'type': self.type,
80
- 'data': {
81
- 'title': self.name,
82
- 'inputs': {i: str(type) for i, type in self.inputs.items()},
83
- 'outputs': {o: str(type) for o, type in self.outputs.items()},
84
- 'params': [p.to_json() for p in self.params.values()],
85
- },
86
- 'sub_nodes': [sub.to_json() for sub in self.sub_nodes.values()] if self.sub_nodes else None,
87
- }
88
-
89
 
90
  @dataclasses.dataclass
91
  class RelationDefinition:
@@ -105,9 +97,9 @@ class Bundle:
105
  Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
106
  It can also carry other data, such as a trained model.
107
  '''
108
- dfs: dict = dataclasses.field(default_factory=dict) # name -> DataFrame
109
  relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
110
- other: dict = None
111
 
112
  @classmethod
113
  def from_nx(cls, graph: nx.Graph):
@@ -161,9 +153,9 @@ def op(name, *, view='basic', sub_nodes=None):
161
  params = {}
162
  for n, param in sig.parameters.items():
163
  if param.kind == param.KEYWORD_ONLY:
164
- params[n] = Parameter(n, param.default, param.annotation)
165
  outputs = {'output': 'yes'} if view == 'basic' else {} # Maybe more fancy later.
166
- op = Op(func, name, params=params, inputs=inputs, outputs=outputs, type=view)
167
  if sub_nodes is not None:
168
  op.sub_nodes = sub_nodes
169
  op.type = 'sub_flow'
 
1
  '''API for implementing LynxKite operations.'''
2
+ from __future__ import annotations
3
  import dataclasses
4
  import enum
5
  import functools
6
  import inspect
7
  import networkx as nx
8
  import pandas as pd
9
+ import pydantic
10
  import typing
11
+ from typing_extensions import Annotated
12
 
13
  ALL_OPS = {}
 
14
  typeof = type # We have some arguments called "type".
15
+ def type_to_json(t):
16
+ if isinstance(t, type) and issubclass(t, enum.Enum):
17
+ return {'enum': list(t.__members__.keys())}
18
+ if isinstance(t, tuple) and t[0] == 'collapsed':
19
+ return {'collapsed': str(t[1])}
20
+ return {'type': str(t)}
21
+ Type = Annotated[
22
+ typing.Any, pydantic.PlainSerializer(type_to_json, return_type=dict)
23
+ ]
24
+ class BaseConfig(pydantic.BaseModel):
25
+ model_config = pydantic.ConfigDict(
26
+ arbitrary_types_allowed=True,
27
+ )
28
+
29
+
30
+ class Parameter(BaseConfig):
31
  '''Defines a parameter for an operation.'''
32
  name: str
33
  default: any
34
+ type: Type = None
35
 
36
  @staticmethod
37
  def options(name, options, default=None):
38
  e = enum.Enum(f'OptionsFor_{name}', options)
39
+ return Parameter.basic(name, e[default or options[0]], e)
40
 
41
  @staticmethod
42
  def collapsed(name, default, type=None):
43
+ return Parameter.basic(name, default, ('collapsed', type or typeof(default)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ @staticmethod
46
+ def basic(name, default=None, type=None):
47
+ if default is inspect._empty:
48
+ default = None
49
+ if type is None or type is inspect._empty:
50
+ type = typeof(default) if default else None
51
+ return Parameter(name=name, default=default, type=type)
52
+
53
+
54
+ class Op(BaseConfig):
55
+ func: callable = pydantic.Field(exclude=True)
56
  name: str
57
  params: dict[str, Parameter]
58
+ inputs: dict[str, Type] # name -> type
59
+ outputs: dict[str, Type] # name -> type
60
  type: str # The UI to use for this operation.
61
+ sub_nodes: list[Op] = None # If set, these nodes can be placed inside the operation's node.
62
 
63
  def __call__(self, *inputs, **params):
64
  # Convert parameters.
 
78
  res = self.func(*inputs, **params)
79
  return res
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  @dataclasses.dataclass
83
  class RelationDefinition:
 
97
  Can efficiently represent a knowledge graph (homogeneous or heterogeneous) or tabular data.
98
  It can also carry other data, such as a trained model.
99
  '''
100
+ dfs: dict[str, pd.DataFrame] = dataclasses.field(default_factory=dict)
101
  relations: list[RelationDefinition] = dataclasses.field(default_factory=list)
102
+ other: dict[str, typing.Any] = None
103
 
104
  @classmethod
105
  def from_nx(cls, graph: nx.Graph):
 
153
  params = {}
154
  for n, param in sig.parameters.items():
155
  if param.kind == param.KEYWORD_ONLY:
156
+ params[n] = Parameter.basic(n, param.default, param.annotation)
157
  outputs = {'output': 'yes'} if view == 'basic' else {} # Maybe more fancy later.
158
+ op = Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs, type=view)
159
  if sub_nodes is not None:
160
  op.sub_nodes = sub_nodes
161
  op.type = 'sub_flow'
server/pytorch_model_ops.py CHANGED
@@ -23,11 +23,11 @@ def register_layer(name):
23
  for name, param in sig.parameters.items()
24
  if param.kind != param.KEYWORD_ONLY}
25
  params = {
26
- name: ops.Parameter(name, param.default, param.annotation)
27
  for name, param in sig.parameters.items()
28
  if param.kind == param.KEYWORD_ONLY}
29
  outputs = {'x': 'tensor'}
30
- LAYERS[name] = ops.Op(func, name, params=params, inputs=inputs, outputs=outputs, type='vertical')
31
  return func
32
  return decorator
33
 
@@ -64,7 +64,7 @@ def nonlinearity(x, *, type: Nonlinearity):
64
 
65
  def register_area(name, params=[]):
66
  '''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.'''
67
- op = ops.Op(ops.no_op, name, params={p.name: p for p in params}, inputs={}, outputs={}, type='area')
68
  LAYERS[name] = op
69
 
70
- register_area('Repeat', params=[ops.Parameter('times', 1, int)])
 
23
  for name, param in sig.parameters.items()
24
  if param.kind != param.KEYWORD_ONLY}
25
  params = {
26
+ name: ops.Parameter.basic(name, param.default, param.annotation)
27
  for name, param in sig.parameters.items()
28
  if param.kind == param.KEYWORD_ONLY}
29
  outputs = {'x': 'tensor'}
30
+ LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs, type='vertical')
31
  return func
32
  return decorator
33
 
 
64
 
65
  def register_area(name, params=[]):
66
  '''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.'''
67
+ op = ops.Op(func=ops.no_op, name=name, params={p.name: p for p in params}, inputs={}, outputs={}, type='area')
68
  LAYERS[name] = op
69
 
70
+ register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)])
web/src/LynxKiteFlow.svelte CHANGED
@@ -31,7 +31,7 @@
31
  const backendWorkspace = useQuery(['workspace', path], async () => {
32
  const res = await fetch(`/api/load?path=${path}`);
33
  return res.json();
34
- }, {staleTime: 10000});
35
  const mutation = useMutation(async(update) => {
36
  const res = await fetch('/api/save', {
37
  method: 'POST',
@@ -79,13 +79,20 @@
79
  };
80
  }
81
  function addNode(e) {
82
- const node = {...e.detail};
83
  nodes.update((n) => {
 
 
 
 
 
 
 
 
 
 
84
  node.position = screenToFlowPosition({x: nodeSearchSettings.pos.x, y: nodeSearchSettings.pos.y});
85
- node.data = { ...node.data };
86
  const title = node.data.title;
87
- node.data.params = Object.fromEntries(
88
- node.data.params.map((p) => [p.name, p.default]));
89
  let i = 1;
90
  node.id = `${title} ${i}`;
91
  while (n.find((x) => x.id === node.id)) {
 
31
  const backendWorkspace = useQuery(['workspace', path], async () => {
32
  const res = await fetch(`/api/load?path=${path}`);
33
  return res.json();
34
+ }, {staleTime: 10000, retry: false});
35
  const mutation = useMutation(async(update) => {
36
  const res = await fetch('/api/save', {
37
  method: 'POST',
 
79
  };
80
  }
81
  function addNode(e) {
82
+ const meta = {...e.detail};
83
  nodes.update((n) => {
84
+ const node = {
85
+ type: meta.type,
86
+ data: {
87
+ title: meta.name,
88
+ params: Object.fromEntries(
89
+ Object.values(meta.params).map((p) => [p.name, p.default])),
90
+ inputs: meta.inputs,
91
+ outputs: meta.outputs,
92
+ },
93
+ };
94
  node.position = screenToFlowPosition({x: nodeSearchSettings.pos.x, y: nodeSearchSettings.pos.y});
 
95
  const title = node.data.title;
 
 
96
  let i = 1;
97
  node.id = `${title} ${i}`;
98
  while (n.find((x) => x.id === node.id)) {
web/src/NodeSearch.svelte CHANGED
@@ -9,7 +9,7 @@
9
  let selectedIndex = 0;
10
  onMount(() => searchBox.focus());
11
  $: fuse = new Fuse(boxes, {
12
- keys: ['data.title']
13
  })
14
  function onInput() {
15
  hits = fuse.search(searchBox.value);
@@ -58,7 +58,7 @@
58
  on:click={addSelected}
59
  class="search-result"
60
  class:selected={index == selectedIndex}>
61
- {box.item.data.title}
62
  </div>
63
  {/each}
64
  </div>
 
9
  let selectedIndex = 0;
10
  onMount(() => searchBox.focus());
11
  $: fuse = new Fuse(boxes, {
12
+ keys: ['name']
13
  })
14
  function onInput() {
15
  hits = fuse.search(searchBox.value);
 
58
  on:click={addSelected}
59
  class="search-result"
60
  class:selected={index == selectedIndex}>
61
+ {box.item.name}
62
  </div>
63
  {/each}
64
  </div>