# AttnViz/src/vit.py
import torch
import torch.nn as nn
import lightning as pl
from torchmetrics import Accuracy


class PatchEmbedding(nn.Module):
"""
img_size: 1d size of each image (32 for CIFAR-10)
patch_size: 1d size of each patch (img_size/num_patch_1d, 4 in this experiment)
in_chans: input channel (3 for RGB images)
emb_dim: flattened length for each token (or patch)
"""
def __init__(self, img_size:int, patch_size:int, in_chans:int=3, emb_dim:int=48):
        super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans,
emb_dim,
kernel_size = patch_size,
stride = patch_size
)
    def forward(self, x):
        # x: [batch, in_chans, img_size, img_size]
        x = self.proj(x)        # [batch, emb_dim, img_size/patch_size, img_size/patch_size]; [batch, 48, 8, 8] in this experiment
        x = x.flatten(2)        # [batch, emb_dim, n_patches]; [batch, 48, 64] in this experiment
        x = x.transpose(1, 2)   # [batch, n_patches, emb_dim], the layout the Transformer encoder expects: [batch, number of words, word_emb_dim]
        return x
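

# A minimal shape check for PatchEmbedding (illustrative, not part of the
# original module): with the CIFAR-10 style settings from the docstring
# (img_size=32, patch_size=4, emb_dim=48), a batch maps to 64 tokens of
# length 48. The helper name `_demo_patch_embedding` is ours.
def _demo_patch_embedding():
    pe = PatchEmbedding(img_size=32, patch_size=4, in_chans=3, emb_dim=48)
    x = torch.randn(2, 3, 32, 32)       # dummy batch of two RGB 32x32 images
    out = pe(x)                         # [2, 48, 8, 8] -> [2, 48, 64] -> [2, 64, 48]
    assert out.shape == (2, 64, 48)     # [batch, n_patches, emb_dim]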


class TransformerEncoder(nn.Module):
    def __init__(self, input_dim:int, mlp_hidden_dim:int, num_head:int=8, dropout:float=0.):
        # input_dim and num_head configure the Multi-Head Attention block
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)  # LayerNorm normalizes over the feature dim (the standard choice for Transformers)
        self.msa = MultiHeadSelfAttention(input_dim, n_heads=num_head)
        self.norm2 = nn.LayerNorm(input_dim)
        # Position-wise feed-forward network with a GELU activation,
        # with dropout after each linear layer so the dropout argument is actually used
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, input_dim),
            nn.Dropout(dropout),
        )
def forward(self, x):
out = self.msa(self.norm1(x)) + x # add residual connection
out = self.mlp(self.norm2(out)) + out # add another residual connection
return out
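

# A quick sanity sketch for TransformerEncoder (our addition, assuming an
# input_dim divisible by num_head): the pre-norm attention and MLP blocks
# with residual connections preserve the token shape end to end.
def _demo_transformer_encoder():
    enc = TransformerEncoder(input_dim=48, mlp_hidden_dim=96, num_head=8)
    tokens = torch.randn(2, 65, 48)     # [batch, n_tokens, dim]
    out = enc(tokens)
    assert out.shape == tokens.shape    # residual blocks keep [batch, n_tokens, dim]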
class MultiHeadSelfAttention(nn.Module):
"""
dim: dimension of input and out per token features (emb dim for tokens)
n_heads: number of heads
qkv_bias: whether to have bias in qkv linear layers
attn_p: dropout probability for attention
proj_p: droupout probability last linear layer
scale: scaling factor for attention (1/sqrt(dk))
qkv: initial linear layer for the query, key, and value
proj: last linear layer
attn_drop, proj_drop: dropout layers for attn and proj
"""
def __init__(self, dim:int, n_heads:int=8, qkv_bias:bool=True, attn_p:float=0.01, proj_p:float=0.01):
        super().__init__()
self.n_heads = n_heads
self.dim = dim # embedding dimension for input
        self.head_dim = dim // n_heads # d_q, d_k, d_v in the paper (integer division, so dim must be divisible by n_heads to preserve input dim = output dim)
self.scale = self.head_dim ** -0.5 # 1/sqrt(d_k)
self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) # lower linear layers in Figure 2 of the paper
self.attn_drop = nn.Dropout(attn_p)
self.proj = nn.Linear(dim, dim) # upper linear layers in Figure 2 of the paper
self.proj_drop = nn.Dropout(proj_p)
def forward(self, x):
"""
Input and Output shape: [batch_size, n_patches + 1, dim]
"""
        batch_size, n_tokens, x_dim = x.shape # n_tokens = n_patches + 1 (the +1 is the cls_token), x_dim is the input dim
        # Sanity checks
        if x_dim != self.dim: # make sure the input dim matches the concatenated (output) dim
            raise ValueError(f"Expected input dim {self.dim}, got {x_dim}")
        if self.dim != self.head_dim*self.n_heads: # make sure dim is divisible by n_heads
            raise ValueError("Input & output dim should be divisible by the number of heads")
# Linear Layers for Query, Key, Value
qkv = self.qkv(x) # (batch_size, n_patches+1, 3*dim)
qkv = qkv.reshape(batch_size, n_tokens, 3, self.n_heads, self.head_dim) # (batch_size, n_patches+1, 3, n_heads, head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, n_heads, n_patches+1, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2] # (batch_size, n_heads, n_patches+1, head_dim)
# Scaled Dot-Product Attention
k_t = k.transpose(-2, -1) # K Transpose: (batch_size, n_heads, head_dim, n_patches+1)
dot_product = (q @ k_t)*self.scale # Query, Key Dot Product with Scale Factor: (batch_size, n_heads, n_patches+1, n_patches+1)
attn = dot_product.softmax(dim=-1) # Softmax: (batch_size, n_heads, n_patches+1, n_patches+1)
attn = self.attn_drop(attn) # Attention Dropout: (batch_size, n_heads, n_patches+1, n_patches+1)
weighted_avg = attn @ v # (batch_size, n_heads, n_patches+1, head_dim)
weighted_avg = weighted_avg.transpose(1, 2) # (batch_size, n_patches+1, n_heads, head_dim)
# Concat and Last Linear Layer
weighted_avg = weighted_avg.flatten(2) # Concat: (batch_size, n_patches+1, dim)
x = self.proj(weighted_avg) # Last Linear Layer: (batch_size, n_patches+1, dim)
x = self.proj_drop(x) # Last Linear Layer Dropout: (batch_size, n_patches+1, dim)
return x
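

# An illustrative check for MultiHeadSelfAttention (not from the original
# file): dim must be divisible by n_heads so the concatenated head outputs
# reconstruct the input dimension, and the output shape matches the input.
def _demo_mhsa():
    msa = MultiHeadSelfAttention(dim=48, n_heads=8)  # head_dim = 48 // 8 = 6
    tokens = torch.randn(2, 65, 48)                  # [batch, n_patches + 1, dim]
    out = msa(tokens)
    assert out.shape == tokens.shape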


class ViT(nn.Module):
def __init__(
self,
in_c:int=3,
num_classes:int=10,
img_size:int=32,
num_patch_1d:int=16,
dropout:float=0.1,
num_enc_layers:int=2,
hidden_dim:int=128,
        mlp_hidden_dim:int=64,
num_head:int=4,
is_cls_token:bool=True
):
        """
        is_cls_token: whether to prepend a learnable class token
        num_patch_1d: number of patches along one row (or column) of the image; 3 in Figure 1 of the paper
        patch_size: side length of each square patch (img_size // num_patch_1d); 16 for ImageNet in the paper
        flattened_patch_dim: flattened vector length of each patch (patch_size x patch_size x in_c)
        num_tokens: total number of patches, plus 1 for the class token if used
        """
        super().__init__()
self.is_cls_token = is_cls_token
self.num_patch_1d = num_patch_1d
self.patch_size = img_size//self.num_patch_1d
num_tokens = (self.num_patch_1d**2)+1 if self.is_cls_token else (self.num_patch_1d**2)
# Divide each image into patches
        self.images_to_patches = PatchEmbedding(
            img_size=img_size,
            patch_size=self.patch_size,
            in_chans=in_c,
            emb_dim=num_patch_1d*num_patch_1d  # per-patch embedding width produced by the conv projection
        )
# Linear Projection of Flattened Patches
        self.lpfp = nn.Linear(num_patch_1d*num_patch_1d, hidden_dim) # maps each patch embedding to hidden_dim (the latent vector size D in the paper)
# Patch + Position Embedding (Learnable)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim)) if is_cls_token else None # learnable classification token, [1, 1, hidden_dim]; 1 in the 2nd dim because there is one class per image, not per patch
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, hidden_dim)) # learnable positional embedding, [1, num_tokens, hidden_dim]
# Transformer Encoder
        enc_list = [TransformerEncoder(hidden_dim, mlp_hidden_dim=mlp_hidden_dim, dropout=dropout, num_head=num_head) for _ in range(num_enc_layers)] # num_enc_layers is L in the Transformer Encoder of Figure 1
        self.enc = nn.Sequential(*enc_list) # * must be added when passing a regular Python list to nn.Sequential
# MLP Head (Standard Classifier)
self.mlp_head = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, num_classes)
)
    def forward(self, x): # x: [batch, in_c, img_size, img_size]
        # Images into patches (including flattening)
        out = self.images_to_patches(x) # [batch, num_patch_1d**2, patch_emb_dim]
        # Linear projection of flattened patches
        out = self.lpfp(out) # [batch, num_patch_1d**2, hidden_dim]
        # Add class token and positional embedding
        if self.is_cls_token:
            out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out], dim=1) # [batch, num_tokens, hidden_dim]; cls token prepended as an extra learnable embedding
        out = out + self.pos_emb # [batch, num_tokens, hidden_dim]
        # Transformer encoder
        out = self.enc(out) # [batch, num_tokens, hidden_dim]
        if self.is_cls_token:
            out = out[:,0] # [batch, hidden_dim], the class token
        else:
            out = out.mean(1) # [batch, hidden_dim], mean over all patch tokens
        # MLP head
        out = self.mlp_head(out) # [batch, num_classes]
        return out
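

# A hedged end-to-end sketch (our addition): the default ViT configuration
# maps a CIFAR-10 sized batch straight to class logits.
def _demo_vit():
    model = ViT()                       # defaults: 32x32 RGB inputs, 10 classes
    x = torch.randn(2, 3, 32, 32)
    logits = model(x)
    assert logits.shape == (2, 10)      # [batch, num_classes]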


class ViTLightning(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        self.vit = ViT(
            in_c=3,
            num_classes=10,
            img_size=32,
            num_patch_1d=16,
            dropout=0.1,
            num_enc_layers=2,
            hidden_dim=96,
            mlp_hidden_dim=64,
            num_head=8,
            is_cls_token=True
        )
        self.criterion = nn.CrossEntropyLoss()  # instantiated once instead of per step
        self.train_acc = Accuracy('multiclass', num_classes=10)
        self.val_acc = Accuracy('multiclass', num_classes=10)
        self.test_acc = Accuracy('multiclass', num_classes=10)
        self.learning_rate = learning_rate
def forward(self, x):
return self.vit(x)
def training_step(self, batch, batch_idx):
x, y = batch
preds = self.forward(x)
        loss = self.criterion(preds, y)
acc = self.train_acc(preds, y)
self.log('train_loss', loss, prog_bar=True, logger=True)
self.log('train_acc', acc, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
preds = self.forward(x)
        loss = self.criterion(preds, y)
acc = self.val_acc(preds, y)
self.log('val_loss', loss, prog_bar=True, logger=True)
self.log('val_acc', acc, prog_bar=True, logger=True)
return loss
def test_step(self, batch, batch_idx):
x, y = batch
preds = self.forward(x)
        loss = self.criterion(preds, y)
acc = self.test_acc(preds, y)
self.log('test_loss', loss, prog_bar=True, logger=True)
self.log('test_acc', acc, prog_bar=True, logger=True)
return loss
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.vit.parameters(), lr=self.learning_rate)
        # OneCycleLR takes either total_steps or (epochs and steps_per_epoch);
        # total_steps from the trainer is sufficient here
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            total_steps=self.trainer.estimated_stepping_batches,
            max_lr=self.learning_rate,
            pct_start=.3,
            div_factor=100,
            final_div_factor=100,
            three_phase=False,
            anneal_strategy='linear'
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }
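

# A hedged usage sketch (our addition): `train_loader` and `val_loader` are
# assumed CIFAR-10 DataLoaders, not defined in this file.
#
#   model = ViTLightning(learning_rate=1e-3)
#   trainer = pl.Trainer(max_epochs=10, accelerator="auto")
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)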