import copy
import json
import os
from datetime import datetime
from pathlib import Path

import torch
import torchaudio
from accelerate import Accelerator
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from ttts.diffusion.diffusion_util import cycle, get_grad_norm
from ttts.diffusion.train import set_requires_grad
from ttts.hifigan.dataset import HiFiGANCollater, HifiGANDataset
from ttts.hifigan.hifigan_discriminator import HifiganDiscriminator
from ttts.hifigan.hifigan_vocoder import HifiDecoder
from ttts.hifigan.losses import DiscriminatorLoss, GeneratorLoss
from ttts.utils.infer_utils import load_model
from ttts.utils.utils import clean_checkpoints, summarize
from ttts.vocoder.feature_extractors import MelSpectrogramFeatures
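
# Linear learning-rate warmup over the first 1000 optimizer steps, then a constant factor of 1.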
def warmup(step):
    if step < 1000:
        return float(step) / 1000
    return 1.0
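
# Trainer couples a frozen GPT latent extractor with a HiFi-GAN decoder/discriminator
# pair and runs the adversarial vocoder training loop under HuggingFace Accelerate.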
class Trainer(object):
def __init__(self, cfg_path='ttts/hifigan/config.json'):
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
self.accelerator = Accelerator()
self.cfg = json.load(open(cfg_path))
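        # Frozen GPT acoustic model whose latents condition the vocoder.
        # Note: the GPT config path and the 'cuda' device are hardcoded here.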
        self.gpt = load_model('gpt', self.cfg['dataset']['gpt_path'], 'ttts/gpt/config.json', 'cuda')
self.mel_length_compression = self.gpt.mel_length_compression
self.hifigan_decoder = HifiDecoder(
**self.cfg['hifigan']
)
self.hifigan_discriminator = HifiganDiscriminator()
self.dataset = HifiGANDataset(self.cfg)
self.dataloader = DataLoader(self.dataset, **self.cfg['dataloader'], collate_fn=HiFiGANCollater())
self.train_steps = self.cfg['train']['train_steps']
self.val_freq = self.cfg['train']['val_freq']
if self.accelerator.is_main_process:
now = datetime.now()
self.logs_folder = Path(self.cfg['train']['logs_folder']+'/'+now.strftime("%Y-%m-%d-%H-%M-%S"))
            self.logs_folder.mkdir(exist_ok=True, parents=True)
self.G_optimizer = AdamW(self.hifigan_decoder.parameters(),lr=self.cfg['train']['lr'], betas=(0.9, 0.999), weight_decay=0.01)
self.D_optimizer = AdamW(self.hifigan_discriminator.parameters(), lr=self.cfg['train']['lr'], betas=(0.9, 0.999), weight_decay=0.01)
self.G_scheduler = torch.optim.lr_scheduler.LambdaLR(self.G_optimizer, lr_lambda=warmup)
self.D_scheduler = torch.optim.lr_scheduler.LambdaLR(self.D_optimizer, lr_lambda=warmup)
        (self.hifigan_decoder, self.hifigan_discriminator, self.dataloader,
         self.G_optimizer, self.D_optimizer, self.G_scheduler, self.D_scheduler,
         self.gpt) = self.accelerator.prepare(
            self.hifigan_decoder, self.hifigan_discriminator, self.dataloader,
            self.G_optimizer, self.D_optimizer, self.G_scheduler, self.D_scheduler,
            self.gpt)
self.dataloader = cycle(self.dataloader)
        self.step = 0
self.mel_extractor = MelSpectrogramFeatures().to(self.accelerator.device)
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
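
    # Speaker embedding from a reference waveform; the speaker encoder expects 16 kHz input.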
def get_speaker_embedding(self, audio, sr):
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
return (
self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.accelerator.device), l2_norm=True)
.unsqueeze(-1)
.to(self.accelerator.device)
)
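
    # Frozen deep copy of a model, e.g. as an EMA target; not used by train() in its current form.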
def _get_target_encoder(self, model):
target_encoder = copy.deepcopy(model)
set_requires_grad(target_encoder, False)
for p in target_encoder.parameters():
p.DO_NOT_TRAIN = True
return target_encoder
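
    # Checkpoint the generator (decoder) weights and the global step, on the local main process only.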
def save(self, milestone):
if not self.accelerator.is_local_main_process:
return
data = {
'step': self.step,
'model': self.accelerator.get_state_dict(self.hifigan_decoder),
}
torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))
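
    # Restore generator weights and the step counter from a checkpoint produced by save().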
def load(self, model_path):
accelerator = self.accelerator
device = accelerator.device
data = torch.load(model_path, map_location=device)
state_dict = data['model']
self.step = data['step']
model = self.accelerator.unwrap_model(self.hifigan_decoder)
model.load_state_dict(state_dict)
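
    # Main GAN loop: alternate discriminator and generator updates each step,
    # with periodic scalar logging and checkpointing on the main process.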
def train(self):
accelerator = self.accelerator
device = accelerator.device
if accelerator.is_main_process:
writer = SummaryWriter(log_dir=self.logs_folder)
writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, 'eval'))
        with tqdm(initial=self.step, total=self.train_steps, disable=not accelerator.is_main_process) as pbar:
while self.step < self.train_steps:
# 'padded_text': padded_text,
# 'padded_mel_code': padded_mel_code,
# 'padded_wav': padded_wav,
# 'padded_wav_refer':padded_wav_refer,
                data = next(self.dataloader)
                if data is None:
                    continue
                data = {k: v.to(device) for k, v in data.items()}
                y = data['padded_wav']
mel_refer = self.mel_extractor(data['padded_wav_refer']).squeeze(1)
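                # Extract conditioning latents from the frozen GPT. The length tensors are
                # derived from the padded batch shapes, so they ignore per-sample lengths.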
                with torch.no_grad():
                    latent = self.gpt(
                        mel_refer, data['padded_text'],
                        torch.tensor([data['padded_text'].shape[-1]], device=device),
                        data['padded_mel_code'],
                        torch.tensor([data['padded_mel_code'].shape[-1] * self.mel_length_compression], device=device),
                        return_latent=True, clip_inputs=False).transpose(1, 2)
x = latent
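                # Discriminator update: decode latents to audio (padded_wav is assumed to be
                # 24 kHz here, per the hardcoded sample rate) and score real vs. detached fake waveforms.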
with self.accelerator.autocast():
g = self.get_speaker_embedding(data['padded_wav'], 24000)
y_hat = self.hifigan_decoder(x, g)
score_fake, feat_fake = self.hifigan_discriminator(y_hat.detach())
score_real, feat_real = self.hifigan_discriminator(y.clone())
loss_d = self.disc_loss(score_fake, score_real)['loss']
self.accelerator.backward(loss_d)
grad_norm_d = get_grad_norm(self.hifigan_discriminator)
accelerator.clip_grad_norm_(self.hifigan_discriminator.parameters(), 1.0)
accelerator.wait_for_everyone()
self.D_optimizer.step()
self.D_optimizer.zero_grad()
self.D_scheduler.step()
accelerator.wait_for_everyone()
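                # Generator update: re-score the fake audio so gradients flow into the decoder.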
score_fake, feat_fake = self.hifigan_discriminator(y_hat)
loss_g = self.gen_loss(y_hat, y, score_fake, feat_fake, feat_real)['loss']
self.accelerator.backward(loss_g)
grad_norm_g = get_grad_norm(self.hifigan_decoder)
accelerator.clip_grad_norm_(self.hifigan_decoder.parameters(), 1.0)
accelerator.wait_for_everyone()
self.G_optimizer.step()
self.G_optimizer.zero_grad()
self.G_scheduler.step()
accelerator.wait_for_everyone()
pbar.set_description(f'loss_d: {loss_d:.4f} loss_g: {loss_g:.4f}')
# if accelerator.is_main_process:
# update_moving_average(self.ema_updater,self.ema_model,self.diffusion)
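                # Periodic scalar logging on the main process.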
if accelerator.is_main_process and self.step % self.val_freq == 0:
                    scalar_dict = {
                        "loss_d": loss_d, "loss/grad_d": grad_norm_d, "lr_d": self.D_scheduler.get_last_lr()[0],
                        "loss_g": loss_g, "loss/grad_g": grad_norm_g, "lr_g": self.G_scheduler.get_last_lr()[0],
                    }
summarize(
writer=writer,
global_step=self.step,
scalars=scalar_dict
)
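                # Periodic checkpointing; prune old checkpoints before saving a new one.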
if accelerator.is_main_process and self.step % self.cfg['train']['save_freq'] == 0:
keep_ckpts = self.cfg['train']['keep_ckpts']
if keep_ckpts > 0:
clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
self.save(self.step//1000)
                    # self.ema_model.train()  # disabled: no EMA model is created in this trainer
self.step += 1
pbar.update(1)
accelerator.print('training complete')
if __name__ == '__main__':
trainer = Trainer()
# trainer.load('')
    trainer.train()