import torch
from torch import nn

class multiHeadSelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention: LayerNorm followed by nn.MultiheadAttention."""
    def __init__(self, embeddingDim=768, numHeads=12, attnDropOut=0):
        super().__init__()
        self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
        self.multiheadAttn = nn.MultiheadAttention(embed_dim=embeddingDim,
                                                   num_heads=numHeads,
                                                   dropout=attnDropOut,
                                                   batch_first=True)

    def forward(self, x):
        layNorm = self.layerNorm(x)
        # Self-attention: query, key, and value are all the normalized input.
        attnOutput, _ = self.multiheadAttn(query=layNorm, key=layNorm, value=layNorm)
        return attnOutput
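# A quick shape check (illustrative; 197 = 196 patches + 1 class token for
# 224x224 images with 16x16 patches is an assumed ViT-Base setup, not taken
# from this code). The block maps (batch, tokens, embeddingDim) to the same shape:
# msa = multiHeadSelfAttentionBlock()
# print(msa(torch.randn(1, 197, 768)).shape)  # torch.Size([1, 197, 768])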
class MLPBlock(nn.Module):
    """Pre-norm MLP: LayerNorm, then Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, embeddingDim, hiddenLayer, dropOut=0.1):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.LayerNorm(normalized_shape=embeddingDim),
            nn.Linear(embeddingDim, hiddenLayer),
            nn.GELU(),
            nn.Dropout(dropOut),
            nn.Linear(hiddenLayer, embeddingDim),
            nn.Dropout(dropOut)
        )

    def forward(self, x):
        return self.MLP(x)
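# Illustrative only: with assumed ViT-Base sizes the MLP expands 768 -> 3072 -> 768
# and preserves the token shape (these values are assumptions, not from this code):
# mlp = MLPBlock(embeddingDim=768, hiddenLayer=3072, dropOut=0.1)
# print(mlp(torch.randn(1, 197, 768)).shape)  # torch.Size([1, 197, 768])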
class transformerEncoderBlock(nn.Module):
    """One encoder layer: MSA and MLP sub-blocks, each wrapped in a residual connection."""
    def __init__(self, embeddingDim, hiddenLayer, numHeads, MLPdropOut, attnDropOut=0):
        super().__init__()
        self.MSABlock = multiHeadSelfAttentionBlock(embeddingDim, numHeads, attnDropOut)
        self.MLPBlock = MLPBlock(embeddingDim, hiddenLayer, MLPdropOut)

    def forward(self, x):
        x = self.MSABlock(x) + x  # residual connection around attention
        x = self.MLPBlock(x) + x  # residual connection around the MLP
        return x
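# For reference, this pre-norm layer is functionally close to PyTorch's built-in
# encoder layer (a sketch with assumed ViT-Base sizes; not a drop-in swap, since
# parameter names and weight initialization differ):
# torchLayer = nn.TransformerEncoderLayer(d_model=768, nhead=12,
#                                         dim_feedforward=3072, dropout=0.1,
#                                         activation="gelu", batch_first=True,
#                                         norm_first=True)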
class patchNPositionalEmbeddingMaker(nn.Module):
    """Turns an image into a sequence of patch embeddings, with a prepended
    class token and learned positional embeddings."""
    def __init__(self, inChannels, outChannels, patchSize, imgSize):
        super().__init__()
        self.outChannels = outChannels  # outChannels is the same as embeddingDim
        self.patchSize = patchSize
        self.numPatches = (imgSize // patchSize) ** 2
        # A Conv2d with kernel == stride == patchSize embeds each patch in one shot.
        self.patchMaker = nn.Conv2d(inChannels, outChannels, kernel_size=patchSize,
                                    stride=patchSize, padding=0)
        self.flattener = nn.Flatten(start_dim=2, end_dim=3)
        # Learnable [class] token and positional embeddings (numPatches + 1 positions).
        self.classEmbedding = nn.Parameter(torch.randn(1, 1, self.outChannels), requires_grad=True)
        self.PositionalEmbedding = nn.Parameter(torch.randn(1, self.numPatches + 1, self.outChannels),
                                                requires_grad=True)
    def forward(self, x):
        batchSize = x.shape[0]
        imgRes = x.shape[-1]
        assert imgRes % self.patchSize == 0, 'Input size must be divisible by patchSize'
        x = self.patchMaker(x)      # (batch, embeddingDim, imgRes/patchSize, imgRes/patchSize)
        x = self.flattener(x)       # (batch, embeddingDim, numPatches)
        x = x.permute(0, 2, 1)      # (batch, numPatches, embeddingDim)
        # Prepend the class token, then add the positional embeddings.
        classToken = self.classEmbedding.expand(batchSize, -1, -1)
        x = torch.cat((classToken, x), dim=1)
        x = x + self.PositionalEmbedding
        return x
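# A minimal end-to-end sketch wiring the pieces into a ViT forward pass.
# Assumptions (not from the original code): ViT-Base hyperparameters of
# 224x224 images, 16x16 patches, 12 layers, 768-dim embeddings, and an
# illustrative 1000-class linear head.
if __name__ == "__main__":
    embedder = patchNPositionalEmbeddingMaker(inChannels=3, outChannels=768,
                                              patchSize=16, imgSize=224)
    encoder = nn.Sequential(*[transformerEncoderBlock(embeddingDim=768,
                                                      hiddenLayer=3072,
                                                      numHeads=12,
                                                      MLPdropOut=0.1)
                              for _ in range(12)])
    img = torch.randn(2, 3, 224, 224)           # dummy batch of 2 RGB images
    tokens = encoder(embedder(img))             # (2, 197, 768)
    classToken = tokens[:, 0]                   # classify from the [class] token
    logits = nn.Linear(768, 1000)(classToken)   # hypothetical classification head
    print(tokens.shape, logits.shape)           # torch.Size([2, 197, 768]) torch.Size([2, 1000])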