darabos committed
Commit 09b8d25 · 1 Parent(s): 6216561

Getting started on model designer implementation. TMP
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -1,7 +1,9 @@
 """Boxes for defining PyTorch models."""
 
-from lynxkite.core import ops
+from lynxkite.core import ops, workspace
 from lynxkite.core.ops import Parameter as P
+import torch
+import torch_geometric as pyg
 
 ENV = "PyTorch model"
 
@@ -22,16 +24,17 @@ def reg(name, inputs=[], outputs=None, params=[]):
     )
 
 
-reg("Input: features", outputs=["x"])
+reg("Input: embedding", outputs=["x"])
 reg("Input: graph edges", outputs=["edges"])
 reg("Input: label", outputs=["y"])
 reg("Input: positive sample", outputs=["x_pos"])
 reg("Input: negative sample", outputs=["x_neg"])
 
-reg("Attention", inputs=["q", "k", "v"], outputs=["x"])
+reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
 reg("LayerNorm", inputs=["x"])
 reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
 reg("Linear", inputs=["x"], params=[P.basic("output_dim", "same")])
+reg("Softmax", inputs=["x"])
 reg(
     "Graph conv",
     inputs=["x", "edges"],
@@ -43,8 +46,13 @@ reg(
     inputs=["x"],
     params=[P.options("type", ["ReLU", "LeakyReLU", "Tanh", "Mish"])],
 )
-reg("Supervised loss", inputs=["x", "y"], outputs=["loss"])
-reg("Triplet loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
+reg("Concatenate", inputs=["a", "b"], outputs=["x"])
+reg("Add", inputs=["a", "b"], outputs=["x"])
+reg("Subtract", inputs=["a", "b"], outputs=["x"])
+reg("Multiply", inputs=["a", "b"], outputs=["x"])
+reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
+reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
+reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
 reg(
     "Optimizer",
     inputs=["loss"],
@@ -73,3 +81,84 @@ ops.register_passive_op(
     outputs=[ops.Output(name="output", position="bottom", type="tensor")],
     params=[ops.Parameter.basic("times", 1, int)],
 )
+
+
+def build_model(ws: workspace.Workspace, inputs: dict):
+    """Builds the model described in the workspace."""
+    optimizers = []
+    for node in ws.nodes:
+        if node.op.name == "Optimizer":
+            optimizers.append(node)
+    assert optimizers, "No optimizer found."
+    assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
+    [optimizer] = optimizers
+    # Collect each node's upstream sources; named to avoid shadowing the inputs parameter.
+    node_inputs = {n.id: [] for n in ws.nodes}
+    for e in ws.edges:
+        node_inputs[e.target].append(e.source)
+    layers = []
+
+
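+# The function below appears to be reference code from an earlier experiment,
+# kept while the designer is being built (the commit is marked TMP). It uses
+# names assumed from that project: E and H (embedding and hidden sizes),
+# TACTICS (the label vocabulary), F (torch.nn.functional), and the cfg
+# config object with its fields.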
+def build_model(cfg, device, dropout=None):
+    # F.triplet_margin_loss
+    layers = []
+    layers.append((pyg.nn.Linear(E, H), "x -> x"))
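+    # (Each tuple pairs a module with a PyG Sequential signature string such as
+    # "x -> x"; entries without a string are applied to the previous output.)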
+    layers.append((torch.nn.LayerNorm(H), "x -> x"))
+    for i in range(cfg.attention_layers):
+        layers.append(
+            (torch.nn.MultiheadAttention(H, 1, batch_first=True), "x, x, x -> x")
+        )
+        # Pick values, not weights.
+        layers.append(lambda res: res[0])
+        layers.append(torch.nn.LayerNorm(H))
+    # Just take the first token embedding after attention?
+    layers.append(lambda res: res[:, 0, :])
+    encoder = pyg.nn.Sequential("x", layers).to(device)
+    # The encoder is complete; start a fresh layer list for the main model.
+    layers = []
+    for i in range(cfg.gnn_layers):
+        layers.append((cfg.conv(E, H), "x, edge_index -> x"))
+        if dropout:
+            layers.append(torch.nn.Dropout(dropout))
+        layers.append(cfg.activation())
+    for i in range(cfg.mlp_layers - 1):
+        layers.append((pyg.nn.Linear(E, H), "x -> x"))
+        if dropout:
+            layers.append(torch.nn.Dropout(dropout))
+        layers.append(cfg.activation())
+    layers.append((pyg.nn.Linear(E, H), "x -> x"))
+    if cfg.predict == "remaining_steps":
+        assert cfg.loss_fn != F.triplet_margin_loss, (
+            "Triplet loss is only for embedding outputs."
+        )
+        layers.append((pyg.nn.Linear(E, 1), "x -> x"))
+    elif cfg.predict == "tactics":
+        assert cfg.loss_fn == F.cross_entropy, (
+            "Use cross entropy for tactic prediction."
+        )
+        layers.append((pyg.nn.Linear(E, len(TACTICS)), "x -> x"))
+    elif cfg.predict == "link_likelihood_for_states":
+        pass  # Just output the embedding.
+    elif cfg.embedding["method"] != "learned":
+        layers.append((pyg.nn.Linear(E, E), "x -> x"))
+    m = pyg.nn.Sequential("x, edge_index", layers).to(device)
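+    # For link prediction, a separate comparator head is attached below as
+    # m.comparator; it scores pairs of embeddings produced by the main model.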
+    if cfg.predict == "link_likelihood_for_states":
+        # The comparator takes two embeddings (state and theorem) and predicts the link.
+        layers = []
+        layers.append(
+            (
+                lambda state, theorem: torch.cat([state, theorem], dim=1),
+                "state, theorem -> x",
+            )
+        )
+        for i in range(cfg.comparator_layers):
+            layers.append((pyg.nn.Linear(E, H), "x -> x"))
+            if dropout:
+                layers.append(torch.nn.Dropout(dropout))
+            layers.append(cfg.activation())
+        assert cfg.loss_fn != F.triplet_margin_loss, (
+            "Triplet loss is only for embedding outputs."
+        )
+        layers.append((pyg.nn.Linear(E, 1), "x -> x"))
+        # Sigmoid activation at the end to get a probability.
+        layers.append((torch.nn.Sigmoid(), "x -> x"))
+        m.comparator = pyg.nn.Sequential("state, theorem", layers).to(device)
+    if encoder and cfg.predict in ["nodes", "links", "links_for_states"]:
+        m.encoder = encoder
+    return m
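
A minimal runnable sketch of the encoder pattern used above, to make the
pyg.nn.Sequential plumbing concrete; the sizes E and H and the single
attention layer are illustrative assumptions, not values from this commit.

    import torch
    import torch_geometric as pyg

    E, H = 8, 16  # assumed embedding and hidden sizes
    encoder = pyg.nn.Sequential("x", [
        (pyg.nn.Linear(E, H), "x -> x"),
        (torch.nn.LayerNorm(H), "x -> x"),
        (torch.nn.MultiheadAttention(H, 1, batch_first=True), "x, x, x -> x"),
        lambda res: res[0],  # MultiheadAttention returns (values, weights); keep values
        torch.nn.LayerNorm(H),
        lambda res: res[:, 0, :],  # first-token embedding summarizes the sequence
    ])
    out = encoder(torch.randn(2, 5, E))  # 2 sequences of 5 tokens -> shape (2, H)

The loss boxes registered above correspond to standard torch.nn.functional
calls; for example the triplet variant, with illustrative shapes:

    x, x_pos, x_neg = torch.randn(4, 16), torch.randn(4, 16), torch.randn(4, 16)
    loss = torch.nn.functional.triplet_margin_loss(x, x_pos, x_neg)  # anchor, positive, negative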