KingNish's picture
Upload ./RepCodec/trainer/autoencoder.py with huggingface_hub
b49f9aa verified
raw
history blame
9.47 kB
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import logging
import os
from collections import defaultdict
import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm
logger = logging.getLogger("repcodec_train")
class Trainer:
    """Step-based training loop for a RepCodec model.

    The dict arguments are keyed by component name. This class reads:

    * ``model["repcodec"]``, ``optimizer["repcodec"]``, ``scheduler["repcodec"]``
    * ``data_loader["train"]`` and ``data_loader["dev"]``
    * ``criterion["repr_reconstruct_loss"]``
    * config keys ``outdir``, ``train_max_steps`` (default 0),
      ``lambda_repr_reconstruct_loss``, ``lambda_vq_loss``, ``grad_norm``,
      ``log_interval_steps``, ``eval_interval_steps``, ``save_interval_steps``.

    Scalar metrics are accumulated under ``"train/<name>"`` / ``"eval/<name>"``
    keys and written to TensorBoard in ``config["outdir"]``.
    """

    def __init__(
        self,
        steps: int,
        epochs: int,
        data_loader: dict,
        model: dict,
        criterion: dict,
        optimizer: dict,
        scheduler: dict,
        config: dict,
        device=torch.device("cpu"),
    ):
        """Store the training components and open a TensorBoard writer.

        Args:
            steps: Global step counter to start from.
            epochs: Epoch counter to start from.
            data_loader: Iterables of batches under ``"train"`` and ``"dev"``.
            model: Modules by name; the codec lives under ``"repcodec"``.
            criterion: Loss callables; ``"repr_reconstruct_loss"`` is used.
            optimizer: Optimizers by name (``"repcodec"``).
            scheduler: LR schedulers by name (``"repcodec"``); stepped every
                training step.
            config: Hyper-parameters (see the class docstring for keys).
            device: Device each batch is moved to before the forward pass.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        # Running sums of recorded scalars, keyed "train/<name>" / "eval/<name>".
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)

    def _compute_codec_loss(self, batch, mode):
        """Run the codec on ``batch``, record its metrics and return the total loss.

        ``batch`` is moved to ``self.device`` and serves as both the model
        input and the reconstruction target (its shape is not checked here).
        """
        x = batch.to(self.device)
        # The model returns a 5-tuple; the 2nd and 3rd outputs are unused here.
        y_, _zq, _z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        # VQ term is evaluated (and recorded) before the reconstruction term.
        codec_loss = self._vq_loss(vqloss, mode=mode) + self._metric_loss(
            y_, x, mode=mode
        )
        self._record_loss("codec_loss", codec_loss, mode=mode)
        return codec_loss

    def _train_step(self, batch):
        """Single step of training: forward, backward, optimizer/scheduler step."""
        codec_loss = self._compute_codec_loss(batch, mode="train")
        self._update_repcodec(codec_loss)
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Single step of evaluation: forward pass only, metrics accumulated."""
        self._compute_codec_loss(batch, mode="eval")

    def run(self):
        """Train until ``self.steps`` reaches ``train_max_steps``.

        The finish check happens after each optimizer step, so at least one
        step is always taken.
        """
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        try:
            while not self.finish_train:
                self._train_epoch()
        finally:
            # Close the progress bar even if an epoch raises.
            self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        """Save model/optimizer/scheduler state plus step and epoch counters.

        Missing parent directories are created; a bare file name is written
        relative to the current working directory.
        """
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict()
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }
        checkpoint_dir = os.path.dirname(checkpoint_path)
        # dirname() is "" for a bare file name and os.makedirs("") raises;
        # exist_ok avoids a check-then-create race.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
        load_only_params: bool = False
    ):
        """Restore state written by ``save_checkpoint``.

        Args:
            checkpoint_path: File produced by ``save_checkpoint``.
            strict: Forwarded to the model's ``load_state_dict``.
            load_only_params: If True, restore only the model weights and keep
                the current steps, epochs, optimizer and scheduler state.
        """
        # NOTE: torch.load is pickle-based; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["repcodec"].load_state_dict(
                state_dict["optimizer"]["repcodec"]
            )
            self.scheduler["repcodec"].load_state_dict(
                state_dict["scheduler"]["repcodec"]
            )

    def _train_epoch(self):
        """One pass over ``data_loader["train"]``; returns early once finished.

        Raises:
            RuntimeError: If the training loader yields no batches.
        """
        train_steps_per_epoch = 0
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            self._train_step(batch)
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()
            if self.finish_train:
                return
        if train_steps_per_epoch == 0:
            # An empty loader can never advance self.steps, so training
            # could never reach train_max_steps.
            raise RuntimeError("Training data loader yielded no batches.")
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """Evaluate on ``data_loader["dev"]`` and report per-batch averaged metrics.

        Models are put in eval mode for the pass and back in train mode
        afterwards, also when evaluation raises. The accumulated eval sums are
        reset at the end in every case.
        """
        logger.info(f"(Steps: {self.steps}) Start evaluation.")
        for module in self.model.values():
            module.eval()
        try:
            eval_steps_per_epoch = 0
            for eval_steps_per_epoch, batch in enumerate(
                tqdm(self.data_loader["dev"], desc="[eval]"), 1
            ):
                self._eval_step(batch)
            logger.info(
                f"(Steps: {self.steps}) Finished evaluation "
                f"({eval_steps_per_epoch} steps per epoch)."
            )
            if eval_steps_per_epoch == 0:
                # Nothing was accumulated; avoid dividing by zero.
                logger.warning(
                    f"(Steps: {self.steps}) Dev data loader yielded no batches; "
                    "no eval metrics recorded."
                )
                return
            # Sums were accumulated once per batch, so this is a per-batch mean.
            for key in self.total_eval_loss:
                self.total_eval_loss[key] /= eval_steps_per_epoch
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_eval_loss)
        finally:
            # Drop (possibly partial) sums so they cannot leak into the next eval.
            self.total_eval_loss = defaultdict(float)
            for module in self.model.values():
                module.train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Weighted reconstruction loss between ``predict_y`` and ``natural_y``."""
        # Out-of-place scaling: an in-place `*=` on a criterion output breaks
        # autograd when the criterion's backward needs its own output.
        repr_reconstruct_loss = (
            self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
            * self.config["lambda_repr_reconstruct_loss"]
        )
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        return repr_reconstruct_loss

    def _update_repcodec(self, repr_loss):
        """Backpropagate ``repr_loss`` and step the RepCodec optimizer and scheduler.

        Gradients are clipped to ``config["grad_norm"]`` when it is positive.
        """
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()
        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )
        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Add ``loss`` to the running sum stored under ``"<mode>/<name>"``.

        Tensors are converted with ``.item()``.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'train'`` nor ``'eval'``.
        """
        if torch.is_tensor(loss):
            loss = loss.item()
        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write every scalar in ``loss`` to TensorBoard at the current step."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        """Save a checkpoint every ``save_interval_steps`` steps (never at step 0)."""
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        """Run a full evaluation every ``eval_interval_steps`` steps."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        """Every ``log_interval_steps`` steps, log/write averaged train metrics and reset them."""
        if self.steps % self.config["log_interval_steps"] == 0:
            # NOTE: divides by the full interval length; the first window after
            # resuming from a step that is not a multiple of the interval holds
            # fewer steps, so its averages are under-estimated.
            for key in self.total_train_loss:
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        """Set and return ``self.finish_train`` (True once ``steps >= train_max_steps``)."""
        self.finish_train = self.steps >= self.train_max_steps
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        """Record perplexity as ``<mode>/ppl[_<label>]``.

        When ``perplexity`` has more than one element, each element ``idx`` is
        recorded separately with an ``_<idx>`` suffix.
        """
        # _record_loss adds the "<mode>/" prefix itself; adding it here too
        # produced keys like "train/train/ppl".
        name = f"ppl_{label}" if label else "ppl"
        if torch.numel(perplexity) > 1:
            for idx, ppl in enumerate(perplexity.tolist()):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        """Sum ``vqloss``, scale it by ``config["lambda_vq_loss"]``, record and return it.

        Recorded as ``<mode>/vqloss[_<label>]``.
        """
        # _record_loss adds the "<mode>/" prefix itself.
        name = f"vqloss_{label}" if label else "vqloss"
        vqloss = torch.sum(vqloss) * self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)
        return vqloss