# File: ttts/hifigan/train.py (TTTS repository)
# Snapshot: commit 4ee33aa ("Add source code"), 8.72 kB.
import copy
from datetime import datetime
from inspect import signature
import json
from pathlib import Path
from accelerate import Accelerator
from tqdm import tqdm
from ttts.diffusion.diffusion_util import cycle, get_grad_norm, normalize_tacotron_mel
from ttts.diffusion.train import set_requires_grad
from ttts.hifigan.dataset import HiFiGANCollater, HifiGANDataset
from torch.utils.tensorboard import SummaryWriter
from ttts.hifigan.hifigan_discriminator import HifiganDiscriminator
from ttts.hifigan.hifigan_vocoder import HifiDecoder
from ttts.hifigan.losses import DiscriminatorLoss, GeneratorLoss
from ttts.utils.infer_utils import load_model
from ttts.utils.utils import EMA, clean_checkpoints, plot_spectrogram_to_numpy, summarize
import torch
from typing import Any, Callable, Dict, Union, Tuple
import torchaudio
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import os
from ttts.vocoder.feature_extractors import MelSpectrogramFeatures
def warmup(step, warmup_steps=1000):
    """Return the linear warm-up multiplier for a learning-rate schedule.

    Intended as the ``lr_lambda`` of ``torch.optim.lr_scheduler.LambdaLR``,
    which calls it with the scheduler step as its only positional argument.

    Args:
        step: Current scheduler step.
        warmup_steps: Number of steps over which the multiplier ramps
            linearly towards 1. Defaults to 1000.

    Returns:
        ``step / warmup_steps`` while ``step < warmup_steps``, otherwise
        ``1.0``. A non-positive ``warmup_steps`` disables warm-up and always
        yields ``1.0``.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return 1.0
    return step / warmup_steps
class Trainer(object):
    """Adversarial trainer for the HiFi-GAN decoder driven by GPT latents.

    A GPT model loaded from a checkpoint is run under ``torch.no_grad()`` to
    turn each batch into a latent sequence. ``HifiDecoder`` (generator) maps
    that latent plus a speaker embedding to a waveform, and
    ``HifiganDiscriminator`` scores real versus generated audio. Both networks
    are optimised with AdamW and a linear learning-rate warm-up.
    """

    def __init__(self, cfg_path='ttts/hifigan/config.json'):
        """Build models, data pipeline, optimizers and schedulers.

        Args:
            cfg_path: Path to a JSON config. Keys read by this class:
                ``dataset.gpt_path``, ``hifigan`` (keyword arguments for
                ``HifiDecoder``), ``dataloader`` (keyword arguments for
                ``DataLoader``) and ``train.{train_steps, val_freq,
                logs_folder, lr, save_freq, keep_ckpts}``.
        """
        # An Accelerator with DistributedDataParallelKwargs(
        # find_unused_parameters=True) was tried previously and is left
        # disabled; the default Accelerator is used.
        self.accelerator = Accelerator()

        # Context manager so the config file handle is closed right away.
        with open(cfg_path) as cfg_file:
            self.cfg = json.load(cfg_file)
        train_cfg = self.cfg['train']

        # NOTE(review): the device passed to load_model is hard-coded to
        # 'cuda' rather than taken from the accelerator - confirm intended.
        self.gpt = load_model(
            'gpt', self.cfg['dataset']['gpt_path'], 'ttts/gpt/config.json', 'cuda'
        )
        # Read before accelerator.prepare() so the attribute is taken from
        # the raw model rather than from whatever prepare() returns.
        self.mel_length_compression = self.gpt.mel_length_compression

        self.hifigan_decoder = HifiDecoder(**self.cfg['hifigan'])
        self.hifigan_discriminator = HifiganDiscriminator()

        self.dataset = HifiGANDataset(self.cfg)
        self.dataloader = DataLoader(
            self.dataset, **self.cfg['dataloader'], collate_fn=HiFiGANCollater()
        )

        self.train_steps = train_cfg['train_steps']
        self.val_freq = train_cfg['val_freq']

        # Only the main process owns a log/checkpoint directory; each run
        # gets its own timestamped sub-folder.
        if self.accelerator.is_main_process:
            run_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            self.logs_folder = Path(train_cfg['logs_folder']) / run_name
            self.logs_folder.mkdir(exist_ok=True, parents=True)

        self.G_optimizer = AdamW(
            self.hifigan_decoder.parameters(),
            lr=train_cfg['lr'], betas=(0.9, 0.999), weight_decay=0.01,
        )
        self.D_optimizer = AdamW(
            self.hifigan_discriminator.parameters(),
            lr=train_cfg['lr'], betas=(0.9, 0.999), weight_decay=0.01,
        )
        self.G_scheduler = torch.optim.lr_scheduler.LambdaLR(self.G_optimizer, lr_lambda=warmup)
        self.D_scheduler = torch.optim.lr_scheduler.LambdaLR(self.D_optimizer, lr_lambda=warmup)

        (
            self.hifigan_decoder,
            self.hifigan_discriminator,
            self.dataloader,
            self.G_optimizer,
            self.D_optimizer,
            self.G_scheduler,
            self.D_scheduler,
            self.gpt,
        ) = self.accelerator.prepare(
            self.hifigan_decoder,
            self.hifigan_discriminator,
            self.dataloader,
            self.G_optimizer,
            self.D_optimizer,
            self.G_scheduler,
            self.D_scheduler,
            self.gpt,
        )
        # train() pulls batches with next(), so wrap the loader with cycle().
        self.dataloader = cycle(self.dataloader)
        self.step = 0
        self.mel_extractor = MelSpectrogramFeatures().to(self.accelerator.device)
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

    def get_speaker_embedding(self, audio, sr):
        """Return the decoder's speaker embedding for ``audio``.

        Args:
            audio: Waveform tensor sampled at ``sr`` Hz.
            sr: Sample rate of ``audio``; it is resampled to 16 kHz before
                being passed to the speaker encoder.

        Returns:
            The L2-normalised speaker-encoder output with one trailing
            singleton dimension appended, on the accelerator device.
        """
        device = self.accelerator.device
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        # accelerator.prepare() may return a wrapper that does not expose the
        # decoder's sub-modules, so reach the speaker encoder via the
        # unwrapped model (same approach as load()).
        decoder = self.accelerator.unwrap_model(self.hifigan_decoder)
        embedding = decoder.speaker_encoder.forward(audio_16k.to(device), l2_norm=True)
        return embedding.unsqueeze(-1).to(device)

    def _get_target_encoder(self, model):
        """Return a frozen deep copy of ``model``.

        Gradients are disabled on the copy through ``set_requires_grad`` and
        every parameter is tagged with ``DO_NOT_TRAIN = True``.
        """
        target_encoder = copy.deepcopy(model)
        set_requires_grad(target_encoder, False)
        for param in target_encoder.parameters():
            param.DO_NOT_TRAIN = True
        return target_encoder

    def save(self, milestone):
        """Write the generator weights and step to ``model-{milestone}.pt``.

        Does nothing on processes other than the main one. The guard matches
        the one used in ``__init__`` when creating ``logs_folder``, so a
        process without that folder never tries to write a checkpoint.

        Args:
            milestone: Value interpolated into the checkpoint file name.
        """
        if not self.accelerator.is_main_process:
            return
        checkpoint = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.hifigan_decoder),
        }
        torch.save(checkpoint, str(self.logs_folder / f'model-{milestone}.pt'))

    def load(self, model_path):
        """Restore the generator weights and step counter written by ``save``.

        Args:
            model_path: Path to a checkpoint holding ``'model'`` and
                ``'step'`` entries.
        """
        checkpoint = torch.load(model_path, map_location=self.accelerator.device)
        self.step = checkpoint['step']
        decoder = self.accelerator.unwrap_model(self.hifigan_decoder)
        decoder.load_state_dict(checkpoint['model'])

    def train(self):
        """Run the adversarial loop until ``self.step`` reaches ``train_steps``.

        Each iteration performs one discriminator update followed by one
        generator update on the same batch. On the main process, scalars are
        logged every ``val_freq`` steps and a checkpoint is written every
        ``train.save_freq`` steps.

        Batches are dicts of tensors; the keys read here are ``padded_text``,
        ``padded_mel_code``, ``padded_wav`` and ``padded_wav_refer``.
        """
        accelerator = self.accelerator
        device = accelerator.device
        train_cfg = self.cfg['train']
        save_freq = train_cfg['save_freq']
        keep_ckpts = train_cfg['keep_ckpts']

        # TensorBoard writers exist only on the main process.
        writer = None
        writer_eval = None
        if accelerator.is_main_process:
            writer = SummaryWriter(log_dir=self.logs_folder)
            # Nothing in this loop writes to the eval writer; it is still
            # created so the existing log layout is unchanged.
            writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, 'eval'))

        try:
            with tqdm(initial=self.step, total=self.train_steps,
                      disable=not accelerator.is_main_process) as pbar:
                while self.step < self.train_steps:
                    data = next(self.dataloader)
                    # Skip empty batches *before* touching the batch contents.
                    if data is None:
                        continue
                    data = {k: v.to(device) for k, v in data.items()}
                    y = data['padded_wav']

                    # ---- Conditioning latent from the GPT (no gradients) ----
                    mel_refer = self.mel_extractor(data['padded_wav_refer']).squeeze(1)
                    text_len = torch.tensor([data['padded_text'].shape[-1]], device=device)
                    wav_len = torch.tensor(
                        [data['padded_mel_code'].shape[-1] * self.mel_length_compression],
                        device=device,
                    )
                    with torch.no_grad():
                        latent = self.gpt(
                            mel_refer, data['padded_text'], text_len,
                            data['padded_mel_code'], wav_len,
                            return_latent=True, clip_inputs=False,
                        ).transpose(1, 2)

                    # ---- Discriminator update ----
                    # NOTE(review): forward passes and loss evaluation run
                    # under autocast, backward/optimizer steps outside it;
                    # confirm this matches the intended mixed-precision scope.
                    with accelerator.autocast():
                        g = self.get_speaker_embedding(data['padded_wav'], 24000)
                        y_hat = self.hifigan_decoder(latent, g)
                        # detach(): the discriminator loss must not reach the generator.
                        score_fake, feat_fake = self.hifigan_discriminator(y_hat.detach())
                        score_real, feat_real = self.hifigan_discriminator(y.clone())
                        loss_d = self.disc_loss(score_fake, score_real)['loss']
                    # The previous iteration's generator backward also left
                    # gradients on the discriminator parameters; clear them so
                    # this update only sees the discriminator loss.
                    self.D_optimizer.zero_grad()
                    accelerator.backward(loss_d)
                    grad_norm_d = get_grad_norm(self.hifigan_discriminator)
                    accelerator.clip_grad_norm_(self.hifigan_discriminator.parameters(), 1.0)
                    accelerator.wait_for_everyone()
                    self.D_optimizer.step()
                    self.D_optimizer.zero_grad()
                    self.D_scheduler.step()
                    accelerator.wait_for_everyone()

                    # ---- Generator update ----
                    with accelerator.autocast():
                        # No detach here: gradients flow through the
                        # discriminator back into the generator.
                        score_fake, feat_fake = self.hifigan_discriminator(y_hat)
                        loss_g = self.gen_loss(y_hat, y, score_fake, feat_fake, feat_real)['loss']
                    accelerator.backward(loss_g)
                    grad_norm_g = get_grad_norm(self.hifigan_decoder)
                    accelerator.clip_grad_norm_(self.hifigan_decoder.parameters(), 1.0)
                    accelerator.wait_for_everyone()
                    self.G_optimizer.step()
                    self.G_optimizer.zero_grad()
                    self.G_scheduler.step()
                    accelerator.wait_for_everyone()

                    pbar.set_description(f'loss_d: {loss_d:.4f} loss_g: {loss_g:.4f}')

                    # ---- Logging (main process only) ----
                    if accelerator.is_main_process and self.step % self.val_freq == 0:
                        scalar_dict = {
                            "loss_d": loss_d,
                            "loss/grad_d": grad_norm_d,
                            "lr_d": self.D_scheduler.get_last_lr()[0],
                            "loss_g": loss_g,
                            "loss/grad_g": grad_norm_g,
                            "lr_g": self.G_scheduler.get_last_lr()[0],
                        }
                        summarize(writer=writer, global_step=self.step, scalars=scalar_dict)

                    # ---- Checkpointing (main process only) ----
                    if accelerator.is_main_process and self.step % save_freq == 0:
                        if keep_ckpts > 0:
                            clean_checkpoints(
                                path_to_models=self.logs_folder,
                                n_ckpts_to_keep=keep_ckpts,
                                sort_by_time=True,
                            )
                        # The file name counts thousands of steps, so saves
                        # within the same thousand overwrite one another.
                        self.save(self.step // 1000)
                        # The former `self.ema_model.train()` call was removed:
                        # this trainer never creates an `ema_model`, so the
                        # call could only raise AttributeError.

                    self.step += 1
                    pbar.update(1)
        finally:
            # Flush and release the TensorBoard event files even on error.
            for tb_writer in (writer, writer_eval):
                if tb_writer is not None:
                    tb_writer.close()

        accelerator.print('training complete')
if __name__ == '__main__':
    # Script entry point: build a Trainer from the default config path and
    # run training from step 0.
    trainer = Trainer()
    # To resume from a checkpoint, call trainer.load(<checkpoint path>) here.
    # trainer.load('')
    trainer.train()