{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- # Understanding Graph Attention Networks (GAT) -->\n",
"<h1><center>Understanding Graph Attention Networks (GAT)</center></h1>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!--  -->\n",
"<img src=\"img/GAT_Cover.jpg\" width=700x/>\n",
"\n",
"This is 4th in the series of blogs <font color=\"green\">*Explained: Graph Representation Learning*</font>. Let's dive right in, assuming you have read the first three. GAT (Graph Attention Network), is a novel neural network architectures that operate on graph-structured data, leveraging masked self-attentional layers to address the shortcomings of prior methods based on graph convolutions or their approximations. By stacking layers in which nodes are able to attend over their neighborhoods’ features, the method enables (implicitly) specifying different weights to different nodes in a neighborhood, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront. In this way, GAT addresses several key challenges of spectral-based graph neural networks simultaneously, and make the model readily applicable to inductive as well as transductive problems.\n",
"\n",
"Analyzing and Visualizing the learned attentional weights also lead to a more interpretable model in terms of importance of neighbors."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But before getting into the meat of this method, I want you to be familiar and thorough with the Attention Mechanism, because we'll be building GATs on the concept of <b>Self Attention</b> and <b>Multi-Head Attention</b> introduced by <b><i>Vaswani et al.</i></b>\n",
"If not, you may read this blog, [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) by Jay Alamar.\n",
"\n",
"<hr/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2><center>Can we do better than GCNs?</center></h2>\n",
"<!-- ## Can we do better than GCNs? -->\n",
"\n",
"From Graph Convolutional Network (GCN), we learnt that combining local graph structure and node-level features yields good performance on node classification task. However, the way GCN aggregates messages is <b>structure-dependent</b>, which may hurt its generalizability.\n",
"\n",
"The fundamental novelty that GAT brings to the table is how the information from the one-hop neighborhood is aggregated. For GCN, a graph convolution operation produces the normalized sum of neighbors' node features as follows:\n",
"\n",
"$$h_i^{(l+1)}=\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)} {\\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\\right)$$\n",
"\n",
"where $\\mathcal{N}(i)$ is the set of its one-hop neighbors (to include $v_{i}$ in the set, we simply added a self-loop to each node), $c_{ij}=\\sqrt{|\\mathcal{N}(i)|}\\sqrt{|\\mathcal{N}(j)|}$ is a normalization constant based on graph structure, $\\sigma$ is an activation function (GCN uses ReLU), and $W^{l}$ is a shared weight matrix for node-wise feature transformation.\n",
"\n",
"GAT introduces the attention mechanism as a substitute for the statically normalized convolution operation. The figure below clearly illustrates the key difference.\n",
"\n",
"<img src=\"img/GCN_vs_GAT.jpg\" width=800x/>\n",
"\n",
"<hr/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- ## How does the attention work in GAT layer? -->\n",
"<h1><center>How does the GAT layer work?</center></h1>\n",
"\n",
"The particular attentional setup utilized by GAT closely follows the work of `Bahdanau et al. (2015)` i.e *Additive Attention*, but the framework is agnostic to the particular choice of attention mechanism.\n",
"\n",
"The input to the layer is a set of node features, $\\mathbf{h} = \\{\\vec{h}_1,\\vec{h}_2,...,\\vec{h}_N\\}, \\vec{h}_i ∈ \\mathbb{R}^{F}$ , where $N$ is the\n",
"number of nodes, and $F$ is the number of features in each node. The layer produces a new set of node\n",
"features (of potentially different cardinality $F'$ ), $\\mathbf{h} = \\{\\vec{h'}_1,\\vec{h'}_2,...,\\vec{h'}_N\\}, \\vec{h'}_i ∈ \\mathbb{R}^{F'}$, as its output.\n",
"\n",
"\n",
"<h3><font color=\"black\" >The Attentional Layer broken into 4 separate parts:</font></h3>\n",
"\n",
"<hr/>\n",
"\n",
"**1)** <font color=\"red\">**Simple linear transformation:**</font> In order to obtain sufficient expressive power to transform the input features into higher level features, atleast one learnable linear transformation is required. To that end, as an initial step, a shared linear transformation, parametrized by a weight matrix, $W ∈ \\mathbb{R}^{F′×F}$ , is applied to every node.\n",
"\n",
"$$\\begin{split}\\begin{align}\n",
"z_i^{(l)}&=W^{(l)}h_i^{(l)},&(1) \\\\\n",
"\\end{align}\\end{split}$$\n",
"\n",
"<div style=\"float: right\">\n",
" <img src=\"img/Attentional_Layer.jpg\" width=400x/>\n",
"</div>\n",
"\n",
"\n",
"<hr/>\n",
"\n",
"**2)** <font color=\"red\">**Attention Coefficients:**</font> We then compute a pair-wise <font color=\"blue\">**un-normalized**</font> attention score between two neighbors. Here, it first concatenates the $z$ embeddings of the two nodes, where $||$ denotes concatenation, then takes a dot product of it and a learnable weight vector $\\vec a^{(l)}$, and applies a LeakyReLU in the end. This form of attention is usually called additive attention, in contrast with the dot-product attention used for the Transformer model. We then perform self-attention on the nodes, a shared attentional mechanism $a$ : $\\mathbb{R}^{F′} × \\mathbb{R}^{F′} → \\mathbb{R}$ to compute attention coefficients \n",
"$$\\begin{split}\\begin{align}\n",
"e_{ij}^{(l)}&=\\text{LeakyReLU}(\\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),&(2)\\\\\n",
"\\end{align}\\end{split}$$\n",
"\n",
"**Q. Is this step the most important step?** \n",
"\n",
"**Ans.** Yes! This indicates the importance of node $j’s$ features to node $i$. This step allows every node to attend on every other node, dropping all structural information.\n",
"\n",
"**NOTE:** The graph structure is injected into the mechanism by performing <b>*masked attention*</b>, we only compute $e_{ij}$ for nodes $j$ ∈ $N_{i}$, where $N_{i}$ is some neighborhood of node $i$ in the graph. In all the experiments, these will be exactly the first-order neighbors of $i$ (including $i$).\n",
"\n",
"\n",
"<hr/>\n",
"\n",
"**3)** <font color=\"red\">**Softmax:**</font> This makes coefficients easily comparable across different nodes, we normalize them across all choices of $j$ using the softmax function\n",
"\n",
"$$\\begin{split}\\begin{align}\n",
"\\alpha_{ij}^{(l)}&=\\frac{\\exp(e_{ij}^{(l)})}{\\sum_{k\\in \\mathcal{N}(i)}^{}\\exp(e_{ik}^{(l)})},&(3)\\\\\n",
"\\end{align}\\end{split}$$\n",
"\n",
"\n",
"<hr/>\n",
"\n",
"**4)** <font color=\"red\">**Aggregation:**</font> This step is similar to GCN. The embeddings from neighbors are aggregated together, scaled by the attention scores. \n",
"\n",
"$$\\begin{split}\\begin{align}\n",
"h_i^{(l+1)}&=\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)} {\\alpha^{(l)}_{ij} z^{(l)}_j }\\right),&(4)\n",
"\\end{align}\\end{split}$$\n",
"\n",
"<hr/>\n",
"\n",
"<!-- {:style=\"float: right;margin-right: 7px;margin-top: 7px;\"} -->\n",
"\n",
"\n",
"<!--  -->"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- ### Multi-head Attention -->\n",
"<h2><center>Multi-head Attention</center></h2>\n",
"\n",
"<p>\n",
"<img src=\"img/MultiHead_Attention.jpeg\" width=500x/>\n",
"*An illustration of multi-head attention (with K = 3 heads) by node 1 on its neighborhood. Different arrow styles and colors denote independent attention computations. The aggregated features from each head are concatenated or averaged to obtain $\\vec{h'}_{1}$.*\n",
"</p>\n",
"\n",
"Analogous to multiple channels in ConvNet, GAT introduces multi-head attention to enrich the model capacity and to stabilize the learning process. Specifically, K independent attention mechanisms execute the transformation of Equation 4, and then their outputs can be combined in 2 ways depending on the use:\n",
"\n",
"* <b>Concatenation</b>\n",
" \n",
" $$\\textbf{Concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)}\\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\\right)$$\n",
" \n",
" * As can be seen in this settingthe final returned output, $h′$, will consist of $KF′$ features (rather than F′) for each node.\n",
"\n",
"\n",
"* <b>Averaging</b>\n",
" * If we perform multi-head attention on the final (prediction) layer of the network, concatenation is no longer sensible and instead, averaging is employed, and delay applying the final nonlinearity (usually a softmax or logistic sigmoid for classification problems) until then:\n",
" \n",
" $$\\textbf{Average}: h_{i}^{(l+1)}=\\sigma\\left(\\frac{1}{K}\\sum_{k=1}^{K}\\sum_{j\\in\\mathcal{N}(i)}\\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\\right)$$\n",
" \n",
"Thus concatenation for intermediary layers and average for the final layer are used."
]
},
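{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two ways of combining heads differ mainly in the output shape and in where the nonlinearity is applied. In the minimal sketch below, random tensors stand in for each head's aggregated features $\\sum_{j} \\alpha_{ij}^{k} W^{k} h_j$ (an assumption for illustration, not outputs of a trained model); ELU is used as $\\sigma$ for the hidden layer and softmax for the final layer:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"N, F_prime, K = 4, 3, 3  # nodes, features per head, number of heads (toy sizes)\n",
"\n",
"# Stand-ins for each head's aggregated features, one (N, F') tensor per head.\n",
"head_outputs = [torch.randn(N, F_prime) for _ in range(K)]\n",
"\n",
"# Intermediate layer: apply sigma per head, then concatenate -> K * F' features.\n",
"h_concat = torch.cat([F.elu(o) for o in head_outputs], dim=1)\n",
"print(h_concat.shape)  # torch.Size([4, 9])\n",
"\n",
"# Final layer: average the heads first, then apply the final nonlinearity.\n",
"h_avg = torch.softmax(torch.stack(head_outputs).mean(dim=0), dim=1)\n",
"print(h_avg.shape)     # torch.Size([4, 3])\n",
"```"
]
},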
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- ## Implementing GAT Layer in PyTorch -->\n",
"<h1><center>Implementing GAT Layer in PyTorch</center></h1>\n",
"\n",
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x1108d6810>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(2020) # seed for reproducible numbers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class GATLayer(nn.Module):\n",
" \"\"\"\n",
" Simple PyTorch Implementation of the Graph Attention layer.\n",
" \"\"\"\n",
"\n",
" def __init__(self, in_features, out_features, dropout, alpha, concat=True):\n",
" super(GATLayer, self).__init__()\n",
" self.dropout = dropout # drop prob = 0.6\n",
" self.in_features = in_features # \n",
" self.out_features = out_features # \n",
" self.alpha = alpha # LeakyReLU with negative input slope, alpha = 0.2\n",
" self.concat = concat # conacat = True for all layers except the output layer.\n",
"\n",
" # Xavier Initialization of Weights\n",
" # Alternatively use weights_init to apply weights of choice \n",
" self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n",
" nn.init.xavier_uniform_(self.W.data, gain=1.414)\n",
" self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))\n",
" nn.init.xavier_uniform_(self.a.data, gain=1.414)\n",
" \n",
" # LeakyReLU\n",
" self.leakyrelu = nn.LeakyReLU(self.alpha)\n",
"\n",
" def forward(self, input, adj):\n",
" # Linear Transformation\n",
" h = torch.mm(input, self.W)\n",
" N = h.size()[0]\n",
"\n",
" # Attention Mechanism\n",
" a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)\n",
" e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))\n",
"\n",
" # Masked Attention\n",
" zero_vec = -9e15*torch.ones_like(e)\n",
" attention = torch.where(adj > 0, e, zero_vec)\n",
" \n",
" attention = F.softmax(attention, dim=1)\n",
" attention = F.dropout(attention, self.dropout, training=self.training)\n",
" h_prime = torch.matmul(attention, h)\n",
"\n",
" if self.concat:\n",
" return F.elu(h_prime)\n",
" else:\n",
" return h_prime"
]
},
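{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `a_input` line in `forward` is the least obvious part of the layer. The sketch below reproduces the same `repeat`/`view`/`cat` construction on a tiny made-up feature matrix and checks that `a_input[i, j]` is exactly `[h_i || h_j]`, where `h` holds the transformed features ($z$ in Equation (2)):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"N, F_prime = 3, 2\n",
"h = torch.arange(N * F_prime, dtype=torch.float32).view(N, F_prime)  # row i = features of node i\n",
"\n",
"# Same construction as in GATLayer.forward.\n",
"left = h.repeat(1, N).view(N * N, -1)  # rows h_0, h_0, h_0, h_1, h_1, h_1, ...\n",
"right = h.repeat(N, 1)                 # rows h_0, h_1, h_2, h_0, h_1, h_2, ...\n",
"a_input = torch.cat([left, right], dim=1).view(N, -1, 2 * F_prime)\n",
"\n",
"# a_input[i, j] is the concatenation [h_i || h_j] used in Equation (2).\n",
"for i in range(N):\n",
"    for j in range(N):\n",
"        assert torch.equal(a_input[i, j], torch.cat([h[i], h[j]]))\n",
"print(a_input.shape)  # torch.Size([3, 3, 4])\n",
"```\n",
"\n",
"This dense construction materializes all $N \\times N$ pairs, so memory grows quadratically with the number of nodes. Sparse implementations such as PyTorch Geometric's `GATConv`, used below, compute attention scores only for the edges in `edge_index`."
]
},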
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# # Alternate approach to applying weights of choice using weights_init()\n",
"# def weights_init(m):\n",
"# if isinstance(m, nn.Linear):\n",
"# torch.nn.init.xavier_uniform_(m.weight)\n",
"\n",
"# # Applying just after calling the model class\n",
"# model.apply(weights_init)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- ## Implementing GAT on Citation Datasets using PyTorch Geometric -->\n",
"<h1><center>Implementing GAT on Citation Datasets using PyTorch Geometric</center></h1>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PyG Imports"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from torch_geometric.data import Data\n",
"from torch_geometric.nn import GATConv\n",
"from torch_geometric.datasets import Planetoid\n",
"import torch_geometric.transforms as T\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib notebook\n",
"\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Classes in Cora: 7\n",
"Number of Node Features in Cora: 1433\n"
]
}
],
"source": [
"name_data = 'Cora'\n",
"dataset = Planetoid(root= '/tmp/' + name_data, name = name_data)\n",
"dataset.transform = T.NormalizeFeatures()\n",
"\n",
"print(f\"Number of Classes in {name_data}:\", dataset.num_classes)\n",
"print(f\"Number of Node Features in {name_data}:\", dataset.num_node_features)"
]
},
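{
"cell_type": "markdown",
"metadata": {},
"source": [
"`T.NormalizeFeatures()` row-normalizes the node features so that each node's feature vector sums to one. A rough plain-PyTorch equivalent is sketched below, assuming binary features such as Cora's bag-of-words vectors (the exact PyG implementation may differ between versions):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Toy binary bag-of-words features for 3 nodes (hypothetical values).\n",
"x = torch.tensor([[1., 0., 1., 1.],\n",
"                  [0., 0., 0., 0.],\n",
"                  [0., 1., 0., 0.]])\n",
"\n",
"# Divide each row by its sum; clamp(min=1) avoids dividing an all-zero row by zero.\n",
"x_norm = x / x.sum(dim=1, keepdim=True).clamp(min=1)\n",
"print(x_norm.sum(dim=1))  # row sums: 1 for non-empty rows, 0 for the all-zero row\n",
"```"
]
},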
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class GAT(torch.nn.Module):\n",
" def __init__(self):\n",
" super(GAT, self).__init__()\n",
" self.hid = 8\n",
" self.in_head = 8\n",
" self.out_head = 1\n",
" \n",
" self.conv1 = GATConv(dataset.num_features, self.hid, heads=self.in_head, dropout=0.6)\n",
" self.conv2 = GATConv(self.hid*self.in_head, dataset.num_classes, concat=False,\n",
" heads=self.out_head, dropout=0.6)\n",
"\n",
" def forward(self, data):\n",
" x, edge_index = data.x, data.edge_index\n",
" \n",
" # Dropout before the GAT layer is used to avoid overfitting in small datasets like Cora.\n",
" # One can skip them if the dataset is sufficiently large.\n",
" \n",
" x = F.dropout(x, p=0.6, training=self.training)\n",
" x = self.conv1(x, edge_index)\n",
" x = F.elu(x)\n",
" x = F.dropout(x, p=0.6, training=self.training)\n",
" x = self.conv2(x, edge_index)\n",
" \n",
" return F.log_softmax(x, dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(1.9467, grad_fn=<NllLossBackward>)\n",
"tensor(0.6551, grad_fn=<NllLossBackward>)\n",
"tensor(0.5155, grad_fn=<NllLossBackward>)\n",
"tensor(0.6176, grad_fn=<NllLossBackward>)\n",
"tensor(0.6120, grad_fn=<NllLossBackward>)\n"
]
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"model = GAT().to(device)\n",
"\n",
"data = dataset[0].to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n",
"\n",
"model.train()\n",
"for epoch in range(1000):\n",
" model.train()\n",
" optimizer.zero_grad()\n",
" out = model(data)\n",
" loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n",
" \n",
" if epoch%200 == 0:\n",
" print(loss)\n",
" \n",
" loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8200\n"
]
}
],
"source": [
"model.eval()\n",
"_, pred = model(data).max(dim=1)\n",
"correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())\n",
"acc = correct / data.test_mask.sum().item()\n",
"print('Accuracy: {:.4f}'.format(acc))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"[Graph Attention Networks](https://arxiv.org/abs/1710.10903)\n",
"\n",
"[Graph attention network, DGL by Zhang et al.](https://docs.dgl.ai/tutorials/models/1_gnn/9_gat.html)\n",
"\n",
"[Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)\n",
"\n",
"[The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)\n",
"\n",
"[Mechanics of Seq2seq Models With Attention](https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/)\n",
"\n",
"[Attention? Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}