Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 2,503 Bytes
			
			| 9bb0389 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from torch import nn
import torch
class multiHeadSelfAttentionBlock(nn.Module):
  """LayerNorm followed by multi-head self-attention.

  No residual connection is applied inside this block; it returns only the
  attention output.

  Args:
    embeddingDim: token feature size (nn.MultiheadAttention requires it to be
      divisible by numHeads).
    numHeads: number of attention heads.
    attnDropOut: dropout probability passed to nn.MultiheadAttention.
  """
  def __init__(self,embeddingDim=768,numHeads=12,attnDropOut=0):
    super().__init__()
    self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
    # batch_first=True: inputs/outputs are (batch, seq, feature).
    self.multiheadAttn = nn.MultiheadAttention(embed_dim=embeddingDim,num_heads=numHeads,dropout=attnDropOut,batch_first=True)
  def forward(self,x):
    """Normalize x, then self-attend (query = key = value = normalized x).

    Returns the attention output, with the same shape as x.
    """
    layNorm = self.layerNorm(x)
    # The attention weights are discarded, so don't ask for them.
    # need_weights=False skips computing/averaging them and allows PyTorch's
    # optimized scaled-dot-product attention path.
    attnOutPut, _ = self.multiheadAttn(query=layNorm,key=layNorm,value=layNorm,need_weights=False)
    return attnOutPut
class MLPBlock(nn.Module):
  """Feed-forward sub-layer applied independently to each token.

  Pipeline: LayerNorm -> Linear(embeddingDim -> hiddenLayer) -> GELU ->
  Dropout -> Linear(hiddenLayer -> embeddingDim) -> Dropout. The output
  keeps the last dimension equal to embeddingDim.

  Args:
    embeddingDim: input and output feature size.
    hiddenLayer: width of the intermediate projection.
    dropOut: dropout probability used after GELU and after the output layer.
  """
  def __init__(self,embeddingDim,hiddenLayer,dropOut=0.1):
    super().__init__()
    # Built in the same order as the layer indices 0..5 so parameter names
    # (MLP.0.*, MLP.1.*, MLP.4.*) and seeded initialization stay stable.
    stages = [nn.LayerNorm(normalized_shape = embeddingDim)]
    stages += [nn.Linear(embeddingDim, hiddenLayer), nn.GELU(), nn.Dropout(dropOut)]
    stages += [nn.Linear(hiddenLayer,embeddingDim), nn.Dropout(dropOut)]
    self.MLP = nn.Sequential(*stages)
  def forward(self,x):
    """Run x through the normalization + two-layer MLP pipeline."""
    pipeline = self.MLP
    return pipeline(x)
class transformerEncoderBlock(nn.Module):
  """One encoder layer: attention sub-block then MLP sub-block.

  Each sub-block's output is added back onto its own input (skip
  connection), so the output has the same shape as the input.

  Args:
    embeddingDim: token feature size.
    hiddenLayer: hidden width of the MLP sub-block.
    numHeads: number of attention heads.
    MLPdropOut: dropout probability for the MLP sub-block.
    attnDropOut: dropout probability for the attention sub-block.
  """
  def __init__(self, embeddingDim, hiddenLayer,numHeads,MLPdropOut,attnDropOut=0):
    super().__init__()
    # Attention is constructed first, matching the original init order.
    self.MSABlock = multiHeadSelfAttentionBlock(
        embeddingDim=embeddingDim,
        numHeads=numHeads,
        attnDropOut=attnDropOut,
    )
    self.MLPBlock = MLPBlock(
        embeddingDim=embeddingDim,
        hiddenLayer=hiddenLayer,
        dropOut=MLPdropOut,
    )
  def forward(self,x):
    """Apply attention and MLP, each with a residual connection."""
    afterAttn = x + self.MSABlock(x)
    return afterAttn + self.MLPBlock(afterAttn)
class patchNPositionalEmbeddingMaker(nn.Module):
  """Turn an image batch into a token sequence for a transformer.

  The image is cut into non-overlapping patchSize x patchSize patches by a
  strided convolution. Each patch is projected to outChannels features, a
  learnable class token is prepended, and a learnable positional embedding
  is added.

  Args:
    inChannels: number of input image channels.
    outChannels: feature size of each token (the embedding dimension).
    patchSize: side length of each square patch (conv kernel and stride).
    imgSize: side length of the square image the positional embedding is
      sized for.
  """
  def __init__(self,inChannels,outChannels,patchSize,imgSize):
    super().__init__()
    # outChannels is the same as embeddingDim
    self.outChannels = outChannels
    self.patchSize = patchSize
    # Integer arithmetic: a kernel=stride=patchSize conv with no padding
    # yields imgSize // patchSize patches per side, so count exactly that
    # (float division over-counted for non-divisible sizes).
    self.numPatches = (imgSize // patchSize) ** 2
    self.patchMaker = nn.Conv2d(inChannels,outChannels, kernel_size=patchSize,stride=patchSize,padding=0)
    # Merge the two spatial dims of the conv output into one patch axis.
    self.flattener  = nn.Flatten(start_dim=2,end_dim=3)
    self.classEmbedding = nn.Parameter(torch.randn(1,1,self.outChannels),requires_grad=True)
    # +1 position for the prepended class token.
    self.PositionalEmbedding = nn.Parameter(torch.randn(1,self.numPatches+1,self.outChannels), requires_grad=True)
  def forward(self,x):
    """Embed an image batch.

    Args:
      x: 4-D image batch (batch, inChannels, H, W).

    Returns:
      Tensor of shape (batch, numPatches + 1, outChannels). Position 0 along
      dim 1 is the class token.

    Raises:
      ValueError: if H or W is not divisible by patchSize.
    """
    batchSize = x.shape[0]
    height, width = x.shape[-2:]
    # Explicit raise rather than assert: asserts vanish under `python -O`.
    # Both spatial dims are checked, not just the width.
    if height % self.patchSize != 0 or width % self.patchSize != 0:
      raise ValueError('Input size must be div by patchSize')
    x = self.patchMaker(x)        # (batch, outChannels, H/p, W/p)
    x = self.flattener(x)         # (batch, outChannels, numPatches)
    x = x.permute(0,2,1)          # (batch, numPatches, outChannels)
    # One shared class token, broadcast (no copy) across the batch.
    classToken = self.classEmbedding.expand(batchSize,-1,-1)
    x = torch.cat((classToken,x),dim=1)
    x = x + self.PositionalEmbedding
    return x
 |