"""Boxes for defining PyTorch models.""" from lynxkite.core import ops, workspace from lynxkite.core.ops import Parameter as P import torch import torch_geometric as pyg ENV = "PyTorch model" def reg(name, inputs=[], outputs=None, params=[]): if outputs is None: outputs = inputs return ops.register_passive_op( ENV, name, inputs=[ ops.Input(name=name, position="bottom", type="tensor") for name in inputs ], outputs=[ ops.Output(name=name, position="top", type="tensor") for name in outputs ], params=params, ) reg("Input: embedding", outputs=["x"]) reg("Input: graph edges", outputs=["edges"]) reg("Input: label", outputs=["y"]) reg("Input: positive sample", outputs=["x_pos"]) reg("Input: negative sample", outputs=["x_neg"]) reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"]) reg("LayerNorm", inputs=["x"]) reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)]) reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")]) reg("Softmax", inputs=["x"]) reg( "Graph conv", inputs=["x", "edges"], outputs=["x"], params=[P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"])], ) reg( "Activation", inputs=["x"], params=[P.options("type", ["ReLU", "LeakyReLU", "Tanh", "Mish"])], ) reg("Concatenate", inputs=["a", "b"], outputs=["x"]) reg("Add", inputs=["a", "b"], outputs=["x"]) reg("Subtract", inputs=["a", "b"], outputs=["x"]) reg("Multiply", inputs=["a", "b"], outputs=["x"]) reg("MSE loss", inputs=["x", "y"], outputs=["loss"]) reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"]) reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"]) reg( "Optimizer", inputs=["loss"], outputs=[], params=[ P.options( "type", [ "AdamW", "Adafactor", "Adagrad", "SGD", "Lion", "Paged AdamW", "Galore AdamW", ], ), P.basic("lr", 0.001), ], ) ops.register_passive_op( ENV, "Repeat", inputs=[ops.Input(name="input", position="top", type="tensor")], outputs=[ops.Output(name="output", position="bottom", type="tensor")], params=[ops.Parameter.basic("times", 1, int)], ) def build_model(ws: workspace.Workspace, inputs: dict): """Builds the model described in the workspace.""" optimizers = [] for node in ws.nodes: if node.op.name == "Optimizer": optimizers.append(node) assert optimizers, "No optimizer found." assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}" [optimizer] = optimizers inputs = {n.id: [] for n in ws.nodes} for e in ws.edges: inputs[e.target].append(e.source) layers = [] def build_model(cfg, device, dropout=None): F.triplet_margin_loss layers.append((pyg.nn.Linear(E, H), "x -> x")) layers.append((torch.nn.LayerNorm(H), "x -> x")) for i in range(cfg.attention_layers): layers.append( (torch.nn.MultiheadAttention(H, 1, batch_first=True), "x, x, x -> x") ) # Pick values, not weights. layers.append(lambda res: res[0]) layers.append(torch.nn.LayerNorm(H)) # Just take the first token embedding after attention? layers.append(lambda res: res[:, 0, :]) encoder = pyg.nn.Sequential("x", layers).to(device) for i in range(cfg.gnn_layers): layers.append((cfg.conv(E, H), "x, edge_index -> x")) if dropout: layers.append(torch.nn.Dropout(dropout)) layers.append(cfg.activation()) for i in range(cfg.mlp_layers - 1): layers.append((pyg.nn.Linear(E, H), "x -> x")) if dropout: layers.append(torch.nn.Dropout(dropout)) layers.append(cfg.activation()) layers.append((pyg.nn.Linear(E, H), "x -> x")) if cfg.predict == "remaining_steps": assert cfg.loss_fn != F.triplet_margin_loss, ( "Triplet loss is only for embedding outputs." ) layers.append((pyg.nn.Linear(E, 1), "x -> x")) elif cfg.predict == "tactics": assert cfg.loss_fn == F.cross_entropy, ( "Use cross entropy for tactic prediction." ) layers.append((pyg.nn.Linear(E, len(TACTICS)), "x -> x")) elif cfg.predict == "link_likelihood_for_states": pass # Just output the embedding. elif cfg.embedding["method"] != "learned": layers.append((pyg.nn.Linear(E, E), "x -> x")) m = pyg.nn.Sequential("x, edge_index", layers).to(device) if cfg.predict == "link_likelihood_for_states": # The comparator takes two embeddings (state and theorem) and predicts the link. layers = [] layers.append( ( lambda state, theorem: torch.cat([state, theorem], dim=1), "state, theorem -> x", ) ) for i in range(cfg.comparator_layers): layers.append((pyg.nn.Linear(E, H), "x -> x")) if dropout: layers.append(torch.nn.Dropout(dropout)) layers.append(cfg.activation()) assert cfg.loss_fn != F.triplet_margin_loss, ( "Triplet loss is only for embedding outputs." ) layers.append((pyg.nn.Linear(E, 1), "x -> x")) # Sigmoid activation at the end to get a probability. layers.append((torch.nn.Sigmoid(), "x -> x")) m.comparator = pyg.nn.Sequential("state, theorem", layers).to(device) if encoder and cfg.predict in ["nodes", "links", "links_for_states"]: m.encoder = encoder return m