Create train_single_gpu.py
Browse files- RingFormer/train_single_gpu.py +295 -0
RingFormer/train_single_gpu.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
3 |
+
import itertools
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.utils.tensorboard import SummaryWriter
|
11 |
+
from torch.utils.data import DistributedSampler, DataLoader
|
12 |
+
import torch.multiprocessing as mp
|
13 |
+
from torch.distributed import init_process_group
|
14 |
+
from torch.nn.parallel import DistributedDataParallel
|
15 |
+
from env import AttrDict, build_env
|
16 |
+
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
17 |
+
from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss,\
|
18 |
+
discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
|
19 |
+
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
|
20 |
+
from stft import TorchSTFT
|
21 |
+
from Utils.JDC.model import JDCNet
|
22 |
+
|
23 |
+
torch.backends.cudnn.benchmark = True
|
24 |
+
|
25 |
+
|
26 |
+
def train(rank, a, h):
|
27 |
+
if h.num_gpus > 1:
|
28 |
+
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
29 |
+
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
30 |
+
|
31 |
+
torch.cuda.manual_seed(h.seed)
|
32 |
+
device = torch.device('cuda:{:d}'.format(rank))
|
33 |
+
|
34 |
+
F0_model = JDCNet(num_class=1, seq_len=192)
|
35 |
+
params = torch.load(h.F0_path)['net']
|
36 |
+
F0_model.load_state_dict(params)
|
37 |
+
|
38 |
+
generator = Generator(h, F0_model).to(device)
|
39 |
+
mpd = MultiPeriodDiscriminator().to(device)
|
40 |
+
msd = MultiScaleSubbandCQTDiscriminator().to(device)
|
41 |
+
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
|
42 |
+
|
43 |
+
if rank == 0:
|
44 |
+
print(generator)
|
45 |
+
os.makedirs(a.checkpoint_path, exist_ok=True)
|
46 |
+
print("checkpoints directory : ", a.checkpoint_path)
|
47 |
+
|
48 |
+
if os.path.isdir(a.checkpoint_path):
|
49 |
+
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
50 |
+
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
51 |
+
|
52 |
+
steps = 0
|
53 |
+
if cp_g is None or cp_do is None:
|
54 |
+
state_dict_do = None
|
55 |
+
last_epoch = -1
|
56 |
+
else:
|
57 |
+
state_dict_g = load_checkpoint(cp_g, device)
|
58 |
+
state_dict_do = load_checkpoint(cp_do, device)
|
59 |
+
generator.load_state_dict(state_dict_g['generator'])
|
60 |
+
mpd.load_state_dict(state_dict_do['mpd'])
|
61 |
+
msd.load_state_dict(state_dict_do['msd'])
|
62 |
+
steps = state_dict_do['steps'] + 1
|
63 |
+
last_epoch = state_dict_do['epoch']
|
64 |
+
|
65 |
+
if h.num_gpus > 1:
|
66 |
+
generator = DistributedDataParallel(generator, device_ids=[rank], find_unused_parameters=True).to(device)
|
67 |
+
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
68 |
+
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
69 |
+
|
70 |
+
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
71 |
+
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
72 |
+
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
73 |
+
|
74 |
+
if state_dict_do is not None:
|
75 |
+
optim_g.load_state_dict(state_dict_do['optim_g'])
|
76 |
+
optim_d.load_state_dict(state_dict_do['optim_d'])
|
77 |
+
|
78 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
79 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
80 |
+
|
81 |
+
training_filelist, validation_filelist = get_dataset_filelist(a)
|
82 |
+
|
83 |
+
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
|
84 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
85 |
+
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
86 |
+
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
|
87 |
+
|
88 |
+
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
89 |
+
|
90 |
+
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
91 |
+
sampler=train_sampler,
|
92 |
+
batch_size=h.batch_size,
|
93 |
+
pin_memory=True,
|
94 |
+
drop_last=True)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
if rank == 0:
|
100 |
+
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
|
101 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
102 |
+
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
103 |
+
base_mels_path=a.input_mels_dir)
|
104 |
+
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
105 |
+
sampler=None,
|
106 |
+
batch_size=1,
|
107 |
+
pin_memory=True,
|
108 |
+
drop_last=True)
|
109 |
+
|
110 |
+
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
111 |
+
|
112 |
+
generator.train()
|
113 |
+
mpd.train()
|
114 |
+
msd.train()
|
115 |
+
for epoch in range(max(0, last_epoch), a.training_epochs):
|
116 |
+
if rank == 0:
|
117 |
+
start = time.time()
|
118 |
+
print("Epoch: {}".format(epoch+1))
|
119 |
+
|
120 |
+
if h.num_gpus > 1:
|
121 |
+
train_sampler.set_epoch(epoch)
|
122 |
+
|
123 |
+
for i, batch in enumerate(train_loader):
|
124 |
+
if rank == 0:
|
125 |
+
start_b = time.time()
|
126 |
+
x, y, _, y_mel = batch
|
127 |
+
x = torch.autograd.Variable(x.to(device, non_blocking=True))
|
128 |
+
y = torch.autograd.Variable(y.to(device, non_blocking=True))
|
129 |
+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
130 |
+
y = y.unsqueeze(1)
|
131 |
+
# y_g_hat = generator(x)
|
132 |
+
spec, phase = generator(x)
|
133 |
+
|
134 |
+
y_g_hat = stft.inverse(spec, phase)
|
135 |
+
|
136 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
137 |
+
h.fmin, h.fmax_for_loss)
|
138 |
+
|
139 |
+
optim_d.zero_grad()
|
140 |
+
|
141 |
+
# MPD
|
142 |
+
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
143 |
+
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
144 |
+
loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
|
145 |
+
|
146 |
+
# MSD
|
147 |
+
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
148 |
+
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
149 |
+
loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
|
150 |
+
|
151 |
+
loss_disc_all = loss_disc_s + loss_disc_f
|
152 |
+
|
153 |
+
loss_disc_all.backward()
|
154 |
+
optim_d.step()
|
155 |
+
|
156 |
+
# Generator
|
157 |
+
optim_g.zero_grad()
|
158 |
+
|
159 |
+
# L1 Mel-Spectrogram Loss
|
160 |
+
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
161 |
+
|
162 |
+
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
163 |
+
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
164 |
+
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
165 |
+
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
166 |
+
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
167 |
+
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
168 |
+
|
169 |
+
loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
|
170 |
+
loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
|
171 |
+
|
172 |
+
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
173 |
+
|
174 |
+
loss_gen_all.backward()
|
175 |
+
optim_g.step()
|
176 |
+
|
177 |
+
if rank == 0:
|
178 |
+
# STDOUT logging
|
179 |
+
if steps % a.stdout_interval == 0:
|
180 |
+
with torch.no_grad():
|
181 |
+
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
182 |
+
|
183 |
+
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
|
184 |
+
format(steps, loss_gen_all, mel_error, time.time() - start_b))
|
185 |
+
|
186 |
+
# checkpointing
|
187 |
+
if steps % a.checkpoint_interval == 0 and steps != 0:
|
188 |
+
checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
|
189 |
+
save_checkpoint(checkpoint_path,
|
190 |
+
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
191 |
+
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
192 |
+
save_checkpoint(checkpoint_path,
|
193 |
+
{'mpd': (mpd.module if h.num_gpus > 1
|
194 |
+
else mpd).state_dict(),
|
195 |
+
'msd': (msd.module if h.num_gpus > 1
|
196 |
+
else msd).state_dict(),
|
197 |
+
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
198 |
+
'epoch': epoch})
|
199 |
+
|
200 |
+
# Tensorboard summary logging
|
201 |
+
if steps % a.summary_interval == 0:
|
202 |
+
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
203 |
+
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
204 |
+
|
205 |
+
# Validation
|
206 |
+
if steps % a.validation_interval == 0: # and steps != 0:
|
207 |
+
generator.eval()
|
208 |
+
torch.cuda.empty_cache()
|
209 |
+
val_err_tot = 0
|
210 |
+
with torch.no_grad():
|
211 |
+
for j, batch in enumerate(validation_loader):
|
212 |
+
x, y, _, y_mel = batch
|
213 |
+
# y_g_hat = generator(x.to(device))
|
214 |
+
spec, phase = generator(x.to(device))
|
215 |
+
|
216 |
+
y_g_hat = stft.inverse(spec, phase)
|
217 |
+
|
218 |
+
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
219 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
220 |
+
h.hop_size, h.win_size,
|
221 |
+
h.fmin, h.fmax_for_loss)
|
222 |
+
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
223 |
+
|
224 |
+
if j <= 4:
|
225 |
+
if steps == 0:
|
226 |
+
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
227 |
+
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
228 |
+
|
229 |
+
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
230 |
+
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
|
231 |
+
h.sampling_rate, h.hop_size, h.win_size,
|
232 |
+
h.fmin, h.fmax)
|
233 |
+
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
234 |
+
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
235 |
+
|
236 |
+
val_err = val_err_tot / (j+1)
|
237 |
+
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
238 |
+
|
239 |
+
generator.train()
|
240 |
+
|
241 |
+
steps += 1
|
242 |
+
|
243 |
+
scheduler_g.step()
|
244 |
+
scheduler_d.step()
|
245 |
+
|
246 |
+
if rank == 0:
|
247 |
+
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
248 |
+
|
249 |
+
|
250 |
+
def main():
|
251 |
+
print('Initializing Training Process..')
|
252 |
+
|
253 |
+
parser = argparse.ArgumentParser()
|
254 |
+
|
255 |
+
parser.add_argument('--group_name', default=None)
|
256 |
+
parser.add_argument('--input_wavs_dir', default='')
|
257 |
+
parser.add_argument('--input_mels_dir', default='ft_dataset')
|
258 |
+
# parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/training.txt')
|
259 |
+
parser.add_argument('--input_training_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/eng_norm.txt')
|
260 |
+
parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/valid_eng.txt')
|
261 |
+
parser.add_argument('--checkpoint_path', default='cp_ringformer_LIBRI')
|
262 |
+
parser.add_argument('--config', default='config_v1.json')
|
263 |
+
parser.add_argument('--training_epochs', default=3100, type=int)
|
264 |
+
parser.add_argument('--stdout_interval', default=10, type=int)
|
265 |
+
parser.add_argument('--checkpoint_interval', default=2500, type=int)
|
266 |
+
parser.add_argument('--summary_interval', default=100, type=int)
|
267 |
+
parser.add_argument('--validation_interval', default=1000, type=int)
|
268 |
+
parser.add_argument('--fine_tuning', default=False, type=bool)
|
269 |
+
|
270 |
+
a = parser.parse_args()
|
271 |
+
|
272 |
+
with open(a.config) as f:
|
273 |
+
data = f.read()
|
274 |
+
|
275 |
+
json_config = json.loads(data)
|
276 |
+
h = AttrDict(json_config)
|
277 |
+
build_env(a.config, 'config.json', a.checkpoint_path)
|
278 |
+
|
279 |
+
torch.manual_seed(h.seed)
|
280 |
+
if torch.cuda.is_available():
|
281 |
+
torch.cuda.manual_seed(h.seed)
|
282 |
+
h.num_gpus = torch.cuda.device_count()
|
283 |
+
h.batch_size = int(h.batch_size / h.num_gpus)
|
284 |
+
print('Batch size per GPU :', h.batch_size)
|
285 |
+
else:
|
286 |
+
pass
|
287 |
+
|
288 |
+
if h.num_gpus > 1:
|
289 |
+
mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
|
290 |
+
else:
|
291 |
+
train(0, a, h)
|
292 |
+
|
293 |
+
|
294 |
+
if __name__ == '__main__':
|
295 |
+
main()
|