from typing import Optional

import torch
from torch import nn

from .vits_config import VitsConfig

#.............................................
def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
    # WaveNet-style gated activation: the first `num_channels` channels act as
    # the tanh "filter", the remaining channels as the sigmoid "gate".
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :num_channels, :])
    s_act = torch.sigmoid(in_act[:, num_channels:, :])
    acts = t_act * s_act
    return acts

#.............................................
class VitsWaveNet(torch.nn.Module):
    def __init__(self, config: VitsConfig, num_layers: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers
        self.speaker_embedding_size = config.speaker_embedding_size

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(config.wavenet_dropout)

        # PyTorch >= 2.1 provides weight norm as a parametrization; fall back
        # to the deprecated hook-based helper on older versions.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if config.speaker_embedding_size != 0:
            # One shared conditioning projection produces 2 * hidden_size
            # channels (filter + gate) for every layer at once.
            cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
            self.cond_layer = weight_norm(cond_layer, name="weight")
        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            # "Same" padding for the dilated convolution.
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer needs no residual branch, only a skip output.
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)
    def forward(self, inputs, padding_mask, global_conditioning=None):
        outputs = torch.zeros_like(inputs)
        num_channels_tensor = torch.IntTensor([self.hidden_size])

        if global_conditioning is not None:
            global_conditioning = self.cond_layer(global_conditioning)

        for i in range(self.num_layers):
            hidden_states = self.in_layers[i](inputs)

            if global_conditioning is not None:
                # Slice out this layer's share of the shared conditioning projection.
                cond_offset = i * 2 * self.hidden_size
                global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
            else:
                global_states = torch.zeros_like(hidden_states)

            acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
            acts = self.dropout(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                # First half is the residual update, second half the skip output.
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = (inputs + res_acts) * padding_mask
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts

        return outputs * padding_mask
    def remove_weight_norm(self):
        # Parametrization-based weight norm must be removed with
        # remove_parametrizations; the legacy helper only handles the
        # hook-based form and would raise on parametrized modules.
        def _remove(module):
            if nn.utils.parametrize.is_parametrized(module):
                nn.utils.parametrize.remove_parametrizations(module, "weight")
            else:
                torch.nn.utils.remove_weight_norm(module)

        if self.speaker_embedding_size != 0:
            _remove(self.cond_layer)
        for layer in self.in_layers:
            _remove(layer)
        for layer in self.res_skip_layers:
            _remove(layer)

    def apply_weight_norm(self):
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if self.speaker_embedding_size != 0:
            weight_norm(self.cond_layer)
        for layer in self.in_layers:
            weight_norm(layer)
        for layer in self.res_skip_layers:
            weight_norm(layer)
#.............................................................................................
class VitsResidualCouplingLayer(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        # Affine coupling: the first half of the channels parameterizes an
        # invertible transform of the second half. log_stddev is fixed at
        # zero, so the transform is a pure shift and volume-preserving
        # (log-determinant zero).
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half) * padding_mask
        hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
        mean = self.conv_post(hidden_states) * padding_mask
        log_stddev = torch.zeros_like(mean)

        if not reverse:
            second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            log_determinant = torch.sum(log_stddev, [1, 2])
            return outputs, log_determinant
        else:
            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            return outputs, None
    def apply_weight_norm(self):
        nn.utils.weight_norm(self.conv_pre)
        self.wavenet.apply_weight_norm()
        nn.utils.weight_norm(self.conv_post)

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv_pre)
        self.wavenet.remove_weight_norm()
        nn.utils.remove_weight_norm(self.conv_post)
#.............................................................................................
class VitsResidualCouplingBlock(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        # Each coupling layer only transforms half of the channels, so the
        # channel order is flipped between flows; the reverse pass undoes the
        # flips and runs the flows in the opposite order.
        if not reverse:
            for flow in self.flows:
                inputs, _ = flow(inputs, padding_mask, global_conditioning)
                inputs = torch.flip(inputs, [1])
        else:
            for flow in reversed(self.flows):
                inputs = torch.flip(inputs, [1])
                inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
        return inputs

    def apply_weight_norm(self):
        for flow in self.flows:
            flow.apply_weight_norm()

    def remove_weight_norm(self):
        for flow in self.flows:
            flow.remove_weight_norm()
    def resize_speaker_embeddings(self, speaker_embedding_size: Optional[int] = None):
        # Re-create each WaveNet's conditioning projection for a new speaker
        # embedding size (e.g. when fine-tuning on a multi-speaker dataset).
        for flow in self.flows:
            flow.wavenet.speaker_embedding_size = speaker_embedding_size
            hidden_size = flow.wavenet.hidden_size
            num_layers = flow.wavenet.num_layers
            cond_layer = torch.nn.Conv1d(speaker_embedding_size, 2 * hidden_size * num_layers, 1)
            flow.wavenet.cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
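
#.............................................................................................
# Minimal smoke test — a sketch, not part of the model. VitsConfig is stood in
# by a SimpleNamespace with made-up but plausible hyperparameters (the real
# values come from the checkpoint config). Run as a module (python -m ...) so
# the relative import above resolves.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        flow_size=192,
        hidden_size=192,
        speaker_embedding_size=0,
        wavenet_dropout=0.0,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        prior_encoder_num_wavenet_layers=4,
        prior_encoder_num_flows=4,
    )
    flow = VitsResidualCouplingBlock(config)

    inputs = torch.randn(2, config.flow_size, 50)       # (batch, channels, time)
    padding_mask = torch.ones(2, 1, 50)                 # no padded frames

    # Forward maps latents through the stacked flows; reverse inverts them.
    # With log_stddev fixed at zero every coupling layer is a pure shift, so
    # the round trip should reconstruct the inputs up to float error.
    latents = flow(inputs, padding_mask)
    recovered = flow(latents, padding_mask, reverse=True)
    print("round trip ok:", torch.allclose(recovered, inputs, atol=1e-4))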