import math
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass

import torch
import lightning as pl
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from torchmetrics import Accuracy
# @dataclass
# class ViTCfg:
#     image_size: int
#     patch_size: int
#     num_channels: int
#     model_dim: int
#     num_attn_heads: int
#     attn_dropout: float
#     d_ff: int
#     number_encoders: int
#     classification_heads: int
class PatchEmbedding(torch.nn.Module):
    """Splits an image into non-overlapping patches and projects each patch to model_dim."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        assert self.image_size % self.patch_size == 0, "patch_size must evenly divide image_size"
        self.num_patchs = (self.image_size // self.patch_size) ** 2
        # A strided convolution is equivalent to cutting the image into patches
        # and applying a shared linear projection to each patch.
        self.img2flattn: torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.model_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (bs, num_channels, image_size, image_size)
        #   >> (bs, model_dim, image_size//patch_size, image_size//patch_size)
        #   >> (bs, model_dim, num_patches)
        #   >> (bs, num_patches, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1, 2)
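
# A minimal shape-check sketch for PatchEmbedding. The cfg values below are assumptions
# chosen for illustration (they mirror the commented-out ViTCfg fields above), not the
# project's actual configuration.
def _demo_patch_embedding() -> None:
    cfg = dict(image_size=32, patch_size=4, num_channels=3, model_dim=48)
    patch_embed = PatchEmbedding(cfg)
    out = patch_embed(torch.randn(2, 3, 32, 32))
    # (32 // 4) ** 2 = 64 patches, each projected to model_dim = 48
    assert out.shape == (2, 64, 48)
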
class Embedding(torch.nn.Module):
    """Patch embedding + a learnable [CLS] token + learnable position embeddings."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.patch_embedding: PatchEmbedding = PatchEmbedding(cfg=cfg)
        # single [CLS] token, prepended to every sequence of patch embeddings
        self.cls_token: torch.nn.Parameter = torch.nn.Parameter(torch.randn(1, 1, self.model_dim))
        # one position embedding per patch, plus one for the [CLS] token
        self.position_embd: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, (self.image_size // self.patch_size) ** 2 + 1, self.model_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embedding(x)                             # (bs, num_patches, model_dim)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (bs, 1, model_dim)
        x = torch.cat((cls_token, x), dim=1)                    # (bs, num_patches + 1, model_dim)
        x = x + self.position_embd
        return x
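
# A hedged sketch of the Embedding output shape: the [CLS] token adds one position to the
# patch sequence. The cfg values are illustrative assumptions, as above.
def _demo_embedding() -> None:
    cfg = dict(image_size=32, patch_size=4, num_channels=3, model_dim=48)
    embed = Embedding(cfg)
    out = embed(torch.randn(2, 3, 32, 32))
    # 64 patches + 1 [CLS] token = 65 positions
    assert out.shape == (2, 65, 48)
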
class AttentionBlock(torch.nn.Module):
    """Multi-head self-attention with a fused qkv projection."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        assert self.model_dim % self.num_attn_heads == 0, "model_dim is not divisible by num_attn_heads"
        # fused projection producing q, k and v in a single matmul
        self.attn_layer: torch.nn.Linear = torch.nn.Linear(self.model_dim, 3 * self.model_dim, bias=False)
        self.out: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        dropout_p = cfg.get("attn_dropout", 0.0)
        self.attn_dropout: torch.nn.Dropout = torch.nn.Dropout(dropout_p)
        self.resid_dropout: torch.nn.Dropout = torch.nn.Dropout(dropout_p)
        # causal mask to ensure that attention is only applied to the left in the input seq
        # (not needed for ViT, where every patch may attend to every other patch)
        # self.register_buffer('bias', tensor=torch.tril(torch.ones(self.block_size, self.block_size)).view(1, 1, self.block_size, self.block_size))
        '''
        block_size=10
        [[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]
        # Batch-1, Seq-1, Mask-(10,10)
        '''
    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        '''
        input (bs, seq_len, embedding_dim) >> output (bs, seq_len, embedding_dim)
        x          :: (bs, seq_len, embedding_dim)
        attn_layer :: (bs, seq_len, 3*embedding_dim)
        .split     :: (bs, seq_len, 3*embedding_dim).split(embedding_dim, dim=2)
        # Each chunk (bs, seq_len, embedding_dim) is a view of the original tensor; splitting
        # across the last dim yields the three q, k, v tensors.
        q, k, v >> (bs, seq_len, n_heads, embedding_dim//n_heads) >> (bs, n_heads, seq_len, embedding_dim//n_heads)
        # Each head attends over the full sequence using its own slice of the embedding.
        '''
        B, T, C = x.size()  # (bs, seq_len, embedding_dim)
        # calculate q, k, v with a single fused projection, then split
        q: torch.Tensor
        k: torch.Tensor
        v: torch.Tensor
        q, k, v = self.attn_layer(x).split(split_size=self.model_dim, dim=2)
        q = q.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        k = k.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        v = v.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        # scaled dot-product attention: (bs, n_heads, T, hd) @ (bs, n_heads, hd, T) >> (bs, n_heads, T, T)
        attn = (q @ k.transpose(-2, -1)) * (1 / math.sqrt(k.size(-1)))
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal masking, unused in ViT
        attn = torch.nn.functional.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        y: torch.Tensor = attn @ v  # (bs, n_heads, T, T) @ (bs, n_heads, T, C//n_heads) >> (bs, n_heads, T, C//n_heads)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        return self.resid_dropout(self.out(y)), (attn if attention_outputs else None)
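
# A hedged sketch of the attention shapes: the block is shape-preserving and, when
# attention_outputs=True, also returns the (bs, n_heads, seq_len, seq_len) attention map.
# The cfg values are illustrative assumptions.
def _demo_attention_block() -> None:
    cfg = dict(model_dim=48, num_attn_heads=4, attn_dropout=0.1)
    block = AttentionBlock(cfg)
    y, attn = block(torch.randn(2, 65, 48), attention_outputs=True)
    assert y.shape == (2, 65, 48)
    assert attn.shape == (2, 4, 65, 65)
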
class MLP(torch.nn.Module):
    """Position-wise feed-forward network applied after self-attention."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense_2(self.layernorm(self.activation(self.dense_1(x)))))
class EncoderBlock(torch.nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention -> residual, LayerNorm -> MLP -> residual."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # self-attention with residual connection
        attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs)
        x = x + attention_op
        # feed-forward with residual connection
        mlp_output = self.mlp(self.layernorm_2(x))
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        return x, (attn if attention_outputs else None)
class Encoder(torch.nn.Module):
    """
    The transformer encoder module: a stack of `number_encoders` EncoderBlocks.
    """
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        # Create a list of transformer blocks
        self.blocks = torch.nn.ModuleList([EncoderBlock(cfg) for _ in range(self.number_encoders)])

    def forward(self, x: torch.Tensor, attention_outputs: bool):
        # Run the input through each transformer block, collecting the attention maps
        all_attn = []
        for block in self.blocks:
            x, attn = block(x, attention_outputs=attention_outputs)
            all_attn.append(attn)
        # Return the encoder's output and the attention probabilities (optional)
        return x, (all_attn if attention_outputs else None)
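
# A hedged sketch of the stacked encoder: it is shape-preserving and, when requested,
# returns one attention map per EncoderBlock. The cfg values are illustrative assumptions.
def _demo_encoder() -> None:
    cfg = dict(model_dim=48, num_attn_heads=4, attn_dropout=0.1, d_ff=96, number_encoders=2)
    encoder = Encoder(cfg)
    y, all_attn = encoder(torch.randn(2, 65, 48), attention_outputs=True)
    assert y.shape == (2, 65, 48)
    assert len(all_attn) == 2  # one (bs, n_heads, 65, 65) map per block
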
class ViTClassifier(torch.nn.Module):
    """Full ViT classifier: Embedding -> Encoder -> linear head on the [CLS] token."""
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items():
            setattr(self, k, v)
        self.embed: Embedding = Embedding(cfg)
        self.encoders: Encoder = Encoder(cfg=cfg)
        self.classifier: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.classification_heads, bias=False)

    def forward(self, x: torch.Tensor, attention_outputs: bool = False):
        x = self.embed(x)                                                # (bs, num_patches + 1, model_dim)
        x, attn = self.encoders(x, attention_outputs=attention_outputs)
        # classify from the [CLS] token (position 0)
        return self.classifier(x[:, 0]), (attn if attention_outputs else None)
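
# A minimal end-to-end smoke test sketch. The cfg below is an assumption made up for
# illustration (CIFAR-10-like 32x32 RGB inputs, 10 classes); swap in the project's real
# configuration as needed.
if __name__ == "__main__":
    cfg = dict(
        image_size=32,
        patch_size=4,
        num_channels=3,
        model_dim=48,
        num_attn_heads=4,
        attn_dropout=0.1,
        d_ff=96,
        number_encoders=2,
        classification_heads=10,
    )
    model = ViTClassifier(cfg)
    logits, attn = model(torch.randn(2, 3, 32, 32), attention_outputs=True)
    print(logits.shape)              # torch.Size([2, 10])
    print(len(attn), attn[0].shape)  # 2 encoder blocks, each map (2, 4, 65, 65)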