import os
import statistics
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
class FastSpeechDataset(Dataset):

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
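        # the FastSpeech cache is built on top of the aligner cache:
        # if either is missing (or a rebuild is forced), it is (re)created first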
        if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache:
            if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=rebuild_cache,
                               device=device)
            datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            # we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir=cache_dir,
                               lang=lang,
                               loading_processes=loading_processes,
                               min_len_in_seconds=min_len_in_seconds,
                               max_len_in_seconds=max_len_in_seconds,
                               cut_silences=cut_silence,
                               rebuild_cache=True)
                datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()
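            # the aligner doubles as a CTC-based recognizer: its alignment paths
            # yield phoneme durations, and its CTC loss scores each sample's quality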
            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            acoustic_model = acoustic_model.to(device)
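            # pitch (Dio) and energy are extracted from the 16 kHz normalized waves;
            # durations come from the aligner's alignment paths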
            dio = Dio(reduction_factor=reduction_factor, fs=16000)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=16000)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=16000, device=device)
            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                norm_wave_length = torch.LongTensor([len(norm_wave)])
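                # skip utterances that are too short; note that this filter
                # is only active when ctc_selection is enabled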
                if len(norm_wave) / 16000 < min_len_in_seconds and ctc_selection:
                    continue
                text = dataset[index][0]
                melspec = dataset[index][2]
                melspec_length = dataset[index][3]
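                # the aligner returns a monotonic alignment path between phonemes and
                # spectrogram frames, plus the CTC loss as a confidence measure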
                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)
                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
                last_vec = None
                for phoneme_index, vec in enumerate(text):
                    if last_vec is not None:
                        if last_vec.numpy().tolist() == vec.numpy().tolist():
                            # we found a case of repeating phonemes!
                            # now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
                            dur_1 = cached_duration[phoneme_index - 1]
                            dur_2 = cached_duration[phoneme_index]
                            total_dur = dur_1 + dur_2
                            new_dur_1 = int((total_dur / 5) * 3)
                            new_dur_2 = total_dur - new_dur_1
                            cached_duration[phoneme_index - 1] = new_dur_1
                            cached_duration[phoneme_index] = new_dur_2
                    last_vec = vec
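                # energy and pitch are extracted framewise; the calculators receive
                # the durations so they can aggregate the values per phoneme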
                cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=cached_duration.unsqueeze(0),
                                            durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
                cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=cached_duration.unsqueeze(0),
                                   durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
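                # an utterance-level condition vector extracted from the reference wave;
                # extraction raises a RuntimeError on audio without any voiced segments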
                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # if there is an audio sample without any voiced segments whatsoever, we have to skip it
                    continue
                self.datapoints.append([dataset[index][0],
                                        dataset[index][1],
                                        dataset[index][2],
                                        dataset[index][3],
                                        cached_duration.cpu(),
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection:
                # now we can filter out some bad datapoints based on the CTC scores we collected
                mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
                std_dev = statistics.stdev(self.ctc_losses)
                threshold = mean_ctc + std_dev
                for index in range(len(self.ctc_losses), 0, -1):
                    if self.ctc_losses[index - 1] > threshold:
                        self.datapoints.pop(index - 1)
                        print(f"Removing datapoint {index - 1}, because its CTC loss is more than one standard deviation above the mean. \n ctc: {round(self.ctc_losses[index - 1], 4)} vs. mean: {round(mean_ctc, 4)}")

            # save to cache
            if len(self.datapoints) > 0:
                torch.save(self.datapoints, os.path.join(cache_dir, "fast_train_cache.pt"))
            else:
                import sys
                print("No datapoints were prepared! Exiting...")
                sys.exit()
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(os.path.join(cache_dir, "fast_train_cache.pt"), map_location='cpu')

        self.cache_dir = cache_dir
        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")
    def __getitem__(self, index):
        return self.datapoints[index][0], \
               self.datapoints[index][1], \
               self.datapoints[index][2], \
               self.datapoints[index][3], \
               self.datapoints[index][4], \
               self.datapoints[index][5], \
               self.datapoints[index][6], \
               self.datapoints[index][7], \
               self.language_id

    def __len__(self):
        return len(self.datapoints)
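    # popping in reverse-sorted order keeps the remaining indices valid while removing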
    def remove_samples(self, list_of_samples_to_remove):
        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The Viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically
        during cache creation, but if you have a cache from
        before March 2022, use this method to postprocess
        those cases.
        """
        for datapoint_index in tqdm(list(range(len(self.datapoints)))):
            last_vec = None
            for phoneme_index, vec in enumerate(self.datapoints[datapoint_index][0]):
                if last_vec is not None:
                    if last_vec.numpy().tolist() == vec.numpy().tolist():
                        # we found a case of repeating phonemes!
                        # now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
                        dur_1 = self.datapoints[datapoint_index][4][phoneme_index - 1]
                        dur_2 = self.datapoints[datapoint_index][4][phoneme_index]
                        total_dur = dur_1 + dur_2
                        new_dur_1 = int((total_dur / 5) * 3)
                        new_dur_2 = total_dur - new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index - 1] = new_dur_1
                        self.datapoints[datapoint_index][4][phoneme_index] = new_dur_2
                        print("fix applied")
                last_vec = vec
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")