Simple Transformer
Author: Eshan Jayasundara
Last Updated: 2nd of March 2025
Created: 28th of February 2025
About:
└── Single-head transformer (Transformer with self-attention, trained with teacher forcing)
Training:
└── Teacher Forcing (Baseline)
    ├── During training, the actual ground-truth tokens (from the dataset) are fed as input to the decoder instead of the model's own predictions (see the sketch after this list).
    ├── This makes training faster and ensures the model learns accurate token-to-token mappings.
    └── Drawback: at inference time the model doesn't see ground-truth inputs, so errors can accumulate (called exposure bias).
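A minimal sketch of how the teacher and target sequences are built from one tokenized sentence; the token ids and special-token ids below are made up for illustration, not taken from the project.

```python
# Hypothetical token ids for "<SOS> hello I'm fine <EOS>" (ids are made up).
SOS, EOS = 1, 2
sequence = [SOS, 37, 98, 14, EOS]

decoder_teacher_tokens = sequence[:-1]  # "<SOS> hello I'm fine"  -> fed to the decoder
decoder_target_tokens = sequence[1:]    # "hello I'm fine <EOS>"  -> compared against the logits
```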
Vocabulary dataset (from Hugging Face):
└── "yukiarimo/english-vocabulary"
Simple Transformer Architecture:
Encoder
├── Input text
│   └── e.g., "Hello, how are you?"
├── Remove punctuation from the input text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Self-attention
│   ├── single-head
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Feed-forward layer
│   ├── Two linear layers with a single hidden layer of width d_ff
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Save the encoder output to be used in cross-attention (a minimal code sketch of these encoder pieces follows)
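A minimal, self-contained PyTorch sketch of the encoder pieces above: sinusoidal positional encoding, single-head scaled dot-product self-attention, add and norm with plain mean-variance normalization (layer norm is listed under Future Improvements), and the two-linear-layer feed-forward block. All names, shapes, and sizes are illustrative and may differ from the original code.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_positional_encoding(seq_len: int, embedding_dim: int) -> torch.Tensor:
    """Classic sin/cos positional encoding, shape (seq_len, embedding_dim); assumes even embedding_dim."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / embedding_dim))
    pe = torch.zeros(seq_len, embedding_dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

def mean_var_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Simple mean-variance normalization over the embedding dimension."""
    return (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + eps)

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embedding_dim: int):
        super().__init__()
        self.Wq = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (seq_len, embedding_dim)
        Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(x.size(-1))  # scaled dot-product
        return torch.softmax(scores, dim=-1) @ V

class EncoderBlock(nn.Module):
    def __init__(self, embedding_dim: int, d_ff: int):
        super().__init__()
        self.attn = SingleHeadSelfAttention(embedding_dim)
        self.ff = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=d_ff),
            nn.ReLU(),
            nn.Linear(in_features=d_ff, out_features=embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: embeddings + positional encoding
        x = mean_var_norm(x + self.attn(x))   # add and norm
        return mean_var_norm(x + self.ff(x))  # add and norm (again)

# Example usage (hypothetical sizes and ids): embedding lookup + positional encoding,
# then one encoder block; encoder_out is saved for the decoder's cross-attention.
vocab_size, embedding_dim, d_ff = 1000, 64, 256
tokens = torch.tensor([4, 27, 9, 311])  # "hello how are you" (made-up ids)
embedding = nn.Embedding(vocab_size, embedding_dim)
x = embedding(tokens) + sinusoidal_positional_encoding(tokens.size(0), embedding_dim)
encoder_out = EncoderBlock(embedding_dim, d_ff)(x)
```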
Decoder
├── Decoder teacher text (the same as the target text but shifted right)
│   ├── e.g., decoder teacher text: "<SOS> hello, I'm fine."
│   └── e.g., target text: "hello, I'm fine. <EOS>"
├── Remove punctuation from the teacher text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Masked self-attention (single-head; a new class signature for masked self-attention is introduced)
│   ├── single-head
│   ├── causal mask built from a lower-triangular matrix
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Cross-attention (the same class signature as the encoder self-attention can be reused)
│   ├── single-head
│   ├── Q = Wq @ add-and-norm output of the masked self-attention
│   ├── K = Wk @ encoder output
│   └── V = Wv @ encoder output
├── Add and norm
├── Feed-forward layer
│   ├── Two linear layers with a single hidden layer of width d_ff
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Linear layer over the vocabulary (no activation or final softmax here, unlike 'Attention Is All You Need'; see the sketch below)
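A minimal sketch of the decoder-specific pieces above: a lower-triangular causal mask, a single attention class reused for both masked self-attention and cross-attention (Q from the decoder stream, K and V from the saved encoder output), and the final linear layer that emits raw logits. Class names, signatures, and shapes are illustrative assumptions, not the original code.

```python
import math
import torch
import torch.nn as nn

def mean_var_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Simple mean-variance normalization over the embedding dimension."""
    return (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + eps)

def causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular matrix: position i may attend only to positions <= i."""
    return torch.tril(torch.ones(seq_len, seq_len))

class SingleHeadAttention(nn.Module):
    """Covers masked self-attention (kv_input == q_input plus a mask)
    and cross-attention (kv_input = encoder output, no mask)."""
    def __init__(self, embedding_dim: int):
        super().__init__()
        self.Wq = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, q_input, kv_input, mask=None):
        Q, K, V = self.Wq(q_input), self.Wk(kv_input), self.Wv(kv_input)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(q_input.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))  # hide future positions
        return torch.softmax(scores, dim=-1) @ V

class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim: int, d_ff: int, vocab_size: int):
        super().__init__()
        self.masked_attn = SingleHeadAttention(embedding_dim)
        self.cross_attn = SingleHeadAttention(embedding_dim)
        self.ff = nn.Sequential(
            nn.Linear(embedding_dim, d_ff), nn.ReLU(), nn.Linear(d_ff, embedding_dim)
        )
        self.out = nn.Linear(embedding_dim, vocab_size)  # final linear layer, no softmax

    def forward(self, x, encoder_out):  # x: (tgt_len, d), encoder_out: (src_len, d)
        x = mean_var_norm(x + self.masked_attn(x, x, causal_mask(x.size(0))))  # add and norm
        x = mean_var_norm(x + self.cross_attn(x, encoder_out))                 # add and norm
        x = mean_var_norm(x + self.ff(x))                                      # add and norm (again)
        return self.out(x)  # raw logits over the vocabulary
```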
Optimization
├── Initialize the Adam optimizer with the model's parameters and a specified learning rate.
│   └── self.optimizer = torch.optim.Adam(params=self.parameters(), lr=learning_rate)
├── Before computing gradients for the current batch, reset any gradients left over from the previous iteration.
│   └── self.optimizer.zero_grad()
├── The model takes `input_tokens` and `decoder_teacher_tokens` and performs a forward pass to compute `logits`.
│   └── logits = self.forward(input_tokens, decoder_teacher_tokens)
├── The cross-entropy loss
│   ├── Measures the difference between the predicted token distribution (logits) and the actual target tokens (decoder_target_tokens).
│   ├── It expects raw scores (not probabilities) as logits and applies softmax internally.
│   └── loss = F.cross_entropy(logits, decoder_target_tokens)
├── Compute the gradients of the loss with respect to all trainable parameters in the model using automatic differentiation (backpropagation).
│   └── loss.backward()
└── The optimizer updates the model's weights using the computed gradients (a minimal training-step sketch follows this list).
    └── self.optimizer.step()
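A minimal sketch of one training step using the operations above, assuming `model` is an nn.Module whose forward pass returns logits of shape (seq_len, vocab_size) aligned with `decoder_target_tokens` of shape (seq_len,); the variable names and learning rate are illustrative.

```python
import torch
import torch.nn.functional as F

# Adam over the model's trainable parameters; note parameters() is a method call.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

optimizer.zero_grad()                                   # reset gradients from the previous step
logits = model(input_tokens, decoder_teacher_tokens)    # forward pass -> (seq_len, vocab_size)
loss = F.cross_entropy(logits, decoder_target_tokens)   # softmax applied internally on raw logits
loss.backward()                                         # backpropagation through the whole model
optimizer.step()                                        # apply the Adam update
```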
After training, output tokens are turned into text with autoregressive text generation (one token at a time):
├── Start with <SOS> as the initial input to the decoder; the input to the encoder is the `prompt`.
├── The model predicts the next token.
├── Append the predicted token to the sequence.
├── Repeat until an <EOS> token or the maximum length is reached.
└── For illustration, let's use words instead of tokens (their numerical representations); a greedy decoding sketch follows this example.
<SOS>
<SOS> hello
<SOS> hello I'm
<SOS> hello I'm good
<SOS> hello I'm good <EOS>
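A minimal greedy decoding loop that mirrors the steps above; `model`, `encoder_tokens`, the maximum length, and the <SOS>/<EOS> ids are assumptions for illustration.

```python
import torch

SOS_ID, EOS_ID, MAX_LEN = 1, 2, 50  # hypothetical special-token ids and length limit

generated = [SOS_ID]                                   # start with <SOS>
with torch.no_grad():
    for _ in range(MAX_LEN):
        decoder_input = torch.tensor(generated)
        logits = model(encoder_tokens, decoder_input)  # (cur_len, vocab_size)
        next_token = int(logits[-1].argmax())          # greedy: most likely next token
        generated.append(next_token)                   # append and feed back in
        if next_token == EOS_ID:                       # stop at <EOS> or max length
            break
```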
Future Improvements:
├── Multi-head attention instead of single-head attention.
├── Layer normalization instead of simple mean-variance normalization.
└── Dropout layers for better generalization.