Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of per-patch vectors.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every non-overlapping patch to a single ``emb_dim``-long vector.

    Args:
        img_size: 1d size of each image (stored on the module; not used in
            the computation here).
        patch_size: 1d size of each square patch.
        in_chans: number of input channels (3 for RGB images).
        emb_dim: length of the vector produced for each patch.
    """

    def __init__(self, img_size: int, patch_size: int, in_chans: int = 3, emb_dim: int = 48):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[batch, in_chans, H, W]`` to ``[batch, n_patches, emb_dim]``."""
        # NOTE(review): the projection runs under no_grad, so the result is
        # detached and self.proj never receives gradients (it stays at its
        # initial weights). Confirm this freezing is intentional.
        with torch.no_grad():
            grid = self.proj(x)  # [batch, emb_dim, patches per col, patches per row]
            # Collapse the patch grid into one axis, then move it in front of
            # the feature axis: [batch, n_patches, emb_dim].
            tokens = grid.flatten(2).transpose(1, 2)
        return tokens
class TransformerEncoder(nn.Module):
    """One encoder block: LayerNorm -> self-attention and LayerNorm -> MLP,
    each followed by a residual connection.

    Args:
        input_dim: width of every token; also the embedding dim handed to the
            attention layer.
        mlp_hidden_dim: hidden width of the position-wise feed-forward network.
        num_head: number of attention heads.
        dropout: accepted for interface compatibility.
            NOTE(review): this value is not used anywhere in the block.
    """

    def __init__(self, input_dim: int, mlp_hidden_dim: int, num_head: int = 8, dropout: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)
        self.msa = MultiHeadSelfAttention(input_dim, n_heads=num_head)
        self.norm2 = nn.LayerNorm(input_dim)
        # Position-wise feed-forward network. A GELU follows BOTH linear
        # layers, including the output one; layer positions are kept as-is so
        # the Sequential's parameter names do not change.
        feed_forward_layers = [
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, input_dim),
            nn.GELU(),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Apply the attention and MLP sub-layers; output has the shape of ``x``."""
        attended = self.msa(self.norm1(x)) + x  # first residual connection
        transformed = self.mlp(self.norm2(attended)) + attended  # second residual connection
        return transformed
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: dimension of input and output per-token features (emb dim for tokens).
        n_heads: number of heads; must be positive and divide ``dim``.
        qkv_bias: whether to have bias in the qkv linear layer.
        attn_p: dropout probability applied to the attention weights.
        proj_p: dropout probability applied after the last linear layer.

    Attributes:
        head_dim: per-head feature size (``dim // n_heads``).
        scale: scaling factor for attention (``1 / sqrt(head_dim)``).
        qkv: single linear layer producing query, key and value together.
        proj: last linear layer.
        attn_drop, proj_drop: dropout layers for attn and proj.

    Raises:
        ValueError: at construction if ``n_heads`` is not positive or does not
            divide ``dim``; in ``forward`` if the input's last dimension is
            not ``dim``.
    """

    def __init__(self, dim: int, n_heads: int = 8, qkv_bias: bool = True, attn_p: float = 0.01, proj_p: float = 0.01):
        super().__init__()
        # Fail fast: a non-divisible head count can never produce a valid
        # forward pass, so reject it here instead of on every call.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError("Input & Output dim should be divisible by Number of Heads")
        self.n_heads = n_heads
        self.dim = dim  # embedding dimension for input
        self.head_dim = dim // n_heads  # d_q, d_k, d_v per head
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Attend over the token axis.

        Input and output shape: ``[batch_size, n_tokens, dim]``.
        """
        batch_size, n_tokens, x_dim = x.shape
        if x_dim != self.dim:
            raise ValueError(
                f"Expected input with last dimension {self.dim}, got {x_dim}"
            )

        # One linear layer yields q, k and v; split it into heads.
        qkv = self.qkv(x)  # (batch_size, n_tokens, 3*dim)
        qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, n_heads, n_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (batch_size, n_heads, n_tokens, head_dim)

        # Scaled dot-product attention.
        scores = (q @ k.transpose(-2, -1)) * self.scale  # (batch_size, n_heads, n_tokens, n_tokens)
        attn = self.attn_drop(scores.softmax(dim=-1))
        weighted_avg = attn @ v  # (batch_size, n_heads, n_tokens, head_dim)

        # Concatenate the heads back into one feature axis, then project.
        weighted_avg = weighted_avg.transpose(1, 2).flatten(2)  # (batch_size, n_tokens, dim)
        return self.proj_drop(self.proj(weighted_avg))
class ViT(nn.Module):
    """Small Vision Transformer classifier.

    Pipeline: patch embedding -> linear projection -> optional class token ->
    learnable positional embedding -> encoder stack -> LayerNorm + Linear head.

    Args:
        in_c: number of input image channels (forwarded to the patch embedding).
        num_classes: number of output logits.
        img_size: 1d size of each (square) input image.
        num_patch_1d: number of patches in one row (or col) of the image.
        dropout: forwarded to every ``TransformerEncoder``.
        num_enc_layers: number of stacked encoder blocks.
        hidden_dim: token width used inside the encoder.
        mlp_hidden_dim: hidden width of each encoder's feed-forward network.
        num_head: number of attention heads per encoder block.
        is_cls_token: if True, prepend a learnable class token and classify
            from it; otherwise classify from the mean over all tokens.
    """

    def __init__(
        self,
        in_c: int = 3,
        num_classes: int = 10,
        img_size: int = 32,
        num_patch_1d: int = 16,
        dropout: float = 0.1,
        num_enc_layers: int = 2,
        hidden_dim: int = 128,
        mlp_hidden_dim: int = 128 // 2,
        num_head: int = 4,
        is_cls_token: bool = True
    ):
        super().__init__()
        self.is_cls_token = is_cls_token
        self.num_patch_1d = num_patch_1d
        self.patch_size = img_size // self.num_patch_1d  # 1d size of each patch
        n_patches = self.num_patch_1d ** 2
        num_tokens = n_patches + 1 if self.is_cls_token else n_patches

        # Divide each image into patches. The width of each patch vector is
        # tied to the patch count (num_patch_1d**2); kept as-is so existing
        # parameter shapes stay valid.
        self.images_to_patches = PatchEmbedding(
            img_size=img_size,
            patch_size=self.patch_size,
            in_chans=in_c,  # previously dropped, which broke any non-3-channel input
            emb_dim=n_patches,
        )
        # Linear projection of the patch vectors to the encoder width.
        self.lpfp = nn.Linear(n_patches, hidden_dim)
        # Learnable class token [1, 1, hidden_dim]; a single token per image.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) if is_cls_token else None
        # Learnable positional embedding [1, num_tokens, hidden_dim].
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden_dim))
        # Transformer encoder stack.
        enc_list = [
            TransformerEncoder(hidden_dim, mlp_hidden_dim=mlp_hidden_dim, dropout=dropout, num_head=num_head)
            for _ in range(num_enc_layers)
        ]
        self.enc = nn.Sequential(*enc_list)
        # MLP head (standard classifier).
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Return class logits ``[batch, num_classes]`` for images ``[batch, in_c, H, W]``."""
        out = self.images_to_patches(x)  # [batch, n_patches, num_patch_1d**2]
        out = self.lpfp(out)  # [batch, n_patches, hidden_dim]
        if self.is_cls_token:
            # expand() broadcasts the single token over the batch without copying it.
            cls = self.cls_token.expand(out.size(0), -1, -1)
            out = torch.cat([cls, out], dim=1)  # [batch, n_patches + 1, hidden_dim]
        out = out + self.pos_emb
        out = self.enc(out)
        # Pool to one vector per image: the class token if present, else the token mean.
        out = out[:, 0] if self.is_cls_token else out.mean(1)
        return self.mlp_head(out)
