Spaces:
Build error
Build error
import copy | |
from datetime import datetime | |
from inspect import signature | |
import json | |
from pathlib import Path | |
from accelerate import Accelerator | |
from tqdm import tqdm | |
from ttts.diffusion.diffusion_util import cycle, get_grad_norm, normalize_tacotron_mel | |
from ttts.diffusion.train import set_requires_grad | |
from ttts.hifigan.dataset import HiFiGANCollater, HifiGANDataset | |
from torch.utils.tensorboard import SummaryWriter | |
from ttts.hifigan.hifigan_discriminator import HifiganDiscriminator | |
from ttts.hifigan.hifigan_vocoder import HifiDecoder | |
from ttts.hifigan.losses import DiscriminatorLoss, GeneratorLoss | |
from ttts.utils.infer_utils import load_model | |
from ttts.utils.utils import EMA, clean_checkpoints, plot_spectrogram_to_numpy, summarize | |
import torch | |
from typing import Any, Callable, Dict, Union, Tuple | |
import torchaudio | |
from torch.utils.data import DataLoader, Dataset | |
from torch.optim import AdamW | |
import os | |
from ttts.vocoder.feature_extractors import MelSpectrogramFeatures | |
def warmup(step, warmup_steps=1000):
    """Return the linear learning-rate warm-up factor for ``LambdaLR``.

    Args:
        step: the scheduler step counter (``LambdaLR`` calls
            ``lr_lambda(step)`` with a single positional argument).
        warmup_steps: number of steps over which the factor ramps linearly
            from 0 to 1. Defaults to 1000, the previously hard-coded value.
            A non-positive value disables warm-up (factor is always 1.0).

    Returns:
        ``step / warmup_steps`` while ``step < warmup_steps``, else ``1.0``.
        Always a float, so callers get a consistent type on both branches.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return 1.0
    return float(step) / warmup_steps
class Trainer(object):
    """GAN trainer for the HiFiGAN decoder, driven by ``accelerate``.

    A GPT model loaded via ``load_model('gpt', ...)`` is run under
    ``torch.no_grad`` to produce latents from text + mel codes. The
    ``HifiDecoder`` maps those latents plus a speaker embedding to a
    waveform, and a ``HifiganDiscriminator`` is trained against it.
    """

    def __init__(self, cfg_path='ttts/hifigan/config.json'):
        """Build models, data pipeline, optimizers and LR schedulers.

        Args:
            cfg_path: path to the JSON training config. Keys read here:
                ``dataset.gpt_path``, ``hifigan`` (kwargs for ``HifiDecoder``),
                ``dataloader`` (kwargs for ``DataLoader``),
                ``train.train_steps``, ``train.val_freq``,
                ``train.logs_folder`` and ``train.lr``. ``train()`` also
                reads ``train.save_freq`` and ``train.keep_ckpts``.
        """
        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        # self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        self.accelerator = Accelerator()
        # Use a context manager so the config file handle is closed promptly.
        with open(cfg_path, encoding='utf-8') as cfg_file:
            self.cfg = json.load(cfg_file)
        self.gpt = load_model('gpt', self.cfg['dataset']['gpt_path'], 'ttts/gpt/config.json', 'cuda')
        # Read before accelerator.prepare(), which may wrap the module.
        self.mel_length_compression = self.gpt.mel_length_compression
        self.hifigan_decoder = HifiDecoder(**self.cfg['hifigan'])
        self.hifigan_discriminator = HifiganDiscriminator()
        self.dataset = HifiGANDataset(self.cfg)
        self.dataloader = DataLoader(self.dataset, **self.cfg['dataloader'], collate_fn=HiFiGANCollater())
        self.train_steps = self.cfg['train']['train_steps']
        self.val_freq = self.cfg['train']['val_freq']
        # Only the main process owns a log/checkpoint folder.
        if self.accelerator.is_main_process:
            now = datetime.now()
            self.logs_folder = Path(self.cfg['train']['logs_folder'] + '/' + now.strftime("%Y-%m-%d-%H-%M-%S"))
            self.logs_folder.mkdir(exist_ok=True, parents=True)
        self.G_optimizer = AdamW(self.hifigan_decoder.parameters(), lr=self.cfg['train']['lr'], betas=(0.9, 0.999), weight_decay=0.01)
        self.D_optimizer = AdamW(self.hifigan_discriminator.parameters(), lr=self.cfg['train']['lr'], betas=(0.9, 0.999), weight_decay=0.01)
        self.G_scheduler = torch.optim.lr_scheduler.LambdaLR(self.G_optimizer, lr_lambda=warmup)
        self.D_scheduler = torch.optim.lr_scheduler.LambdaLR(self.D_optimizer, lr_lambda=warmup)
        (
            self.hifigan_decoder,
            self.hifigan_discriminator,
            self.dataloader,
            self.G_optimizer,
            self.D_optimizer,
            self.G_scheduler,
            self.D_scheduler,
            self.gpt,
        ) = self.accelerator.prepare(
            self.hifigan_decoder,
            self.hifigan_discriminator,
            self.dataloader,
            self.G_optimizer,
            self.D_optimizer,
            self.G_scheduler,
            self.D_scheduler,
            self.gpt,
        )
        # The training loop is step-based, so turn the loader into an endless iterator.
        self.dataloader = cycle(self.dataloader)
        self.step = 0
        self.mel_extractor = MelSpectrogramFeatures().to(self.accelerator.device)
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

    def get_speaker_embedding(self, audio, sr):
        """Return an L2-normalised speaker embedding for ``audio``.

        ``audio`` is resampled from ``sr`` Hz to 16 kHz, passed through the
        decoder's ``speaker_encoder`` with ``l2_norm=True``, and given a
        trailing singleton dimension.
        """
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        # After accelerator.prepare() the decoder may be a DDP wrapper, which
        # does not expose submodules like ``speaker_encoder`` as attributes.
        speaker_encoder = self.accelerator.unwrap_model(self.hifigan_decoder).speaker_encoder
        return (
            speaker_encoder.forward(audio_16k.to(self.accelerator.device), l2_norm=True)
            .unsqueeze(-1)
            .to(self.accelerator.device)
        )

    def _get_target_encoder(self, model):
        """Return a deep copy of ``model`` with gradients disabled.

        Each parameter is also tagged with ``DO_NOT_TRAIN = True``.
        NOTE(review): no consumer of that tag is visible in this file.
        """
        target_encoder = copy.deepcopy(model)
        set_requires_grad(target_encoder, False)
        for p in target_encoder.parameters():
            p.DO_NOT_TRAIN = True
        return target_encoder

    def save(self, milestone):
        """Write ``{'step', 'model'}`` (decoder weights only) to ``model-{milestone}.pt``.

        Gated on ``is_main_process`` because ``logs_folder`` is only created
        on the main process in ``__init__``.
        """
        if not self.accelerator.is_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.hifigan_decoder),
        }
        torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))

    def load(self, model_path):
        """Restore decoder weights and the step counter from a ``save()`` checkpoint.

        Only the decoder and ``step`` are restored; the discriminator,
        optimizer and scheduler states are not part of the checkpoint.
        """
        device = self.accelerator.device
        data = torch.load(model_path, map_location=device)
        self.step = data['step']
        model = self.accelerator.unwrap_model(self.hifigan_decoder)
        model.load_state_dict(data['model'])

    def train(self):
        """Run the GAN loop until ``self.step`` reaches ``self.train_steps``.

        Each iteration runs three stages:

        1. Compute GPT latents without gradients.
        2. Update the discriminator on real audio vs. detached generated audio.
        3. Update the decoder against the just-updated discriminator.

        The main process logs scalars every ``val_freq`` steps and writes a
        checkpoint every ``save_freq`` steps.
        """
        accelerator = self.accelerator
        device = accelerator.device
        writer = None
        writer_eval = None
        if accelerator.is_main_process:
            writer = SummaryWriter(log_dir=self.logs_folder)
            writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, 'eval'))
        with tqdm(initial=self.step, total=self.train_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_steps:
                # Expected batch keys: 'padded_text', 'padded_mel_code',
                # 'padded_wav', 'padded_wav_refer'.
                data = next(self.dataloader)
                # Check for a skipped batch BEFORE touching its contents.
                if data is None:
                    continue
                data = {k: v.to(device) for k, v in data.items()}
                y = data['padded_wav']

                mel_refer = self.mel_extractor(data['padded_wav_refer']).squeeze(1)
                with torch.no_grad():
                    # NOTE(review): lengths are one-element tensors holding the
                    # padded length, not per-sample lengths — confirm this is
                    # what the GPT forward expects.
                    latent = self.gpt(
                        mel_refer,
                        data['padded_text'],
                        torch.tensor([data['padded_text'].shape[-1]], device=device),
                        data['padded_mel_code'],
                        torch.tensor([data['padded_mel_code'].shape[-1] * self.mel_length_compression], device=device),
                        return_latent=True,
                        clip_inputs=False,
                    ).transpose(1, 2)

                # ---- Discriminator step ----
                with accelerator.autocast():
                    g = self.get_speaker_embedding(y, 24000)
                    y_hat = self.hifigan_decoder(latent, g)
                    # Detach so the D loss does not backpropagate into the generator.
                    score_fake, _ = self.hifigan_discriminator(y_hat.detach())
                    score_real, _ = self.hifigan_discriminator(y)
                    loss_d = self.disc_loss(score_fake, score_real)['loss']
                # The previous generator backward also accumulated gradients
                # into D's parameters; clear them before computing D's own.
                self.D_optimizer.zero_grad()
                accelerator.backward(loss_d)  # backward outside autocast, per AMP guidance
                grad_norm_d = get_grad_norm(self.hifigan_discriminator)
                accelerator.clip_grad_norm_(self.hifigan_discriminator.parameters(), 1.0)
                accelerator.wait_for_everyone()
                self.D_optimizer.step()
                self.D_scheduler.step()
                accelerator.wait_for_everyone()

                # ---- Generator step ----
                with accelerator.autocast():
                    score_fake, feat_fake = self.hifigan_discriminator(y_hat)
                    # Real features are only feature-matching targets. Recompute
                    # them with the updated D and without a graph, instead of
                    # reusing D-step tensors whose graph loss_d's backward freed.
                    with torch.no_grad():
                        _, feat_real = self.hifigan_discriminator(y)
                    loss_g = self.gen_loss(y_hat, y, score_fake, feat_fake, feat_real)['loss']
                self.G_optimizer.zero_grad()
                accelerator.backward(loss_g)
                grad_norm_g = get_grad_norm(self.hifigan_decoder)
                accelerator.clip_grad_norm_(self.hifigan_decoder.parameters(), 1.0)
                accelerator.wait_for_everyone()
                self.G_optimizer.step()
                self.G_scheduler.step()
                accelerator.wait_for_everyone()

                loss_d_value = loss_d.item()
                loss_g_value = loss_g.item()
                pbar.set_description(f'loss_d: {loss_d_value:.4f} loss_g: {loss_g_value:.4f}')
                if accelerator.is_main_process and self.step % self.val_freq == 0:
                    scalar_dict = {
                        "loss_d": loss_d_value,
                        "loss/grad_d": grad_norm_d,
                        "lr_d": self.D_scheduler.get_last_lr()[0],
                        "loss_g": loss_g_value,
                        "loss/grad_g": grad_norm_g,
                        "lr_g": self.G_scheduler.get_last_lr()[0],
                    }
                    summarize(
                        writer=writer,
                        global_step=self.step,
                        scalars=scalar_dict
                    )
                if accelerator.is_main_process and self.step % self.cfg['train']['save_freq'] == 0:
                    keep_ckpts = self.cfg['train']['keep_ckpts']
                    if keep_ckpts > 0:
                        clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
                    # Checkpoints are named by thousands of steps.
                    self.save(self.step // 1000)
                self.step += 1
                pbar.update(1)
        if writer is not None:
            writer.close()
        if writer_eval is not None:
            writer_eval.close()
        accelerator.print('training complete')
if __name__ == '__main__':
    # Script entry point: build a trainer from the default config path and
    # run it. To resume training, insert trainer.load('<checkpoint.pt>')
    # between construction and train().
    trainer = Trainer()
    trainer.train()