lynxkite / server /pytorch_model_ops.py
darabos's picture
Control plug position for each input/output.
6a24dfe
raw
history blame
2.08 kB
'''Boxes for defining and using PyTorch models.'''
from enum import Enum
import inspect
from . import ops
LAYERS = {}
@ops.op("Define PyTorch model", sub_nodes=LAYERS)
def define_pytorch_model(*, sub_flow):
print('sub_flow:', sub_flow)
return ops.Bundle(other={'model': str(sub_flow)})
@ops.op("Train PyTorch model")
def train_pytorch_model(model, graph):
# import torch # Lazy import because it's slow.
return 'hello ' + str(model)
def register_layer(name):
def decorator(func):
sig = inspect.signature(func)
inputs = {
name: ops.Input(name=name, type=param.annotation, position='bottom')
for name, param in sig.parameters.items()
if param.kind != param.KEYWORD_ONLY}
params = {
name: ops.Parameter.basic(name, param.default, param.annotation)
for name, param in sig.parameters.items()
if param.kind == param.KEYWORD_ONLY}
outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs)
return func
return decorator
@register_layer('LayerNorm')
def layernorm(x):
return 'LayerNorm'
@register_layer('Dropout')
def dropout(x, *, p=0.5):
return f'Dropout ({p})'
@register_layer('Linear')
def linear(*, output_dim: int):
return f'Linear {output_dim}'
class GraphConv(Enum):
GCNConv = 'GCNConv'
GATConv = 'GATConv'
GATv2Conv = 'GATv2Conv'
SAGEConv = 'SAGEConv'
@register_layer('Graph Convolution')
def graph_convolution(x, edges, *, type: GraphConv):
return 'GraphConv'
class Nonlinearity(Enum):
Mish = 'Mish'
ReLU = 'ReLU'
Tanh = 'Tanh'
@register_layer('Nonlinearity')
def nonlinearity(x, *, type: Nonlinearity):
return 'ReLU'
def register_area(name, params=[]):
'''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.'''
op = ops.Op(func=ops.no_op, name=name, params={p.name: p for p in params}, inputs={}, outputs={}, type='area')
LAYERS[name] = op
register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)])