# AttnViz/src/old.py.old
import math
from typing import Dict, Tuple, Union

import torch
# Expected keys of the `cfg` dict, kept here as a commented-out reference:
# @dataclass
# class ViTCfg:
#     image_size: int
#     patch_size: int
#     num_channels: int
#     model_dim: int
#     num_attn_heads: int
#     attn_dropout: float
#     d_ff: int
#     number_encoders: int
#     classification_heads: int
class PatchEmbedding(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        assert self.image_size % self.patch_size == 0, "patch_size does not evenly divide image_size"
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # one non-overlapping convolution turns each patch into a model_dim-sized token
        self.img2flattn: torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.model_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (bs, num_channels, image_size, image_size)
        #   >> (bs, model_dim, image_size//patch_size, image_size//patch_size)
        #   >> (bs, model_dim, num_patches) >> (bs, num_patches, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1, 2)
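
# Worked example (hypothetical sizes): with image_size=32, patch_size=4, model_dim=48,
# the Conv2d maps (bs, 3, 32, 32) -> (bs, 48, 8, 8); flatten + transpose then yields
# (bs, (32//4)**2, 48) = (bs, 64, 48), i.e. 64 patch tokens of width model_dim.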
class Embedding(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.patch_embedding: PatchEmbedding = PatchEmbedding(cfg=cfg)
        # single learnable [CLS] token
        self.cls_token: torch.nn.Parameter = torch.nn.Parameter(torch.randn(1, 1, self.model_dim))
        # learnable position embedding for num_patches + 1 tokens (+1 for [CLS])
        self.position_embd: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, (self.image_size // self.patch_size) ** 2 + 1, self.model_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embedding(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.position_embd
        return x
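
# Continuing the example above (image_size=32, patch_size=4): the 64 patch tokens are
# prefixed with one [CLS] token, so the sequence length becomes 65 and position_embd
# is broadcast from shape (1, 65, model_dim) over the batch.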
class AttentionBlock(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        assert self.model_dim % self.num_attn_heads == 0, "model_dim is not divisible by num_attn_heads"
        # fused q, k, v projection
        self.attn_layer: torch.nn.Linear = torch.nn.Linear(self.model_dim, 3 * self.model_dim, bias=False)
        self.out: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        # the setattr loop stored cfg's `attn_dropout` probability; use it before rebinding the name to the module
        self.attn_dropout: torch.nn.Dropout = torch.nn.Dropout(self.attn_dropout)
        self.resid_dropout: torch.nn.Dropout = torch.nn.Dropout()
        # causal mask (GPT-style, attend only to the left) -- intentionally disabled: ViT uses bidirectional attention
        # self.register_buffer('bias', torch.tril(torch.ones(self.block_size, self.block_size)).view(1, 1, self.block_size, self.block_size))
'''
block_size=10
[[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]
# Batch-1, Seq-1, Mask-(10,10)
'''
    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        '''
        input (bs, seq_len, embedding_dim) >> output (bs, seq_len, embedding_dim)
        x             :: (bs, seq_len, embedding_dim)
        attn_layer(x) :: (bs, seq_len, 3*embedding_dim)
        .split        :: (bs, seq_len, 3*embedding_dim).split(embedding_dim, dim=2)
                         each chunk (bs, seq_len, embedding_dim) is a view of the original tensor,
                         so splitting along the last dim yields the three q, k, v tensors
        q, k, v       :: (bs, seq_len, n_heads, embedding_dim//n_heads) >> (bs, n_heads, seq_len, embedding_dim//n_heads)
        each head attends over the full sequence with its own slice of the embedding
        '''
        B, T, C = x.size()  # (bs, seq_len, embedding_dim)
        # project to q, k, v in one matmul, then split
        q: torch.Tensor
        k: torch.Tensor
        v: torch.Tensor
        q, k, v = self.attn_layer(x).split(split_size=self.model_dim, dim=2)
        q = q.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        k = k.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        v = v.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        # scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal mask, unused for ViT
        attn = torch.nn.functional.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        y: torch.Tensor = attn @ v  # (bs, n_heads, T, T) @ (bs, n_heads, T, C//n_heads) >> (bs, n_heads, T, C//n_heads)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads back to (bs, T, C)
        return self.resid_dropout(self.out(y)), attn if attention_outputs else None
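
# Shape walkthrough (hypothetical sizes): with model_dim=48 and num_attn_heads=4, each head
# works in dimension 48 // 4 = 12; for a 65-token sequence the attention map of one block
# has shape (bs, 4, 65, 65), which is the tensor returned when attention_outputs=True.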
class MLP(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        # position-wise feed-forward: model_dim -> d_ff -> model_dim
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense_2(self.layernorm(self.activation(self.dense_1(x)))))
class EncoderBlock(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # self-attention sub-layer (pre-norm residual)
        attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs)
        x = x + attention_op
        # feed-forward sub-layer (pre-norm residual)
        mlp_output = self.mlp(self.layernorm_2(x))
        x = x + mlp_output
        # return the block's output and the attention probabilities (optional)
        return x, attn if attention_outputs else None
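
# In effect each EncoderBlock computes the pre-norm residual form:
#   x = x + Attention(LayerNorm(x))
#   x = x + MLP(LayerNorm(x))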
class Encoder(torch.nn.Module):
    """
    The transformer encoder: a stack of `number_encoders` EncoderBlocks.
    """
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        # stack of transformer blocks
        self.blocks = torch.nn.ModuleList([EncoderBlock(cfg) for _ in range(self.number_encoders)])

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[list, None]]:
        # run each block in turn, collecting the per-block attention maps
        all_attn = []
        for block in self.blocks:
            x, attn = block(x, attention_outputs=attention_outputs)
            all_attn.append(attn)
        # return the encoder's output and the attention probabilities (optional)
        return x, all_attn if attention_outputs else None
class ViTClassifier(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.embed: Embedding = Embedding(cfg)
        self.encoders: Encoder = Encoder(cfg=cfg)
        self.classifier: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.classification_heads, bias=False)

    def forward(self, x: torch.Tensor, attention_outputs: bool = False):
        x = self.embed(x)
        x, attn = self.encoders(x, attention_outputs=attention_outputs)
        # classify from the [CLS] token (position 0)
        return self.classifier(x[:, 0]), attn if attention_outputs else None
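
# Minimal smoke test (a sketch, not part of the original module): the config values below
# are hypothetical examples of the keys each module reads via its setattr loop.
if __name__ == "__main__":
    cfg = {
        "image_size": 32,
        "patch_size": 4,
        "num_channels": 3,
        "model_dim": 48,
        "num_attn_heads": 4,
        "attn_dropout": 0.1,
        "d_ff": 128,
        "number_encoders": 4,
        "classification_heads": 10,
    }
    model = ViTClassifier(cfg)
    imgs = torch.randn(2, 3, 32, 32)                    # (bs, num_channels, image_size, image_size)
    logits, attn = model(imgs, attention_outputs=True)
    print(logits.shape)                                 # torch.Size([2, 10])
    print(len(attn), attn[0].shape)                     # 4 blocks, each map (2, 4, 65, 65)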