'''
Components of a Decoder-Only Transformer
-> Word Embedding
-> Position Encoding
-> Masked Self-Attention
-> Residual Connections
-> A fully connected layer
-> Classification Head
'''
import torch
import torch.nn as nn
import math
class WordEmbeddings(nn.Module):
def __init__(self, d_model: int, vocab_size: int):
## d_model: The dimension of the transformer, which is also the number of embedding values per token.
## vocab_size: Get the size of the underlying vocabulary
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.embedding = nn.Embedding(num_embeddings=vocab_size,
embedding_dim=d_model)
def forward(self, x):
# (batch, seq_len) --> (batch, seq_len, d_model)
# multiply by sqrt(d_model) to scale the embeddings
return self.embedding(x) * math.sqrt(self.d_model)
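## For example (an illustrative sketch of my own, not part of the original code), with
## d_model=512 and vocab_size=10, a batch of 2 sequences of 3 token ids each...
##
##   emb = WordEmbeddings(d_model=512, vocab_size=10)
##   tokens = torch.tensor([[1, 2, 3],
##                          [4, 5, 6]])      ## shape: (2, 3)
##   emb(tokens).shape                       ## torch.Size([2, 3, 512])
##
## ...and every embedding value has been scaled by sqrt(512), which is roughly 22.6.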
class PositionEncoding(nn.Module):
'''Ref: https://github.com/StatQuest/decoder_transformer_from_scratch/blob/main/decoder_transformers_with_pytorch_and_lightning_v2.ipynb
'''
def __init__(self, d_model: int, seq_len: int, dropout: float):
## d_model = The dimension of the transformer, which is also the number of embedding values per token.
## In the transformer used in StatQuest: Transformer Neural Networks Clearly Explained!!!,
## d_model=2. However, in "Attention Is All You Need" d_model=512, which is also the
## default used by build_transformer() at the bottom of this file.
## seq_len = maximum number of tokens we allow as input.
## Since we are precomputing the position encoding values and storing them in a lookup table
## we can use d_model and seq_len to determine the number of rows and columns in that
## lookup table.
##
## In the original StatQuest example, only short phrases are used, so seq_len=6 is plenty.
## However, in The Annotated Transformer, the default value for seq_len is 5000.
## We call the super's init because by creating our own __init__() method, we overwrite the one
## we inherited from nn.Module. So we have to explicitly call nn.Module's __init__(), otherwise it
## won't get initialized. NOTE: If we didn't write our own __init__(), then we would not have
## to call super().__init__(). Alternatively, if we didn't want to access any of nn.Module's methods,
## we wouldn't have to call it then either.
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)
## Now we create a lookup table, pe, of position encoding values and initialize all of them to 0.
## To do this, we will make a matrix of 0s that has seq_len rows and d_model columns.
## for example...
## torch.zeros(3, 2)
## ...returns a matrix of 0s with 3 rows and 2 columns...
## tensor([[0., 0.],
## [0., 0.],
## [0., 0.]])
pe = torch.zeros(seq_len, d_model)
## Now we create a sequence of numbers for each position that a token can have in the input (or output).
## For example, if the input tokens were "I'm happy today!", then "I'm" would get the first
## position, 0, "happy" would get the second position, 1, and "today!" would get the third position, 2.
## NOTE: Since we are going to be doing math with these position indices to create the
## positional encoding for each one, we need them to be floats rather than ints.
##
## NOTE: Two ways to create floats are...
##
## torch.arange(start=0, end=3, step=1, dtype=torch.float)
##
## ...and...
##
## torch.arange(start=0, end=3, step=1).float()
##
## ...but the latter is just as clear and requires less typing.
##
## Lastly, .unsqueeze(1) converts the single list of numbers that torch.arange creates into a matrix with
## one row for each index, and all of the indices in a single column. So if "seq_len" = 3, then we
## would create a matrix with 3 rows and 1 column like this...
##
## torch.arange(start=0, end=3, step=1, dtype=torch.float).unsqueeze(1)
##
## ...returns...
##
## tensor([[0.],
## [1.],
## [2.]])
position = torch.arange(start=0, end=seq_len, step=1).float().unsqueeze(1)
## Here is where we start doing the math to determine the y-axis coordinates on the
## sine and cosine curves.
##
## The positional encoding equations used in "Attention is all you need" are...
##
## PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
## PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
##
## ...and we see, within the sin() and cos() functions, we divide "pos" by some number that depends
## on the index (i) and total number of PE values we want per token (d_model).
##
## NOTE: When the index, i, is 0 then we are calculating the y-axis coordinates on the **first pair**
## of sine and cosine curves. When i=1, then we are calculating the y-axis coordinates on the
## **second pair** of sine and cosine curves. etc. etc.
##
## Now, pretty much everyone calculates the term we use to divide "pos" by first, and they do it with
## code that looks like this...
##
## div_term = torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(math.log(10000.0) / d_model))
##
## Now, at least to me, it's not obvious that div_term = 1/(10000^(2i/d_model)) for a few reasons:
##
## 1) div_term wraps everything in a call to torch.exp()
## 2) It uses log()
## 3) The order of the terms is different
##
## The reason for these differences is, presumably, trying to prevent underflow (getting too close to 0).
## So, to show that div_term = 1/(10000^(2i/d_model))...
##
## 1) Swap out math.log() for torch.log() (doing this requires converting 10000.0 to a tensor, which is my
## guess for why they used math.log() instead of torch.log())...
##
## torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
##
## 2) Rearrange the terms...
##
## torch.exp(-1 * (torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model))
##
## 3) Pull out the -1 with exp(-1 * x) = 1/exp(x)
##
## 1/torch.exp(torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model)
##
## 4) Use exp(a * b) = exp(a)^b to pull out the 2i/d_model term...
##
## 1/torch.exp(torch.log(torch.tensor(10000.0)))^(torch.arange(start=0, end=d_model, step=2).float() / d_model)
##
## 5) Use exp(log(x)) = x to get the original form of the denominator...
##
## 1/(torch.tensor(10000.0)^(torch.arange(start=0, end=d_model, step=2).float() / d_model))
##
## 6) Bam.
##
## That being said, I don't think underflow is actually that big an issue. In fact, a coder at Hugging Face
## apparently doesn't think so either: their code for the positional encoding in DistilBERT (a streamlined
## version of BERT, which is a transformer model) calculates the values directly, using the form of the
## equation found in the original "Attention Is All You Need" manuscript. See...
## https://github.com/huggingface/transformers/blob/455c6390938a5c737fa63e78396cedae41e4e87e/src/transformers/modeling_distilbert.py#L53
## So I think we can simplify the code, but I'm also writing all these comments to show that it is equivalent to what
## you'll see in the wild...
##
## Now let's create an index for the embedding positions to simplify the code a little more...
embedding_index = torch.arange(start=0, end=d_model, step=2).float()
## NOTE: Setting step=2 results in the same sequence of numbers that we would get if we multiplied i by 2.
## So we can save ourselves a little math by just setting step=2.
## And now, finally, let's create div_term...
div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)
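## As a quick sanity check (my own addition, not in the original), the direct form above matches
## the torch.exp()-based version you'll see in the wild to within floating point error...
##
##   torch.allclose(div_term,
##                  torch.exp(embedding_index * -(math.log(10000.0) / d_model)))
##
## ...returns True, because exp(2i * -log(10000)/d_model) = 10000^(-2i/d_model) = 1/(10000^(2i/d_model)).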
## Now we calculate the actual positional encoding values. Remember 'pe' was initialized as a matrix of 0s
## with seq_len (max number of input tokens) rows and d_model (number of embedding values per token) columns.
pe[:, 0::2] = torch.sin(position * div_term) ## every other column, starting with the 1st, has sin() values
pe[:, 1::2] = torch.cos(position * div_term) ## every other column, starting with the 2nd, has cos() values
## NOTE: If the notation for indexing 'pe[]' looks cryptic to you, read on...
##
## First, let's look at the general indexing notation:
##
## For each row or column in a matrix, we can select elements in that
## row or column with the following index notation...
##
## i:j:k = select elements between i and j with stepsize = k.
##
## ...where...
##
## i defaults to 0
## j defaults to the number of elements in the row, column or whatever.
## k defaults to 1
##
## Now that we have looked at the general notation, let's look at specific
## examples so that we can understand it.
##
## We'll start with: pe[:, 0::2]
##
## The stuff that comes before the comma (in this case ':') refers to the rows we want to select.
## The ':' before the comma means "select all rows" because we are not providing specific
## values for i, j and k and, instead, just using the default values.
##
## The stuff after the comma refers to the columns we want to select.
## In this case, we have '0::2', and that means we start with
## the first column (column = 0) and go to the end (using the default value for j)
## and we set the stepsize to 2, which means we skip every other column.
##
## Now to understand pe[:, 1::2]
##
## Again, the stuff before the comma refers to the rows, and, just like before
## we use default values for i,j and k, so we select all rows.
##
## The stuff that comes after the comma refers to the columns.
## In this case, we start with the 2nd column (column = 1), and go to the end
## (using the default value for 'j') and we set the stepsize to 2, which
## means we skip every other column.
##
## NOTE: using this ':' based notation is called "indexing" and also called "slicing"
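## For example (illustrative values of my own, not from the original), slicing a small
## 2x4 matrix the same way we sliced 'pe'...
##
##   m = torch.arange(8).reshape(2, 4)   ## tensor([[0, 1, 2, 3],
##                                       ##         [4, 5, 6, 7]])
##   m[:, 0::2]                          ## tensor([[0, 2],
##                                       ##         [4, 6]])
##   m[:, 1::2]                          ## tensor([[1, 3],
##                                       ##         [5, 7]])
##
## ...selects the even-numbered columns (0 and 2) and the odd-numbered columns (1 and 3), respectively.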
## Add a batch dimension to the positional encoding
pe = pe.unsqueeze(0) # (1, seq_len, d_model)
## Now we "register 'pe'.
self.register_buffer('pe', pe) ## "register_buffer()" ensures that
## 'pe' will be moved to wherever the model gets
## moved to. So if the model is moved to a GPU, then,
## even though we don't need to optimize 'pe', it will
## also be moved to that GPU. This, in turn, means
## that accessing 'pe' will be relatively fast compared
## to having a GPU have to get the data from a CPU.
def forward(self, word_embeddings):
## Because this class, PositionEncoding, inherits from nn.Module, the forward() method
## is called by default when we use a PositionEncoding() object.
## In other words, after we create a PositionEncoding object, pos_enc = PositionEncoding(d_model, seq_len, dropout),
## then pos_enc(word_embeddings) will call forward() and so this is where
## we will add the position encoding values to the word embedding values
## (batch, seq_len, d_model)
x = word_embeddings + (self.pe[:,:word_embeddings.shape[1], :]).requires_grad_(False)
return self.dropout(x)
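## Putting WordEmbeddings and PositionEncoding together (an illustrative sketch of my own,
## not part of the original code)...
##
##   emb = WordEmbeddings(d_model=512, vocab_size=10)
##   pos = PositionEncoding(d_model=512, seq_len=6, dropout=0.1)
##   tokens = torch.tensor([[1, 2, 3]])   ## (1, 3) -- only 3 of the 6 allowed positions are used
##   pos(emb(tokens)).shape               ## torch.Size([1, 3, 512])
##
## ...and, because forward() slices 'pe' to the length of the input, shorter inputs work fine.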
class LayerNormalization(nn.Module):
def __init__(self, features: int, eps:float=10**-6) -> None:
super().__init__()
self.eps = eps
self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter
def forward(self, x):
# x: (batch, seq_len, hidden_size)
# Keep the dimension for broadcasting
mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
# Keep the dimension for broadcasting
std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
# eps is to prevent dividing by zero or when std is very small
return self.alpha * (x - mean) / (std + self.eps) + self.bias
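## For example (a worked example of my own, not from the original), for a single token
## vector x = [1., 2., 3.] the mean is 2 and the standard deviation is 1 (note that
## torch.std() uses the sample standard deviation, dividing by n-1), so before alpha and
## bias are applied the normalized values are approximately [-1., 0., 1.].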
class MultiHeadAttentionBlock(nn.Module):
def __init__(self, d_model: int, h: int, dropout: float) -> None:
super().__init__()
# Make sure d_model is divisible by h
assert d_model % h == 0, "d_model is not divisible by h"
self.d_model = d_model # Embedding vector size
self.h = h # Number of heads
self.d_k = d_model // h # Dimension of vector seen by each head
self.w_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wq
self.w_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wk
self.w_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wv
self.w_o = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # Wo
self.dropout = nn.Dropout(dropout)
@staticmethod
def attention(query, key, value, mask, dropout: nn.Dropout):
d_k = query.shape[-1]
## (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
## Compute the attention scores; the equation is (q * k^T)/sqrt(d_k)
attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
## Here we are masking out things we don't want to pay attention to,
## like tokens that come after the current token.
## We can also use masking to block out the <PAD> token,
## which is used when we have a batch of input sequences
## and they are not all the exact same length. Because the batch is passed
## in as a matrix, each input sequence has to have the same length, so we
## add <PAD> to the shorter sequences so that they are all as long as the
## longest sequence.
##
## We replace <PAD>, or tokens that come after the current token
## with a very large negative number so that the SoftMax() function
## will give all masked elements an output value (or "probability") of 0.
## That is, we write a very low value (effectively -inf) to the positions where mask == 0.
attention_scores.masked_fill_(mask == 0, -1e9)
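## For example (illustrative, not from the original), a causal mask for a 3-token input looks like...
##
##   torch.tril(torch.ones(3, 3))   ## tensor([[1., 0., 0.],
##                                  ##         [1., 1., 0.],
##                                  ##         [1., 1., 1.]])
##
## ...so token 0 can only attend to itself, token 1 to tokens 0 and 1, and so on.
## Everywhere the mask is 0, the score becomes -1e9, which softmax() then turns into (almost) 0.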
## Apply softmax to determine what percent of each token's value to
## use in the final attention values.
## (batch, h, seq_len, seq_len)
attention_scores = attention_scores.softmax(dim=-1)
if dropout is not None:
attention_scores = dropout(attention_scores)
## (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
## return attention scores which can be used for visualization
return (attention_scores @ value), attention_scores
def forward(self, q, k, v, mask):
query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
# (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
# Calculate attention
x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
# Combine all the heads together
# (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
# Multiply by Wo
# (batch, seq_len, d_model) --> (batch, seq_len, d_model)
return self.w_o(x)
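## For example (an illustrative sketch of my own, not part of the original code), with d_model=512
## and h=8 heads, each head sees d_k=64 values per token, and the output shape matches the input...
##
##   mha = MultiHeadAttentionBlock(d_model=512, h=8, dropout=0.1)
##   x = torch.randn(2, 6, 512)        ## (batch=2, seq_len=6, d_model=512)
##   mha(x, x, x, mask=None).shape     ## torch.Size([2, 6, 512])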
class ResidualConnection(nn.Module):
def __init__(self, features: int, dropout: float) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout)
self.norm = LayerNormalization(features)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
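## NOTE (my own observation): this applies LayerNormalization *before* the sublayer ("pre-norm"),
## rather than after the residual addition as in the original "Attention Is All You Need" figure
## ("post-norm"). Both arrangements are common; pre-norm tends to be easier to train.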
class FeedForwardBlock(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2
def forward(self, x):
# (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
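## NOTE (my own observation): in "Attention Is All You Need" the hidden layer of this block is
## four times wider than the model (d_ff=2048 vs. d_model=512), which matches the defaults
## used by build_transformer() at the bottom of this file.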
class DecoderBlock(nn.Module):
def __init__(self,
features: int,
self_attention_block: MultiHeadAttentionBlock,
feed_forward_block: FeedForwardBlock,
dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
def forward(self, x, mask):
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, mask))
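## NOTE (my own observation): the lambda is needed because ResidualConnection calls its
## sublayer with a single argument, sublayer(self.norm(x)), while the self-attention block
## needs four (q, k, v and the mask). The feed-forward block below already takes a single
## argument, so it can be passed in directly.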
x = self.residual_connections[1](x, self.feed_forward_block)
return x
class Decoder(nn.Module):
def __init__(self, features: int, layers: nn.ModuleList) -> None:
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features)
def forward(self, x, mask):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class ProjectionLayer(nn.Module):
def __init__(self, d_model, vocab_size):
super().__init__()
self.proj = nn.Linear(d_model, vocab_size)
def forward(self, x) -> torch.Tensor:
# (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
return self.proj(x)
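## NOTE (my own observation): the output of this layer is a vector of raw logits over the
## vocabulary for each position; a softmax (or nn.CrossEntropyLoss, which applies log-softmax
## internally) is typically used on top of it to train the model or to sample the next token.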
class DecoderOnlyTransformer(nn.Module):
def __init__(self,
word_embedding: WordEmbeddings,
position_embedding: PositionEncoding,
decoder: Decoder,
projection_layer: ProjectionLayer):
super().__init__()
self.word_embedding = word_embedding
self.position_embedding = position_embedding
self.decoder = decoder
self.projection_layer = projection_layer
def decode(self, x: torch.Tensor, mask: torch.Tensor):
# x shape (batch, seq_len)
x = self.word_embedding(x)
x = self.position_embedding(x)
# x shape (batch, seq_len, d_model)
return self.decoder(x, mask)
def project(self, x):
# (batch, seq_len, vocab_size)
return self.projection_layer(x)
def build_transformer(vocab_size: int,
seq_len: int,
d_model: int=512,
N: int=6,
h: int=8,
dropout: float=0.1,
d_ff: int=2048) -> DecoderOnlyTransformer:
# Create the embedding layers
word_embedding = WordEmbeddings(d_model, vocab_size)
# Create the positional encoding layers
position_encoding = PositionEncoding(d_model, seq_len, dropout)
# Create the decoder blocks
decoder_blocks = []
for _ in range(N):
multi_head_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
decoder_block = DecoderBlock(d_model, multi_head_self_attention_block, feed_forward_block, dropout)
decoder_blocks.append(decoder_block)
# Create the decoder
decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
# Create the projection layer
projection_layer = ProjectionLayer(d_model, vocab_size)
# Create the transformer
transformer = DecoderOnlyTransformer(word_embedding,
position_encoding,
decoder,
projection_layer)
# Initialize the parameters
for p in transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return transformer
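## A minimal usage sketch (my own addition, not part of the original file): build a small model,
## create a causal mask so that each token can only attend to itself and to earlier tokens, and
## push a batch of random token ids through decode() and project(). The hyperparameter values
## below are arbitrary and chosen only to keep the example fast.
if __name__ == "__main__":
    vocab_size = 100
    seq_len = 6
    model = build_transformer(vocab_size=vocab_size,
                              seq_len=seq_len,
                              d_model=64,
                              N=2,
                              h=4,
                              dropout=0.1,
                              d_ff=256)
    ## (1, 1, seq_len, seq_len) lower-triangular mask; it broadcasts over the batch and head dimensions
    mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, dtype=torch.int))
    tokens = torch.randint(low=0, high=vocab_size, size=(2, seq_len))  ## (batch=2, seq_len)
    hidden = model.decode(tokens, mask)   ## (2, seq_len, d_model)
    logits = model.project(hidden)        ## (2, seq_len, vocab_size)
    print(hidden.shape, logits.shape)     ## torch.Size([2, 6, 64]) torch.Size([2, 6, 100])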