File size: 2,503 Bytes
9bb0389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from torch import nn
import torch

class multiHeadSelfAttentionBlock(nn.Module):
  """Pre-norm multi-head self-attention sub-block.

  Applies LayerNorm to the input, then multi-head self-attention that uses the
  normalised tensor as query, key and value. No residual connection is added
  here; the caller is responsible for that.

  Args:
    embeddingDim: size of the last (embedding) dimension of the input.
    numHeads: number of attention heads; nn.MultiheadAttention requires it to
      divide embeddingDim.
    attnDropOut: dropout probability applied to the attention weights.
  """

  def __init__(self, embeddingDim=768, numHeads=12, attnDropOut=0):
    super().__init__()
    self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
    # batch_first=True: inputs and outputs are (batch, sequence, embeddingDim).
    self.multiheadAttn = nn.MultiheadAttention(embed_dim=embeddingDim, num_heads=numHeads, dropout=attnDropOut, batch_first=True)

  def forward(self, x):
    """Return the self-attention output for ``x`` (same shape as ``x``)."""
    normed = self.layerNorm(x)
    # The attention weights are discarded, so don't ask MultiheadAttention to
    # build (and head-average) them: need_weights=False skips that work and
    # lets PyTorch use its optimized scaled-dot-product attention path.
    attnOutPut, _ = self.multiheadAttn(query=normed, key=normed, value=normed, need_weights=False)
    return attnOutPut

class MLPBlock(nn.Module):
  """Pre-norm feed-forward sub-block.

  LayerNorm -> Linear(embeddingDim, hiddenLayer) -> GELU -> Dropout
  -> Linear(hiddenLayer, embeddingDim) -> Dropout, all acting on the last
  dimension, so the output has the same shape as the input. No residual
  connection is added here.

  Args:
    embeddingDim: size of the last dimension of the input and output.
    hiddenLayer: width of the hidden projection.
    dropOut: dropout probability used after the GELU and after the output
      projection.
  """

  def __init__(self, embeddingDim, hiddenLayer, dropOut=0.1):
    super().__init__()
    # Keep this exact layer order: it fixes both the parameter-initialisation
    # RNG sequence and the state_dict keys (MLP.0, MLP.1, MLP.4).
    layers = [
      nn.LayerNorm(normalized_shape=embeddingDim),
      nn.Linear(in_features=embeddingDim, out_features=hiddenLayer),
      nn.GELU(),
      nn.Dropout(p=dropOut),
      nn.Linear(in_features=hiddenLayer, out_features=embeddingDim),
      nn.Dropout(p=dropOut),
    ]
    self.MLP = nn.Sequential(*layers)

  def forward(self, x):
    out = self.MLP(x)
    return out

class transformerEncoderBlock(nn.Module):
  """One transformer encoder layer.

  Runs the input through a multi-head self-attention sub-block and then an
  MLP sub-block, adding a residual (skip) connection around each.

  Args:
    embeddingDim: token embedding size.
    hiddenLayer: hidden width of the MLP sub-block.
    numHeads: number of attention heads.
    MLPdropOut: dropout probability inside the MLP sub-block.
    attnDropOut: dropout probability passed to the attention sub-block.
  """

  def __init__(self, embeddingDim, hiddenLayer, numHeads, MLPdropOut, attnDropOut=0):
    super().__init__()
    # Construction order (attention first, then MLP) determines parameter
    # initialisation and state_dict ordering; keep it.
    self.MSABlock = multiHeadSelfAttentionBlock(embeddingDim=embeddingDim, numHeads=numHeads, attnDropOut=attnDropOut)
    self.MLPBlock = MLPBlock(embeddingDim=embeddingDim, hiddenLayer=hiddenLayer, dropOut=MLPdropOut)

  def forward(self, x):
    # Residual around attention, then residual around the MLP.
    attended = x + self.MSABlock(x)
    return attended + self.MLPBlock(attended)

class patchNPositionalEmbeddingMaker(nn.Module):
  """Turn an image batch into a token sequence with a class token.

  A Conv2d whose kernel and stride both equal ``patchSize`` projects each
  non-overlapping patch to an ``outChannels``-dim vector. The patch grid is
  flattened into a sequence, a learnable class token is prepended, and a
  learnable positional embedding is added.

  Args:
    inChannels: channels of the input image.
    outChannels: embedding dimension of each token (same as embeddingDim).
    patchSize: side length of the square patches.
    imgSize: side length of the square image the positional embedding is
      sized for.
  """

  def __init__(self, inChannels, outChannels, patchSize, imgSize):
    super().__init__()
    # outChannels is the same as embeddingDim
    self.outChannels = outChannels
    self.patchSize = patchSize
    # With kernel == stride == patchSize and no padding, Conv2d yields
    # floor(imgSize / patchSize) patches per side. Integer division keeps this
    # count in sync with the conv output; the previous float division
    # over-counted when imgSize was not a multiple of patchSize.
    self.numPatches = (imgSize // patchSize) ** 2
    self.patchMaker = nn.Conv2d(inChannels, outChannels, kernel_size=patchSize, stride=patchSize, padding=0)
    self.flattener = nn.Flatten(start_dim=2, end_dim=3)
    # Parameters are created in the same order as before (class token, then
    # positional embedding) so random initialisation is unchanged.
    self.classEmbedding = nn.Parameter(torch.randn(1, 1, self.outChannels))
    self.PositionalEmbedding = nn.Parameter(torch.randn(1, self.numPatches + 1, self.outChannels))

  def forward(self, x):
    """Embed an image batch.

    Args:
      x: image tensor; assumed batched (batch, inChannels, height, width) as
        consumed by Conv2d, with height/width matching ``imgSize`` so the
        positional embedding lines up.

    Returns:
      Tensor of shape (batch, numPatches + 1, outChannels); index 0 along
      dim 1 is the class token.

    Raises:
      ValueError: if the height or width is not divisible by ``patchSize``.
    """
    batchSize = x.shape[0]
    height, width = x.shape[-2], x.shape[-1]
    # Explicit check instead of ``assert`` (stripped under ``python -O``);
    # without it Conv2d would silently drop trailing rows/columns.
    if height % self.patchSize or width % self.patchSize:
      raise ValueError(
        f'Input size must be div by patchSize (got {height}x{width}, patchSize={self.patchSize})')
    x = self.patchMaker(x)      # (batch, outChannels, height/p, width/p)
    x = self.flattener(x)       # (batch, outChannels, numPatches)
    x = x.permute(0, 2, 1)      # (batch, numPatches, outChannels)
    classToken = self.classEmbedding.expand(batchSize, -1, -1)
    x = torch.cat((classToken, x), dim=1)
    return x + self.PositionalEmbedding