# TTTS: ttts/vqvae/train.py
# Uploaded by mrfakename, commit 4ee33aa ("Add source code").
import copy
from datetime import datetime
import json
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from ttts.utils.utils import EMA, clean_checkpoints, plot_spectrogram_to_numpy, summarize, update_moving_average
from ttts.vqvae.dataset import PreprocessedMelDataset
import torch
import os
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import AdamW
from accelerate import Accelerator
from ttts.vqvae.xtts_dvae import DiscreteVAE
def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``, in place.

    Args:
        model: a ``torch.nn.Module`` whose parameters are (un)frozen.
        val: the value assigned to each parameter's ``requires_grad`` flag.
    """
    for param in model.parameters():
        param.requires_grad = val
def get_grad_norm(model):
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters whose ``.grad`` is ``None`` (not reached by the last backward
    pass) are skipped, and their names are printed so unused parameters stay
    visible while debugging.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        float: ``sqrt(sum(||grad_p||_2 ** 2))`` over parameters that have a
        gradient; ``0.0`` when none do.
    """
    total_norm = 0.0
    for name, p in model.named_parameters():
        # Explicit check instead of a bare ``except:``, which would also have
        # swallowed KeyboardInterrupt and any unrelated error.
        if p.grad is None:
            print(name)
            continue
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
    return total_norm ** 0.5
def cycle(dl):
    """Yield items from ``dl`` endlessly, starting a new pass whenever it is exhausted.

    Each pass iterates ``dl`` afresh instead of replaying cached items, so an
    iterable that reorders on every ``iter()`` call gives a new order per pass.
    If ``dl`` is empty, the generator spins forever without yielding.
    """
    while True:
        yield from dl
