Spaces:
Sleeping
Sleeping
File size: 1,828 Bytes
b3fb4dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Pointwise-normalized 1-D convolution block for encoder features.

    Applies Conv1d -> BatchNorm1d -> SiLU over the channel dimension.
    Input and output are channels-last: (batch, time, encoder_dim).
    """

    def __init__(self, args) -> None:
        super().__init__()
        # 'same' padding with stride 1 keeps the time dimension unchanged.
        conv = nn.Conv1d(
            in_channels=args.encoder_dim,
            out_channels=args.encoder_dim,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,  # bias is redundant before BatchNorm
        )
        norm = nn.BatchNorm1d(num_features=args.encoder_dim)
        act = nn.SiLU()
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a (batch, time, encoder_dim) tensor."""
        # Conv1d/BatchNorm1d expect channels-first, so swap in and out.
        channels_first = x.transpose(1, 2)
        out = self.layers(channels_first)
        return out.transpose(1, 2)
class ConvBlockDecoder(nn.Module):
    """Decoder-side 1-D convolution block: Conv1d -> BatchNorm1d -> SiLU.

    Mirrors ConvBlock but operates on decoder_dim channels.
    Input and output are channels-last: (batch, time, decoder_dim).
    """

    def __init__(self, args) -> None:
        super().__init__()
        # Width-preserving depth: 'same' padding with unit stride.
        conv = nn.Conv1d(
            in_channels=args.decoder_dim,
            out_channels=args.decoder_dim,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,  # BatchNorm supplies the affine shift
        )
        norm = nn.BatchNorm1d(num_features=args.decoder_dim)
        act = nn.SiLU()
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a (batch, time, decoder_dim) tensor."""
        # Swap to channels-first for the conv stack, then swap back.
        channels_first = x.transpose(1, 2)
        out = self.layers(channels_first)
        return out.transpose(1, 2)
class ResNetLayer(nn.Module):
    """Single residual 1-D convolution layer.

    Computes x + SiLU(BatchNorm(Conv1d(x))) with a fixed kernel size of 3.
    Unlike ConvBlock, this layer expects channels-first input:
    (batch, encoder_dim, time) — no transpose is performed.
    """

    def __init__(self, args) -> None:
        super().__init__()
        # kernel_size=3 with 'same' padding preserves the time length,
        # which the residual addition in forward() relies on.
        conv = nn.Conv1d(
            in_channels=args.encoder_dim,
            out_channels=args.encoder_dim,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,  # BatchNorm makes a conv bias redundant
        )
        norm = nn.BatchNorm1d(num_features=args.encoder_dim)
        act = nn.SiLU()
        self.conv_layer = nn.Sequential(conv, norm, act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv path and add the identity shortcut."""
        residual = self.conv_layer(x)
        return x + residual
class ResNetBlock(nn.Module):
    """Stack of residual convolution layers (ResNetLayer).

    Generalized: the depth was previously hard-coded to 3; it is now the
    ``num_layers`` parameter, defaulting to 3 for backward compatibility.
    Expects channels-first input (batch, encoder_dim, time), matching
    ResNetLayer, and preserves that shape.
    """

    def __init__(self, args, num_layers: int = 3) -> None:
        super().__init__()
        # Sequential of identical residual layers; each has its own weights.
        self.layers = nn.Sequential(
            *[ResNetLayer(args) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual layer in order."""
        return self.layers(x)