The BaselineModel class in the baselines.py file is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG] 3 Dec 2022). Below is an excerpt from the paper that describes the model:
Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In
a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of
the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes
in the GNN that will execute the algorithm.
A sample of every algorithm is represented as a graph, with each input, output and hint located in
either the nodes, the edges, or the graph itself, and therefore has shape (excluding batch dimension,
and, for hints, time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of
the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar,
categorical, mask, mask_one and pointer, with their own encoding and decoding strategies and
loss functions—e.g. a scalar type will be encoded and decoded directly by a single linear layer, and
optimised using mean squared error.
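To make the three feature locations concrete, here is a minimal JAX sketch of a scalar feature (f = 1) stored on nodes and on edges, each encoded by its own single linear layer, with the node feature decoded by another linear layer and scored with mean squared error. The helper names (`init_linear`, `linear`) and the plain-JAX parameter dictionaries are illustrative assumptions, not the baselines.py code.

```python
import jax
import jax.numpy as jnp

H = 128  # embedding width h used in the paper


def init_linear(key, d_in, d_out):
    """Hypothetical dense-layer parameters: weight matrix and zero bias."""
    w = jax.random.normal(key, (d_in, d_out)) / jnp.sqrt(d_in)
    return {"w": w, "b": jnp.zeros(d_out)}


def linear(params, x):
    # Acts on the last axis, so it applies to n x f, n x n x f or f arrays alike.
    return x @ params["w"] + params["b"]


n = 5
k_node, k_edge, k_enc_n, k_enc_e, k_dec = jax.random.split(jax.random.PRNGKey(0), 5)
node_scalar = jax.random.uniform(k_node, (n, 1))     # node feature: n x f, f = 1
edge_scalar = jax.random.uniform(k_edge, (n, n, 1))  # edge feature: n x n x f

node_enc = init_linear(k_enc_n, 1, H)   # each feature gets its own linear encoder
edge_enc = init_linear(k_enc_e, 1, H)
node_dec = init_linear(k_dec, H, 1)     # and its own linear decoder

node_emb = linear(node_enc, node_scalar)   # (n, h)
edge_emb = linear(edge_enc, edge_scalar)   # (n, n, h)
node_pred = linear(node_dec, node_emb)     # (n, 1)

# Scalar features are optimised with mean squared error against the target.
mse = jnp.mean((node_pred - node_scalar) ** 2)
```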
Base Model
Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS
benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder
$f_\tau$, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as
high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the
same dimension and are added together; the same happens with hints and inputs located in edges,
and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph
embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a
single set of embeddings $\{x_i^{(t)}, e_{ij}^{(t)}, g^{(t)}\}$, shapes n × h, n × n × h, and h, in the nodes, edges and
graph, respectively. Note that this is independent of the number and type of the inputs and hints of
the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS.
Further, note that at each step, the input encoding is fed directly to these embeddings—this recall
mechanism significantly improves the model’s robustness over long trajectories [34].
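The encoder bookkeeping can be sketched as follows: every input and hint gets its own linear encoder, and the resulting embeddings are summed per location into x (n × h), e (n × n × h) and g (h). The feature names below are hypothetical, and re-encoding the inputs together with the current hints at every step is how this sketch mimics the recall mechanism; it is not the baselines.py implementation.

```python
import jax
import jax.numpy as jnp

H = 128


def init_linear(key, d_in, d_out):
    return {"w": jax.random.normal(key, (d_in, d_out)) / jnp.sqrt(d_in),
            "b": jnp.zeros(d_out)}


def encode(params, features, n):
    """Sum per-feature linear embeddings by location (node, edge or graph).

    `features` maps a name to (location, array) with array shapes
    (n, f), (n, n, f) or (f,) respectively.
    """
    x = jnp.zeros((n, H))
    e = jnp.zeros((n, n, H))
    g = jnp.zeros((H,))
    for name, (location, value) in features.items():
        p = params[name]
        emb = value @ p["w"] + p["b"]  # one linear encoder per input/hint
        if location == "node":
            x = x + emb
        elif location == "edge":
            e = e + emb
        else:
            g = g + emb
    return x, e, g


n = 4
keys = jax.random.split(jax.random.PRNGKey(0), 3)
# Hypothetical task: a node-level scalar input, an edge-level weight input
# and a node-level mask hint whose value changes from step to step.
params = {
    "key": init_linear(keys[0], 1, H),
    "weight": init_linear(keys[1], 1, H),
    "visited": init_linear(keys[2], 1, H),
}
inputs = {
    "key": ("node", jnp.arange(n, dtype=jnp.float32)[:, None]),
    "weight": ("edge", jnp.ones((n, n, 1))),
}
hint_trajectory = [jnp.zeros((n, 1)), jnp.ones((n, 1))]

for hint in hint_trajectory:
    # Inputs are re-encoded at every step alongside the current hints (recall).
    x, e, g = encode(params, {**inputs, "visited": ("node", hint)}, n)
```

Because every location is reduced to a single h-dimensional embedding, the processor downstream never needs to know which inputs or hints a given algorithm has.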
Processor. The embeddings are fed into a processor $P$, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed
node embeddings, $h_i^{(t)}$. Additionally, the processor uses the processed node embeddings from the
previous step, $h_i^{(t-1)}$, as inputs. Importantly, the same processor model can operate on graphs of any
size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and
passing messages over a fully-connected graph, as our base model. The MPNN computes processed
embeddings as follows:

$$z_i^{(t)} = x_i^{(t)} \,\|\, h_i^{(t-1)}, \qquad m_i^{(t)} = \max_{1 \le j \le n} f_m\left(z_i^{(t)}, z_j^{(t)}, e_{ij}^{(t)}, g^{(t)}\right), \qquad h_i^{(t)} = f_r\left(z_i^{(t)}, m_i^{(t)}\right) \qquad (1)$$

starting from $h^{(0)} = 0$. Here $\|$ denotes concatenation, $f_m : \mathbb{R}^{2h} \times \mathbb{R}^{2h} \times \mathbb{R}^{h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$ is the
message function (for which we use a three-layer MLP with ReLU activations), and $f_r : \mathbb{R}^{2h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$
is the readout function (for which we use a linear layer with ReLU activation). The use of the max
aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph, letting the
neighbours j range over all nodes (1 ≤ j ≤ n), in order to allow the model to overcome situations
where the input graph structure may be suboptimal. Layer normalisation [36] is applied to $h_i^{(t)}$ before
using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
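Equation 1 can be sketched in plain JAX as below. Assumptions not fixed by the excerpt: parameters are stored as lists of dense layers, $f_m$ has ReLU between its three layers but a linear last layer, and the layer norm has no learned scale or offset. The max over j runs over all n nodes, which is exactly the fully-connected graph.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [{"w": jax.random.normal(k, (a, b)) / jnp.sqrt(a), "b": jnp.zeros(b)}
            for k, a, b in zip(keys, sizes[:-1], sizes[1:])]


def mlp(layers, x):
    for i, layer in enumerate(layers):
        x = x @ layer["w"] + layer["b"]
        if i < len(layers) - 1:
            x = jax.nn.relu(x)  # ReLU between layers
    return x


def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def init_mpnn(key, h):
    k_m, k_r = jax.random.split(key)
    return {"f_m": init_mlp(k_m, [6 * h, h, h, h]),  # three-layer message MLP
            "f_r": init_mlp(k_r, [3 * h, h])}        # single linear readout layer


def mpnn_step(params, x, e, g, h_prev):
    """One processor step (Eq. 1) on a fully-connected graph of n nodes."""
    n, h = x.shape
    z = jnp.concatenate([x, h_prev], axis=-1)              # z_i = x_i || h_i^(t-1)
    z_i = jnp.broadcast_to(z[:, None, :], (n, n, 2 * h))   # receiver i
    z_j = jnp.broadcast_to(z[None, :, :], (n, n, 2 * h))   # sender j
    g_ij = jnp.broadcast_to(g, (n, n, h))
    msgs = mlp(params["f_m"], jnp.concatenate([z_i, z_j, e, g_ij], axis=-1))
    m = msgs.max(axis=1)                                   # max over all j
    h_new = jax.nn.relu(mlp(params["f_r"], jnp.concatenate([z, m], axis=-1)))
    return layer_norm(h_new)


h, n = 16, 4  # small h for the example; the paper uses h = 128
kx, ke, kg, kp = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (n, h))
e = jax.random.normal(ke, (n, n, h))
g = jax.random.normal(kg, (h,))
params = init_mpnn(kp, h)
h_t = mpnn_step(params, x, e, g, jnp.zeros((n, h)))  # h^(0) = 0
```

Because the aggregation is a max over every node, relabelling the nodes simply relabels the output rows, and nothing in the parameters depends on n, which is why one processor serves graphs of any size.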
Decoder. The processed embeddings are finally decoded with a task-based decoder $g_\tau$, to predict
the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder
relies mainly on a linear decoder for each hint and output, along with a mechanism to compute
pairwise node similarities when appropriate. Specifically, the pointer type decoder computes
a score, $s_{ij}$, for each pair of nodes, and then chooses the pointer of node i by taking either the
$\arg\max_j s_{ij}$ or $\mathrm{softmax}_j\, s_{ij}$ (depending on whether a hard or soft prediction is used).
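A pointer decoder can be sketched as a pairwise score matrix followed by either a hard or a soft choice over j. The excerpt does not say how $s_{ij}$ is computed; the dot product of two linear projections of the node embeddings used here is an assumption, as are the parameter names `w_q` and `w_k`.

```python
import jax
import jax.numpy as jnp


def pointer_scores(params, h):
    """s_ij for every pair of nodes, from node embeddings h of shape (n, d)."""
    query = h @ params["w_q"]   # (n, k): representation of node i
    cand = h @ params["w_k"]    # (n, k): representation of candidate j
    return query @ cand.T       # (n, n): row i holds s_i1 ... s_in


def decode_pointer(scores, hard):
    if hard:
        return jnp.argmax(scores, axis=-1)   # node index that node i points to
    return jax.nn.softmax(scores, axis=-1)   # distribution over j for each i


kq, kk, kh = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w_q": jax.random.normal(kq, (8, 4)), "w_k": jax.random.normal(kk, (8, 4))}
scores = pointer_scores(params, jax.random.normal(kh, (5, 8)))
hard_ptr = decode_pointer(scores, hard=True)
soft_ptr = decode_pointer(scores, hard=False)
```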
Loss. The decoded hints and outputs are used to compute the loss during training, according to their
type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time,
and the output loss is averaged across outputs (most algorithms have a single output, though some
have two outputs). The hint loss and output loss are added together. In addition, the hint predictions at
each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing
is used (see Section 3.2.1).
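The per-sample loss bookkeeping above reduces to two means and a sum. This sketch assumes the per-hint, per-step losses have already been computed and stacked into an equal-length array (a simplification), and that the batch loss is the mean over samples, which the excerpt does not state.

```python
import jax
import jax.numpy as jnp


def sample_loss(hint_losses, output_losses):
    """hint_losses: (num_hints, T) per-hint, per-step losses for one sample.
    output_losses: (num_outputs,) losses for the same sample."""
    hint_term = jnp.mean(hint_losses)      # averaged across hints and time
    output_term = jnp.mean(output_losses)  # averaged across outputs
    return hint_term + output_term         # the two terms are added together


def batch_loss(hint_losses, output_losses):
    # Assumption: per-sample losses are averaged over the batch.
    return jnp.mean(jax.vmap(sample_loss)(hint_losses, output_losses))


example = sample_loss(jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5]))  # 2.5 + 0.5 = 3.0
```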
We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution
samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution
evaluation score so far on OOD samples of size n = 64. In what follows, we will be reporting only
these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can
be found in Appendix A.
3.2 Model improvements
As previously discussed, single-task improvements, especially in terms of learning stability, will
empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner,
all the changes made to the model, which have led to an absolute improvement of over 20% on
average across all 30 tasks in CLRS.
3.2.1 Dataset and training
Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints
in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes
advisable to stabilise the trajectories with teacher forcing [37]—providing the ground-truth hint
values instead of the network’s own predictions. In the prior model [5], ground-truth hints were
provided during training with probability 0.5, as, without teacher forcing, losses tended to grow
unbounded along a trajectory when scalar hints were present, destabilising the training. In this
work we incorporate several significant stabilising changes (described in the following paragraphs), which
allow us to remove teacher forcing altogether, aligning training with evaluation and preventing the
network from becoming overconfident in always expecting correct hint predictions. With teacher forcing,
performance deteriorates significantly in sorting algorithms and Kruskal’s algorithm. Naïve String
Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
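Teacher forcing amounts to a coin flip between the ground-truth hint and the model's own prediction before each step. In this sketch the probability is a parameter: 0.5 mirrors the prior model, and 0.0 corresponds to removing teacher forcing altogether. Drawing one coin per sample per step, and the function name `next_hint_input`, are assumptions.

```python
import jax
import jax.numpy as jnp


def next_hint_input(key, predicted, ground_truth, teacher_forcing_prob):
    """Pick the hint fed into the next step for each sample in the batch.

    predicted, ground_truth: arrays with a leading batch axis.
    """
    batch = predicted.shape[0]
    use_truth = jax.random.bernoulli(key, teacher_forcing_prob, (batch,))
    # Broadcast the per-sample choice over the remaining feature axes.
    use_truth = use_truth.reshape((batch,) + (1,) * (predicted.ndim - 1))
    return jnp.where(use_truth, ground_truth, predicted)


key = jax.random.PRNGKey(0)
pred = jnp.zeros((8, 5))
truth = jnp.ones((8, 5))
prior_model = next_hint_input(key, pred, truth, 0.5)  # mixes the two sources
this_work = next_hint_input(key, pred, truth, 0.0)    # always the model's own predictions
```

With the probability at 0.0, training sees exactly the inputs the model will see at evaluation time, which is the alignment the paragraph above argues for.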
Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed
CLRS training dataset [5], we augmented the training data in three key ways, without breaking
the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new
training examples on the fly, rather than using a fixed dataset which is easier to overfit to. Secondly,
we trained on examples of mixed sizes, n ≤ 16, rather than only n = 16, which helps the model anticipate
a diverse range of sizes rather than overfitting to the specifics of size n = 16. Lastly, for graph
algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi
model [38]); and for string matching algorithms, we varied the length of the pattern to be matched.
Both serve to expose the model to different trajectory lengths; for example, in many graph
algorithms, the number of steps the algorithm should run for is related to the graph’s diameter, and
varying the connection probability in the graph generation allows for varying the expected diameter.
These changes considerably increase training data variability, compared to the original dataset in
Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process
in Appendix A.
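The size and connectivity randomisation can be sketched as a hypothetical on-line sampler that draws a size n ≤ 16 and an edge probability p per example, then samples an undirected Erdős-Rényi graph. The lower size bound of 4 and the set of p values are placeholders, not values from the paper, and the real CLRS samplers also produce each algorithm's inputs, hints and outputs, which this sketch omits.

```python
import jax
import jax.numpy as jnp

MAX_N = 16                          # training sizes n <= 16
P_CHOICES = (0.1, 0.2, 0.3, 0.5)    # hypothetical connectivity probabilities


def sample_er_graph(key, min_n=4, max_n=MAX_N, p_choices=P_CHOICES):
    """Return (n, p, adjacency) for one randomly sized Erdős-Rényi graph."""
    k_n, k_p, k_adj = jax.random.split(key, 3)
    n = int(jax.random.randint(k_n, (), min_n, max_n + 1))  # maxval is exclusive
    p = float(jax.random.choice(k_p, jnp.array(p_choices)))
    coins = jax.random.bernoulli(k_adj, p, (n, n)).astype(jnp.int32)
    upper = jnp.triu(coins, k=1)      # keep only i < j: no self-loops
    return n, p, upper + upper.T      # symmetric (undirected) adjacency matrix


graphs = [sample_er_graph(k) for k in jax.random.split(jax.random.PRNGKey(0), 5)]
```

Varying p changes the expected diameter of the sampled graphs, and with it the number of steps many graph algorithms run for.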
Soft hint propagation. When predicted hints are fed back as inputs during training, gradients
may or may not be allowed to flow through them. In previous work, only hints of the scalar type
allowed gradients through, as all categoricals were post-processed from logits into the ground-truth
format via argmax or thresholding before being fed back. Instead, in this work we use softmax
for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without
these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing),
as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
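The difference between hard and soft hint feedback can be sketched as a small dispatch on the hint type. The sketch assumes the class (or target-node) axis is last; the hard branch reproduces the earlier argmax/threshold post-processing, which blocks gradients, while the soft branch keeps them. The function name and signature are illustrative.

```python
import jax
import jax.numpy as jnp


def feedback_hint(logits, hint_type, soft=True):
    """Convert decoder logits into the hint value fed to the next step."""
    if hint_type in ("categorical", "mask_one", "pointer"):
        if soft:
            return jax.nn.softmax(logits, axis=-1)
        # Hard version: one-hot of the argmax (no gradient flows through it).
        return jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1])
    if hint_type == "mask":
        if soft:
            return jax.nn.sigmoid(logits)
        return (logits > 0).astype(logits.dtype)  # hard thresholding
    if hint_type == "scalar":
        return logits  # scalars are fed back directly in both settings
    raise ValueError(f"unknown hint type: {hint_type}")


logits = jnp.array([0.5, -1.0, 2.0])
soft_grad = jax.grad(lambda l: feedback_hint(l, "mask").sum())(logits)
hard_grad = jax.grad(lambda l: feedback_hint(l, "mask", soft=False).sum())(logits)
```

`soft_grad` is non-zero everywhere, while `hard_grad` is identically zero, which is the gradient blockage the previous approach imposed on every non-scalar hint.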
Static hint elimination. Eleven algorithms in CLRS³ specify a fixed ordering of the nodes, common
to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of
this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the
model can overfit to the fixed training values. We therefore turned this fixed hint into an input for
these 11 algorithms, eliminating the need for explicitly predicting it.
Improving training stability with encoder initialisation and gradient clipping. The scalar
hints have unbounded values, in principle, and are optimised using mean-squared error, hence their
gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then
get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to
exploding signals (and consequently gradients), even before any training takes place.
To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for
scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun
initialisation [46] elsewhere. This combination of initialisations proved important for the initial
learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw
drastic improvements in learning stability, as well as significant increases in validation performance,
with gradient clipping [47], which we subsequently employed in all experiments.
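Both stabilisers are short to express in JAX. The sketch picks the Glorot (Xavier) uniform initialiser for 1-dimensional scalar-hint encoders and the LeCun normal initialiser elsewhere, using `jax.nn.initializers`, and clips gradients by their global norm. The uniform/normal variants, the helper names and the clipping threshold of 1.0 are assumptions; the excerpt does not specify them.

```python
import jax
import jax.numpy as jnp

xavier = jax.nn.initializers.glorot_uniform()
lecun = jax.nn.initializers.lecun_normal()


def init_encoder_weights(key, d_in, d_out, is_scalar_hint):
    """Xavier for scalar-hint encoders (d_in == 1), LeCun for everything else."""
    init = xavier if is_scalar_hint else lecun
    return init(key, (d_in, d_out))


def clip_by_global_norm(grads, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


key = jax.random.PRNGKey(0)
w_scalar = init_encoder_weights(key, 1, 128, is_scalar_hint=True)
w_default = init_encoder_weights(key, 1, 128, is_scalar_hint=False)
# For fan-in 1, Xavier's standard deviation sqrt(2 / 129) ≈ 0.12 is far below LeCun's 1.
clipped = clip_by_global_norm({"w": jnp.array([3.0, 4.0])}, max_norm=1.0)
```

The comparison of `w_scalar` and `w_default` shows why Xavier helps here: with a fan-in of 1, LeCun scaling leaves unit-variance weights on every re-encoded scalar hint.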
3.2.2 Encoders and decoders
Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar
input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node
index. To avoid overfitting to these linearly spaced values during training, we replaced them with
random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly
spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to
these positions, such as string matching. Namely, the model could learn to base all of its computations
on the assumption that it will always be finding an m-character pattern inside an n-character string,
even though at test time, m and n will increase fourfold.

Footnote 3: Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity
Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis’ March [44].
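The randomised position scalar can be sketched in a few lines: draw n uniform values in [0, 1) and sort them, so the node order implied by the evenly spaced positions is preserved while the exact spacing changes from sample to sample. Using i / n for the evenly spaced baseline is an assumption about the exact grid.

```python
import jax
import jax.numpy as jnp


def linear_positions(n):
    # Evenly spaced positions along the node index (assumed grid i / n).
    return jnp.arange(n) / n


def randomised_positions(key, n):
    # Uniform samples in [0, 1), sorted so node i keeps rank i.
    return jnp.sort(jax.random.uniform(key, (n,)))


n = 8
fixed = linear_positions(n)
rand = randomised_positions(jax.random.PRNGKey(0), n)
```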
Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble
Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS
benchmark, this permutation is encoded as a pointer where each node points to its predecessor in
the sorted order (the first node points to itself); this is represented as an n × n matrix P where each
row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of
pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained
decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does
not explicitly take advantage of the fact that the pointers encode a permutation, which the model
has to learn instead. Our early experiments showed that the model was often failing to predict valid
permutations OOD.
Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as
follows. First, we modify the output representation by rewiring the first node to point to the last one,
turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We
also augment the representation with a one-hot vector of size n that specifies the first node, so we do
not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the
permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax
with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly
stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating
and repeatedly normalizing rows and columns so they sum to 1. Specifically, S is defined by:
$$S^0(Y) = \exp(Y), \qquad S^l(Y) = T_c\left(T_r\left(S^{l-1}(Y)\right)\right), \qquad S(Y) = \lim_{l \to \infty} S^l(Y) \qquad (2)$$

where exp acts element-wise, and $T_r$ and $T_c$ denote row and column normalisation respectively.
Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix,
we can obtain a permutation matrix by introducing a temperature parameter, $\tau > 0$, and taking
$P = \lim_{\tau \to 0^+} S(Y/\tau)$; as long as there are no ties in the elements of Y, P is guaranteed to be a
permutation matrix [52, Theorem 1].
In practice, we compute the Sinkhorn operator using a fixed number of iterations $l_{\max}$. We use a
smaller number of iterations $l_{\max} = 10$ for training, to limit vanishing and exploding gradients, and
$l_{\max} = 60$ for evaluation. A fixed temperature $\tau = 0.1$ was experimentally found to give a good
balance between speed of convergence and tie-breaking. We also encode the fact that no node points
to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to
−∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise into the elements of Y prior to
applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P,
and mask_one pointing to the first element, into the original pointer representation used by CLRS.
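The permutation decoder's Sinkhorn step can be sketched as below, using τ = 0.1 and $l_{\max}$ = 10 or 60 from the text. It works in log space for numerical stability (an implementation choice the excerpt does not state), masks the diagonal with −∞, optionally adds Gumbel noise, and alternates row then column normalisation as in Eq. 2. The helper that maps P back to CLRS pointers follows the rewiring described above; both function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn(y, tau=0.1, n_iters=10, key=None):
    """Approximate S(Y / tau) for a square score matrix y of shape (n, n)."""
    n = y.shape[-1]
    # No node points to itself: diagonal logits become -inf (zero probability).
    y = jnp.where(jnp.eye(n, dtype=bool), -jnp.inf, y)
    if key is not None:
        # Gumbel noise to break ties; used during training only.
        y = y + jax.random.gumbel(key, y.shape)
    log_p = y / tau                                   # log of S^0 = exp(Y / tau)
    for _ in range(n_iters):
        log_p = log_p - logsumexp(log_p, axis=-1, keepdims=True)  # T_r: rows sum to 1
        log_p = log_p - logsumexp(log_p, axis=-2, keepdims=True)  # T_c: columns sum to 1
    return jnp.exp(log_p)


def to_clrs_pointers(p, first):
    """Undo the rewiring: the first node points back to itself."""
    pointers = jnp.argmax(p, axis=-1)
    return pointers.at[first].set(first)


# Nodes 1..4 point to their predecessor; node 0 (first) was rewired to 4 (last).
perm = jnp.array([4, 0, 1, 2, 3])
y = 5.0 * jax.nn.one_hot(perm, 5)
p_eval = sinkhorn(y, n_iters=60)                               # l_max = 60 at evaluation
p_train = sinkhorn(y, n_iters=10, key=jax.random.PRNGKey(0))   # l_max = 10 in training
pointers = to_clrs_pointers(p_eval, first=0)                   # [0, 0, 1, 2, 3]
```

Subtracting a log-sum-exp is the log-space equivalent of dividing by a row or column sum, so each loop iteration applies $T_r$ then $T_c$ exactly as in Eq. 2 without overflowing exp(Y / τ) at small τ.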
3.2.3 Processor networks
Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping
the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it
updates all hidden states in each step. Although it is theoretically possible for the network to keep the
states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness
in NDRs [54], we augment the network with an update gate, biased to be closed by default. We
found that the gate stabilizes learning on many of the tasks, and increases the mean performance
over all tasks on single-task training significantly. Surprisingly, however, we did not find gating to be
advantageous in the multi-task case.
To add gating to the MPNN model we produce a per-node gating vector from the same inputs that
process the embeddings in Equation 1:

$$g_i^{(t)} = f_g\left(z_i^{(t)}, m_i^{(t)}\right) \qquad (3)$$

where $f_g : \mathbb{R}^{2h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$ is the gating function, for which we use a two-layer MLP, with
ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the
final layer bias of $f_g$ is initialized to a value of −3, which biases the network towards not updating its
representations, unless necessary. The processed gated embeddings, $\hat{h}_i^{(t)}$, are computed as follows:

$$\hat{h}_i^{(t)} = g_i^{(t)} \odot h_i^{(t)} + \left(1 - g_i^{(t)}\right) \odot h_i^{(t-1)} \qquad (4)$$

and are used instead of $h_i^{(t)}$ in the subsequent steps, replacing $z^{(t)}$ in Eq. 1 by $z_i^{(t)} = x_i^{(t)} \,\|\, \hat{h}_i^{(t-1)}$.

[Figure 2 (chart not reproduced): per-algorithm bars of "Average score [%]" for "Our model" and "Previous SOTA [5]" over the 30 CLRS algorithms.]
Figure 2: The OOD performance in single-task experiments before and after the improvements
presented in this paper, sorted in descending order of current performance. Error bars represent
standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current).
The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).
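Equations 3 and 4 can be sketched as a two-layer gate MLP whose output bias starts at −3, followed by an element-wise blend. When the rest of the final layer contributes little, the gate sits near σ(−3) ≈ 0.047, so a node keeps most of its previous state. Parameter names and the initial weight scaling are assumptions; z and m are the quantities from Eq. 1.

```python
import jax
import jax.numpy as jnp


def init_gate(key, h):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (3 * h, h)) / jnp.sqrt(3 * h),
        "b1": jnp.zeros(h),
        "w2": jax.random.normal(k2, (h, h)) / jnp.sqrt(h),
        "b2": jnp.full((h,), -3.0),  # final-layer bias of f_g initialised to -3
    }


def gate(params, z, m):
    """Eq. 3: g_i = f_g(z_i, m_i), ReLU hidden layer and sigmoid output."""
    hidden = jax.nn.relu(jnp.concatenate([z, m], axis=-1) @ params["w1"] + params["b1"])
    return jax.nn.sigmoid(hidden @ params["w2"] + params["b2"])


def gated_update(g, h_new, h_prev):
    """Eq. 4: element-wise blend of the new and the previous node states."""
    return g * h_new + (1.0 - g) * h_prev


n, h = 4, 16
kz, km, kh, kp = jax.random.split(jax.random.PRNGKey(0), 4)
z = jax.random.normal(kz, (n, 2 * h))   # z_i from Eq. 1
m = jax.random.normal(km, (n, h))       # m_i from Eq. 1
h_new = jax.random.normal(kh, (n, h))   # processor output h_i^(t)
h_prev = jnp.zeros((n, h))              # state carried from the previous step
params = init_gate(kp, h)
g = gate(params, z, m)
h_hat = gated_update(g, h_new, h_prev)
```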
Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning—
where edges store values, and update them based on other edges’ values. An example of this is the
Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The
update rule for $d_{ij}$, its estimate for the best distance from node i to j, is $d_{ij} = \min_k \left(d_{ik} + d_{kj}\right)$, which
roughly says “the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then
from k to j”. Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic
programming. Even though there are no node representations in the above update, all our processors
are centered on passing messages between node representations $h_i$.
To rectify this situation, we augment our processor to perform message passing towards edges.
Referring again to the update for $d_{ij}$, we note that the edge representations are updated by choosing
an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations
over triplets of nodes, then reducing over one node to obtain edge latents:

$$t_{ijk} = \psi_t\left(h_i, h_j, h_k, e_{ij}, e_{ik}, e_{kj}, g\right), \qquad h_{ij} = \phi_t\left(\max_k t_{ijk}\right) \qquad (5)$$
Here, $\psi_t$ is a triplet message function, mapping all relevant representations to a single vector for
each triplet of nodes, and $\phi_t$ is an edge readout function, which transforms the aggregated triplets
for each edge for later use. According to prior findings on the CLRS benchmark [5], we use the
max aggregation to obtain edge representations. The computed hij vectors can then be used in any
edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks
where we did not initially anticipate such benefits. One example is Kruskal’s minimum spanning tree
algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily
sort the edges by weight, as it selects how to augment the spanning forest at each step.
In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only
8-dimensional features in $\psi_t$. $\phi_t$ then upscales the aggregated edge features back to 128 dimensions,
to make them compatible with the rest of the architecture. Our initial experimentation demonstrated
that the output dimensionality of $\psi_t$ did not significantly affect downstream performance. Note that
computing triplet representations has been a useful approach in general GNN design [57]—however,
it has predominantly been studied in the context of GNNs over constant input features. Our study is
among the first to verify their utility over reasoning tasks with well-specified initial features.
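Equation 5 can be sketched with broadcasting over an (n, n, n, 8) tensor indexed [i, j, k]. Here $\psi_t$ is assumed to be a sum of separate linear projections of its seven arguments into 8 dimensions, and $\phi_t$ a single linear layer back to 128; the excerpt fixes only the 8- and 128-dimensional sizes and the max over k. Parameter names are hypothetical.

```python
import jax
import jax.numpy as jnp

H, T = 128, 8  # embedding width and triplet feature width


def init_triplet(key, h=H, t=T):
    names = ["w_i", "w_j", "w_k", "w_ij", "w_ik", "w_kj", "w_g"]
    keys = jax.random.split(key, len(names) + 1)
    params = {nm: jax.random.normal(k, (h, t)) / jnp.sqrt(h) for nm, k in zip(names, keys)}
    params["w_out"] = jax.random.normal(keys[-1], (t, h)) / jnp.sqrt(t)
    return params


def triplet_messages(params, h, e, g):
    """t_ijk = psi_t(h_i, h_j, h_k, e_ij, e_ik, e_kj, g) as an (n, n, n, T) array."""
    a_i, a_j, a_k = h @ params["w_i"], h @ params["w_j"], h @ params["w_k"]        # (n, T)
    a_ij, a_ik, a_kj = e @ params["w_ij"], e @ params["w_ik"], e @ params["w_kj"]  # (n, n, T)
    return (a_i[:, None, None] + a_j[None, :, None] + a_k[None, None, :]
            + a_ij[:, :, None] + a_ik[:, None, :]
            + jnp.swapaxes(a_kj, 0, 1)[None, :, :]  # value at [i, j, k] is a_kj[k, j]
            + g @ params["w_g"])


def triplet_edges(params, h, e, g):
    """h_ij = phi_t(max_k t_ijk): reduce over k, then upscale 8 -> 128."""
    return triplet_messages(params, h, e, g).max(axis=2) @ params["w_out"]


n = 4
kh, ke, kg, kp = jax.random.split(jax.random.PRNGKey(0), 4)
h = jax.random.normal(kh, (n, H))
e = jax.random.normal(ke, (n, n, H))
g = jax.random.normal(kg, (H,))
params = init_triplet(kp)
edge_latents = triplet_edges(params, h, e, g)  # (n, n, 128)
```

The max over the k axis plays the role of the min over mid-points in the Floyd-Warshall update, and keeping the triplet features at 8 dimensions keeps the cubic-size intermediate tensor small.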
3.3 Results
Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our
best model Triplet-GMPNN with all our improvements, after 10,000 training steps.

Alg. Type       Memnet [5]        MPNN [5]          PGN [5]           Triplet-GMPNN (ours)
Div. & C.       13.05% ± 0.14     20.30% ± 0.85     65.23% ± 4.44     76.36% ± 1.34
DP              67.94% ± 8.20     65.10% ± 6.44     70.58% ± 6.48     81.99% ± 4.98
Geometry        45.14% ± 11.95    73.11% ± 17.19    61.19% ± 7.01     94.09% ± 2.30
Graphs          24.12% ± 5.30     62.79% ± 8.75     60.25% ± 8.42     81.41% ± 6.21
Greedy          53.42% ± 20.82    82.39% ± 3.01     75.84% ± 6.59     91.21% ± 2.95
Search          34.35% ± 21.67    41.20% ± 19.87    56.11% ± 21.56    58.61% ± 24.34
Sorting         71.53% ± 1.41     11.83% ± 2.78     15.45% ± 8.46     60.37% ± 12.16
Strings         1.51% ± 0.46      3.21% ± 0.94      2.04% ± 0.20      49.09% ± 23.49
Overall avg.    38.88%            44.99%            50.84%            74.14%
> 90%           0/30              6/30              3/30              11/30
> 80%           3/30              9/30              7/30              17/30
> 60%           10/30             14/30             15/30             24/30

By incorporating the changes described in the previous sections we arrived at a single model type,
with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance
on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as
Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test
set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines
include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5].
Figure 2 displays the comparison between the improved model and the best model from Veličković
et al. [5]. Our improvements lead to an overall average performance that is more than 20 percentage points
higher than that of the next best model (74.14% vs. 50.84% for PGN; see Table 1), and to a significant performance
improvement in all but one algorithm family, compared to every other model. Further, our stabilising
changes (such as gradient clipping) have empirically reduced the scale of our model’s gradient
updates across the 30 tasks, preparing us better for the numerical issues of the multi-task regime. We
finally note that, although we do not show it in Tables 1 & 2, applying the same improvements to
the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
There are two notable examples of algorithm families with significant OOD performance improvement.
The first are geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis’ March), now
solved at approximately 94% OOD, compared to the previous best of about 73%; the second being
string algorithms (Knuth-Morris-Pratt and Naïve String Matcher) for which our model now exceeds
49% compared to the previous best of approximately 3%.
The significant overall performance boost is reflected in the increased number of algorithms we can
now solve at over 60%, 80% & 90% OOD performance, compared to previous SOTA [5]. Specifically,
we now exceed 60% accuracy in 24 algorithms (15 algorithms previously), 80% for 17 algorithms (9
previously) and 90% for 11 algorithms (6 previously).