import lightning as pl | |
from torchmetrics import Accuracy | |
class ViTLightning(pl.LightningModule):
    """Lightning wrapper that trains a fixed-configuration ``ViT`` on 10 classes.

    Args:
        learning_rate: peak learning rate of the one-cycle schedule; also
            passed to Adam as its ``lr``.
    """

    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        self.vit = ViT(
            in_c=3,
            num_classes=10,
            img_size=32,
            num_patch_1d=16,
            dropout=0.1,
            num_enc_layers=2,
            hidden_dim=96,
            mlp_hidden_dim=64,
            num_head=8,
            is_cls_token=True
        )
        # One metric object per stage so their accumulated states never mix.
        self.train_acc = Accuracy('multiclass', num_classes=10)
        self.val_acc = Accuracy('multiclass', num_classes=10)
        self.test_acc = Accuracy('multiclass', num_classes=10)
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the ViT logits for a batch of images."""
        return self.vit(x)

    def _shared_step(self, batch, stage: str, metric):
        """Compute loss and accuracy for one batch and log them as ``{stage}_loss`` / ``{stage}_acc``."""
        x, y = batch
        preds = self(x)
        loss = nn.functional.cross_entropy(preds, y)
        acc = metric(preds, y)
        self.log(f'{stage}_loss', loss, prog_bar=True, logger=True)
        self.log(f'{stage}_acc', acc, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one training batch."""
        return self._shared_step(batch, 'train', self.train_acc)

    def validation_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one validation batch."""
        return self._shared_step(batch, 'val', self.val_acc)

    def test_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one test batch."""
        return self._shared_step(batch, 'test', self.test_acc)

    def configure_optimizers(self):
        """Build Adam plus a per-step, two-phase linear one-cycle LR schedule."""
        optimizer = torch.optim.Adam(self.vit.parameters(), lr=self.learning_rate)
        # The schedule length comes from total_steps alone. The previous code
        # also passed `epochs`, but as an accidental 1-tuple (trailing comma),
        # so that argument is dropped rather than repaired.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.learning_rate,  # was hard-coded to 1e-3, ignoring the constructor argument
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=.3,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear'
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': "val_loss",
                "interval": "step",
                "frequency": 1
            }
        }