Spaces:
Sleeping
Sleeping
File size: 2,110 Bytes
3010d5b d994c06 3010d5b e7fa7ee 3010d5b d994c06 e7fa7ee d994c06 3010d5b 6a24dfe 3010d5b 942065e 3010d5b 6a24dfe 3010d5b b7a4f8b 3010d5b b7a4f8b 3010d5b d994c06 3010d5b b7a4f8b 3010d5b d994c06 3010d5b b7a4f8b 3010d5b b7a4f8b 942065e b7a4f8b 942065e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
'''Boxes for defining and using PyTorch models.'''
from enum import Enum
import inspect
from . import ops
# Registry of sub-node ops (layers and areas) offered inside the
# "Define PyTorch model" box, keyed by display name. Populated by
# register_layer and register_area further down in this module.
LAYERS = dict()
# Op-registration decorator obtained from ops.op_registration for 'LynxKite'.
op = ops.op_registration('LynxKite')
@op("Define PyTorch model", sub_nodes=LAYERS)
def define_pytorch_model(*, sub_flow):
    '''Wrap the sub-flow drawn inside this box into a bundle.

    Args:
        sub_flow: The sub-flow handed in by the op framework -- presumably built
            from the LAYERS nodes placed inside the box (type not visible here;
            confirm against ops).

    Returns:
        An ops.Bundle whose `other` dict stores str(sub_flow) under 'model'.
    '''
    import logging  # local import: nothing else in this module uses logging

    # Debug trace of the received sub-flow; logged instead of printed so it
    # does not clutter stdout of the running app.
    logging.getLogger(__name__).debug('sub_flow: %s', sub_flow)
    # For now the "model" is just the string form of the sub-flow.
    return ops.Bundle(other={'model': str(sub_flow)})
@op("Train PyTorch model")
def train_pytorch_model(model, graph):
    '''Placeholder training op that greets the model instead of training it.

    `graph` is accepted but not used yet. A note left by the original author
    says torch should be imported lazily inside this function once training
    is implemented, because importing it is slow.
    '''
    return f'hello {model!s}'
def register_layer(name):
    '''Return a decorator that registers the decorated function as a layer op.

    The function's signature is introspected:
      - every parameter that is not keyword-only becomes an ops.Input placed at
        the 'bottom', typed by its annotation (inspect.Parameter.empty when the
        parameter is unannotated);
      - every keyword-only parameter becomes an ops.Parameter built with
        ops.Parameter.basic(name, default, annotation);
      - each layer gets a single output 'x' of type 'tensor' placed at the 'top'.

    The resulting ops.Op is stored in LAYERS[name]; the function itself is
    returned unchanged, so it remains directly callable.

    Args:
        name: Key under which the op is stored in LAYERS; also the Op's name.
    '''
    def decorator(func):
        sig_params = inspect.signature(func).parameters.values()
        # The loop variable is `p`, not `name`, so it cannot be mistaken for
        # the layer name captured from register_layer.
        inputs = {
            p.name: ops.Input(name=p.name, type=p.annotation, position='bottom')
            for p in sig_params
            if p.kind != p.KEYWORD_ONLY}
        params = {
            p.name: ops.Parameter.basic(p.name, p.default, p.annotation)
            for p in sig_params
            if p.kind == p.KEYWORD_ONLY}
        outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
        LAYERS[name] = ops.Op(
            func=func, name=name, params=params, inputs=inputs, outputs=outputs)
        return func
    return decorator
@register_layer('LayerNorm')
def layernorm(x):
    '''LayerNorm layer stub: returns its label.

    `x` is not keyword-only, so register_layer turns it into an input port.
    '''
    label = 'LayerNorm'
    return label
@register_layer('Dropout')
def dropout(x, *, p=0.5):
    '''Dropout layer stub: returns a label that includes the rate `p`.

    register_layer exposes `x` as an input port and the keyword-only `p`
    (default 0.5) as a parameter.
    '''
    return 'Dropout ({})'.format(p)
@register_layer('Linear')
def linear(x=None, *, output_dim: int):
    '''Linear (fully connected) layer stub: returns a label with the output size.

    Args:
        x: The layer's input port. Every other layer in this module declares a
            positional `x`; without it register_layer would give Linear no input
            to connect to. It defaults to None so existing direct calls such as
            ``linear(output_dim=8)`` keep working.
        output_dim: Keyword-only, so register_layer exposes it as a parameter.
    '''
    return f'Linear {output_dim}'
class GraphConv(Enum):
    '''Choices for the `type` parameter of the 'Graph Convolution' layer.

    graph_convolution annotates its keyword-only `type` with this enum, and
    register_layer forwards that annotation to ops.Parameter.basic, so the
    member names, values and order are visible at runtime.
    '''
    # Values equal the member names; presumably named after the matching
    # PyTorch Geometric layers -- confirm before relying on that.
    GCNConv = 'GCNConv'
    GATConv = 'GATConv'
    GATv2Conv = 'GATv2Conv'
    SAGEConv = 'SAGEConv'
@register_layer('Graph Convolution')
def graph_convolution(x, edges, *, type: GraphConv):
    '''Graph convolution layer stub.

    register_layer turns `x` and `edges` into input ports and the keyword-only
    `type` (a GraphConv choice) into a parameter. The selected variant is not
    used yet; the returned label is the same for every choice.
    '''
    result = 'GraphConv'
    return result
class Nonlinearity(Enum):
    '''Choices for the `type` parameter of the 'Nonlinearity' layer.

    nonlinearity annotates its keyword-only `type` with this enum, and
    register_layer forwards that annotation to ops.Parameter.basic, so the
    member names, values and order are visible at runtime.
    '''
    # Each value equals its member name.
    Mish = 'Mish'
    ReLU = 'ReLU'
    Tanh = 'Tanh'
@register_layer('Nonlinearity')
def nonlinearity(x, *, type: Nonlinearity):
    '''Activation layer stub: returns the name of the selected nonlinearity.

    Args:
        x: Input port (exposed by register_layer).
        type: The chosen activation. It may be a Nonlinearity member or its
            string value; which one the op framework passes is not visible here.

    Returns:
        The selected Nonlinearity's value, e.g. 'Mish'. Before this change the
        label was always 'ReLU', whatever was selected.

    Raises:
        ValueError: If `type` is neither a Nonlinearity member nor one of its values.
    '''
    # Calling the enum accepts either a member or a value and normalizes it to a member.
    return Nonlinearity(type).value
def register_area(name, params=()):
    '''Register a node that represents an area.

    An area can contain other nodes but does not restrict their movement in
    any way. It is stored in LAYERS[name] as an ops.Op of type 'area', with
    func=ops.no_op and no inputs or outputs.

    Args:
        name: Key in LAYERS and the Op's name.
        params: Iterable of ops.Parameter objects, keyed by their `.name` in the
            Op. Defaults to an empty tuple, an immutable default in place of a
            shared mutable list.
    '''
    # Named area_op so it does not shadow the module-level `op` decorator.
    area_op = ops.Op(
        func=ops.no_op, name=name, params={p.name: p for p in params},
        inputs={}, outputs={}, type='area')
    LAYERS[name] = area_op
# The 'Repeat' area exposes one int parameter, 'times', with default 1.
register_area(
    name='Repeat',
    params=[ops.Parameter.basic('times', 1, int)],
)
|