File size: 2,080 Bytes
3010d5b
d994c06
3010d5b
 
 
 
 
 
 
 
d994c06
 
 
 
 
 
3010d5b
 
 
 
 
6a24dfe
3010d5b
 
 
942065e
3010d5b
 
6a24dfe
 
3010d5b
 
 
 
b7a4f8b
3010d5b
 
 
b7a4f8b
3010d5b
 
 
 
 
 
d994c06
 
 
 
 
 
3010d5b
b7a4f8b
3010d5b
 
d994c06
 
 
 
 
3010d5b
b7a4f8b
3010d5b
 
b7a4f8b
 
942065e
b7a4f8b
 
942065e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''Boxes for defining and using PyTorch models.'''
from enum import Enum
import inspect
from . import ops

# Registry of layer ops usable inside the "Define PyTorch model" box
# (passed as sub_nodes below); populated by register_layer()/register_area().
LAYERS = {}

@ops.op("Define PyTorch model", sub_nodes=LAYERS)
def define_pytorch_model(*, sub_flow):
  '''Turn the editor sub-flow into a bundle carrying the model description.

  The sub-flow is stringified as-is; no actual model is built here.
  '''
  print('sub_flow:', sub_flow)  # NOTE(review): debug print — presumably intentional while developing
  description = str(sub_flow)
  return ops.Bundle(other={'model': description})

@ops.op("Train PyTorch model")
def train_pytorch_model(model, graph):
  '''Placeholder training op: echoes a greeting with the model it received.'''
  # Keep torch out of module import time — importing it is slow.
  # import torch
  return f'hello {model!s}'

def register_layer(name):
  '''Decorator factory that registers the wrapped function as a layer op.

  Keyword-only parameters of the function become op parameters; all other
  parameters become inputs attached to the bottom of the node. Every layer
  exposes a single 'x' output at the top. The op is stored in LAYERS under
  *name*; the function itself is returned unchanged.
  '''
  def decorator(func):
    signature = inspect.signature(func)
    inputs = {}
    params = {}
    # Partition the signature in one pass instead of two comprehensions.
    for param_name, param in signature.parameters.items():
      if param.kind == param.KEYWORD_ONLY:
        params[param_name] = ops.Parameter.basic(param_name, param.default, param.annotation)
      else:
        inputs[param_name] = ops.Input(name=param_name, type=param.annotation, position='bottom')
    outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
    LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs)
    return func
  return decorator

@register_layer('LayerNorm')
def layernorm(x):
  '''Placeholder LayerNorm layer; returns its label rather than computing anything.'''
  return 'LayerNorm'

@register_layer('Dropout')
def dropout(x, *, p=0.5):
  '''Placeholder Dropout layer; reports the configured drop probability *p*.'''
  return 'Dropout ({})'.format(p)

@register_layer('Linear')
def linear(*, output_dim: int):
  '''Placeholder Linear layer; reports the requested output dimension.'''
  return 'Linear {}'.format(output_dim)

class GraphConv(Enum):
  '''Closed set of graph-convolution flavors selectable on the layer node.

  The member names presumably mirror PyTorch Geometric convolution classes —
  TODO confirm against the eventual training implementation.
  '''
  GCNConv = 'GCNConv'
  GATConv = 'GATConv'
  GATv2Conv = 'GATv2Conv'
  SAGEConv = 'SAGEConv'

@register_layer('Graph Convolution')
def graph_convolution(x, edges, *, type: GraphConv):
  # Placeholder: ignores x, edges, and the selected convolution type,
  # always returning the same label (consistent with the other stub layers).
  return 'GraphConv'

class Nonlinearity(Enum):
  '''Closed set of activation functions selectable on the Nonlinearity node.'''
  Mish = 'Mish'
  ReLU = 'ReLU'
  Tanh = 'Tanh'

@register_layer('Nonlinearity')
def nonlinearity(x, *, type: Nonlinearity):
  # Placeholder: returns 'ReLU' regardless of the selected *type*
  # (matches the stub behavior of the other layer functions).
  return 'ReLU'

def register_area(name, params=()):
  '''Register an "area" node: it can contain other nodes, but does not
  restrict movement in any way.

  Args:
    name: Display name of the area; also its key in LAYERS.
    params: Optional iterable of ops.Parameter objects exposed by the node.

  The default was changed from a mutable [] to an immutable () — behavior
  is identical (params is only iterated here), but this avoids the shared
  mutable-default-argument pitfall and is backward-compatible for callers.
  '''
  op = ops.Op(func=ops.no_op, name=name, params={p.name: p for p in params}, inputs={}, outputs={}, type='area')
  LAYERS[name] = op

# Built-in area: 'Repeat' groups nodes and exposes an integer 'times' parameter (default 1).
register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)])