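# File-view header captured with this source (commented out so the module parses):
# author: Dionyssos | commit c7362aa "oscillate vits duration" | 28.6 kB | (raw / history / blame)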
import math
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple
from transformers.tokenization_utils import PreTrainedTokenizer
import phonemizer
import uroman as ur
import torch.nn.functional as F
def has_non_roman_characters(input_string):
    """Return True if *input_string* contains any character outside ASCII.

    A character counts as non-Roman when its code point is above U+007F. The
    empty string contains no such character.
    """
    return not input_string.isascii()
class VitsConfig(PretrainedConfig):
    """Configuration for the VITS text-to-speech model defined in this file.

    Holds the hyper-parameters read by the text encoder, the prior flow and the
    HiFi-GAN decoder. Some arguments (dropouts, stochastic-duration and
    duration-predictor settings, ``speaking_rate``) are not consumed by the
    modules in this file. They are still stored so that a loaded
    ``config.json`` keeps them when it is saved again.

    The list-valued arguments default to ``None`` and are replaced by a fresh
    list per instance:

    - ``upsample_rates``: [8, 8, 2, 2]
    - ``upsample_kernel_sizes``: [16, 16, 4, 4]
    - ``resblock_kernel_sizes``: [3, 7, 11]
    - ``resblock_dilation_sizes``: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

    Raises:
        ValueError: if ``upsample_kernel_sizes`` and ``upsample_rates`` have
            different lengths.
    """

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,  # unused
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        **kwargs,
    ):
        # Build list defaults per call. A list literal used as a default argument
        # is a single shared object, so mutating it on one config would silently
        # change every other config created with the default.
        if upsample_rates is None:
            upsample_rates = [8, 8, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 16, 4, 4]
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        # The decoder pairs each upsample rate with one transposed-conv kernel size.
        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)
@dataclass
class VitsTextEncoderOutput(ModelOutput):
    """Container returned by ``VitsTextEncoder.forward``.

    Every field is optional and defaults to ``None``. The field order is
    significant, because ``ModelOutput`` supports tuple-style indexing.
    """

    # Output of the transformer encoder stack.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # First ``flow_size`` channels of the 1x1 projection of the encoder output.
    prior_means: Optional[torch.FloatTensor] = None
    # Remaining ``flow_size`` channels of that projection.
    prior_log_variances: Optional[torch.FloatTensor] = None
    # Reserved for per-layer hidden states (not populated by VitsTextEncoder here).
    hidden_states: Optional[torch.FloatTensor] = None
    # Reserved for attention weights (not populated by VitsTextEncoder here).
    attentions: Optional[torch.FloatTensor] = None
class VitsWaveNet(torch.nn.Module):
def __init__(self, config, num_layers):
super().__init__()
self.hidden_size = config.hidden_size
self.num_layers = num_layers
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
# if hasattr(nn.utils.parametrizations, "weight_norm"):
# # raise ValueError
weight_norm = nn.utils.parametrizations.weight_norm
# else:
# raise ValueError
# # weight_norm = nn.utils.weight_norm
for i in range(num_layers):
dilation = config.wavenet_dilation_rate**i
padding = (config.wavenet_kernel_size * dilation - dilation) // 2
in_layer = torch.nn.Conv1d(
in_channels=config.hidden_size,
out_channels=2 * config.hidden_size,
kernel_size=config.wavenet_kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
if i < num_layers - 1:
res_skip_channels = 2 * config.hidden_size
else:
res_skip_channels = config.hidden_size
res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self,
inputs):
outputs = torch.zeros_like(inputs)
num_channels = torch.IntTensor([self.hidden_size])[0]
for i in range(self.num_layers):
in_act = self.in_layers[i](inputs)
# global_states = torch.zeros_like(hidden_states) # style ?
# acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
# --
# def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
# in_act = input_a # + input_b
t_act = torch.tanh(in_act[:, :num_channels, :])
s_act = torch.sigmoid(in_act[:, num_channels:, :])
acts = t_act * s_act
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.num_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_size, :]
inputs = inputs + res_acts
outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
else:
outputs = outputs + res_skip_acts
return outputs
# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
    """HiFi-GAN residual block with one dilated/plain conv pair per dilation value.

    Each pair applies leaky-ReLU -> dilated conv -> leaky-ReLU -> conv, and
    adds its result back onto its input. Padding keeps the time length
    unchanged for odd kernel sizes.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.convs1 = nn.ModuleList(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=rate,
                padding=self.get_padding(kernel_size, rate),
            )
            for rate in dilation
        )
        self.convs2 = nn.ModuleList(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=1,
                padding=self.get_padding(kernel_size, 1),
            )
            for _ in dilation
        )

    def get_padding(self, kernel_size, dilation=1):
        """Return the symmetric padding ``dilation * (kernel_size - 1) // 2``."""
        return dilation * (kernel_size - 1) // 2

    def forward(self, hidden_states):
        slope = self.leaky_relu_slope
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(nn.functional.leaky_relu(hidden_states, slope))
            branch = plain_conv(nn.functional.leaky_relu(branch, slope))
            hidden_states = branch + hidden_states
        return hidden_states
class VitsHifiGan(nn.Module):
    """HiFi-GAN generator that turns flow latents into a waveform.

    ``conv_pre`` maps ``config.flow_size`` channels to
    ``config.upsample_initial_channel``. Each upsampling stage halves the
    channel count. Its output is then averaged over one HifiGanResidualBlock
    per entry of ``config.resblock_kernel_sizes``. ``conv_post`` projects to a
    single channel and ``tanh`` bounds the result.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        base_channels = config.upsample_initial_channel

        self.conv_pre = nn.Conv1d(config.flow_size, base_channels, kernel_size=7, stride=1, padding=3)

        stages = zip(config.upsample_rates, config.upsample_kernel_sizes)
        self.upsampler = nn.ModuleList(
            nn.ConvTranspose1d(
                base_channels // (2**stage),
                base_channels // (2 ** (stage + 1)),
                kernel_size=kernel,
                stride=rate,
                padding=(kernel - rate) // 2,
            )
            for stage, (rate, kernel) in enumerate(stages)
        )

        # Residual blocks are stored flat: stage i owns indices
        # [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.upsampler)):
            channels = base_channels // (2 ** (stage + 1))
            self.resblocks.extend(
                HifiGanResidualBlock(channels, kernel, dilation, config.leaky_relu_slope)
                for kernel, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
            )

        # `channels` now holds the width of the last upsampling stage.
        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)

    def forward(self, spectrogram):
        """Decode ``spectrogram`` (batch, flow_size, frames) to (batch, 1, samples)."""
        slope = self.config.leaky_relu_slope
        hidden = self.conv_pre(spectrogram)
        for stage in range(self.num_upsamples):
            hidden = self.upsampler[stage](nn.functional.leaky_relu(hidden, slope))
            first = stage * self.num_kernels
            blocks = self.resblocks[first:first + self.num_kernels]
            fused = blocks[0](hidden)
            for block in blocks[1:]:
                fused += block(hidden)
            hidden = fused / self.num_kernels
        # PyTorch's default negative slope here, not config.leaky_relu_slope.
        hidden = nn.functional.leaky_relu(hidden)
        return torch.tanh(self.conv_post(hidden))
class VitsResidualCouplingLayer(nn.Module):
    """Mean-only coupling layer of the prior normalizing flow.

    The channels are split into two halves. A WaveNet conditioned on the first
    half predicts a shift (``mean``) for the second half. The first half passes
    through unchanged, which makes the transform invertible.
    """

    def __init__(self, config):
        super().__init__()
        self.half_channels = config.flow_size // 2
        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, x, reverse=False):
        """Apply the coupling transform to ``x``.

        Args:
            x: tensor whose dim 1 holds ``2 * half_channels`` channels.
            reverse: ``False`` applies the forward shift
                (``second_half + mean``). ``True`` applies its inverse
                (``second_half - mean``). Previously this flag was ignored
                and the inverse was always applied.

        Returns:
            Tensor with the same shape as ``x``.
        """
        first_half, second_half = torch.split(x, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half)
        hidden_states = self.wavenet(hidden_states)
        mean = self.conv_post(hidden_states)
        if reverse:
            second_half = second_half - mean
        else:
            second_half = second_half + mean
        return torch.cat([first_half, second_half], dim=1)


class VitsResidualCouplingBlock(nn.Module):
    """Sequence of ``config.prior_encoder_num_flows`` coupling layers with channel flips.

    Forward direction: each layer is applied, then the channel order is flipped.
    Reverse direction: the exact inverse, meaning layers are visited in reverse
    order and each flip precedes the inverse layer.
    """

    def __init__(self, config):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, x, reverse=False):
        """Run the flow on ``x`` (batch, flow_size, frames).

        Previously ``reverse`` was ignored and the inverse was always computed.
        The in-file caller passes ``reverse=True``, whose result is unchanged.
        """
        if reverse:
            for flow in reversed(self.flows):
                x = torch.flip(x, [1])  # flip the channel axis
                x = flow(x, reverse=True)
        else:
            for flow in self.flows:
                x = flow(x, reverse=False)
                x = torch.flip(x, [1])  # flip the channel axis
        return x
class VitsAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with no positional information.

    ``config.window_size`` is stored but not used. No relative position terms
    are added to the scores. ``layer_head_mask`` and ``output_attentions`` are
    accepted but ignored, and the returned attention weights are always ``None``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.window_size = config.window_size
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError

        bias = config.use_bias
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)

    def _shape(self, tensor, seq_len, bsz):
        """Split the last dim into heads: (bsz, seq, embed) -> (bsz, heads, seq, head_dim)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        layer_head_mask=None,
        output_attentions=False,
    ):
        """Attend over ``hidden_states`` (batch, seq, embed_dim); return ``(output, None)``."""
        bsz, tgt_len, _ = hidden_states.size()
        flat_heads = (bsz * self.num_heads, -1, self.head_dim)

        def to_heads(projected, seq_len):
            # (bsz, seq, embed) -> (bsz * heads, seq, head_dim) for torch.bmm
            return self._shape(projected, seq_len, bsz).view(*flat_heads)

        queries = to_heads(self.q_proj(hidden_states) * self.scaling, tgt_len)
        keys = to_heads(self.k_proj(hidden_states), -1)
        values = to_heads(self.v_proj(hidden_states), -1)

        probs = nn.functional.softmax(torch.bmm(queries, keys.transpose(1, 2)), dim=-1)
        context = torch.bmm(probs, values)

        # (bsz * heads, tgt, head_dim) -> (bsz, tgt, embed_dim); embed_dim comes
        # from the stored attribute rather than the input's last dim.
        context = context.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        context = context.reshape(bsz, tgt_len, self.embed_dim)
        return self.out_proj(context), None
class VitsFeedForward(nn.Module):
    """Position-wise feed-forward made of two 1-D convolutions with a ReLU between.

    Input and output are (batch, seq, hidden_size). The time axis is padded
    before each convolution so the sequence length is preserved. The padding
    is ``(k - 1) // 2`` on the left and ``k // 2`` on the right for kernel size
    ``k > 1``.
    """

    def __init__(self, config):
        super().__init__()
        kernel = config.ffn_kernel_size
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, kernel)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, kernel)
        self.act_fn = nn.ReLU()
        # F.pad spec for a 3-D tensor: only the last (time) axis is padded.
        self.padding = [(kernel - 1) // 2, kernel // 2, 0, 0, 0, 0] if kernel > 1 else None

    def _pad(self, hidden_states):
        """Apply ``self.padding`` when a kernel wider than 1 needs it."""
        if self.padding is None:
            return hidden_states
        return nn.functional.pad(hidden_states, self.padding)

    def forward(self, hidden_states):
        channels_first = hidden_states.transpose(1, 2)
        channels_first = self.act_fn(self.conv_1(self._pad(channels_first)))
        channels_first = self.conv_2(self._pad(channels_first))
        return channels_first.transpose(1, 2)
class VitsEncoderLayer(nn.Module):
    """Post-norm transformer block.

    Self-attention and the convolutional feed-forward are each wrapped in a
    residual connection followed by LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = VitsAttention(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        output_attentions=False,
    ):
        """Return a 1-tuple ``(hidden_states,)``. Attention weights are discarded."""
        attended, _ = self.attention(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = self.layer_norm(hidden_states + attended)
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))
        return (hidden_states,)
class VitsEncoder(nn.Module):
    """Applies ``config.num_hidden_layers`` VitsEncoderLayer blocks in sequence.

    ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
    accepted but ignored. Only ``last_hidden_state`` is filled in the result.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(VitsEncoderLayer(config) for _ in range(config.num_hidden_layers))
        # Stored for API compatibility; forward() reads neither attribute.
        self.gradient_checkpointing = False
        self.layerdrop = config.layerdrop

    def forward(
        self,
        hidden_states,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)[0]
        return BaseModelOutput(last_hidden_state=hidden_states)
class VitsTextEncoder(nn.Module):
    """Text encoder: token embedding -> VitsEncoder -> 1x1 projection to prior stats.

    The projection outputs ``2 * config.flow_size`` channels. These are split
    into ``prior_means`` and ``prior_log_variances``, and both are now returned.
    Previously the log-variances were computed but left out of the output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)  # stack of config.num_hidden_layers encoder layers
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def forward(self, input_ids):
        """Encode ``input_ids``; returns a VitsTextEncoderOutput.

        ``prior_means`` and ``prior_log_variances`` are each
        (batch, seq, flow_size). This assumes ``input_ids`` is (batch, seq).
        """
        # Embeddings are scaled by sqrt(hidden_size) before entering the encoder.
        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
        last_hidden_state = self.encoder(hidden_states=hidden_states).last_hidden_state
        # Conv1d expects channels first, so transpose around the projection.
        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2)
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            prior_log_variances=prior_log_variances,
        )
class VitsPreTrainedModel(PreTrainedModel):
    """Base class that plugs VitsConfig into the Hugging Face PreTrainedModel machinery."""

    config_class = VitsConfig
    main_input_name = "input_ids"
    base_model_prefix = "vits"
    supports_gradient_checkpointing = True
class VitsModel(VitsPreTrainedModel):
    """VITS text-to-speech model driven by a fixed, oscillating duration pattern.

    There is no duration predictor. Each input token gets a frame count taken
    from a short per-language pattern that is cycled over the sequence. The
    first and last tokens get fixed, longer counts. The text encoder's prior
    means are repeated accordingly, run through the inverse prior flow, and
    decoded to audio by HiFi-GAN. No sampling noise is added to the prior means.
    """

    # Frames per input token, cycled over the sequence. Each voice/language
    # sounds better with its own pattern; unknown codes use the default.
    _DURATION_PATTERNS = {
        "deu": (1, 2, 1),
        "rmc-script_latin": (2, 2, 1, 2, 2),  # also tried: (2, 2, 2, 1, 2)
        "hun": (1, 2, 1, 1, 1),  # also tried: (1, 2, 2, 1, 1, 1), which has a valley-pause
    }
    _DEFAULT_DURATION_PATTERN = (1, 2, 1)
    # Frame counts forced on the first and last token of each sequence.
    _FIRST_TOKEN_DURATION = 4
    _LAST_TOKEN_DURATION = 3

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)  # token embedding + transformer encoder
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _oscillating_durations(self, lang_code, in_len, device):
        """Return a (1, 1, in_len) int64 tensor of frames per token for ``lang_code``."""
        pattern = self._DURATION_PATTERNS.get(lang_code, self._DEFAULT_DURATION_PATTERN)
        pattern = torch.tensor(pattern, device=device)
        # Tile past in_len, then truncate.
        duration = pattern.repeat(in_len // len(pattern) + 2)[None, None, :in_len]
        duration[:, :, 0] = self._FIRST_TOKEN_DURATION
        duration[:, :, -1] = self._LAST_TOKEN_DURATION  # wins when in_len == 1
        return duration

    @staticmethod
    def _alignment(duration, input_padding_mask):
        """Build a soft alignment of shape (bs, out_len, in_len) from integer durations.

        ``duration`` and ``input_padding_mask`` are (bs, 1, in_len). Row ``t``
        is one-hot on the token that covers output frame ``t``. A small uniform
        noise (1e-4) is then added and every row is renormalized to sum to 1.
        As a result, rows with no covering token (frames past a sequence's
        length) become a near-uniform mix of the tokens.
        """
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape

        # Token j covers output frames [cum[j-1], cum[j]). Mark frames < cum[j],
        # then subtract the previous token's row to keep only its own span.
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
        attn = attn[:, 0, :, :]

        # Tiny noise avoids all-zero rows (division by zero) and slightly blends tokens.
        attn = attn + 1e-4 * torch.rand_like(attn)
        return attn / attn.sum(2, keepdim=True)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        speaker_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        speed=None,
        lang_code='deu',  # selects the duration-oscillation pattern per voice/language
    ):
        """Synthesize a waveform of shape (batch, samples) from ``input_ids``.

        ``attention_mask`` (optional) marks real tokens. Padded tokens now get
        zero frames, and batches larger than 1 are supported. Previously the
        duration tensor always had batch size 1, so ``view()`` failed for
        batch > 1. ``speaker_id``, ``output_attentions``,
        ``output_hidden_states``, ``return_dict``, ``labels`` and ``speed`` are
        accepted for API compatibility but unused.
        """
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
        input_padding_mask = input_padding_mask.transpose(1, 2)  # (bs, 1, in_len)

        out = self.text_encoder(input_ids=input_ids)
        prior_means = out.prior_means  # (bs, in_len, flow_size)
        in_len = out.last_hidden_state.shape[1]

        # VITS duration oscillation. Multiplying by the mask gives padded tokens
        # zero frames and broadcasts the (1, 1, in_len) pattern to (bs, 1, in_len).
        duration = self._oscillating_durations(lang_code, in_len, prior_means.device)
        duration = duration * input_padding_mask.long()

        attn = self._alignment(duration, input_padding_mask)
        # Repeat each prior mean for the frames its token covers.
        prior_means = torch.matmul(attn, prior_means)

        latents = self.flow(prior_means.transpose(1, 2), reverse=True)
        waveform = self.decoder(latents)  # (bs, 1, samples)
        return waveform[:, 0, :]
class VitsTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer backed by a ``vocab.json`` character -> id mapping.

    Text can optionally be lower-cased (keeping vocabulary tokens intact),
    romanized with ``uroman`` and phonemized with ``phonemizer``/espeak. When
    ``add_blank`` is set, a blank (the token for id 0) follows every character.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize

        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special token ids that may be part or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages"""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize=None, **kwargs
    ):
        """Normalize, romanize and/or phonemize ``text``; return ``(text, kwargs)``.

        ``phonemizer`` and ``uroman`` are imported unconditionally at the top of
        this module, so they are always available here. The previous
        availability checks called helpers that were never defined and raised
        NameError.
        """
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalise for casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            uroman = ur.Uroman()
            filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any chars outside of the vocab (punctuation)
            filtered_text = "".join(char for char in filtered_text if char in self.encoder).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into characters, placing a blank after each when ``add_blank``.

        With blanks, the layout is ``[c0, <blank>, c1, <blank>, ..., c_{n-1}, <blank>, <blank>]``.
        Each character is followed by the id-0 token, plus one extra trailing
        blank. Empty text yields a single blank.
        """
        tokens = list(text)

        if self.add_blank:
            # Without blanks between characters the speech sounds garbled; with
            # more than one blank between characters it sounds disconnected.
            blank = self._convert_id_to_token(0)
            interspersed = [blank] * (len(tokens) * 2)
            interspersed[::2] = tokens
            tokens = interspersed + [blank]

        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)