import json
import math
import re
from dataclasses import dataclass
from typing import List

import phonemizer
import torch
import uroman as ur
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
# availability checks for the tokenizer's optional dependencies
from transformers.utils import is_phonemizer_available, is_uroman_available


def has_non_roman_characters(input_string):
    """Return True if the string contains any character outside the ASCII range."""
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")
    match = non_roman_pattern.search(input_string)
    return match is not None


class VitsConfig(PretrainedConfig):
    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_dropout=0.1,  # unused in this inference-only port
        attention_dropout=0.1,  # unused
        activation_dropout=0.1,  # unused
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,  # unused
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,  # unused
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,  # unused
        speaking_rate=1.0,  # unused
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)
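# Sketch (not part of the original module): the defaults above mirror the MMS-TTS
# family of VITS checkpoints (16 kHz audio, 192-dim hidden and flow channels);
# per-language values such as `vocab_size` normally come from a checkpoint's
# config.json rather than being set by hand.
#
#     config = VitsConfig()
#     config.flow_size       # 192: channels entering the flow and the HiFi-GAN decoder
#     config.upsample_rates  # [8, 8, 2, 2] -> 8*8*2*2 = 256x upsampling in the decoder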
@dataclass
class VitsTextEncoderOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    prior_means: torch.FloatTensor = None
    prior_log_variances: torch.FloatTensor = None
    hidden_states: torch.FloatTensor = None
    attentions: torch.FloatTensor = None


class VitsWaveNet(torch.nn.Module):
    def __init__(self, config, num_layers):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        # requires torch >= 2.1 (the older `nn.utils.weight_norm` is deprecated)
        weight_norm = nn.utils.parametrizations.weight_norm

        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # the last layer needs no residual output, only the skip half
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs):
        outputs = torch.zeros_like(inputs)
        num_channels = self.hidden_size

        for i in range(self.num_layers):
            in_act = self.in_layers[i](inputs)
            # gated activation unit: tanh(first half) * sigmoid(second half);
            # the original fused_add_tanh_sigmoid_multiply also added a global
            # (speaker/style) conditioning tensor, dropped in this single-speaker port
            t_act = torch.tanh(in_act[:, :num_channels, :])
            s_act = torch.sigmoid(in_act[:, num_channels:, :])
            acts = t_act * s_act

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = inputs + res_acts
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts
        return outputs
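# Sketch (not part of the original module): the WaveNet stack is shape-preserving,
# mapping (batch, hidden_size, time) -> (batch, hidden_size, time); every dilated
# convolution pads so the time axis is unchanged. The layer count below matches
# `prior_encoder_num_wavenet_layers`, which is how the flow uses it further down.
def _demo_wavenet_shapes():
    config = VitsConfig()
    net = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
    x = torch.randn(1, config.hidden_size, 100)
    with torch.no_grad():
        y = net(x)
    assert y.shape == x.shape  # (1, 192, 100)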
# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        return (kernel_size * dilation - dilation) // 2

    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states


class VitsHifiGan(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.conv_pre = nn.Conv1d(
            config.flow_size,
            config.upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )

        self.upsampler = nn.ModuleList()
        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            self.upsampler.append(
                nn.ConvTranspose1d(
                    config.upsample_initial_channel // (2**i),
                    config.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=kernel_size,
                    stride=upsample_rate,
                    padding=(kernel_size - upsample_rate) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.upsampler)):
            channels = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))

        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)

    def forward(self, spectrogram):
        hidden_states = self.conv_pre(spectrogram)

        for i in range(self.num_upsamples):
            hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
            hidden_states = self.upsampler[i](hidden_states)

            # average the outputs of the parallel residual blocks at this resolution
            res_state = self.resblocks[i * self.num_kernels](hidden_states)
            for j in range(1, self.num_kernels):
                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
            hidden_states = res_state / self.num_kernels

        hidden_states = nn.functional.leaky_relu(hidden_states)
        hidden_states = self.conv_post(hidden_states)
        waveform = torch.tanh(hidden_states)
        return waveform
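# Sketch (not part of the original module): each ConvTranspose1d multiplies the
# time axis exactly by its stride, so the total upsampling factor is the product
# of `upsample_rates`. With the defaults [8, 8, 2, 2] that is 256 waveform samples
# per latent frame, i.e. 16000 / 256 = 62.5 latent frames per second of audio.
def _demo_upsampling_factor():
    config = VitsConfig()
    factor = math.prod(config.upsample_rates)  # 8 * 8 * 2 * 2 = 256
    latents = torch.randn(1, config.flow_size, 10)
    with torch.no_grad():
        waveform = VitsHifiGan(config)(latents)
    assert waveform.shape == (1, 1, 10 * factor)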
class VitsResidualCouplingLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, x, reverse=False):
        # only the inverse (inference) direction is kept in this port: the first
        # half conditions a predicted mean that is subtracted from the second half
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half)
        hidden_states = self.wavenet(hidden_states)
        mean = self.conv_post(hidden_states)
        second_half = second_half - mean
        return torch.cat([first_half, second_half], dim=1)


class VitsResidualCouplingBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, x, reverse=False):
        # inverse pass only: run the flows in reverse order, flipping the channel
        # axis between layers so both halves get transformed (x: [1, 192, T])
        for flow in reversed(self.flows):
            x = torch.flip(x, [1])
            x = flow(x, reverse=True)
        return x


class VitsAttention(nn.Module):
    """Multi-head self-attention without relative positional information."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.window_size = config.window_size

        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by num_heads ({self.num_heads})"
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        layer_head_mask=None,
        output_attentions=False,
    ):
        bsz, tgt_len, _ = hidden_states.size()

        # self-attention: queries, keys and values all come from `hidden_states`
        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.bmm(attn_weights, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_states`
        # because `attn_output` can be partitioned across GPUs when using tensor parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)
        return attn_output, None
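# Sketch (not part of the original module): the coupling layers only shift the
# second half of the channels by a predicted mean, so the flow is shape-preserving
# and volume-preserving (its Jacobian log-determinant is zero). Inverting a layer
# would simply add the mean back instead of subtracting it.
def _demo_flow_shapes():
    config = VitsConfig()
    flow = VitsResidualCouplingBlock(config)
    z = torch.randn(1, config.flow_size, 50)
    with torch.no_grad():
        out = flow(z, reverse=True)
    assert out.shape == z.shape  # (1, 192, 50)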
class VitsFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
        self.act_fn = nn.ReLU()

        if config.ffn_kernel_size > 1:
            # asymmetric "same" padding so the convolutions preserve sequence length
            pad_left = (config.ffn_kernel_size - 1) // 2
            pad_right = config.ffn_kernel_size // 2
            self.padding = [pad_left, pad_right, 0, 0, 0, 0]
        else:
            self.padding = None

    def forward(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)

        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_2(hidden_states)

        hidden_states = hidden_states.permute(0, 2, 1)
        return hidden_states


class VitsEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = VitsAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, output_attentions=False):
        residual = hidden_states
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = self.layer_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(residual + hidden_states)

        return (hidden_states,)


class VitsEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.layerdrop = config.layerdrop

    def forward(
        self,
        hidden_states,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        for layer in self.layers:
            hidden_states = layer(hidden_states)[0]
        return BaseModelOutput(last_hidden_state=hidden_states)


class VitsTextEncoder(nn.Module):
    """Token embedding, a stack of `VitsEncoder` layers, and a projection to the prior statistics."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)  # num_hidden_layers (6) layers of VitsAttention
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def forward(self, input_ids):
        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
        last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state
        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            prior_log_variances=prior_log_variances,
        )


class VitsPreTrainedModel(PreTrainedModel):
    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True
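# Sketch (not part of the original module): for token ids of shape (batch, seq_len)
# the text encoder returns prior means of shape (batch, seq_len, flow_size). The
# config needs a `pad_token_id`; in practice it comes from the checkpoint, and is
# assumed to be 0 here.
def _demo_text_encoder_shapes():
    config = VitsConfig(pad_token_id=0)
    encoder = VitsTextEncoder(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 17))
    with torch.no_grad():
        out = encoder(input_ids)
    assert out.prior_means.shape == (1, 17, config.flow_size)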
class VitsModel(VitsPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)  # VitsEncoder with 6 layers of VitsAttention
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        speaker_id=None,  # unused, kept for API compatibility
        output_attentions=None,  # unused
        output_hidden_states=None,  # unused
        return_dict=None,  # unused
        labels=None,  # unused
        speed=None,  # unused; speed comes from the oscillation pattern per voice/language
        lang_code="deu",  # selects the duration oscillation pattern
    ):
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

        out = self.text_encoder(input_ids=input_ids)
        hidden_states = out.last_hidden_state.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = out.prior_means
        bs, _, in_len = hidden_states.shape

        # -- duration oscillation: instead of a learned duration predictor, tile a
        # fixed per-language pattern of frame counts across the input tokens;
        # each voice (lang_code) sounds better with its own pattern
        if lang_code == "deu":
            pattern = [1, 2, 1]
        elif lang_code == "rmc-script_latin":
            pattern = [2, 2, 1, 2, 2]  # alternative: [2, 2, 2, 1, 2]
        elif lang_code == "hun":
            # pattern = [1, 2, 2, 1, 1, 1]  # sounds cool / has valley-pause
            pattern = [1, 2, 1, 1, 1]
        else:
            pattern = [1, 2, 1]
        duration = torch.tensor(pattern, device=hidden_states.device).repeat(
            int(in_len / len(pattern)) + 2
        )[None, None, :in_len]
        duration[:, :, 0] = 4  # hold the first token longer
        duration[:, :, -1] = 3  # and the last one

        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # build the monotonic alignment that repeats each prior mean `duration` times
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
        attn = attn[:, 0, :, :]
        attn = attn + 1e-4 * torch.rand_like(attn)
        attn = attn / attn.sum(2, keepdim=True)

        prior_means = torch.matmul(attn, prior_means)
        # idea: let attn contain .5/.5 instead of 1/0 so it smoothly interpolates repeated prior means
        # prior_means = F.interpolate(prior_means.transpose(1, 2), int(1.74 * prior_means.shape[1]), mode="linear").transpose(1, 2)  # extend for slow speed

        # prior means have now been replicated x the duration of each prior mean
        latents = self.flow(
            prior_means.transpose(1, 2),  # + torch.randn_like(prior_means) * .94 noise is omitted
            reverse=True,
        )
        waveform = self.decoder(latents)  # [bs, 1, samples]
        return waveform[:, 0, :]
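# Sketch (not part of the original module): end-to-end synthesis. The checkpoint
# name is an assumption -- any MMS-TTS style VITS checkpoint with a matching
# `vocab.json` should work; weights for the stripped-out duration predictor are
# simply reported as unused keys when loading.
def _demo_synthesis():
    tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-deu")  # assumed checkpoint
    model = VitsModel.from_pretrained("facebook/mms-tts-deu").eval()
    inputs = tokenizer("Hallo Welt", return_tensors="pt")
    with torch.no_grad():
        waveform = model(input_ids=inputs["input_ids"], lang_code="deu")  # (1, samples)
    return waveform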
class VitsTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize
        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special tokens that may be partly or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            # greedily keep a vocabulary match verbatim; otherwise lowercase the character
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break
            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1
        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages."""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize=None, **kwargs
    ):
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalize for casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            if not is_uroman_available():
                print(
                    "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
                    "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman`. "
                    "Note that `uroman` requires Python >= 3.10. Otherwise, apply the Romanizer manually as per "
                    "the instructions: https://github.com/isi-nlp/uroman"
                )
            else:
                uroman = ur.Uroman()
                filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")
            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any characters outside of the vocab (punctuation)
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs
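    # Worked example (sketch, not in the original): with `normalize=True` and
    # `phonemize=False`, an input like "Hallo, Welt!" is lower-cased by
    # `normalize_text`, then any character missing from `vocab.json` (here the
    # comma and exclamation mark, assuming a plain MMS character vocabulary) is
    # stripped, leaving "hallo welt" for `_tokenize` below.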
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
        tokens = list(text)

        if self.add_blank:
            # without a blank between letters the output sounds dyslexic; with more
            # than two blanks between letters it sounds disconnected
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2)
            interspersed[::2] = tokens
            # append one trailing blank; sizing the list as `len(tokens) * 2 + 1`
            # up front raises a ValueError (slice size mismatch) when the token
            # count is odd
            tokens = interspersed + [self._convert_id_to_token(0)]
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index)
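# Sketch (not part of the original module): what `_tokenize` produces. Vocabulary
# id 0 is assumed to be the blank/<pad> symbol, as in the MMS vocab files, so
# "abc" becomes [a, <pad>, b, <pad>, c, <pad>, <pad>] -- a blank after every
# character plus one extra trailing blank.
#
#     tokenizer = VitsTokenizer("vocab.json")  # hypothetical local vocab file
#     tokens = tokenizer._tokenize("abc")
#     ids = tokenizer.convert_tokens_to_ids(tokens)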