Spaces:
Running
Running
'''Boxes for defining and using PyTorch models.''' | |
from enum import Enum | |
import inspect | |
from . import ops | |
LAYERS = {} | |
def define_pytorch_model(*, sub_flow): | |
print('sub_flow:', sub_flow) | |
return ops.Bundle(other={'model': str(sub_flow)}) | |
def train_pytorch_model(model, graph): | |
# import torch # Lazy import because it's slow. | |
return 'hello ' + str(model) | |
def register_layer(name): | |
def decorator(func): | |
sig = inspect.signature(func) | |
inputs = { | |
name: ops.Input(name=name, type=param.annotation, position='bottom') | |
for name, param in sig.parameters.items() | |
if param.kind != param.KEYWORD_ONLY} | |
params = { | |
name: ops.Parameter.basic(name, param.default, param.annotation) | |
for name, param in sig.parameters.items() | |
if param.kind == param.KEYWORD_ONLY} | |
outputs = {'x': ops.Output(name='x', type='tensor', position='top')} | |
LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs) | |
return func | |
return decorator | |
def layernorm(x): | |
return 'LayerNorm' | |
def dropout(x, *, p=0.5): | |
return f'Dropout ({p})' | |
def linear(*, output_dim: int): | |
return f'Linear {output_dim}' | |
class GraphConv(Enum): | |
GCNConv = 'GCNConv' | |
GATConv = 'GATConv' | |
GATv2Conv = 'GATv2Conv' | |
SAGEConv = 'SAGEConv' | |
def graph_convolution(x, edges, *, type: GraphConv): | |
return 'GraphConv' | |
class Nonlinearity(Enum): | |
Mish = 'Mish' | |
ReLU = 'ReLU' | |
Tanh = 'Tanh' | |
def nonlinearity(x, *, type: Nonlinearity): | |
return 'ReLU' | |
def register_area(name, params=[]): | |
'''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.''' | |
op = ops.Op(func=ops.no_op, name=name, params={p.name: p for p in params}, inputs={}, outputs={}, type='area') | |
LAYERS[name] = op | |
register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)]) | |