Flux9665 committed
Commit e9f478c · 1 Parent(s): 6327d28

Delete TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py

TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py DELETED
@@ -1,211 +0,0 @@
- import librosa.display as lbd
- import matplotlib.pyplot as plt
- import torch
- import torch.multiprocessing
- from torch.cuda.amp import GradScaler
- from torch.cuda.amp import autocast
- from torch.nn.utils.rnn import pad_sequence
- from torch.utils.data.dataloader import DataLoader
- from tqdm import tqdm
-
- from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
- from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
- from Utility.WarmupScheduler import WarmupScheduler
- from Utility.path_to_transcript_dicts import *
- from Utility.utils import cumsum_durations
- from Utility.utils import delete_old_checkpoints
- from Utility.utils import get_most_recent_checkpoint
-
-
- def train_loop(net,
-                datasets,
-                device,
-                save_directory,
-                batch_size,
-                steps,
-                steps_per_checkpoint,
-                lr,
-                path_to_checkpoint,
-                resume=False,
-                warmup_steps=4000):
-     # ============
-     # Preparations
-     # ============
-     net = net.to(device)
-     torch.multiprocessing.set_sharing_strategy('file_system')
-     train_loaders = list()
-     train_iters = list()
-     for dataset in datasets:
-         train_loaders.append(DataLoader(batch_size=batch_size,
-                                         dataset=dataset,
-                                         drop_last=True,
-                                         num_workers=2,
-                                         pin_memory=True,
-                                         shuffle=True,
-                                         prefetch_factor=5,
-                                         collate_fn=collate_and_pad,
-                                         persistent_workers=True))
-         train_iters.append(iter(train_loaders[-1]))
-     default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None}
-     for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]):
-         default_embedding = None
-         for datapoint in datasets[index]:
-             if default_embedding is None:
-                 default_embedding = datapoint[7].squeeze()
-             else:
-                 default_embedding = default_embedding + datapoint[7].squeeze()
-         default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device)
-     optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0)
-     grad_scaler = GradScaler()
-     scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
-     if resume:
-         previous_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
-         if previous_checkpoint is not None:
-             path_to_checkpoint = previous_checkpoint
-         else:
-             raise RuntimeError(f"No checkpoint found that can be resumed from in {save_directory}")
-     step_counter = 0
-     train_losses_total = list()
-     if path_to_checkpoint is not None:
-         check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
-         net.load_state_dict(check_dict["model"])
-         if resume:
-             optimizer.load_state_dict(check_dict["optimizer"])
-             step_counter = check_dict["step_counter"]
-             grad_scaler.load_state_dict(check_dict["scaler"])
-             scheduler.load_state_dict(check_dict["scheduler"])
-             if step_counter > steps:
-                 print("Desired steps already reached in loaded checkpoint.")
-                 return
-
-     net.train()
-     # =============================
-     # Actual train loop starts here
-     # =============================
-     for step in tqdm(range(step_counter, steps)):
-         batches = []
-         for index in range(len(datasets)):
-             # we get one batch for each task (i.e. language in this case)
-             try:
-                 batch = next(train_iters[index])
-                 batches.append(batch)
-             except StopIteration:
-                 train_iters[index] = iter(train_loaders[index])
-                 batch = next(train_iters[index])
-                 batches.append(batch)
-         train_loss = 0.0
-         for batch in batches:
-             with autocast():
-                 # we sum the loss for each task, as we would do for the
-                 # second order regular MAML, but we do it only over one
-                 # step (i.e. iterations of inner loop = 1)
-                 train_loss = train_loss + net(text_tensors=batch[0].to(device),
-                                               text_lengths=batch[1].to(device),
-                                               gold_speech=batch[2].to(device),
-                                               speech_lengths=batch[3].to(device),
-                                               gold_durations=batch[4].to(device),
-                                               gold_pitch=batch[6].to(device),  # mind the switched order
-                                               gold_energy=batch[5].to(device),  # mind the switched order
-                                               utterance_embedding=batch[7].to(device),
-                                               lang_ids=batch[8].to(device),
-                                               return_mels=False)
-         # then we directly update our meta-parameters without
-         # the need for any task specific parameters
-         train_losses_total.append(train_loss.item())
-         optimizer.zero_grad()
-         grad_scaler.scale(train_loss).backward()
-         grad_scaler.unscale_(optimizer)
-         torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
-         grad_scaler.step(optimizer)
-         grad_scaler.update()
-         scheduler.step()
-
-         if step % steps_per_checkpoint == 0:
-             # ==============================
-             # Enough steps for some insights
-             # ==============================
-             net.eval()
-             print(f"Total Loss: {round(sum(train_losses_total) / len(train_losses_total), 3)}")
-             train_losses_total = list()
-             torch.save({
-                 "model"       : net.state_dict(),
-                 "optimizer"   : optimizer.state_dict(),
-                 "scaler"      : grad_scaler.state_dict(),
-                 "scheduler"   : scheduler.state_dict(),
-                 "step_counter": step,
-                 "default_emb" : default_embeddings["en"]
-             },
-                 os.path.join(save_directory, "checkpoint_{}.pt".format(step)))
-             delete_old_checkpoints(save_directory, keep=5)
-             for lang in ["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]:
-                 plot_progress_spec(net=net,
-                                    device=device,
-                                    lang=lang,
-                                    save_dir=save_directory,
-                                    step=step,
-                                    utt_embeds=default_embeddings)
-             net.train()
-
-
- @torch.inference_mode()
- def plot_progress_spec(net, device, save_dir, step, lang, utt_embeds):
-     tf = ArticulatoryCombinedTextFrontend(language=lang)
-     sentence = ""
-     default_embed = utt_embeds[lang]
-     if lang == "en":
-         sentence = "This is a complex sentence, it even has a pause!"
-     elif lang == "de":
-         sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
-     elif lang == "el":
-         sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
-     elif lang == "es":
-         sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
-     elif lang == "fi":
-         sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
-     elif lang == "ru":
-         sentence = "Это сложное предложение, в нем даже есть пауза!"
-     elif lang == "hu":
-         sentence = "Ez egy összetett mondat, még szünet is van benne!"
-     elif lang == "nl":
-         sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
-     elif lang == "fr":
-         sentence = "C'est une phrase complexe, elle a même une pause !"
-     phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
-     spec, durations, *_ = net.inference(text=phoneme_vector,
-                                         return_duration_pitch_energy=True,
-                                         utterance_embedding=default_embed,
-                                         lang_id=get_language_id(lang).to(device))
-     spec = spec.transpose(0, 1).to("cpu").numpy()
-     duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
-     if not os.path.exists(os.path.join(save_dir, "spec")):
-         os.makedirs(os.path.join(save_dir, "spec"))
-     fig, ax = plt.subplots(nrows=1, ncols=1)
-     lbd.specshow(spec,
-                  ax=ax,
-                  sr=16000,
-                  cmap='GnBu',
-                  y_axis='mel',
-                  x_axis=None,
-                  hop_length=256)
-     ax.yaxis.set_visible(False)
-     ax.set_xticks(duration_splits, minor=True)
-     ax.xaxis.grid(True, which='minor')
-     ax.set_xticks(label_positions, minor=False)
-     ax.set_xticklabels(tf.get_phone_string(sentence))
-     ax.set_title(sentence)
-     plt.savefig(os.path.join(os.path.join(save_dir, "spec"), f"{step}_{lang}.png"))
-     plt.clf()
-     plt.close()
-
-
- def collate_and_pad(batch):
-     # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
-     return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
-             torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
-             pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
-             torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
-             pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
-             pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
-             pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
-             torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
-             torch.stack([datapoint[8] for datapoint in batch]))
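
The comments in the deleted loop describe its update rule: per-task losses are summed and a single joint update is taken, i.e. a first-order simplification of MAML with an inner-loop length of 1, so no task-specific parameter copies are needed. Below is a minimal sketch of that update pattern with a stand-in model and synthetic tasks; toy_net, task_loaders and the random data are illustrative assumptions, not repository code.

import torch

toy_net = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(toy_net.parameters(), lr=0.01)
# one synthetic regression "task" per language, one batch per task
task_loaders = [[(torch.randn(8, 4), torch.randn(8, 1))] for _ in range(3)]

loss = 0.0
for task in task_loaders:
    inputs, targets = task[0]  # one batch per task per meta-step
    # sum the per-task losses, as the deleted loop does across languages
    loss = loss + torch.nn.functional.mse_loss(toy_net(inputs), targets)
# a single joint update on the summed loss: first-order, inner loop of
# length 1, so the meta-parameters are updated directly
optimizer.zero_grad()
loss.backward()
optimizer.step()

With an inner loop of length 1 and no second-order terms, this coincides with plain multi-task training on the summed loss, which is what the deleted file implemented across its nine language datasets.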