import math

import torch
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass

import lightning as pl
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from torchmetrics import Accuracy

# Expected keys of the cfg dict used throughout this module:
# @dataclass
# class ViTCfg:
#     image_size: int
#     patch_size: int
#     num_channels: int
#     model_dim: int
#     num_attn_heads: int
#     attn_dropout: float
#     d_ff: int
#     number_encoders: int
#     classification_heads: int


class PatchEmbedding(torch.nn.Module):
    """Split an image into non-overlapping patches and project each patch to model_dim."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        assert self.image_size % self.patch_size == 0, "patch_size must evenly divide image_size"
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # A strided convolution slices the image into patches and applies a shared
        # linear projection to each patch in a single op.
        self.img2flattn: torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.model_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (bs, num_channels, image_size, image_size)
        #   >> (bs, model_dim, image_size//patch_size, image_size//patch_size)
        #   >> (bs, model_dim, num_patches)
        #   >> (bs, num_patches, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1, 2)


class Embedding(torch.nn.Module):
    """Patch embeddings plus a learnable [CLS] token and learnable position embeddings."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.patch_embedding: PatchEmbedding = PatchEmbedding(cfg=cfg)
        # single [CLS] token, prepended to the patch sequence
        self.cls_token: torch.nn.Parameter = torch.nn.Parameter(torch.randn(1, 1, self.model_dim))
        # one position embedding per patch, plus one for the [CLS] token
        self.position_embd: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, (self.image_size // self.patch_size) ** 2 + 1, self.model_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embedding(x)                             # (bs, num_patches, model_dim)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (bs, 1, model_dim)
        x = torch.cat((cls_token, x), dim=1)                    # (bs, num_patches + 1, model_dim)
        x = x + self.position_embd
        return x


class AttentionBlock(torch.nn.Module):
    """Multi-head self-attention with a fused q/k/v projection."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        assert self.model_dim % self.num_attn_heads == 0, "model_dim is not divisible by num_attn_heads"
        # dropout probability from cfg; the attribute is replaced by the Dropout modules below
        p_drop = float(self.attn_dropout)
        self.attn_layer: torch.nn.Linear = torch.nn.Linear(self.model_dim, 3 * self.model_dim, bias=False)
        self.out: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        self.attn_dropout: torch.nn.Dropout = torch.nn.Dropout(p=p_drop)
        self.resid_dropout: torch.nn.Dropout = torch.nn.Dropout(p=p_drop)
        # causal mask to ensure that attention is only applied to the left in the input seq
        # (not needed for ViT, where every patch may attend to every other patch)
        # self.register_buffer('bias', tensor=torch.tril(torch.ones(self.block_size, self.block_size)).view(1, 1, self.block_size, self.block_size))
        '''
        block_size=10
        [[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]  # Batch-1, Seq-1, Mask-(10, 10)
        '''

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        '''
        input (bs, seq_len, embedding_dim) >> output (bs, seq_len, embedding_dim)

        x                  :: (bs, seq_len, embedding_dim)
        self.attn_layer(x) :: (bs, seq_len, 3*embedding_dim)
        .split(embedding_dim, dim=2) yields three views of that tensor (q, k, v),
        each (bs, seq_len, embedding_dim); each is reshaped to
        (bs, seq_len, n_heads, embedding_dim//n_heads) and transposed to
        (bs, n_heads, seq_len, embedding_dim//n_heads), so each head attends over
        the full sequence using its own slice of the embedding.
        '''
        B, T, C = x.size()  # (bs, seq_len, embedding_dim)
        # calc q, k, v with a single projection, then split
        q: torch.Tensor; k: torch.Tensor; v: torch.Tensor
        q, k, v = self.attn_layer(x).split(split_size=self.model_dim, dim=2)
        q = q.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        k = k.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        v = v.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        # scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal mask (unused for ViT)
        attn = torch.nn.functional.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        # (bs, n_heads, T, T) @ (bs, n_heads, T, C//n_heads) >> (bs, n_heads, T, C//n_heads)
        y: torch.Tensor = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads back to (bs, seq_len, embedding_dim)
        return self.resid_dropout(self.out(y)), attn if attention_outputs else None


class MLP(torch.nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> LayerNorm -> Linear -> Dropout."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense_2(self.layernorm(self.activation(self.dense_1(x)))))


class EncoderBlock(torch.nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual connection."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # self-attention with residual connection
        attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs)
        x = x + attention_op
        # feed-forward with residual connection
        mlp_output = self.mlp(self.layernorm_2(x))
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        return x, attn if attention_outputs else None


class Encoder(torch.nn.Module):
    """The transformer encoder module: a stack of EncoderBlocks."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        # Create a list of transformer blocks
        self.blocks = torch.nn.ModuleList([])
        for _ in range(self.number_encoders):
            self.blocks.append(EncoderBlock(cfg))

    def forward(self, x: torch.Tensor, attention_outputs: bool):
        # Calculate each transformer block's output, collecting attention maps if requested
        all_attn = []
        for block in self.blocks:
            x, attn = block(x, attention_outputs=attention_outputs)
            all_attn.append(attn)
        # Return the encoder's output and the attention probabilities (optional)
        return x, all_attn if attention_outputs else None


class ViTClassifier(torch.nn.Module):
    """Vision Transformer classifier: embeddings -> encoder stack -> linear head on the [CLS] token."""

    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.embed: Embedding = Embedding(cfg)
        self.encoders: Encoder = Encoder(cfg=cfg)
        self.classifier: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.classification_heads, bias=False)

    def forward(self, x: torch.Tensor, attention_outputs: bool = False):
        x = self.embed(x)
        x, attn = self.encoders(x, attention_outputs=attention_outputs)
        # classify from the final representation of the [CLS] token
        return self.classifier(x[:, 0]), attn if attention_outputs else None
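

# --- Usage sketch (not from the original source) ---
# A minimal smoke test for the classes above. The cfg values here are illustrative
# assumptions (CIFAR-10-sized 32x32 RGB inputs, a deliberately small model), not
# settings defined anywhere in this file.
if __name__ == "__main__":
    cfg = {
        "image_size": 32,
        "patch_size": 4,
        "num_channels": 3,
        "model_dim": 64,
        "num_attn_heads": 4,
        "attn_dropout": 0.1,
        "d_ff": 128,
        "number_encoders": 2,
        "classification_heads": 10,
    }
    model = ViTClassifier(cfg)
    images = torch.randn(8, 3, 32, 32)  # (bs, num_channels, image_size, image_size)
    logits, attn_maps = model(images, attention_outputs=True)
    print(logits.shape)        # torch.Size([8, 10])
    print(len(attn_maps))      # number_encoders attention maps, one per EncoderBlock
    print(attn_maps[0].shape)  # torch.Size([8, 4, 65, 65]) -> (bs, n_heads, num_patches+1, num_patches+1)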