# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
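
"""Trainer for RepCodec: runs step-based training with periodic evaluation,
logs losses and codebook perplexity to TensorBoard, and saves/loads checkpoints."""
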
import logging
import os
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm

logger = logging.getLogger("repcodec_train")


class Trainer:
    def __init__(
        self,
        steps: int,
        epochs: int,
        data_loader: dict,
        model: dict,
        criterion: dict,
        optimizer: dict,
        scheduler: dict,
        config: dict,
        device=torch.device("cpu"),
    ):
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)
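
    # The constructor arguments are dictionaries keyed by component name:
    #   data_loader: {"train": ..., "dev": ...}
    #   model / optimizer / scheduler: {"repcodec": ...}
    #   criterion: {"repr_reconstruct_loss": ...}
    # config is expected to provide "outdir", "train_max_steps", "grad_norm",
    # "lambda_repr_reconstruct_loss", "lambda_vq_loss", and the
    # "*_interval_steps" keys used by the _check_* helpers below.
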
    def _train_step(self, batch):
        """Single step of training."""
        mode = "train"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0

        # forward pass through RepCodec: reconstruction y_, quantized/pre-quantization
        # latents (zq, z), VQ commitment loss and codebook perplexity
        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)
        self._update_repcodec(codec_loss)

        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _eval_step(self, batch):
        """Single step of evaluation."""
        mode = "eval"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0

        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)

    def run(self):
        """Run training."""
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        while True:
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        """Save model, optimizer, scheduler states and counters to checkpoint_path."""
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict()
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }

        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
        load_only_params: bool = False,
    ):
        """Load a checkpoint; optionally restore only the model parameters."""
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["repcodec"].load_state_dict(
                state_dict["optimizer"]["repcodec"]
            )
            self.scheduler["repcodec"].load_state_dict(
                state_dict["scheduler"]["repcodec"]
            )
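
    # Example (hypothetical paths): load_checkpoint restores model, optimizer,
    # scheduler and the step/epoch counters; with load_only_params=True only the
    # model weights are loaded, e.g. to fine-tune from pretrained weights:
    #   trainer.load_checkpoint("exp/train/checkpoint-50000steps.pkl")
    #   trainer.load_checkpoint("exp/pretrained.pkl", load_only_params=True)
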
    def _train_epoch(self):
        """One epoch of training."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """One epoch of evaluation."""
        logger.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            # eval one step
            self._eval_step(batch)
        logger.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logger.info(
                f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
            )

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Metric losses."""
        metric_loss = 0.0

        repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
        repr_reconstruct_loss *= self.config["lambda_repr_reconstruct_loss"]
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        metric_loss += repr_reconstruct_loss

        return metric_loss

    def _update_repcodec(self, repr_loss):
        """Update RepCodec parameters."""
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()

        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )

        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Record loss."""
        if torch.is_tensor(loss):
            loss = loss.item()

        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.train_max_steps:
            self.finish_train = True
        else:
            self.finish_train = False
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        if label:
            name = f"{mode}/ppl_{label}"
        else:
            name = f"{mode}/ppl"

        if torch.numel(perplexity) > 1:
            perplexity = perplexity.tolist()
            for idx, ppl in enumerate(perplexity):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        if label:
            name = f"{mode}/vqloss_{label}"
        else:
            name = f"{mode}/vqloss"

        vqloss = torch.sum(vqloss)
        vqloss *= self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)

        return vqloss
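

# Minimal smoke-test sketch (an assumption, not the project's actual entry point):
# it shows how the dictionaries and config keys read by Trainer fit together.
# _DummyCodec and all numeric values below are illustrative stand-ins; in practice
# the model, criterion, optimizer, scheduler and config come from the training
# script. The dummy only mimics the (y_, zq, z, vqloss, perplexity) output that
# this trainer unpacks.
if __name__ == "__main__":
    import torch.nn as nn

    class _DummyCodec(nn.Module):
        """Stand-in returning (reconstruction, zq, z, vq_loss, perplexity)."""

        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            y = self.proj(x)
            return y, y, x, torch.zeros(1), torch.ones(1)

    device = torch.device("cpu")
    config = {
        "outdir": "exp/repcodec_smoke",
        "train_max_steps": 5,
        "save_interval_steps": 5,
        "eval_interval_steps": 5,
        "log_interval_steps": 5,
        "grad_norm": -1,
        "lambda_vq_loss": 1.0,
        "lambda_repr_reconstruct_loss": 1.0,
    }
    model = {"repcodec": _DummyCodec().to(device)}
    criterion = {"repr_reconstruct_loss": nn.MSELoss()}
    optimizer = {"repcodec": torch.optim.Adam(model["repcodec"].parameters(), lr=1e-3)}
    scheduler = {
        "repcodec": torch.optim.lr_scheduler.ExponentialLR(optimizer["repcodec"], gamma=1.0)
    }
    # lists of tensors are enough here; real training uses DataLoader instances
    data = [torch.randn(4, 8) for _ in range(10)]
    data_loader = {"train": data, "dev": data[:2]}

    trainer = Trainer(
        steps=0,
        epochs=0,
        data_loader=data_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )
    trainer.run()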