Spaces:
Running
Running
"""Boxes for defining PyTorch models.""" | |
import enum | |
from lynxkite.core import ops | |
from lynxkite.core.ops import Parameter as P | |
import torch | |
import torch_geometric.nn as pyg_nn | |
from .core import op, reg, ENV | |
reg("Input: tensor", outputs=["output"], params=[P.basic("name")]) | |
reg("Input: graph edges", outputs=["edges"]) | |
reg("Input: sequential", outputs=["y"]) | |
reg("Output", inputs=["x"], outputs=["x"], params=[P.basic("name")]) | |
def lstm(x, *, input_size=1024, hidden_size=1024, dropout=0.0): | |
return torch.nn.LSTM(input_size, hidden_size, dropout=0.0) | |
reg( | |
"Neural ODE", | |
inputs=["x"], | |
params=[ | |
P.basic("relative_tolerance"), | |
P.basic("absolute_tolerance"), | |
P.options( | |
"method", | |
[ | |
"dopri8", | |
"dopri5", | |
"bosh3", | |
"fehlberg2", | |
"adaptive_heun", | |
"euler", | |
"midpoint", | |
"rk4", | |
"explicit_adams", | |
"implicit_adams", | |
], | |
), | |
], | |
) | |
def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0): | |
return torch.nn.MultiHeadAttention(embed_dim, num_heads, dropout=dropout, need_weights=True) | |
def layernorm(x, *, normalized_shape=""): | |
normalized_shape = [int(s.strip()) for s in normalized_shape.split(",")] | |
return torch.nn.LayerNorm(normalized_shape) | |
def dropout(x, *, p=0.0): | |
return torch.nn.Dropout(p) | |
def linear(x, *, output_dim=1024): | |
return pyg_nn.Linear(-1, output_dim) | |
class ActivationTypes(enum.Enum): | |
ReLU = "ReLU" | |
Leaky_ReLU = "Leaky ReLU" | |
Tanh = "Tanh" | |
Mish = "Mish" | |
def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU): | |
return getattr(torch.nn.functional, type.name.lower().replace(" ", "_")) | |
def mse_loss(x, y): | |
return torch.nn.functional.mse_loss | |
def constant_vector(*, value=0, size=1): | |
return lambda _: torch.full((size,), value) | |
def softmax(x, *, dim=1): | |
return torch.nn.Softmax(dim=dim) | |
def concatenate(a, b): | |
return lambda a, b: torch.concatenate(*torch.broadcast_tensors(a, b)) | |
reg( | |
"Graph conv", | |
inputs=["x", "edges"], | |
outputs=["x"], | |
params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])], | |
) | |
reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"]) | |
reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"]) | |
reg( | |
"Optimizer", | |
inputs=["loss"], | |
outputs=[], | |
params=[ | |
P.options( | |
"type", | |
[ | |
"AdamW", | |
"Adafactor", | |
"Adagrad", | |
"SGD", | |
"Lion", | |
"Paged AdamW", | |
"Galore AdamW", | |
], | |
), | |
P.basic("lr", 0.001), | |
], | |
) | |
ops.register_passive_op( | |
ENV, | |
"Repeat", | |
inputs=[ops.Input(name="input", position="top", type="tensor")], | |
outputs=[ops.Output(name="output", position="bottom", type="tensor")], | |
params=[ | |
ops.Parameter.basic("times", 1, int), | |
ops.Parameter.basic("same_weights", False, bool), | |
], | |
) | |
ops.register_passive_op( | |
ENV, | |
"Recurrent chain", | |
inputs=[ops.Input(name="input", position="top", type="tensor")], | |
outputs=[ops.Output(name="output", position="bottom", type="tensor")], | |
params=[], | |
) | |
def _set_handle_positions(op): | |
op: ops.Op = op.__op__ | |
for v in op.outputs.values(): | |
v.position = "top" | |
for v in op.inputs.values(): | |
v.position = "bottom" | |
def _register_simple_pytorch_layer(func): | |
op = ops.op(ENV, func.__name__.title())(lambda input: func) | |
_set_handle_positions(op) | |
def _register_two_tensor_function(func): | |
op = ops.op(ENV, func.__name__.title())(lambda a, b: func) | |
_set_handle_positions(op) | |
SIMPLE_FUNCTIONS = [ | |
torch.sin, | |
torch.cos, | |
torch.log, | |
torch.exp, | |
] | |
TWO_TENSOR_FUNCTIONS = [ | |
torch.multiply, | |
torch.add, | |
torch.subtract, | |
] | |
for f in SIMPLE_FUNCTIONS: | |
_register_simple_pytorch_layer(f) | |
for f in TWO_TENSOR_FUNCTIONS: | |
_register_two_tensor_function(f) | |