"""Boxes for defining PyTorch models.""" import enum from lynxkite.core import ops from lynxkite.core.ops import Parameter as P import torch import torch_geometric.nn as pyg_nn from .core import op, reg, ENV reg("Input: tensor", outputs=["output"], params=[P.basic("name")]) reg("Input: graph edges", outputs=["edges"]) reg("Input: sequential", outputs=["y"]) reg("LSTM", inputs=["x", "h"], outputs=["x", "h"]) reg( "Neural ODE", inputs=["x"], params=[ P.basic("relative_tolerance"), P.basic("absolute_tolerance"), P.options( "method", [ "dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun", "euler", "midpoint", "rk4", "explicit_adams", "implicit_adams", ], ), ], ) reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"]) reg("LayerNorm", inputs=["x"]) reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)]) @op("Linear") def linear(x, *, output_dim=1024): return pyg_nn.Linear(-1, output_dim) class ActivationTypes(enum.Enum): ReLU = "ReLU" Leaky_ReLU = "Leaky ReLU" Tanh = "Tanh" Mish = "Mish" @op("Activation") def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU): return getattr(torch.nn.functional, type.name.lower().replace(" ", "_")) @op("MSE loss") def mse_loss(x, y): return torch.nn.functional.mse_loss reg("Softmax", inputs=["x"]) reg( "Graph conv", inputs=["x", "edges"], outputs=["x"], params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])], ) reg("Concatenate", inputs=["a", "b"], outputs=["x"]) reg("Add", inputs=["a", "b"], outputs=["x"]) reg("Subtract", inputs=["a", "b"], outputs=["x"]) reg("Multiply", inputs=["a", "b"], outputs=["x"]) reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"]) reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"]) reg( "Optimizer", inputs=["loss"], outputs=[], params=[ P.options( "type", [ "AdamW", "Adafactor", "Adagrad", "SGD", "Lion", "Paged AdamW", "Galore AdamW", ], ), P.basic("lr", 0.001), ], ) ops.register_passive_op( ENV, "Repeat", inputs=[ops.Input(name="input", position="top", type="tensor")], outputs=[ops.Output(name="output", position="bottom", type="tensor")], params=[ ops.Parameter.basic("times", 1, int), ops.Parameter.basic("same_weights", False, bool), ], ) ops.register_passive_op( ENV, "Recurrent chain", inputs=[ops.Input(name="input", position="top", type="tensor")], outputs=[ops.Output(name="output", position="bottom", type="tensor")], params=[], )