Delete TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py
DELETED
@@ -1,211 +0,0 @@
import os  # explicit import; os is used for checkpoint and plot paths below

import librosa.display as lbd
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Utility.WarmupScheduler import WarmupScheduler
from Utility.path_to_transcript_dicts import *
from Utility.utils import cumsum_durations
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint


def train_loop(net,
               datasets,
               device,
               save_directory,
               batch_size,
               steps,
               steps_per_checkpoint,
               lr,
               path_to_checkpoint,
               resume=False,
               warmup_steps=4000):
    # ============
    # Preparations
    # ============
    net = net.to(device)
    torch.multiprocessing.set_sharing_strategy('file_system')
    # one DataLoader per language, so that every meta-step can draw one batch per task
    train_loaders = list()
    train_iters = list()
    for dataset in datasets:
        train_loaders.append(DataLoader(batch_size=batch_size,
                                        dataset=dataset,
                                        drop_last=True,
                                        num_workers=2,
                                        pin_memory=True,
                                        shuffle=True,
                                        prefetch_factor=5,
                                        collate_fn=collate_and_pad,
                                        persistent_workers=True))
        train_iters.append(iter(train_loaders[-1]))
    # average the utterance embeddings of each language's dataset to get a default voice per language
    default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None}
    for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]):
        default_embedding = None
        for datapoint in datasets[index]:
            if default_embedding is None:
                default_embedding = datapoint[7].squeeze()
            else:
                default_embedding = default_embedding + datapoint[7].squeeze()
        default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device)
    optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0)
    grad_scaler = GradScaler()
    scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
    if resume:
        previous_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
        if previous_checkpoint is not None:
            path_to_checkpoint = previous_checkpoint
        else:
            raise RuntimeError(f"No checkpoint found that can be resumed from in {save_directory}")
    step_counter = 0
    train_losses_total = list()
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        net.load_state_dict(check_dict["model"])
        if resume:
            optimizer.load_state_dict(check_dict["optimizer"])
            step_counter = check_dict["step_counter"]
            grad_scaler.load_state_dict(check_dict["scaler"])
            scheduler.load_state_dict(check_dict["scheduler"])
            if step_counter > steps:
                print("Desired steps already reached in loaded checkpoint.")
                return

    net.train()
    # =============================
    # Actual train loop starts here
    # =============================
    for step in tqdm(range(step_counter, steps)):
        batches = []
        for index in range(len(datasets)):
            # we get one batch for each task (i.e. language in this case)
            try:
                batch = next(train_iters[index])
                batches.append(batch)
            except StopIteration:
                # a language's loader is exhausted: restart it and keep going
                train_iters[index] = iter(train_loaders[index])
                batch = next(train_iters[index])
                batches.append(batch)
        train_loss = 0.0
        for batch in batches:
            with autocast():
                # we sum the loss for each task, as we would do for
                # second-order regular MAML, but we do it over only one
                # step (i.e. iterations of inner loop = 1)
                train_loss = train_loss + net(text_tensors=batch[0].to(device),
                                              text_lengths=batch[1].to(device),
                                              gold_speech=batch[2].to(device),
                                              speech_lengths=batch[3].to(device),
                                              gold_durations=batch[4].to(device),
                                              gold_pitch=batch[6].to(device),  # mind the switched order
                                              gold_energy=batch[5].to(device),  # mind the switched order
                                              utterance_embedding=batch[7].to(device),
                                              lang_ids=batch[8].to(device),
                                              return_mels=False)
        # then we directly update our meta-parameters without
        # the need for any task-specific parameters
        train_losses_total.append(train_loss.item())
        optimizer.zero_grad()
        grad_scaler.scale(train_loss).backward()
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        scheduler.step()

        if step % steps_per_checkpoint == 0:
            # ==============================
            # Enough steps for some insights
            # ==============================
            net.eval()
            print(f"Total Loss: {round(sum(train_losses_total) / len(train_losses_total), 3)}")
            train_losses_total = list()
            torch.save({
                "model"       : net.state_dict(),
                "optimizer"   : optimizer.state_dict(),
                "scaler"      : grad_scaler.state_dict(),
                "scheduler"   : scheduler.state_dict(),
                "step_counter": step,
                "default_emb" : default_embeddings["en"]
            },
                os.path.join(save_directory, "checkpoint_{}.pt".format(step)))
            delete_old_checkpoints(save_directory, keep=5)
            for lang in ["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]:
                plot_progress_spec(net=net,
                                   device=device,
                                   lang=lang,
                                   save_dir=save_directory,
                                   step=step,
                                   utt_embeds=default_embeddings)
            net.train()


@torch.inference_mode()
def plot_progress_spec(net, device, save_dir, step, lang, utt_embeds):
    tf = ArticulatoryCombinedTextFrontend(language=lang)
    sentence = ""
    default_embed = utt_embeds[lang]
    if lang == "en":
        sentence = "This is a complex sentence, it even has a pause!"
    elif lang == "de":
        sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
    elif lang == "el":
        sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
    elif lang == "es":
        sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
    elif lang == "fi":
        sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
    elif lang == "ru":
        sentence = "Это сложное предложение, в нем даже есть пауза!"
    elif lang == "hu":
        sentence = "Ez egy összetett mondat, még szünet is van benne!"
    elif lang == "nl":
        sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
    elif lang == "fr":
        sentence = "C'est une phrase complexe, elle a même une pause !"
    phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
    spec, durations, *_ = net.inference(text=phoneme_vector,
                                        return_duration_pitch_energy=True,
                                        utterance_embedding=default_embed,
                                        lang_id=get_language_id(lang).to(device))
    spec = spec.transpose(0, 1).to("cpu").numpy()
    duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
    os.makedirs(os.path.join(save_dir, "spec"), exist_ok=True)
    fig, ax = plt.subplots(nrows=1, ncols=1)
    lbd.specshow(spec,
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',
                 x_axis=None,
                 hop_length=256)
    ax.yaxis.set_visible(False)
    ax.set_xticks(duration_splits, minor=True)
    ax.xaxis.grid(True, which='minor')
    ax.set_xticks(label_positions, minor=False)
    ax.set_xticklabels(tf.get_phone_string(sentence))
    ax.set_title(sentence)
    plt.savefig(os.path.join(save_dir, "spec", f"{step}_{lang}.png"))
    plt.clf()
    plt.close()


def collate_and_pad(batch):
    # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
    return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
            torch.stack([datapoint[8] for datapoint in batch]))