|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from collections import defaultdict |
|
|
|
import torch |
|
from tensorboardX import SummaryWriter |
|
from tqdm import tqdm |
|
|
|
logger = logging.getLogger("repcodec_train") |
|
|
|
|
|
class Trainer:
    """Training driver for a RepCodec model.

    Holds the model / criterion / optimizer / scheduler dictionaries (each
    keyed by ``"repcodec"``), runs the train and eval loops, accumulates
    losses for tensorboard logging, and saves/loads checkpoints.
    """

    def __init__(
        self,
        steps: int,
        epochs: int,
        data_loader: dict,
        model: dict,
        criterion: dict,
        optimizer: dict,
        scheduler: dict,
        config: dict,
        device=torch.device("cpu"),
    ):
        """Initialize trainer state.

        Args:
            steps: global step counter to start from (0 for a fresh run).
            epochs: epoch counter to start from.
            data_loader: dict with "train" and "dev" iterable loaders.
            model: dict with key "repcodec" -> torch.nn.Module.
            criterion: dict with key "repr_reconstruct_loss" -> loss module.
            optimizer: dict with key "repcodec" -> torch optimizer.
            scheduler: dict with key "repcodec" -> LR scheduler.
            config: training config (outdir, loss lambdas, interval steps,
                grad_norm, train_max_steps, ...).
            device: device that input batches are moved to.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        # Running sums of losses, flushed every log/eval interval.
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)

    def _train_step(self, batch):
        """Single step of training: forward, loss, backward, update."""
        mode = "train"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0
        # Model returns (reconstruction, quantized latent, latent, vq loss,
        # codebook perplexity).
        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)
        self._update_repcodec(codec_loss)

        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Single step of evaluation (no parameter update)."""
        mode = "eval"
        x = batch
        x = x.to(self.device)

        codec_loss = 0.0
        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)

    def run(self):
        """Run training epochs until train_max_steps is reached."""
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        while True:
            self._train_epoch()

            if self.finish_train:
                break

        self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        """Save model/optimizer/scheduler state plus progress counters.

        Args:
            checkpoint_path: destination file; parent directories are
                created if missing.
        """
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict()
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }

        # exist_ok avoids the exists()/makedirs() race; the dirname guard
        # avoids os.makedirs("") when the path has no directory component.
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
        load_only_params: bool = False
    ):
        """Restore state from a checkpoint written by :meth:`save_checkpoint`.

        Args:
            checkpoint_path: checkpoint file to load (mapped to CPU first).
            strict: passed through to ``load_state_dict`` for the model.
            load_only_params: if True, restore model weights only and keep
                the current step/epoch counters and optimizer/scheduler.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )

        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["repcodec"].load_state_dict(
                state_dict["optimizer"]["repcodec"]
            )
            self.scheduler["repcodec"].load_state_dict(
                state_dict["scheduler"]["repcodec"]
            )

    def _train_epoch(self):
        """One epoch of training."""
        # Initialize so an empty train loader cannot leave the name unbound.
        train_steps_per_epoch = 0
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            self._train_step(batch)

            # Periodic logging / evaluation / checkpointing, all gated on
            # the global step counter.
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            if self.finish_train:
                return

        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        # NOTE(review): the epoch summary is only logged for epochs longer
        # than 200 steps; kept as-is to preserve original behavior.
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """One epoch of evaluation over the "dev" loader."""
        logger.info(f"(Steps: {self.steps}) Start evaluation.")

        for key in self.model.keys():
            self.model[key].eval()

        # Initialize so an empty dev loader cannot cause a NameError (or a
        # division by zero below).
        eval_steps_per_epoch = 0
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.data_loader["dev"], desc="[eval]"), 1
        ):
            self._eval_step(batch)

        logger.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        if eval_steps_per_epoch > 0:
            # Average the accumulated eval losses over the epoch.
            for key in self.total_eval_loss.keys():
                self.total_eval_loss[key] /= eval_steps_per_epoch
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
                )

        self._write_to_tensorboard(self.total_eval_loss)

        # Reset the accumulator for the next evaluation.
        self.total_eval_loss = defaultdict(float)

        # Restore train mode for all submodels.
        for key in self.model.keys():
            self.model[key].train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Weighted representation-reconstruction loss between prediction and target."""
        metric_loss = 0.0

        repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
        repr_reconstruct_loss *= self.config["lambda_repr_reconstruct_loss"]
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        metric_loss += repr_reconstruct_loss

        return metric_loss

    def _update_repcodec(self, repr_loss):
        """Backprop the loss and step optimizer + scheduler (with optional grad clipping)."""
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()
        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )
        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Accumulate a scalar loss under ``{mode}/{name}``.

        Raises:
            NotImplementedError: if ``mode`` is neither 'train' nor 'eval'.
        """
        if torch.is_tensor(loss):
            loss = loss.item()

        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write every entry of ``loss`` to tensorboard at the current step."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        """Save a checkpoint every ``save_interval_steps`` (never at step 0)."""
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        """Run an evaluation epoch every ``eval_interval_steps``."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        """Every ``log_interval_steps``: average, log, and reset train losses."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        """Set and return ``finish_train`` based on the max-steps budget."""
        if self.steps >= self.train_max_steps:
            self.finish_train = True
        else:
            self.finish_train = False
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        """Record codebook perplexity (per codebook if it is a multi-element tensor).

        NOTE(review): ``name`` already starts with ``{mode}/`` and
        ``_record_loss`` prefixes the mode again, producing tags like
        ``train/train/ppl``; kept as-is to preserve the existing
        tensorboard tag scheme.
        """
        if label:
            name = f"{mode}/ppl_{label}"
        else:
            name = f"{mode}/ppl"
        if torch.numel(perplexity) > 1:
            perplexity = perplexity.tolist()
            for idx, ppl in enumerate(perplexity):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        """Sum, weight by ``lambda_vq_loss``, record, and return the VQ loss.

        NOTE(review): same double ``{mode}/`` prefix as ``_perplexity``;
        preserved intentionally.
        """
        if label:
            name = f"{mode}/vqloss_{label}"
        else:
            name = f"{mode}/vqloss"
        vqloss = torch.sum(vqloss)
        vqloss *= self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)

        return vqloss
|
|