class Trainer(object):
    """Train a DiscreteVAE on preprocessed mel spectrograms with HuggingFace Accelerate.

    The JSON config must provide:
      * ``vqvae``: keyword arguments for ``DiscreteVAE``;
      * ``dataloader``: keyword arguments for ``torch.utils.data.DataLoader``;
      * ``train``: ``train_steps``, ``val_freq``, ``save_freq``, ``keep_ckpts``
        and ``logs_folder``.

    NOTE: EMA tracking is currently disabled. ``_get_target_encoder`` and
    ``ema_updater`` are kept so it can be re-enabled.
    """

    def __init__(self, cfg_path='ttts/vqvae/config.json'):
        self.accelerator = Accelerator()
        # Context manager so the config file handle is closed promptly.
        with open(cfg_path, encoding='utf-8') as f:
            self.cfg = json.load(f)
        self.vqvae = DiscreteVAE(**self.cfg['vqvae'])
        self.dataset = PreprocessedMelDataset(self.cfg)
        self.dataloader = DataLoader(self.dataset, **self.cfg['dataloader'])
        self.train_steps = self.cfg['train']['train_steps']
        self.val_freq = self.cfg['train']['val_freq']
        # Only the main process writes logs/checkpoints. Other ranks keep None
        # so the attribute always exists.
        self.logs_folder = None
        if self.accelerator.is_main_process:
            now = datetime.now()
            self.logs_folder = Path(self.cfg['train']['logs_folder']) / now.strftime("%Y-%m-%d-%H-%M-%S")
            self.logs_folder.mkdir(exist_ok=True, parents=True)
        self.ema_updater = EMA(0.999)
        self.optimizer = AdamW(self.vqvae.parameters(), lr=3e-4, betas=(0.9, 0.9999), weight_decay=0.01)
        self.vqvae, self.dataloader, self.optimizer = self.accelerator.prepare(self.vqvae, self.dataloader, self.optimizer)
        # Infinite iterator: training is bounded by steps, not epochs.
        self.dataloader = cycle(self.dataloader)
        self.step = 0
        self.gradient_accumulate_every = 1

    def _get_target_encoder(self, model):
        """Return a frozen deep copy of ``model`` (intended as an EMA target)."""
        target_encoder = copy.deepcopy(model)
        set_requires_grad(target_encoder, False)
        for p in target_encoder.parameters():
            # Marker attribute; presumably checked elsewhere to exclude these
            # parameters from training. Verify before relying on it.
            p.DO_NOT_TRAIN = True
        return target_encoder

    def save(self, milestone):
        """Save ``{'step', 'model'}`` to ``<logs_folder>/model-<milestone>.pt``.

        Only the global main process writes. It is the only process for which
        ``logs_folder`` is set in ``__init__``.
        """
        if not self.accelerator.is_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.vqvae),
        }
        torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))

    def load(self, model_path):
        """Restore model weights and ``self.step`` from a checkpoint written by ``save``.

        A leading ``~`` in ``model_path`` is expanded. Optimizer state is not
        stored in checkpoints, so it is not restored.
        """
        data = torch.load(os.path.expanduser(model_path), map_location=self.accelerator.device)
        self.step = data['step']
        vqvae = self.accelerator.unwrap_model(self.vqvae)
        vqvae.load_state_dict(data['model'])

    def _train_step(self):
        """Run one optimizer step, accumulating over ``gradient_accumulate_every`` micro-batches.

        Returns:
            ``(total_loss, recon_loss, commitment_loss, grad_norm, mel, mel_recon)``.
            ``total_loss`` is a float summed over micro-batches. The remaining
            values come from the last micro-batch.
        """
        accelerator = self.accelerator
        device = accelerator.device
        total_loss = 0.
        for _ in range(self.gradient_accumulate_every):
            mel = next(self.dataloader)
            # squeeze(1) drops a singleton dim 1 if present. The mel[0, :, :]
            # indexing used for logging implies a 3-D tensor afterwards.
            mel = mel.to(device).squeeze(1)
            with accelerator.autocast():
                recon_loss, commitment_loss, mel_recon = self.vqvae(mel)
                recon_loss = torch.mean(recon_loss)
                loss = recon_loss + 0.25 * commitment_loss
                loss = loss / self.gradient_accumulate_every
                total_loss += loss.item()
            accelerator.backward(loss)
        # Measure and clip once, on the fully accumulated gradient, not per
        # micro-batch. The norm is taken before clipping.
        # NOTE(review): with fp16 loss scaling these grads may still be scaled.
        grad_norm = get_grad_norm(self.vqvae)
        accelerator.clip_grad_norm_(self.vqvae.parameters(), 1.0)
        accelerator.wait_for_everyone()
        self.optimizer.step()
        self.optimizer.zero_grad()
        accelerator.wait_for_everyone()
        return total_loss, recon_loss, commitment_loss, grad_norm, mel, mel_recon

    def _log_validation(self, writer, total_loss, recon_loss, commitment_loss, grad_norm, mel, mel_recon):
        """Write scalars and spectrogram images for the current step to TensorBoard."""
        eval_model = self.accelerator.unwrap_model(self.vqvae)
        with torch.no_grad():
            eval_model.eval()
            try:
                # EMA is disabled, so the "ema" image below comes from the live
                # model's infer(). The key name is kept for log continuity.
                mel_recon_ema = eval_model.infer(mel)[0]
            finally:
                # Always restore training mode, even if inference fails.
                eval_model.train()
        scalar_dict = {
            "loss": total_loss,
            "loss_mel": float(recon_loss),
            "loss_commitment": float(commitment_loss),
            "loss/grad": grad_norm,
        }
        image_dict = {
            "all/spec": plot_spectrogram_to_numpy(mel[0, :, :].detach().unsqueeze(-1).cpu()),
            "all/spec_pred": plot_spectrogram_to_numpy(mel_recon[0, :, :].detach().unsqueeze(-1).cpu()),
            "all/spec_pred_ema": plot_spectrogram_to_numpy(mel_recon_ema[0, :, :].detach().unsqueeze(-1).cpu()),
        }
        summarize(
            writer=writer,
            global_step=self.step,
            images=image_dict,
            scalars=scalar_dict
        )

    def _save_checkpoint(self):
        """Prune old checkpoints (when ``keep_ckpts > 0``), then save the current model."""
        keep_ckpts = self.cfg['train']['keep_ckpts']
        if keep_ckpts > 0:
            clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
        # NOTE(review): the milestone is step // 1000, so with save_freq < 1000
        # consecutive saves overwrite the same model-<k>.pt file.
        self.save(self.step // 1000)

    def train(self):
        """Train until ``self.step`` reaches ``train_steps``.

        On the main process, every ``val_freq`` steps it logs losses, the
        gradient norm and spectrograms. Every ``save_freq`` steps it saves a
        checkpoint.
        """
        accelerator = self.accelerator
        writer = SummaryWriter(log_dir=str(self.logs_folder)) if accelerator.is_main_process else None
        try:
            with tqdm(initial=self.step, total=self.train_steps,
                      disable=not accelerator.is_main_process) as pbar:
                while self.step < self.train_steps:
                    total_loss, recon_loss, commitment_loss, grad_norm, mel, mel_recon = self._train_step()
                    pbar.set_description(f'loss: {total_loss:.4f}')
                    if accelerator.is_main_process and self.step % self.val_freq == 0:
                        self._log_validation(writer, total_loss, recon_loss, commitment_loss,
                                             grad_norm, mel, mel_recon)
                    if accelerator.is_main_process and self.step % self.cfg['train']['save_freq'] == 0:
                        self._save_checkpoint()
                    self.step += 1
                    pbar.update(1)
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            if writer is not None:
                writer.close()
        accelerator.print('training complete')
if __name__ == '__main__':
    # Build a trainer from the default config path and train from step 0.
    # To resume instead, call trainer.load('<path/to/model-N.pt>') before train().
    trainer = Trainer()
    trainer.train()