File size: 2,110 Bytes
3010d5b
d994c06
3010d5b
 
 
 
 
e7fa7ee
 
 
3010d5b
 
d994c06
 
e7fa7ee
d994c06
 
 
3010d5b
 
 
 
 
6a24dfe
3010d5b
 
 
942065e
3010d5b
 
6a24dfe
 
3010d5b
 
 
 
b7a4f8b
3010d5b
 
 
b7a4f8b
3010d5b
 
 
 
 
 
d994c06
 
 
 
 
 
3010d5b
b7a4f8b
3010d5b
 
d994c06
 
 
 
 
3010d5b
b7a4f8b
3010d5b
 
b7a4f8b
 
942065e
b7a4f8b
 
942065e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''Boxes for defining and using PyTorch models.'''
from enum import Enum
import inspect
from . import ops

# Registry of layer/area ops keyed by their name. Populated by register_layer()
# and register_area() in this module, and handed to the "Define PyTorch model"
# box as its `sub_nodes`.
LAYERS = {}

# Decorator used to register this module's boxes, obtained from
# ops.op_registration('LynxKite').
op = ops.op_registration('LynxKite')

@op("Define PyTorch model", sub_nodes=LAYERS)
def define_pytorch_model(*, sub_flow):
  '''Box whose sub-flow (built from the LAYERS ops) describes a model.

  Prints the received sub-flow, then returns a Bundle whose `other` mapping
  holds the sub-flow's string form under the key 'model'.
  '''
  print('sub_flow:', sub_flow)
  model_description = str(sub_flow)
  return ops.Bundle(other=dict(model=model_description))

@op("Train PyTorch model")
def train_pytorch_model(model, graph):
  '''Placeholder training box: returns 'hello ' followed by str(model).

  `graph` is accepted but not used yet.
  '''
  # torch is intended to be imported lazily inside this function (it is slow
  # to import); it is not needed until real training is implemented.
  greeting = 'hello %s' % (model,)
  return greeting

def register_layer(name):
  '''Return a decorator that registers the decorated function as a layer op.

  The function's signature defines the op:
    - every parameter that is not keyword-only becomes an ops.Input with
      position='bottom' and its annotation as the type;
    - every keyword-only parameter becomes an ops.Parameter built with
      ops.Parameter.basic(param_name, default, annotation);
    - a single output 'x' of type 'tensor' with position='top' is added.
  Unannotated parameters pass inspect.Parameter.empty as their type, and
  parameters without a default pass inspect.Parameter.empty as the default.

  Args:
    name: Key under which the op is stored in LAYERS; also used as the op name.

  Returns:
    A decorator that stores the op in LAYERS and returns the function unchanged.
  '''
  def decorator(func):
    sig = inspect.signature(func)
    # Comprehension variables use `pname` so they cannot be confused with the
    # layer `name` used as the LAYERS key below.
    inputs = {
      pname: ops.Input(name=pname, type=param.annotation, position='bottom')
      for pname, param in sig.parameters.items()
      if param.kind != param.KEYWORD_ONLY}
    params = {
      pname: ops.Parameter.basic(pname, param.default, param.annotation)
      for pname, param in sig.parameters.items()
      if param.kind == param.KEYWORD_ONLY}
    outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
    LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs)
    return func
  return decorator

@register_layer('LayerNorm')
def layernorm(x):
  '''Registered as the 'LayerNorm' layer (one input `x`, no parameters).

  Currently returns a placeholder label.
  '''
  label = 'LayerNorm'
  return label

@register_layer('Dropout')
def dropout(x, *, p=0.5):
  '''Registered as the 'Dropout' layer: input `x`, keyword-only parameter `p` (default 0.5).

  Currently returns a placeholder label that includes `p`.
  '''
  return 'Dropout ({})'.format(p)

@register_layer('Linear')
def linear(x=None, *, output_dim: int):
  '''Registered as the 'Linear' layer: input `x`, keyword-only parameter `output_dim`.

  Every other layer in this module takes an input `x`. Without it,
  register_layer would create no input port, and the layer could not be
  connected to anything upstream. `x` defaults to None so existing direct
  calls like linear(output_dim=...) keep working.

  Currently returns a placeholder label that includes `output_dim`.
  '''
  return f'Linear {output_dim}'

class GraphConv(Enum):
  '''Choices for the `type` parameter of the 'Graph Convolution' layer.'''
  GCNConv = 'GCNConv'
  GATConv = 'GATConv'
  GATv2Conv = 'GATv2Conv'
  SAGEConv = 'SAGEConv'

@register_layer('Graph Convolution')
def graph_convolution(x, edges, *, type: GraphConv):
  '''Registered as the 'Graph Convolution' layer: inputs `x` and `edges`,
  keyword-only parameter `type` (a GraphConv choice).

  Currently returns a fixed placeholder label; `type` is not used yet.
  '''
  label = 'GraphConv'
  return label

class Nonlinearity(Enum):
  '''Choices for the `type` parameter of the 'Nonlinearity' layer.'''
  Mish = 'Mish'
  ReLU = 'ReLU'
  Tanh = 'Tanh'

@register_layer('Nonlinearity')
def nonlinearity(x, *, type: Nonlinearity):
  '''Registered as the 'Nonlinearity' layer: input `x`, keyword-only parameter
  `type` (a Nonlinearity choice).

  Returns a placeholder label naming the selected activation. Previously it
  always returned 'ReLU', even when Mish or Tanh was selected.
  '''
  # NOTE(review): `type` may arrive as a Nonlinearity member or as its raw
  # string value depending on the caller — TODO confirm; both are handled.
  return str(getattr(type, 'value', type))

def register_area(name, params=()):
  '''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.

  Registers an op of type 'area' in LAYERS under `name`. The op has no inputs
  or outputs and uses ops.no_op as its function.

  Args:
    name: Key in LAYERS; also used as the op name.
    params: Iterable of parameter objects, each keyed by its `.name` attribute.
      Defaults to an empty tuple, so no mutable default object is shared
      between calls.
  '''
  # Named `area_op` so it does not shadow the module-level `op` decorator.
  area_op = ops.Op(
    func=ops.no_op, name=name, params={p.name: p for p in params},
    inputs={}, outputs={}, type='area')
  LAYERS[name] = area_op

# Register the 'Repeat' area with one parameter, 'times' (default 1, type int).
register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)])