Spaces:
Running
Running
File size: 2,971 Bytes
c3b473d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""Boxes for defining PyTorch models."""
import enum
from lynxkite.core import ops
from lynxkite.core.ops import Parameter as P
import torch
import torch_geometric.nn as pyg_nn
from .core import op, reg, ENV
# Input boxes and the LSTM box. `reg` is imported from `.core`; each call here
# only declares a box name together with its port and parameter names.
reg(
    "Input: tensor",
    outputs=["output"],
    params=[P.basic("name")],
)
reg(
    "Input: graph edges",
    outputs=["edges"],
)
reg(
    "Input: sequential",
    outputs=["y"],
)
# The LSTM box declares the same two port names ("x", "h") on both sides.
reg(
    "LSTM",
    inputs=["x", "h"],
    outputs=["x", "h"],
)
# Choices offered for the "method" parameter of the "Neural ODE" box.
# NOTE(review): these look like torchdiffeq solver names -- confirm against the
# code that consumes this box.
_NEURAL_ODE_METHODS = [
    "dopri8",
    "dopri5",
    "bosh3",
    "fehlberg2",
    "adaptive_heun",
    "euler",
    "midpoint",
    "rk4",
    "explicit_adams",
    "implicit_adams",
]

# No `outputs` argument is passed, so whatever `reg` defaults to applies.
# Neither tolerance parameter is given a default value here.
reg(
    "Neural ODE",
    inputs=["x"],
    params=[
        P.basic("relative_tolerance"),
        P.basic("absolute_tolerance"),
        P.options("method", _NEURAL_ODE_METHODS),
    ],
)
# "Attention" takes three inputs (q, k, v) and declares two outputs.
reg(
    "Attention",
    inputs=["q", "k", "v"],
    outputs=["x", "weights"],
)
# The next two boxes pass no `outputs`, so whatever `reg` defaults to applies.
reg(
    "LayerNorm",
    inputs=["x"],
)
# "Dropout" exposes a single parameter "p" whose default is 0.5.
reg(
    "Dropout",
    inputs=["x"],
    params=[P.basic("p", 0.5)],
)
@op("Linear")
def linear(x, *, output_dim=1024):
    """Build the layer behind the "Linear" box.

    Args:
        x: Declares the box's single input; it is not read in this body.
        output_dim: Number of output features of the layer.

    Returns:
        A ``torch_geometric.nn.Linear`` constructed with ``-1`` as its input
        size and ``output_dim`` as its output size.
    """
    # -1 is PyG's documented placeholder for "infer the input size lazily".
    lazy_in_channels = -1
    layer = pyg_nn.Linear(lazy_in_channels, output_dim)
    return layer
class ActivationTypes(enum.Enum):
    """Activation functions selectable on the "Activation" box.

    Each member's value is a display string; the member name is the
    identifier-safe spelling of the same activation.
    """

    # Declaration order is the enum's iteration order.
    ReLU = "ReLU"
    # The only member whose name and value differ (underscore vs. space).
    Leaky_ReLU = "Leaky ReLU"
    Tanh = "Tanh"
    Mish = "Mish"
@op("Activation")
def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
    """Pick the ``torch.nn.functional`` callable for the chosen activation.

    Args:
        x: Declares the box's single input; it is not read in this body.
        type: Which activation to use. Only its ``name`` attribute is read.

    Returns:
        The attribute of ``torch.nn.functional`` whose name is the member name
        lower-cased, with any spaces turned into underscores
        (e.g. ``Leaky_ReLU`` -> ``leaky_relu``).
    """
    member_name = type.name
    functional_name = member_name.lower().replace(" ", "_")
    return getattr(torch.nn.functional, functional_name)
@op("MSE loss")
def mse_loss(x, y):
    """Return the callable behind the "MSE loss" box.

    Args:
        x: Declares the first input of the box; not read in this body.
        y: Declares the second input of the box; not read in this body.

    Returns:
        ``torch.nn.functional.mse_loss`` itself (the function, not a result).
    """
    loss_fn = torch.nn.functional.mse_loss
    return loss_fn
# No `outputs` argument is passed, so whatever `reg` defaults to applies.
reg(
    "Softmax",
    inputs=["x"],
)

# Choices offered for the "type" parameter of the "Graph conv" box.
# NOTE(review): these match class names in torch_geometric.nn -- confirm that
# the consumer resolves them there.
_GRAPH_CONV_TYPES = ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"]

reg(
    "Graph conv",
    inputs=["x", "edges"],
    outputs=["x"],
    params=[P.options("type", _GRAPH_CONV_TYPES)],
)
# Four boxes with an identical port layout: inputs "a" and "b", output "x".
# Fresh lists are built on every iteration and the order is kept as declared.
for _binary_box_name in ("Concatenate", "Add", "Subtract", "Multiply"):
    reg(_binary_box_name, inputs=["a", "b"], outputs=["x"])

# Loss boxes: each declares a single output named "loss".
reg(
    "Triplet margin loss",
    inputs=["x", "x_pos", "x_neg"],
    outputs=["loss"],
)
reg(
    "Cross-entropy loss",
    inputs=["x", "y"],
    outputs=["loss"],
)
# Choices offered for the "type" parameter of the "Optimizer" box.
_OPTIMIZER_TYPES = [
    "AdamW",
    "Adafactor",
    "Adagrad",
    "SGD",
    "Lion",
    "Paged AdamW",
    "Galore AdamW",
]

# The optimizer box takes "loss" as input and explicitly declares no outputs.
# Its "lr" parameter defaults to 0.001.
reg(
    "Optimizer",
    inputs=["loss"],
    outputs=[],
    params=[
        P.options("type", _OPTIMIZER_TYPES),
        P.basic("lr", 0.001),
    ],
)
# "Repeat" box: one tensor input placed at the top, one tensor output placed at
# the bottom, and two parameters ("times": int, default 1; "same_weights":
# bool, default False). `P` is the file's alias for `ops.Parameter`.
_repeat_inputs = [ops.Input(name="input", position="top", type="tensor")]
_repeat_outputs = [ops.Output(name="output", position="bottom", type="tensor")]
_repeat_params = [
    P.basic("times", 1, int),
    P.basic("same_weights", False, bool),
]
ops.register_passive_op(
    ENV,
    "Repeat",
    inputs=_repeat_inputs,
    outputs=_repeat_outputs,
    params=_repeat_params,
)
# "Recurrent chain" box: one tensor input placed at the top, one tensor output
# placed at the bottom, and no parameters.
_recurrent_chain_inputs = [
    ops.Input(name="input", position="top", type="tensor"),
]
_recurrent_chain_outputs = [
    ops.Output(name="output", position="bottom", type="tensor"),
]
ops.register_passive_op(
    ENV,
    "Recurrent chain",
    inputs=_recurrent_chain_inputs,
    outputs=_recurrent_chain_outputs,
    params=[],
)
|