import math
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
import torch
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
import lightning as pl
from torchmetrics import Accuracy
# @dataclass
# class ViTCfg:
#     image_size: int
#     patch_size: int
#     num_channels: int
#     model_dim: int
#     num_attn_heads: int
#     attn_dropout: float
#     d_ff: int
#     number_encoders: int
#     classification_heads: int
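# Example of the cfg dict these modules expect (the values below are illustrative
# assumptions for CIFAR-10-sized inputs, not the original training configuration):
# cfg = {
#     "image_size": 32, "patch_size": 4, "num_channels": 3,
#     "model_dim": 96, "num_attn_heads": 4, "attn_dropout": 0.1,
#     "d_ff": 384, "number_encoders": 6, "classification_heads": 10,
# }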
class PatchEmbedding(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        assert self.image_size % self.patch_size == 0, "patch_size does not evenly divide image_size"
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # kernel_size == stride == patch_size, so each non-overlapping patch is
        # projected to a single model_dim vector
        self.img2flattn: torch.nn.Conv2d = torch.nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.model_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (bs, num_channels, image_size, image_size)
        #   >> (bs, model_dim, image_size//patch_size, image_size//patch_size)
        #   >> (bs, model_dim, num_patches)
        #   >> (bs, num_patches, model_dim)
        return self.img2flattn(x).flatten(2).transpose(1, 2)
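# Shape walkthrough with the illustrative cfg above (image_size=32, patch_size=4,
# num_channels=3, model_dim=96); these numbers are assumptions, not from this file:
#   images:  (bs, 3, 32, 32)
#   conv:    (bs, 96, 8, 8)    # 32 // 4 = 8 patches per side
#   flatten: (bs, 96, 64)      # 8 * 8 = 64 patches
#   output:  (bs, 64, 96)      # one 96-d token per patch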
class Embedding(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.patch_embedding: PatchEmbedding = PatchEmbedding(cfg=cfg)
        # single [CLS] token, prepended to every sequence of patch tokens
        self.cls_token: torch.nn.Parameter = torch.nn.Parameter(torch.randn(1, 1, self.model_dim))
        # learned position embedding: one vector per patch plus one for [CLS]
        self.position_embd: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(1, int((self.image_size // self.patch_size) ** 2 + 1), self.model_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embedding(x)                             # (bs, num_patches, model_dim)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (bs, 1, model_dim)
        x = torch.cat((cls_token, x), dim=1)                    # (bs, num_patches + 1, model_dim)
        x = x + self.position_embd
        return x
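# With the illustrative numbers above: 64 patch tokens + 1 [CLS] token gives a
# sequence of length 65, so position_embd has shape (1, 65, 96) and broadcasts
# over the batch dimension.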
class AttentionBlock(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        assert self.model_dim % self.num_attn_heads == 0, "model_dim is not divisible by num_attn_heads"
        # single projection producing q, k and v in one matmul
        self.attn_layer: torch.nn.Linear = torch.nn.Linear(self.model_dim, 3 * self.model_dim, bias=False)
        self.out: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.model_dim, bias=False)
        # use the configured attention dropout probability from cfg
        self.attn_dropout: torch.nn.Dropout = torch.nn.Dropout(p=self.attn_dropout)
        self.resid_dropout: torch.nn.Dropout = torch.nn.Dropout()
        # causal mask to ensure that attention is only applied to the left in the input seq
        # (not needed for ViT, where every patch may attend to every other patch)
        # self.register_buffer('bias', tensor=torch.tril(torch.ones(self.block_size, self.block_size)).view(1, 1, self.block_size, self.block_size))
        '''
        block_size=10
        [[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]
        # Batch-1, Seq-1, Mask-(10,10)
        '''
    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        '''
        input (bs, seq_len, embedding_dim) >> output (bs, seq_len, embedding_dim)
        x          :: (bs, seq_len, embedding_dim)
        attn_layer :: (bs, seq_len, 3*embedding_dim)
        .split     :: (bs, seq_len, 3*embedding_dim).split(embedding_dim, dim=2)
        # Each chunk (bs, seq_len, embedding_dim) is a view of the original tensor,
        # split across embedding_dim, so we get three chunks: q, k, v.
        q, k, v >> (bs, seq_len, n_heads, embedding_dim//n_heads) >> (bs, n_heads, seq_len, embedding_dim//n_heads)
        # Each head attends over the full sequence using a different slice of the embedding.
        '''
        B, T, C = x.size()  # (bs, seq_len, embedding_dim)
        # project to q, k, v in one matmul, then split along the channel dim
        q: torch.Tensor
        k: torch.Tensor
        v: torch.Tensor
        q, k, v = self.attn_layer(x).split(split_size=self.model_dim, dim=2)
        q = q.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        k = k.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        v = v.view(B, T, self.num_attn_heads, C // self.num_attn_heads).transpose(1, 2)
        # scaled dot-product attention: softmax(q @ k^T / sqrt(d_head)) @ v
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # causal masking, unused for ViT
        attn = torch.nn.functional.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)
        y: torch.Tensor = attn @ v  # (bs, n_heads, T, T) @ (bs, n_heads, T, C//n_heads) >> (bs, n_heads, T, C//n_heads)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads back to (bs, seq_len, embedding_dim)
        return self.resid_dropout(self.out(y)), attn if attention_outputs else None
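# Worked head-split example with the illustrative cfg above (model_dim=96,
# num_attn_heads=4) and a 65-token sequence: q, k, v are each reshaped
# (bs, 65, 96) -> (bs, 4, 65, 24), the attention matrix is (bs, 4, 65, 65),
# and the merged output is (bs, 65, 96).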
class MLP(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.dense_1 = torch.nn.Linear(self.model_dim, self.d_ff)
        self.activation = torch.nn.ReLU()
        self.layernorm = torch.nn.LayerNorm(self.d_ff)
        self.dense_2 = torch.nn.Linear(self.d_ff, self.model_dim)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # model_dim -> d_ff -> ReLU -> LayerNorm -> model_dim -> dropout
        return self.dropout(self.dense_2(self.layernorm(self.activation(self.dense_1(x)))))
class EncoderBlock(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.attn_block = AttentionBlock(cfg)
        self.layernorm_1 = torch.nn.LayerNorm(self.model_dim)
        self.mlp = MLP(cfg)
        self.layernorm_2 = torch.nn.LayerNorm(self.model_dim)

    def forward(self, x: torch.Tensor, attention_outputs: bool) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # pre-norm self-attention with a residual connection
        attention_op, attn = self.attn_block(self.layernorm_1(x), attention_outputs=attention_outputs)
        x = x + attention_op
        # pre-norm MLP with a residual connection
        mlp_output = self.mlp(self.layernorm_2(x))
        x = x + mlp_output
        # Return the block's output and the attention probabilities (optional)
        return x, attn if attention_outputs else None
class Encoder(torch.nn.Module):
    """
    The transformer encoder module.
    """
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        # Create a list of transformer blocks
        self.blocks = torch.nn.ModuleList([])
        for _ in range(self.number_encoders):
            block = EncoderBlock(cfg)
            self.blocks.append(block)

    def forward(self, x: torch.Tensor, attention_outputs: bool):
        # Calculate the transformer block's output for each block
        all_attn = []
        for block in self.blocks:
            x, attn = block(x, attention_outputs=attention_outputs)
            all_attn.append(attn)
        # Return the encoder's output and the attention probabilities (optional)
        return x, all_attn if attention_outputs else None
class ViTClassifier(torch.nn.Module):
    def __init__(self, cfg: Dict) -> None:
        super().__init__()
        for k, v in cfg.items(): setattr(self, k, v)
        self.embed: Embedding = Embedding(cfg)
        self.encoders: Encoder = Encoder(cfg=cfg)
        self.classifier: torch.nn.Linear = torch.nn.Linear(self.model_dim, self.classification_heads, bias=False)

    def forward(self, x: torch.Tensor, attention_outputs=False):
        x = self.embed(x)  # (bs, num_patches + 1, model_dim)
        x, attn = self.encoders(x, attention_outputs=attention_outputs)
        # classify from the [CLS] token (position 0 in the sequence)
        return self.classifier(x[:, 0]), attn if attention_outputs else None
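# Minimal smoke test (a sketch; the hyperparameter values below are illustrative
# assumptions, not this file's original training configuration):
if __name__ == "__main__":
    cfg = {
        "image_size": 32, "patch_size": 4, "num_channels": 3,
        "model_dim": 96, "num_attn_heads": 4, "attn_dropout": 0.1,
        "d_ff": 384, "number_encoders": 6, "classification_heads": 10,
    }
    model = ViTClassifier(cfg)
    images = torch.randn(2, 3, 32, 32)
    logits, attn = model(images, attention_outputs=True)
    print(logits.shape)               # torch.Size([2, 10])
    print(len(attn), attn[0].shape)   # 6 torch.Size([2, 4, 65, 65])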