import json
import math
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

import phonemizer
import torch
import uroman as ur
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import is_phonemizer_available, is_uroman_available

def has_non_roman_characters(input_string):
    """Return True if the string contains any character outside the ASCII (Roman) range."""
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")
    return non_roman_pattern.search(input_string) is not None

class VitsConfig(PretrainedConfig):
    """Configuration class holding the architecture and inference hyper-parameters of a VITS model."""

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)

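# A minimal usage sketch (illustration only, not part of the model code): with the defaults above,
# the HiFi-GAN decoder upsamples every latent frame by prod(upsample_rates) = 8 * 8 * 2 * 2 = 256
# waveform samples, i.e. 16 ms of audio per frame at the default 16 kHz sampling rate.
#
#     config = VitsConfig()
#     samples_per_frame = math.prod(config.upsample_rates)  # 256
#     frame_duration_ms = 1000 * samples_per_frame / config.sampling_rate  # 16.0
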
@dataclass
class VitsTextEncoderOutput(ModelOutput):
    """Output of `VitsTextEncoder`: hidden states plus the mean and log-variance of the flow prior."""

    last_hidden_state: Optional[torch.FloatTensor] = None
    prior_means: Optional[torch.FloatTensor] = None
    prior_log_variances: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class VitsWaveNet(torch.nn.Module):
    """Stack of dilated, gated 1D convolutions with residual and skip connections (WaveNet-style)."""

    def __init__(self, config, num_layers):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()

        weight_norm = nn.utils.parametrizations.weight_norm

        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer only produces skip channels; earlier layers produce residual + skip channels.
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size
            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs):
        outputs = torch.zeros_like(inputs)
        num_channels = self.hidden_size

        for i in range(self.num_layers):
            in_act = self.in_layers[i](inputs)

            # Gated activation unit: tanh on the first half of the channels, sigmoid on the second half.
            t_act = torch.tanh(in_act[:, :num_channels, :])
            s_act = torch.sigmoid(in_act[:, num_channels:, :])
            acts = t_act * s_act

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = inputs + res_acts
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts
        return outputs

class HifiGanResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        return (kernel_size * dilation - dilation) // 2

    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states

class VitsHifiGan(nn.Module):
    """HiFi-GAN vocoder: upsamples latent frames to a waveform, one transposed convolution per upsample rate."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.conv_pre = nn.Conv1d(
            config.flow_size,
            config.upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )

        self.upsampler = nn.ModuleList()
        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            self.upsampler.append(
                nn.ConvTranspose1d(
                    config.upsample_initial_channel // (2**i),
                    config.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=kernel_size,
                    stride=upsample_rate,
                    padding=(kernel_size - upsample_rate) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.upsampler)):
            channels = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)

    def forward(self, spectrogram):
        hidden_states = self.conv_pre(spectrogram)
        for i in range(self.num_upsamples):
            hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
            hidden_states = self.upsampler[i](hidden_states)

            # Average the parallel residual blocks attached to this upsampling stage.
            res_state = self.resblocks[i * self.num_kernels](hidden_states)
            for j in range(1, self.num_kernels):
                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
            hidden_states = res_state / self.num_kernels

        hidden_states = nn.functional.leaky_relu(hidden_states)
        hidden_states = self.conv_post(hidden_states)
        waveform = torch.tanh(hidden_states)
        return waveform

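# Shape sketch (illustration only): with the default configuration the decoder maps a latent
# tensor of shape (batch, flow_size=192, frames) to a waveform of shape (batch, 1, frames * 256),
# since each transposed convolution upsamples the time axis by its stride.
#
#     decoder = VitsHifiGan(VitsConfig())
#     audio = decoder(torch.zeros(1, 192, 100))  # torch.Size([1, 1, 25600])
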
class VitsResidualCouplingLayer(nn.Module):
    """Mean-only affine coupling layer, used here in inverse mode only (inference)."""

    def __init__(self, config):
        super().__init__()
        self.half_channels = config.flow_size // 2
        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, x, reverse=False):
        # Predict a mean from the first half of the channels and shift the second half by it.
        # Only the inverse transform (subtracting the mean) is implemented, which is all inference needs.
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half)
        hidden_states = self.wavenet(hidden_states)
        mean = self.conv_post(hidden_states)
        second_half = second_half - mean
        outputs = torch.cat([first_half, second_half], dim=1)
        return outputs

class VitsResidualCouplingBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, x, reverse=False):
        # Inference-only inverse flow: apply the coupling layers in reverse order, flipping the
        # channel halves between layers so that both halves get transformed.
        for flow in reversed(self.flows):
            x = torch.flip(x, [1])
            x = flow(x, reverse=True)
        return x

class VitsAttention(nn.Module):
    """Multi-head self-attention with no positional information."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.window_size = config.window_size
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"hidden_size ({self.embed_dim}) must be divisible by num_attention_heads ({self.num_heads})"
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        layer_head_mask=None,
        output_attentions=False,
    ):
        bsz, tgt_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the head dimension into the batch dimension for the batched matmuls.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.bmm(attn_weights, value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None

class VitsFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
        self.act_fn = nn.ReLU()

        # Manual asymmetric padding keeps the sequence length unchanged for any kernel size > 1.
        if config.ffn_kernel_size > 1:
            pad_left = (config.ffn_kernel_size - 1) // 2
            pad_right = config.ffn_kernel_size // 2
            self.padding = [pad_left, pad_right, 0, 0, 0, 0]
        else:
            self.padding = None

    def forward(self, hidden_states):
        # Conv1d expects (batch, channels, time), so swap the last two dimensions around the convolutions.
        hidden_states = hidden_states.permute(0, 2, 1)

        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)
        hidden_states = self.conv_2(hidden_states)

        hidden_states = hidden_states.permute(0, 2, 1)
        return hidden_states

class VitsEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = VitsAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        output_attentions=False,
    ):
        # Post-norm transformer block: attention and feed-forward, each followed by a residual
        # connection and layer norm.
        residual = hidden_states
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = self.layer_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(residual + hidden_states)

        outputs = (hidden_states,)
        return outputs

class VitsEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.layerdrop = config.layerdrop

    def forward(
        self,
        hidden_states,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Plain sequential encoder; layerdrop and per-layer attention/hidden-state outputs are not used here.
        for layer in self.layers:
            layer_outputs = layer(hidden_states)
            hidden_states = layer_outputs[0]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
        )

class VitsTextEncoder(nn.Module):
    """Token embedding plus `VitsEncoder`, projected to the mean and log-variance of the flow prior."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def forward(self, input_ids):
        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
        last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state

        # Project to 2 * flow_size channels and split into prior means and log-variances.
        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)

        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            prior_log_variances=prior_log_variances,
        )

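# Shape sketch (illustration only): for `input_ids` of shape (batch, seq_len), the text encoder
# returns `last_hidden_state` of shape (batch, seq_len, hidden_size) and `prior_means` /
# `prior_log_variances` of shape (batch, seq_len, flow_size).
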
class VitsPreTrainedModel(PreTrainedModel):
    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

class VitsModel(VitsPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        speaker_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        speed=None,
        lang_code="deu",
    ):
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

        out = self.text_encoder(input_ids=input_ids)
        hidden_states = out.last_hidden_state.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = out.prior_means
        bs, _, in_len = hidden_states.shape

        # Instead of the stochastic duration predictor of the original VITS model, each input token
        # is assigned a fixed number of output frames according to a repeating, language-dependent
        # pattern. The alignment construction below assumes a batch size of 1.
        if lang_code == "deu":
            pattern = [1, 2, 1]
        elif lang_code == "rmc-script_latin":
            pattern = [2, 2, 1, 2, 2]
        elif lang_code == "hun":
            pattern = [1, 2, 1, 1, 1]
        else:
            pattern = [1, 2, 1]
        duration = torch.tensor(pattern, device=hidden_states.device).repeat(int(in_len / len(pattern)) + 2)[
            None, None, :in_len
        ]
        duration[:, :, 0] = 4  # hold the first token longer
        duration[:, :, -1] = 3  # hold the last token longer

        # Build a hard monotonic alignment `attn` of shape (batch, output_len, input_len) that
        # repeats each input position `duration` times along the output axis.
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
        attn = attn[:, 0, :, :]

        # Add a small amount of noise so that no row of `attn` is all zeros before normalising.
        attn = attn + 1e-4 * torch.rand_like(attn)
        attn /= attn.sum(2, keepdim=True)

        # Expand the prior means to frame level, run the inverse flow, and vocode to a waveform.
        prior_means = torch.matmul(attn, prior_means)
        latents = self.flow(prior_means.transpose(1, 2), reverse=True)
        waveform = self.decoder(latents)

        return waveform[:, 0, :]

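# Worked example of the fixed-duration alignment (illustration only): for `lang_code="deu"` and a
# five-token input, the per-token durations become [4, 2, 1, 1, 3] (the pattern [1, 2, 1] repeated,
# with the first and last tokens overridden), so the flow receives 4 + 2 + 1 + 1 + 3 = 11 latent
# frames and the decoder returns 11 * 256 waveform samples per batch item.
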
class VitsTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize
        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, leaving any special tokens (which may be partly or fully upper-cased) intact."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            # Copy a vocabulary entry verbatim if one starts at this position; otherwise lowercase
            # the single character.
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages."""
        if self.language == "ron":
            # Romanian: replace t-comma (U+021B) with the older t-cedilla form (U+0163).
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
    ):
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            if not is_uroman_available():
                print(
                    "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
                    "step automatically, ensure the `uroman` Romanizer is installed with `pip install uroman`. "
                    "Note that `uroman` requires Python >= 3.10. "
                    "Otherwise, apply the Romanizer manually as per the instructions: https://github.com/isi-nlp/uroman"
                )
            else:
                uroman = ur.Uroman()
                filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")

            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # Strip any characters that are not part of the vocabulary.
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string into characters, inserting the blank token (vocab id 0) at every character boundary."""
        tokens = list(text)

        if self.add_blank:
            # Intersperse the blank token so the sequence starts and ends with it:
            # "ab" -> [blank, "a", blank, "b", blank].
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
            interspersed[1::2] = tokens
            tokens = interspersed

        return tokens

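    # Worked example (illustration only): with `add_blank=True`, `_tokenize("hi")` returns
    # [blank, "h", blank, "i", blank], where `blank` is the token with vocab id 0, so a text of
    # N characters becomes 2 * N + 1 tokens.
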
    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.decoder.get(index)

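# A minimal end-to-end sketch, assuming a local directory "./vits_checkpoint" that contains a
# `config.json`, model weights saved in this module's format, and a `vocab.json` (the directory
# and file names are hypothetical and not shipped with this code). `phonemize=False` assumes a
# character-level vocabulary.
if __name__ == "__main__":
    checkpoint_dir = "./vits_checkpoint"  # hypothetical path

    tokenizer = VitsTokenizer(vocab_file=f"{checkpoint_dir}/vocab.json", phonemize=False)
    model = VitsModel.from_pretrained(checkpoint_dir)
    model.eval()

    inputs = tokenizer("Hallo, wie geht es dir?", return_tensors="pt")
    with torch.no_grad():
        waveform = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            lang_code="deu",
        )

    # `waveform` has shape (batch, num_samples) at `model.config.sampling_rate` Hz.
    print(waveform.shape, model.config.sampling_rate)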