The `BaselineModel` class in `baselines.py` is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG], 3 Dec 2022). Below is an excerpt from the paper that describes the model:
Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes in the GNN that will execute the algorithm.

A sample of every algorithm is represented as a graph, with each input, output and hint located in either the nodes, the edges, or the graph itself, and therefore has shape (excluding batch dimension, and, for hints, time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one and pointer, each with its own encoding and decoding strategy and loss function; e.g. a scalar type will be encoded and decoded directly by a single linear layer, and optimised using mean squared error.
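To make the simplest case concrete, a scalar feature can be encoded and decoded by a single linear layer and optimised with mean squared error. The NumPy sketch below uses made-up weights and dimensions purely for illustration; the actual implementation in `baselines.py` uses JAX/Haiku modules:

```python
import numpy as np

rng = np.random.default_rng(0)
n, h = 4, 128  # nodes, hidden dimension (h = 128 in the paper)

# Hypothetical linear encoder/decoder weights for a `scalar` feature.
W_enc = rng.normal(0, 0.02, (1, h))   # scalar -> h-dim embedding
W_dec = rng.normal(0, 0.02, (h, 1))   # h-dim embedding -> scalar

x = rng.uniform(0, 1, (n, 1))         # one scalar per node

emb = x @ W_enc                        # encode: (n, h)
pred = (emb @ W_dec).squeeze(-1)       # decode: (n,)

# Scalar features are optimised with mean squared error.
mse = np.mean((pred - x.squeeze(-1)) ** 2)
print(emb.shape, pred.shape, mse)
```

The other four feature types follow the same encode/decode pattern, but with their own output non-linearities and losses (e.g. cross entropy for categorical features).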
Base Model
Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder f_τ, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located in edges, and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a single set of embeddings {x_i^(t), e_ij^(t), g^(t)}, of shapes n × h, n × n × h, and h, in the nodes, edges and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step, the input encoding is fed directly to these embeddings; this recall mechanism significantly improves the model's robustness over long trajectories [34].
Processor. The embeddings are fed into a processor P, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed node embeddings, h_i^(t). Additionally, the processor uses the processed node embeddings from the previous step, h_i^(t−1), as inputs. Importantly, the same processor model can operate on graphs of any size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and passing messages over a fully-connected graph, as our base model. The MPNN computes processed embeddings as follows:

    z_i^(t) = x_i^(t) ‖ h_i^(t−1)
    m_i^(t) = max_{1≤j≤n} f_m(z_i^(t), z_j^(t), e_ij^(t), g^(t))
    h_i^(t) = f_r(z_i^(t), m_i^(t))    (1)

starting from h^(0) = 0. Here ‖ denotes concatenation, f_m : R^{2h} × R^{2h} × R^h × R^h → R^h is the message function (for which we use a three-layer MLP with ReLU activations), and f_r : R^{2h} × R^h → R^h is the readout function (for which we use a linear layer with ReLU activation). The use of the max aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph (letting the neighbours j range over all nodes, 1 ≤ j ≤ n) in order to allow the model to overcome situations where the input graph structure may be suboptimal. Layer normalisation [36] is applied to h_i^(t) before using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
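The message-passing step of Equation (1) can be sketched in NumPy as below. The functions `f_m` and `f_r` here are toy stand-ins for the three-layer MLP and linear readout, chosen only so the shapes can be checked; they are not the learned functions from `baselines.py`:

```python
import numpy as np

def mpnn_step(x, h_prev, e, g, f_m, f_r):
    """One step of Eq. (1): max aggregation over a fully connected graph.
    Shapes: x (n, h), h_prev (n, h), e (n, n, h), g (h,)."""
    n = x.shape[0]
    z = np.concatenate([x, h_prev], axis=-1)              # z_i = x_i || h_i^(t-1)
    # Messages f_m(z_i, z_j, e_ij, g) for every ordered pair (i, j).
    msgs = np.stack([[f_m(z[i], z[j], e[i, j], g) for j in range(n)]
                     for i in range(n)])                  # (n, n, h)
    m = msgs.max(axis=1)                                  # max over neighbours j
    return f_r(np.concatenate([z, m], axis=-1))           # h_i^(t) = f_r(z_i, m_i)

# Toy stand-ins for the message and readout functions.
h_dim = 8
f_m = lambda zi, zj, eij, g: np.maximum(zi[:h_dim] + zj[:h_dim] + eij + g, 0)
f_r = lambda u: np.maximum(u[..., :h_dim], 0)

n = 3
x = np.ones((n, h_dim)); h_prev = np.zeros((n, h_dim))    # h^(0) = 0
e = np.zeros((n, n, h_dim)); g = np.zeros(h_dim)
h_new = mpnn_step(x, h_prev, e, g, f_m, f_r)
print(h_new.shape)
```

Because every node attends to every other node, the same code runs unchanged for any n, which is what lets the processor operate on graphs of any size.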
Decoder. The processed embeddings are finally decoded with a task-based decoder g_τ, to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, s_ij, for each pair of nodes, and then chooses the pointer of node i by taking either the argmax_j s_ij or softmax_j s_ij (depending on whether a hard or soft prediction is used).
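A pointer-type decoder of the kind described above can be sketched as follows. The bilinear scoring via hypothetical projection matrices `W_q` and `W_k` is an illustrative assumption, not necessarily the exact parameterisation used in `baselines.py`:

```python
import numpy as np

def decode_pointers(node_emb, W_q, W_k, hard=True):
    """Compute pairwise scores s_ij, then pick node i's pointer by
    argmax over j (hard) or return a row-wise softmax (soft)."""
    q, k = node_emb @ W_q, node_emb @ W_k
    s = q @ k.T                                   # (n, n) pairwise scores
    if hard:
        return s.argmax(axis=1)                   # pointer of node i
    e = np.exp(s - s.max(axis=1, keepdims=True))  # row-wise softmax
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
n, h = 5, 16
emb = rng.normal(size=(n, h))
W_q, W_k = rng.normal(size=(h, h)), rng.normal(size=(h, h))
ptrs = decode_pointers(emb, W_q, W_k, hard=True)
probs = decode_pointers(emb, W_q, W_k, hard=False)
print(ptrs, probs.shape)
```

The soft variant is what enables cross-entropy training of pointer hints, while the hard variant is used when a discrete pointer is needed.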
Loss. The decoded hints and outputs are used to compute the loss during training, according to their type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two). The hint loss and output loss are added together. Additionally, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see Section 3.2.1).

We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size n = 64. In what follows, we report only these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can be found in Appendix A.
3.2 Model improvements
As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.
3.2.1 Dataset and training
Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes advisable to stabilise the trajectories with teacher forcing [37], providing the ground-truth hint values instead of the network's own predictions. In the prior model [5], ground-truth hints were provided during training with probability 0.5, as, without teacher forcing, losses tended to grow unbounded along a trajectory when scalar hints were present, destabilising the training. In this work we incorporate several significant stabilising changes (described in the following paragraphs), which allow us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the network becoming overconfident in always expecting correct hint predictions. With teacher forcing, performance deteriorates significantly in sorting algorithms and Kruskal's algorithm. Naïve String Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset [5], we augmented the training data in three key ways, without breaking the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new training examples on the fly, rather than using a fixed dataset, which is easier to overfit to. Secondly, we trained on examples of mixed sizes, n ≤ 16, rather than only 16, which helps the model anticipate a diverse range of sizes, rather than overfitting to the specifics of size n = 16. Lastly, for graph algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi model [38]); and for string matching algorithms, we varied the length of the pattern to be matched. These both serve to expose the model to different trajectory lengths; for example, in many graph algorithms, the number of steps the algorithm should run for is related to the graph's diameter, and varying the connection probability in the graph generation allows for varying the expected diameter. These changes considerably increase training data variability, compared to the original dataset in Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process in Appendix A.
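The graph-algorithm augmentation can be illustrated with a minimal Erdős-Rényi sampler that varies both the size n ≤ 16 and the connectivity p per sample. The function below is a hypothetical stand-in for the CLRS on-line samplers, and the specific p values are illustrative:

```python
import numpy as np

def erdos_renyi(n, p, rng):
    """Sample a symmetric Erdős-Rényi adjacency matrix (no self-loops)
    with independent edge probability p."""
    upper = rng.random((n, n)) < p
    adj = np.triu(upper, k=1)          # keep strict upper triangle
    return (adj | adj.T).astype(int)   # mirror to make it undirected

rng = np.random.default_rng(0)
# Vary both the problem size n <= 16 and the connectivity p per sample.
sizes = rng.integers(4, 17, size=3)
probs = rng.choice([0.2, 0.5, 0.8], size=3)
graphs = [erdos_renyi(n, p, rng) for n, p in zip(sizes, probs)]
print([g.shape for g in graphs])
```

Sparser graphs (small p) tend to have larger diameters, so varying p varies the expected number of steps graph algorithms must run for.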
Soft hint propagation. When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them. In previous work, only hints of the scalar type allowed gradients through, as all categoricals were post-processed from logits into the ground-truth format via argmax or thresholding before being fed back. Instead, in this work we use softmax for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing), as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
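The difference between the two post-processing styles can be seen in a few lines: the hard variant commits to a one-hot vector (so no gradient can flow through it), while the soft variants remain differentiable. This is a minimal sketch, with made-up logits:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([2.0, -1.0, 0.5])

# Hard post-processing (previous work): argmax -> one-hot, blocks gradients.
hard = np.eye(len(logits))[logits.argmax()]

# Soft post-processing (this work): softmax for categorical / mask_one /
# pointer hints, logistic sigmoid for mask hints; both are differentiable.
soft_cat = softmax(logits)
soft_mask = sigmoid(logits)
print(hard, soft_cat, soft_mask)
```

In the actual JAX model, gradients flow through `soft_cat` and `soft_mask` when the predicted hints are fed back at the next step.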
Static hint elimination. Eleven algorithms in CLRS³ specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the model can overfit to the fixed training values. We therefore turned this fixed hint into an input for these 11 algorithms, eliminating the need to explicitly predict it.
Improving training stability with encoder initialisation and gradient clipping. The scalar hints have unbounded values, in principle, and are optimised using mean-squared error, hence their gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to exploding signals (and consequently gradients), even before any training takes place.

To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun initialisation [46] elsewhere. This combination of initialisations proved important for the initial learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw drastic improvements in learning stability, as well as significant increases in validation performance, with gradient clipping [47], which we subsequently employed in all experiments.
3.2.2 Encoders and decoders
Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching. Namely, the model could learn to base all of its computations on the assumption that it will always be finding an m-character pattern inside an n-character string, even though at test time, m and n will increase fourfold.

³Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis' March [44].
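The randomised position scalar amounts to a one-line change: sample uniform values and sort them, so the relative order matches the original linearly spaced positions. A minimal sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8

# Original CLRS position input: linearly spaced in [0, 1].
linear_pos = np.linspace(0, 1, n)

# Randomised replacement: uniform samples, sorted so the relative
# ordering of nodes is preserved while the exact values vary per sample.
random_pos = np.sort(rng.uniform(0, 1, n))
print(linear_pos, random_pos)
```

Because only the ordering is preserved, the model can no longer key its computations to the exact spacing of the training positions.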
Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an n × n matrix P where each row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model has to learn instead. Our early experiments showed that the model was often failing to predict valid permutations OOD.
Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows. First, we modify the output representation by rewiring the first node to point to the last one, turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size n that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating and repeatedly normalising rows and columns so they sum to 1. Specifically, S is defined by:

    S^0(Y) = exp(Y)
    S^l(Y) = T_c(T_r(S^{l−1}(Y)))
    S(Y) = lim_{l→∞} S^l(Y)    (2)

where exp acts element-wise, and T_r and T_c denote row and column normalisation, respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking P = lim_{τ→0+} S(Y/τ); as long as there are no ties in the elements of Y, P is guaranteed to be a permutation matrix [52, Theorem 1].
In practice, we compute the Sinkhorn operator using a fixed number of iterations l_max. We use a smaller number of iterations, l_max = 10, for training, to limit vanishing and exploding gradients, and l_max = 60 for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, i.e. that all diagonal elements of P should be 0, by setting the diagonal elements of Y to −∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise into the elements of Y prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P, and the mask_one pointing to the first element, back into the original pointer representation used by CLRS.
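A minimal NumPy version of the Sinkhorn operator with the temperature and masked diagonal described above might look as follows. Working in log space is an implementation choice for numerical stability, not something mandated by the paper:

```python
import numpy as np

def logsumexp(a, axis, keepdims=False):
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

def sinkhorn(Y, temperature=0.1, iterations=10):
    """Sinkhorn operator of Eq. (2), in log space. The diagonal is set
    to -inf (no node points to itself), then rows and columns are
    alternately normalised to sum to 1."""
    log_p = Y / temperature
    np.fill_diagonal(log_p, -np.inf)
    for _ in range(iterations):
        log_p = log_p - logsumexp(log_p, axis=1, keepdims=True)  # T_r
        log_p = log_p - logsumexp(log_p, axis=0, keepdims=True)  # T_c
    return np.exp(log_p)

rng = np.random.default_rng(0)
Y = rng.normal(size=(5, 5))                 # unconstrained decoder logits
P = sinkhorn(Y, temperature=0.1, iterations=60)  # l_max = 60 at evaluation
print(P.sum(axis=0), P.sum(axis=1))
```

With a finite number of iterations the result is only approximately doubly stochastic (columns sum to 1 exactly after the final column normalisation; rows are close), and lowering the temperature pushes P towards a hard permutation.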
3.2.3 Processor networks
Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it updates all hidden states in each step. Although it is theoretically possible for the network to keep the states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness in NDRs [54], we augment the network with an update gate, biased to be closed by default. We found that the gate stabilises learning on many of the tasks, and significantly increases the mean performance over all tasks in single-task training. Surprisingly, however, we did not find gating to be advantageous in the multi-task case.

To add gating to the MPNN model we produce a per-node gating vector from the same inputs that process the embeddings in Equation 1:

    g_i^(t) = f_g(z_i^(t), m_i^(t))    (3)

where f_g : R^{2h} × R^h → R^h is the gating function, for which we use a two-layer MLP, with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final-layer bias of f_g is initialised to −3, which biases the network towards not updating its representations unless necessary.

[Figure 2: bar chart of per-algorithm average score [%], comparing "Our model" against "Previous SOTA [5]".]
Figure 2: The OOD performance in single-task experiments before and after the improvements presented in this paper, sorted in descending order of current performance. Error bars represent standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current). The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).
The processed gated embeddings, ĥ_i^(t), are computed as follows:

    ĥ_i^(t) = g_i^(t) ⊙ h_i^(t) + (1 − g_i^(t)) ⊙ h_i^(t−1)    (4)

and are used instead of h_i^(t) in the subsequent steps, replacing z^(t) in Eq. 1 by z_i^(t) = x_i^(t) ‖ ĥ_i^(t−1).
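Equations (3) and (4) can be sketched together as below. The weights are zero-initialised here only to demonstrate the effect of the −3 output bias: a nearly closed gate keeps the state close to the previous step's h_i^(t−1). This is a NumPy illustration, not the Haiku module from `baselines.py`:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_update(z, m, h, h_prev, W1, b1, W2, b2):
    """Gating of Eqs. (3)-(4): a two-layer MLP (ReLU hidden, sigmoid
    output) produces per-node gates g_i from (z_i, m_i); the gated state
    interpolates between the new and the previous node embeddings."""
    u = np.concatenate([z, m], axis=-1)
    g = sigmoid(np.maximum(u @ W1 + b1, 0) @ W2 + b2)   # Eq. (3)
    return g * h + (1.0 - g) * h_prev                   # Eq. (4)

rng = np.random.default_rng(0)
n, h_dim = 4, 8
z = rng.normal(size=(n, 2 * h_dim)); m = rng.normal(size=(n, h_dim))
h = rng.normal(size=(n, h_dim)); h_prev = rng.normal(size=(n, h_dim))

# Zero weights isolate the bias: the gate is sigmoid(-3) ~ 0.047 everywhere,
# i.e. closed by default, so the gated state stays close to h_prev.
W1 = np.zeros((3 * h_dim, h_dim)); b1 = np.zeros(h_dim)
W2 = np.zeros((h_dim, h_dim)); b2 = np.full(h_dim, -3.0)
h_gated = gated_update(z, m, h, h_prev, W1, b1, W2, b2)
print(h_gated.shape)
```

During training the MLP weights are learned, so the gate opens only for the nodes the algorithm actually needs to update at a given step.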
Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning, where edges store values and update them based on other edges' values. An example of this is the Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The update rule for d_ij, its estimate for the best distance from node i to j, is d_ij = min_k (d_ik + d_kj), which roughly says "the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then from k to j". Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations h_i.

To rectify this situation, we augment our processor to perform message passing towards edges. Referring again to the update for d_ij, we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations over triplets of nodes, then reducing over one node to obtain edge latents:

    t_ijk = ψ_t(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)
    h_ij = φ_t(max_k t_ijk)    (5)
Here, ψ_t is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and φ_t is an edge readout function, which transforms the aggregated triplets for each edge for later use. In line with prior findings on the CLRS benchmark [5], we use max aggregation to obtain edge representations. The computed h_ij vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal's minimum spanning tree algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.

In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in ψ_t. φ_t then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experimentation demonstrated that the output dimensionality of ψ_t did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design [57]; however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify their utility on reasoning tasks with well-specified initial features.
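Equation (5) can be shape-checked with toy stand-ins for ψ_t and φ_t. Here ψ_t outputs 8-dimensional triplet features and φ_t upscales them to a wider model dimension, mirroring the lightweight-triplet design above, though the real functions are learned:

```python
import numpy as np

def triplet_step(h, e, g, psi_t, phi_t):
    """Triplet reasoning of Eq. (5): compute a vector per node triplet
    (i, j, k), max-reduce over the intermediate node k, then read out
    per-edge latents h_ij."""
    n = h.shape[0]
    t = np.stack([[[psi_t(h[i], h[j], h[k], e[i, j], e[i, k], e[k, j], g)
                    for k in range(n)] for j in range(n)] for i in range(n)])
    return phi_t(t.max(axis=2))     # (n, n, d_model) edge latents

# Toy stand-ins: psi_t maps to small (8-dim) triplet features,
# phi_t upscales them back to the model width.
d_tri, d_model = 8, 16
psi_t = lambda hi, hj, hk, eij, eik, ekj, g: (hi + hj + hk) + eij + eik + ekj + g
phi_t = lambda t: np.repeat(t, d_model // d_tri, axis=-1)

n = 3
h = np.arange(n)[:, None] * np.ones((n, d_tri))   # node i carries value i
e = np.zeros((n, n, d_tri)); g = np.zeros(d_tri)
h_edges = triplet_step(h, e, g, psi_t, phi_t)
print(h_edges.shape)
```

With these stand-ins, each edge latent picks up the largest i + j + k over intermediate nodes k, which loosely echoes how Floyd-Warshall aggregates over mid-points.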
3.3 Results
By incorporating the changes described in the previous sections, we arrived at a single model type, with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5].

Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our best model Triplet-GMPNN with all our improvements, after 10,000 training steps.

| Alg. Type | Memnet [5] | MPNN [5] | PGN [5] | Triplet-GMPNN (ours) |
|---|---|---|---|---|
| Div. & C. | 13.05% ± 0.14 | 20.30% ± 0.85 | 65.23% ± 4.44 | 76.36% ± 1.34 |
| DP | 67.94% ± 8.20 | 65.10% ± 6.44 | 70.58% ± 6.48 | 81.99% ± 4.98 |
| Geometry | 45.14% ± 11.95 | 73.11% ± 17.19 | 61.19% ± 7.01 | 94.09% ± 2.30 |
| Graphs | 24.12% ± 5.30 | 62.79% ± 8.75 | 60.25% ± 8.42 | 81.41% ± 6.21 |
| Greedy | 53.42% ± 20.82 | 82.39% ± 3.01 | 75.84% ± 6.59 | 91.21% ± 2.95 |
| Search | 34.35% ± 21.67 | 41.20% ± 19.87 | 56.11% ± 21.56 | 58.61% ± 24.34 |
| Sorting | 71.53% ± 1.41 | 11.83% ± 2.78 | 15.45% ± 8.46 | 60.37% ± 12.16 |
| Strings | 1.51% ± 0.46 | 3.21% ± 0.94 | 2.04% ± 0.20 | 49.09% ± 23.49 |
| Overall avg. | 38.88% | 44.99% | 50.84% | 74.14% |
| > 90% | 0/30 | 6/30 | 3/30 | 11/30 |
| > 80% | 3/30 | 9/30 | 7/30 | 17/30 |
| > 60% | 10/30 | 14/30 | 15/30 | 24/30 |

Figure 2 displays the comparison between the improved model and the best model from Veličković et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher (in absolute terms) compared to the next best model (see Table 1), and to a significant performance improvement in all but one algorithm family, compared to every other model. Further, our stabilising changes (such as gradient clipping) have empirically reduced the scale of our model's gradient updates across the 30 tasks, better preparing us for the numerical issues of the multi-task regime. Finally, we note that, though not shown in Tables 1 & 2, applying the same improvements to the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
There are two notable examples of algorithm families with significant OOD performance improvements. The first is geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis' March), now solved at approximately 94% OOD, compared to the previous best of about 73%; the second is string algorithms (Knuth-Morris-Pratt and Naïve String Matcher), for which our model now exceeds 49%, compared to the previous best of approximately 3%.
The significant overall performance boost is reflected in the increased number of algorithms we can now solve at over 60%, 80% and 90% OOD performance, compared to previous SOTA [5]. Specifically, we now exceed 60% accuracy in 24 algorithms (15 previously), 80% in 17 algorithms (9 previously) and 90% in 11 algorithms (6 previously).