import torch
import torch.nn as nn


class PatchEmbedding(nn.Module): # Done
    """
    img_size: 1d size of each image (32 for CIFAR-10)
    patch_size: 1d size of each patch (img_size // num_patch_1d, e.g. 4 for 8x8 patches on CIFAR-10)
    in_chans: number of input channels (3 for RGB images)
    emb_dim: flattened length for each token (or patch)
    """
    def __init__(self, img_size:int, patch_size:int, in_chans:int=3, emb_dim:int=48):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Non-overlapping convolution: each patch_size x patch_size patch is mapped to an emb_dim vector
        self.proj = nn.Conv2d(
            in_chans,
            emb_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    def forward(self, x):
        # x: [batch, in_chans, img_size, img_size]
        x = self.proj(x)       # [batch, emb_dim, # patches per row, # patches per col], e.g. [batch, 48, 8, 8]
        x = x.flatten(2)       # [batch, emb_dim, total # of patches], e.g. [batch, 48, 64]
        x = x.transpose(1, 2)  # [batch, total # of patches, emb_dim]: the Transformer encoder expects [batch, number of tokens, token_emb_dim]
        return x
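
# Quick shape check for PatchEmbedding (an illustrative sketch, not part of the model;
# img_size=32, patch_size=4, emb_dim=48 are example settings, not the configuration used
# by ViTLightning below). A CIFAR-10 sized batch should become 8*8 = 64 tokens of length 48.
if __name__ == "__main__":
    _pe = PatchEmbedding(img_size=32, patch_size=4, in_chans=3, emb_dim=48)
    _dummy_images = torch.randn(2, 3, 32, 32)   # fake CIFAR-10 batch
    print(_pe(_dummy_images).shape)             # expected: torch.Size([2, 64, 48])
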

class TransformerEncoder(nn.Module): # Done
    def __init__(self, input_dim:int, mlp_hidden_dim:int, num_head:int=8, dropout:float=0.):
        # input_dim and num_head configure the Multi-Head Attention
        super(TransformerEncoder, self).__init__()
        self.norm1 = nn.LayerNorm(input_dim)  # LayerNorm is the Transformer counterpart of BatchNorm
        self.msa = MultiHeadSelfAttention(input_dim, n_heads=num_head)
        self.norm2 = nn.LayerNorm(input_dim)
        # Position-wise feed-forward network with GELU activation functions
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, input_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        out = self.msa(self.norm1(x)) + x      # residual connection around attention
        out = self.mlp(self.norm2(out)) + out  # residual connection around the MLP
        return out

class MultiHeadSelfAttention(nn.Module):
    """
    dim: dimension of the input and output features per token (token embedding dim)
    n_heads: number of attention heads
    qkv_bias: whether the qkv linear layer has a bias
    attn_p: dropout probability for the attention weights
    proj_p: dropout probability for the last linear layer
    scale: scaling factor for attention (1/sqrt(d_k))
    qkv: initial linear layer producing the query, key, and value
    proj: last linear layer
    attn_drop, proj_drop: dropout layers for attn and proj
    """
    def __init__(self, dim:int, n_heads:int=8, qkv_bias:bool=True, attn_p:float=0.01, proj_p:float=0.01):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim                      # embedding dimension of the input
        self.head_dim = dim // n_heads      # d_q, d_k, d_v in the paper (integer division keeps input dim == output dim)
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)
        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)  # lower linear layers in Figure 2 of the paper
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)                  # upper linear layer in Figure 2 of the paper
        self.proj_drop = nn.Dropout(proj_p)
    def forward(self, x):
        """
        Input and output shape: [batch_size, n_patches + 1, dim]
        """
        batch_size, n_tokens, x_dim = x.shape  # n_tokens = n_patches + 1 (the 1 is the cls_token), x_dim is the input dim
        # Sanity checks
        if x_dim != self.dim:  # input dim must match the concatenated (output) dim
            raise ValueError("Input dim does not match the attention dim")
        if self.dim != self.head_dim * self.n_heads:  # dim must be divisible by n_heads
            raise ValueError("Input & output dim should be divisible by the number of heads")
        # Linear layer for query, key, value
        qkv = self.qkv(x)                                                        # (batch_size, n_patches+1, 3*dim)
        qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim)  # (batch_size, n_patches+1, 3, n_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                                         # (3, batch_size, n_heads, n_patches+1, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]                                         # (batch_size, n_heads, n_patches+1, head_dim)
        # Scaled dot-product attention
        k_t = k.transpose(-2, -1)             # K transpose: (batch_size, n_heads, head_dim, n_patches+1)
        dot_product = (q @ k_t) * self.scale  # query-key dot product with scale factor: (batch_size, n_heads, n_patches+1, n_patches+1)
        attn = dot_product.softmax(dim=-1)    # softmax: (batch_size, n_heads, n_patches+1, n_patches+1)
        attn = self.attn_drop(attn)           # attention dropout: (batch_size, n_heads, n_patches+1, n_patches+1)
        weighted_avg = attn @ v               # (batch_size, n_heads, n_patches+1, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2)  # (batch_size, n_patches+1, n_heads, head_dim)
        # Concat heads and apply the last linear layer
        weighted_avg = weighted_avg.flatten(2)  # concat: (batch_size, n_patches+1, dim)
        x = self.proj(weighted_avg)             # last linear layer: (batch_size, n_patches+1, dim)
        x = self.proj_drop(x)                   # last linear layer dropout: (batch_size, n_patches+1, dim)
        return x
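
# Illustrative shape check for the attention and encoder blocks (a minimal sketch with
# assumed sizes dim=128, 65 tokens, 8 heads; these are not taken from the training config).
# Both modules should preserve the [batch, n_tokens, dim] shape.
if __name__ == "__main__":
    _tokens = torch.randn(2, 65, 128)           # fake [batch, n_patches+1, dim] input
    _msa = MultiHeadSelfAttention(dim=128, n_heads=8)
    _block = TransformerEncoder(input_dim=128, mlp_hidden_dim=64, num_head=8)
    print(_msa(_tokens).shape)                  # expected: torch.Size([2, 65, 128])
    print(_block(_tokens).shape)                # expected: torch.Size([2, 65, 128])
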

class ViT(nn.Module): # Done
    """
    is_cls_token: whether to prepend a learnable classification token
    num_patch_1d: number of patches in one row (or column), 3 in Figure 1 of the paper
    patch_size: 1d size (row or column) of each patch, 16 for ImageNet in the paper; here img_size // num_patch_1d
    num_tokens: total number of patches, plus 1 for the class token if it is used (10 in Figure 1 of the paper)
    The patch embedding length is set to num_patch_1d**2 here and then projected to hidden_dim.
    """
    def __init__(
        self,
        in_c:int=3,
        num_classes:int=10,
        img_size:int=32,
        num_patch_1d:int=16,
        dropout:float=0.1,
        num_enc_layers:int=2,
        hidden_dim:int=128,
        mlp_hidden_dim:int=128//2,
        num_head:int=4,
        is_cls_token:bool=True
    ):
        super(ViT, self).__init__()
        self.is_cls_token = is_cls_token
        self.num_patch_1d = num_patch_1d
        self.patch_size = img_size // self.num_patch_1d
        num_tokens = (self.num_patch_1d**2) + 1 if self.is_cls_token else (self.num_patch_1d**2)
        # Divide each image into patches
        self.images_to_patches = PatchEmbedding(
            img_size=img_size,
            patch_size=img_size//num_patch_1d,
            in_chans=in_c,
            emb_dim=num_patch_1d*num_patch_1d
        )
        # Linear projection of the flattened patches: num_patch_1d**2 -> hidden_dim (the latent vector size D in the paper)
        self.lpfp = nn.Linear(num_patch_1d*num_patch_1d, hidden_dim)
        # Patch + position embedding (learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) if is_cls_token else None  # learnable classification token of shape [1, 1, hidden_dim]; a single token per image, not per patch
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden_dim))  # learnable positional embedding of shape [1, num_tokens, hidden_dim]
        # Transformer Encoder
        enc_list = [TransformerEncoder(hidden_dim, mlp_hidden_dim=mlp_hidden_dim, dropout=dropout, num_head=num_head) for _ in range(num_enc_layers)]  # num_enc_layers is L in the Transformer Encoder of Figure 1
        self.enc = nn.Sequential(*enc_list)  # a regular Python list must be unpacked with * when passed to nn.Sequential
        # MLP Head (standard classifier)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_classes)
        )
    def forward(self, x):  # x: [batch, in_c, img_size, img_size]
        # Images into patches (including flattening)
        out = self.images_to_patches(x)  # [batch, num_patch_1d**2, num_patch_1d**2]
        # Linear projection of the flattened patches
        out = self.lpfp(out)  # [batch, num_patch_1d**2, hidden_dim]
        # Add class token and positional embedding
        if self.is_cls_token:
            out = torch.cat([self.cls_token.repeat(out.size(0), 1, 1), out], dim=1)  # [batch, num_tokens, hidden_dim]; the class token is an extra learnable embedding
        out = out + self.pos_emb  # [batch, num_tokens, hidden_dim]
        # Transformer Encoder
        out = self.enc(out)  # [batch, num_tokens, hidden_dim]
        if self.is_cls_token:
            out = out[:, 0]  # [batch, hidden_dim], keep only the class token
        else:
            out = out.mean(1)  # [batch, hidden_dim], mean-pool over all tokens instead
        # MLP Head
        out = self.mlp_head(out)  # [batch, num_classes]
        return out
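
# End-to-end shape check for ViT (a minimal sketch; the arguments mirror the configuration
# used in ViTLightning below, which is assumed to be the intended setup). Each CIFAR-10
# image should map to a vector of num_classes logits.
if __name__ == "__main__":
    _vit = ViT(in_c=3, num_classes=10, img_size=32, num_patch_1d=16,
               num_enc_layers=2, hidden_dim=96, mlp_hidden_dim=64, num_head=8)
    print(_vit(torch.randn(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])
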

import lightning as pl
from torchmetrics import Accuracy


class ViTLightning(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super(ViTLightning, self).__init__()
        self.vit = ViT(
            in_c=3,
            num_classes=10,
            img_size=32,
            num_patch_1d=16,
            dropout=0.1,
            num_enc_layers=2,
            hidden_dim=96,
            mlp_hidden_dim=64,
            num_head=8,
            is_cls_token=True
        )
        self.train_acc = Accuracy(task='multiclass', num_classes=10)
        self.val_acc = Accuracy(task='multiclass', num_classes=10)
        self.test_acc = Accuracy(task='multiclass', num_classes=10)
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.vit(x)
    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.train_acc(preds, y)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        self.log('train_acc', acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.val_acc(preds, y)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        self.log('val_acc', acc, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = nn.CrossEntropyLoss()(preds, y)
        acc = self.test_acc(preds, y)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        self.log('test_acc', acc, prog_bar=True, logger=True)
        return loss
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.vit.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            total_steps=self.trainer.estimated_stepping_batches,
            max_lr=self.learning_rate,
            pct_start=0.3,
            div_factor=100,
            final_div_factor=100,
            three_phase=False,
            anneal_strategy='linear'
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }
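

# Hypothetical training sketch (assumptions: torchvision is installed, standard CIFAR-10
# normalization statistics, and an illustrative batch size / epoch count; the original code
# does not specify a data pipeline). The CIFAR-10 test split is reused for validation here
# purely for brevity.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    _tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    _train_set = datasets.CIFAR10("./data", train=True, download=True, transform=_tf)
    _test_set = datasets.CIFAR10("./data", train=False, download=True, transform=_tf)

    _model = ViTLightning(learning_rate=1e-3)
    _trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    _trainer.fit(_model,
                 DataLoader(_train_set, batch_size=128, shuffle=True),
                 DataLoader(_test_set, batch_size=128))
    _trainer.test(_model, DataLoader(_test_set, batch_size=128))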