Getting started on model designer implementation. TMP
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -1,7 +1,9 @@
 """Boxes for defining PyTorch models."""
 
-from lynxkite.core import ops
+from lynxkite.core import ops, workspace
 from lynxkite.core.ops import Parameter as P
+import torch
+import torch_geometric as pyg
 
 ENV = "PyTorch model"
 
@@ -22,16 +24,17 @@ def reg(name, inputs=[], outputs=None, params=[]):
 )
 
 
-reg("Input:
+reg("Input: embedding", outputs=["x"])
 reg("Input: graph edges", outputs=["edges"])
 reg("Input: label", outputs=["y"])
 reg("Input: positive sample", outputs=["x_pos"])
 reg("Input: negative sample", outputs=["x_neg"])
 
-reg("Attention", inputs=["q", "k", "v"], outputs=["x"])
+reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
 reg("LayerNorm", inputs=["x"])
 reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
 reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
+reg("Softmax", inputs=["x"])
 reg(
     "Graph conv",
     inputs=["x", "edges"],
@@ -43,8 +46,13 @@ reg(
     inputs=["x"],
     params=[P.options("type", ["ReLU", "LeakyReLU", "Tanh", "Mish"])],
 )
-reg("
-reg("
+reg("Concatenate", inputs=["a", "b"], outputs=["x"])
+reg("Add", inputs=["a", "b"], outputs=["x"])
+reg("Subtract", inputs=["a", "b"], outputs=["x"])
+reg("Multiply", inputs=["a", "b"], outputs=["x"])
+reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
+reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
+reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
 reg(
     "Optimizer",
     inputs=["loss"],
@@ -73,3 +81,84 @@ ops.register_passive_op(
     outputs=[ops.Output(name="output", position="bottom", type="tensor")],
     params=[ops.Parameter.basic("times", 1, int)],
 )
+
+
+def build_model(ws: workspace.Workspace, inputs: dict):
+    """Builds the model described in the workspace."""
+    optimizers = []
+    for node in ws.nodes:
+        if node.op.name == "Optimizer":
+            optimizers.append(node)
+    assert optimizers, "No optimizer found."
+    assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
+    [optimizer] = optimizers
+    inputs = {n.id: [] for n in ws.nodes}
+    for e in ws.edges:
+        inputs[e.target].append(e.source)
+    layers = []
+
+
+def build_model(cfg, device, dropout=None):
+    F.triplet_margin_loss
+    layers.append((pyg.nn.Linear(E, H), "x -> x"))
+    layers.append((torch.nn.LayerNorm(H), "x -> x"))
+    for i in range(cfg.attention_layers):
+        layers.append(
+            (torch.nn.MultiheadAttention(H, 1, batch_first=True), "x, x, x -> x")
+        )
+        # Pick values, not weights.
+        layers.append(lambda res: res[0])
+        layers.append(torch.nn.LayerNorm(H))
+    # Just take the first token embedding after attention?
+    layers.append(lambda res: res[:, 0, :])
+    encoder = pyg.nn.Sequential("x", layers).to(device)
+    for i in range(cfg.gnn_layers):
+        layers.append((cfg.conv(E, H), "x, edge_index -> x"))
+        if dropout:
+            layers.append(torch.nn.Dropout(dropout))
+        layers.append(cfg.activation())
+    for i in range(cfg.mlp_layers - 1):
+        layers.append((pyg.nn.Linear(E, H), "x -> x"))
+        if dropout:
+            layers.append(torch.nn.Dropout(dropout))
+        layers.append(cfg.activation())
+    layers.append((pyg.nn.Linear(E, H), "x -> x"))
+    if cfg.predict == "remaining_steps":
+        assert cfg.loss_fn != F.triplet_margin_loss, (
+            "Triplet loss is only for embedding outputs."
+        )
+        layers.append((pyg.nn.Linear(E, 1), "x -> x"))
+    elif cfg.predict == "tactics":
+        assert cfg.loss_fn == F.cross_entropy, (
+            "Use cross entropy for tactic prediction."
+        )
+        layers.append((pyg.nn.Linear(E, len(TACTICS)), "x -> x"))
+    elif cfg.predict == "link_likelihood_for_states":
+        pass  # Just output the embedding.
+    elif cfg.embedding["method"] != "learned":
+        layers.append((pyg.nn.Linear(E, E), "x -> x"))
+    m = pyg.nn.Sequential("x, edge_index", layers).to(device)
+    if cfg.predict == "link_likelihood_for_states":
+        # The comparator takes two embeddings (state and theorem) and predicts the link.
+        layers = []
+        layers.append(
+            (
+                lambda state, theorem: torch.cat([state, theorem], dim=1),
+                "state, theorem -> x",
+            )
+        )
+        for i in range(cfg.comparator_layers):
+            layers.append((pyg.nn.Linear(E, H), "x -> x"))
+            if dropout:
+                layers.append(torch.nn.Dropout(dropout))
+            layers.append(cfg.activation())
+        assert cfg.loss_fn != F.triplet_margin_loss, (
+            "Triplet loss is only for embedding outputs."
+        )
+        layers.append((pyg.nn.Linear(E, 1), "x -> x"))
+        # Sigmoid activation at the end to get a probability.
+        layers.append((torch.nn.Sigmoid(), "x -> x"))
+        m.comparator = pyg.nn.Sequential("state, theorem", layers).to(device)
+    if encoder and cfg.predict in ["nodes", "links", "links_for_states"]:
+        m.encoder = encoder
+    return m