File size: 1,828 Bytes
b3fb4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
  """Channel-last 1-D convolution block: Conv1d -> BatchNorm1d -> SiLU.

  ``forward`` swaps dims 1 and 2 before the convolution stack and swaps them
  back afterwards, so the feature axis is expected at position 2 (presumably
  ``(batch, time, encoder_dim)`` -- confirm against callers). The convolution
  uses stride 1 with ``padding='same'``, so the size of dim 1 is preserved.

  Args:
    args: configuration object; reads ``args.encoder_dim`` (used as both the
      input and output channel count) and ``args.kernel_size``.
  """

  def __init__(self, args) -> None:
    super().__init__()
    channels = args.encoder_dim
    # No conv bias: BatchNorm's mean subtraction would cancel it anyway.
    conv = nn.Conv1d(
        in_channels=channels,
        out_channels=channels,
        kernel_size=args.kernel_size,
        stride=1,
        padding='same',
        bias=False,
    )
    norm = nn.BatchNorm1d(num_features=channels)
    # Attribute name and submodule order are kept so state_dict keys
    # (layers.0.*, layers.1.*) match existing checkpoints.
    self.layers = nn.Sequential(conv, norm, nn.SiLU())

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Conv1d/BatchNorm1d take channels on dim 1, so move features there.
    channels_first = x.transpose(1, 2)
    out = self.layers(channels_first)
    return out.transpose(1, 2)

class ConvBlockDecoder(nn.Module):
  """Decoder-side twin of ConvBlock: Conv1d -> BatchNorm1d -> SiLU.

  Identical in structure to the encoder block but sized by
  ``args.decoder_dim``. Dims 1 and 2 are swapped around the convolution
  stack, so the feature axis is expected at position 2 (presumably
  ``(batch, time, decoder_dim)`` -- confirm against callers). Stride 1 with
  ``padding='same'`` keeps the size of dim 1 unchanged.

  Args:
    args: configuration object; reads ``args.decoder_dim`` (input and output
      channel count) and ``args.kernel_size``.
  """

  def __init__(self, args) -> None:
    super().__init__()
    width = args.decoder_dim
    stages = [
        # Bias-free: the BatchNorm that follows re-centres the activations.
        nn.Conv1d(in_channels=width, out_channels=width,
                  kernel_size=args.kernel_size, stride=1,
                  padding='same', bias=False),
        nn.BatchNorm1d(num_features=width),
        nn.SiLU(),
    ]
    # Kept as `layers` with this order so state_dict keys stay unchanged.
    self.layers = nn.Sequential(*stages)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Move features onto dim 1 for Conv1d, then restore the caller's layout.
    return self.layers(x.transpose(1, 2)).transpose(1, 2)

class ResNetLayer(nn.Module):
  """Residual 1-D convolution layer computing ``x + SiLU(BN(Conv1d(x)))``.

  Unlike ConvBlock, no transpose is performed: ``x`` is handed straight to
  ``nn.Conv1d``. It must therefore already be channel-first,
  ``(batch, encoder_dim, length)``. Stride 1, ``padding='same'`` and equal
  input/output channel counts keep the branch output the same shape as ``x``,
  which the residual addition requires. SiLU is applied to the convolution
  branch before the addition; the sum itself is not activated.

  Args:
    args: configuration object; only ``args.encoder_dim`` is read.
    kernel_size: convolution kernel width. Defaults to 3, the previously
      hard-coded value, so existing callers and checkpoints are unaffected.
  """

  def __init__(self, args, kernel_size: int = 3) -> None:
    super().__init__()
    # Attribute name and submodule order are kept so state_dict keys
    # (conv_layer.0.*, conv_layer.1.*) stay compatible with checkpoints.
    self.conv_layer = nn.Sequential(
        nn.Conv1d(in_channels=args.encoder_dim,
                  out_channels=args.encoder_dim,
                  kernel_size=kernel_size,
                  stride=1, padding='same', bias=False),
        nn.BatchNorm1d(num_features=args.encoder_dim),
        nn.SiLU(),
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.conv_layer(x) + x
        
        
class ResNetBlock(nn.Module):
  """Stack of ``num_layers`` ResNetLayer modules applied in sequence.

  Args:
    args: configuration object forwarded unchanged to every ResNetLayer.
    num_layers: number of stacked residual layers. Defaults to 3, the
      previously hard-coded value, so existing callers and checkpoints
      (``layers.0.*`` .. ``layers.2.*``) are unaffected.

  Raises:
    ValueError: if ``num_layers`` is negative. Without this check, a negative
      value would silently build an empty stack that acts as the identity.
  """

  def __init__(self, args, num_layers: int = 3) -> None:
    super().__init__()
    if num_layers < 0:
      raise ValueError(f"num_layers must be non-negative, got {num_layers}")
    self.layers = nn.Sequential(*[ResNetLayer(args) for _ in range(num_layers)])

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # ResNetLayer feeds x to Conv1d untransposed, so x must be channel-first.
    return self.layers(x)