# Spaces:
# Runtime error
# Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import Conv1d, ConvTranspose1d | |
from torch.nn.utils import remove_weight_norm, weight_norm | |
# Negative slope for the leaky ReLUs in ResBlock.forward and before each
# upsampling layer in Generator.forward.
LRELU_SLOPE: float = 0.1
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw the ``weight`` of convolution-like modules from N(mean, std).

    A module is affected when its class name contains ``"Conv"`` (for
    example ``Conv1d`` or ``ConvTranspose1d``); any other module is left
    untouched. Designed to be passed to ``nn.Module.apply``.

    Args:
        m: Module to (possibly) initialise.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    module_name = m.__class__.__name__
    if "Conv" not in module_name:
        return
    # Write through ``.data`` so the in-place fill is not tracked by autograd.
    m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding that keeps a stride-1 dilated conv length-preserving.

    A kernel of ``kernel_size`` taps with spacing ``dilation`` reaches
    ``dilation * (kernel_size - 1)`` samples beyond its first tap; half of
    that on each side keeps the output length equal to the input length
    whenever the reach is even (e.g. for odd ``kernel_size``).
    """
    scaled = kernel_size * dilation
    return (scaled - dilation) // 2
class ResBlock(torch.nn.Module):
    """Residual block of dilated 1-D convolutions (HiFi-GAN-style).

    One residual stage is built per entry of ``dilation``. Each stage computes
    ``c2(leaky_relu(c1(leaky_relu(x))))`` and adds it back onto ``x``, where
    ``c1`` uses that stage's dilation and ``c2`` always uses dilation 1.
    Every convolution maps ``channels`` -> ``channels`` with stride 1 and
    padding from ``get_padding``, so the residual sum lines up for odd
    ``kernel_size``.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Sequence of dilations, one per residual stage. The default
            ``(1, 3, 5)`` yields the same three stages (and the same
            parameter names ``convs1.0``..``convs1.2``) as before.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        # Dilated conv of each stage. Building them from ``dilation`` (rather
        # than hard-coding three) means every entry is honoured instead of
        # extra entries being ignored / short sequences raising IndexError.
        self.convs1 = nn.ModuleList(
            [self._make_conv(channels, kernel_size, d) for d in dilation]
        )
        # NOTE(review): with the legacy torch.nn.utils.weight_norm, this
        # writes into the derived ``weight`` tensor, which may be recomputed
        # from ``weight_g``/``weight_v`` on the next forward - confirm whether
        # this initialisation actually takes effect.
        self.convs1.apply(init_weights)
        # Undilated conv of each stage (same count as convs1 so zip pairs them).
        self.convs2 = nn.ModuleList(
            [self._make_conv(channels, kernel_size, 1) for _ in dilation]
        )
        self.convs2.apply(init_weights)

    @staticmethod
    def _make_conv(channels, kernel_size, dilation):
        """Build one weight-normalised, length-preserving, stride-1 Conv1d."""
        return weight_norm(
            Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                padding=get_padding(kernel_size, dilation),
            )
        )

    def forward(self, x):
        """Run every residual stage in order and return the result.

        Presumably ``x`` is ``(batch, channels, time)`` as expected by
        ``Conv1d`` - confirm with callers. The output has the same shape when
        ``kernel_size`` is odd.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        # Inside this method the bare name resolves to the module-level
        # torch.nn.utils.remove_weight_norm import, not to this method.
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
class Generator(torch.nn.Module):
    """Upsampling convolutional generator (appears to follow HiFi-GAN).

    Pipeline: ``conv_pre`` projects the input to ``upsample_initial_channel``
    channels. Each upsampling stage then applies a leaky ReLU and a
    transposed convolution that floor-divides the channel count by 2. It
    finishes with the average of ``len(resblock_kernel_sizes)`` parallel
    ``ResBlock``s. ``conv_post`` maps to a single channel, and ``tanh``
    bounds the output to (-1, 1). Presumably the output is an audio
    waveform - confirm with callers.

    Args:
        cfg: Mapping with keys ``upsample_rates``, ``upsample_kernel_sizes``,
            ``upsample_initial_channel``, ``resblock_kernel_sizes`` and
            ``resblock_dilation_sizes``. ``model_in_dim`` is optional
            (default 80; presumably the number of input feature bins).
    """

    def __init__(self, cfg):
        super().__init__()
        init_ch = cfg["upsample_initial_channel"]
        self.num_kernels = len(cfg["resblock_kernel_sizes"])
        self.num_upsamples = len(cfg["upsample_rates"])
        self.conv_pre = weight_norm(
            Conv1d(
                cfg.get("model_in_dim", 80),
                init_ch,
                7,
                1,
                padding=3,
            )
        )
        # Stage i maps init_ch // 2**i -> init_ch // 2**(i+1) channels and
        # upsamples time by rate u.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
            zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
        ):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        init_ch // (2**i),
                        init_ch // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # Resblocks are stored flat, stage-major: stage i owns indices
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = init_ch // (2 ** (i + 1))
            for k, d in zip(
                cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
            ):
                self.resblocks.append(ResBlock(ch, k, d))
        # Channel count after the last upsampling stage. Computed directly
        # (same value as the last loop ``ch``) so it is also defined when
        # there are no upsampling stages, where the loop variable used to be
        # unbound and raised NameError.
        out_ch = init_ch // (2 ** len(self.ups))
        self.conv_post = weight_norm(Conv1d(out_ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Generate the output signal from input features ``x``.

        Presumably ``x`` is ``(batch, model_in_dim, frames)`` - confirm with
        callers. The result has one channel and values in (-1, 1).
        """
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's parallel resblocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # No slope is passed here, so torch's default negative_slope (0.01)
        # applies rather than LRELU_SLOPE. Kept as-is because trained
        # checkpoints presumably depend on it.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer of the generator."""
        print("Removing weight norm...")
        # Bare ``remove_weight_norm`` here is the torch.nn.utils function;
        # ResBlocks are handled by their own method.
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)