|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
from sparktts.modules.blocks.layers import ( |
|
Snake1d, |
|
WNConv1d, |
|
ResidualUnit, |
|
WNConvTranspose1d, |
|
init_weights, |
|
) |
|
|
|
|
|
class DecoderBlock(nn.Module):
    """One decoder stage: Snake1d, a transposed convolution, then three residual units.

    Args:
        input_dim: Channel count fed to ``Snake1d`` and the transposed convolution.
        output_dim: Channel count produced by the transposed convolution and
            kept by every residual unit.
        kernel_size: Kernel size forwarded to ``WNConvTranspose1d``.
        stride: Stride forwarded to ``WNConvTranspose1d``.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        kernel_size: int = 2,
        stride: int = 1,
    ):
        super().__init__()

        # Submodules are created in the same order they run in, so the
        # indices inside ``self.block`` (and thus state-dict keys) stay fixed.
        stages = [Snake1d(input_dim)]
        stages.append(
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                # NOTE(review): presumably chosen so the output length is
                # ``stride`` times the input length when
                # ``kernel_size - stride`` is even -- confirm against
                # WNConvTranspose1d.
                padding=(kernel_size - stride) // 2,
            )
        )
        # Three residual units with growing dilation: 1, 3, 9.
        stages.extend(
            ResidualUnit(output_dim, dilation=rate) for rate in (1, 3, 9)
        )

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        out = self.block(x)
        return out
|
|
|
|
|
class WaveGenerator(nn.Module):
    """Sequential generator: input conv, a chain of ``DecoderBlock``s, output head.

    The model is assembled as:

    1. ``WNConv1d(input_channel, channels, kernel_size=7, padding=3)``
    2. one ``DecoderBlock`` per ``(kernel_size, stride)`` pair, where stage
       ``i`` maps ``channels // 2**i`` channels to ``channels // 2**(i + 1)``
    3. ``Snake1d`` -> ``WNConv1d(..., d_out, kernel_size=7, padding=3)`` ->
       ``nn.Tanh`` (so outputs lie in ``[-1, 1]``)

    After construction, ``init_weights`` is applied to every submodule via
    ``self.apply``.

    Args:
        input_channel: Channel count of the input passed to the first conv.
        channels: Channel count after the first conv; halved (integer
            division) at every decoder stage.
        rates: Per-stage strides for the decoder blocks.
        kernel_sizes: Per-stage kernel sizes for the decoder blocks.
        d_out: Channel count of the final output.

    Note:
        ``kernel_sizes`` and ``rates`` are paired with ``zip``, so if their
        lengths differ the extra entries of the longer one are ignored.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        kernel_sizes,
        d_out: int = 1,
    ):
        super().__init__()

        # Input projection to the widest channel count.
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Seed the head width with the stem width. Without this, an empty
        # ``rates``/``kernel_sizes`` leaves ``output_dim`` unbound and the
        # head construction below raises UnboundLocalError.
        output_dim = channels

        # Each stage halves the channel count relative to the previous one.
        for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]

        # Output head operating on the channel count left by the last stage.
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        # Recursively visit every submodule (and self) with init_weights.
        self.apply(init_weights)

    def forward(self, x):
        """Run ``x`` through the full sequential model and return its output."""
        return self.model(x)
|
|