# FoodVision101 / MyModel / partViT.py
# Chaitanya Garg
# wrong dir
# 60d75eb
from torch import nn
import torch
class multiHeadSelfAttentionBlock(nn.Module):
    """Layer-normalised multi-head self-attention.

    The input is passed through ``nn.LayerNorm`` and the normalised tensor is
    then used as query, key and value of an ``nn.MultiheadAttention`` layer.
    No residual connection is added here; the caller is responsible for it.

    Args:
        embeddingDim: Size of the last (feature) dimension of the input.
        numHeads: Number of attention heads.
        attnDropOut: Dropout probability handed to ``nn.MultiheadAttention``.
    """

    def __init__(self, embeddingDim=768, numHeads=12, attnDropOut=0):
        super().__init__()
        self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
        # batch_first=True: inputs are laid out as (batch, tokens, embeddingDim).
        self.multiheadAttn = nn.MultiheadAttention(
            embed_dim=embeddingDim,
            num_heads=numHeads,
            dropout=attnDropOut,
            batch_first=True,
        )

    def forward(self, x):
        """Return the attention output for ``x``; attention weights are dropped."""
        normed = self.layerNorm(x)
        # The layer returns (output, weights); only the output is needed.
        attended = self.multiheadAttn(query=normed, key=normed, value=normed)
        return attended[0]
class MLPBlock(nn.Module):
    """Layer-normalised two-layer feed-forward network.

    The pipeline is LayerNorm -> Linear(embeddingDim, hiddenLayer) -> GELU ->
    Dropout -> Linear(hiddenLayer, embeddingDim) -> Dropout, so the output has
    the same feature size as the input. No residual connection is added here.

    Args:
        embeddingDim: Size of the last (feature) dimension of the input.
        hiddenLayer: Width of the hidden linear layer.
        dropOut: Dropout probability used after each linear layer.
    """

    def __init__(self, embeddingDim, hiddenLayer, dropOut=0.1):
        super().__init__()
        # Order matters: the positions become the state_dict keys (MLP.0, MLP.1, ...).
        stages = [
            nn.LayerNorm(normalized_shape=embeddingDim),
            nn.Linear(embeddingDim, hiddenLayer),
            nn.GELU(),
            nn.Dropout(dropOut),
            nn.Linear(hiddenLayer, embeddingDim),
            nn.Dropout(dropOut),
        ]
        self.MLP = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the normalised feed-forward stack."""
        projected = self.MLP(x)
        return projected
class transformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention then MLP, each with a skip connection.

    Args:
        embeddingDim: Feature size shared by the attention and MLP sub-blocks.
        hiddenLayer: Hidden width of the MLP sub-block.
        numHeads: Number of attention heads.
        MLPdropOut: Dropout probability inside the MLP sub-block.
        attnDropOut: Dropout probability inside the attention sub-block.
    """

    def __init__(self, embeddingDim, hiddenLayer, numHeads, MLPdropOut, attnDropOut=0):
        super().__init__()
        self.MSABlock = multiHeadSelfAttentionBlock(
            embeddingDim=embeddingDim,
            numHeads=numHeads,
            attnDropOut=attnDropOut,
        )
        self.MLPBlock = MLPBlock(
            embeddingDim=embeddingDim,
            hiddenLayer=hiddenLayer,
            dropOut=MLPdropOut,
        )

    def forward(self, x):
        """Apply attention and MLP in sequence, adding each input back to its output."""
        afterAttention = self.MSABlock(x) + x
        afterMLP = self.MLPBlock(afterAttention) + afterAttention
        return afterMLP
class patchNPositionalEmbeddingMaker(nn.Module):
    """Convert an image batch into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patchSize`` cuts
    the image into non-overlapping patches and projects each one to
    ``outChannels`` features. A learnable class token is prepended and a
    learnable positional embedding is added to every token.

    Args:
        inChannels: Number of channels of the input images.
        outChannels: Feature size of every produced token (the embedding dim).
        patchSize: Side length of one square patch, in pixels.
        imgSize: Side length of the (square) images the module is built for.
    """

    def __init__(self, inChannels, outChannels, patchSize, imgSize):
        super().__init__()
        # outChannels is the same as embeddingDim
        self.outChannels = outChannels
        self.patchSize = patchSize
        # Patches per side, squared. Integer arithmetic keeps this equal to the
        # number of positions the stride==kernel convolution below emits for an
        # imgSize x imgSize input (float division could over-count).
        self.numPatches = int(imgSize // patchSize) ** 2
        self.patchMaker = nn.Conv2d(
            inChannels,
            outChannels,
            kernel_size=patchSize,
            stride=patchSize,
            padding=0,
        )
        # Merge the two spatial dims of the conv output into one patch axis.
        self.flattener = nn.Flatten(start_dim=2, end_dim=3)
        self.classEmbedding = nn.Parameter(
            torch.randn(1, 1, self.outChannels), requires_grad=True
        )
        self.PositionalEmbedding = nn.Parameter(
            torch.randn(1, self.numPatches + 1, self.outChannels), requires_grad=True
        )

    def forward(self, x):
        """Embed ``x`` as ``(batch, numPatches + 1, outChannels)`` tokens.

        Args:
            x: Image batch laid out as (batch, inChannels, height, width).

        Returns:
            The class token followed by the patch tokens, with the positional
            embedding added.

        Raises:
            AssertionError: If the last dimension of ``x`` is not a multiple
                of ``patchSize``. Only the last dimension is checked.
        """
        batchSize = x.shape[0]
        imgRes = x.shape[-1]
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; the exception type and message are unchanged.
        if imgRes % self.patchSize != 0:
            raise AssertionError('Input size must be div by patchSize')

        x = self.patchMaker(x)      # (batch, outChannels, rows, cols)
        x = self.flattener(x)       # (batch, outChannels, rows * cols)
        x = x.permute(0, 2, 1)      # (batch, rows * cols, outChannels)

        # One copy of the class token per sample, placed first in the sequence.
        classToken = self.classEmbedding.expand(batchSize, -1, -1)
        x = torch.cat((classToken, x), dim=1)
        return x + self.PositionalEmbedding