""" Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import argparse import copy import json import logging import os import sys import warnings from typing import Any, Dict import model.vqvae as vqvae import numpy as np import torch import torch.optim as optim from data_loaders.get_data import get_dataset_loader, load_local_data from diffusion.nn import sum_flat from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from utils.vq_parser_utils import train_args warnings.filterwarnings("ignore") def cycle(iterable): while True: for x in iterable: yield x def get_logger(out_dir: str): logger = logging.getLogger("Exp") logger.setLevel(logging.INFO) formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") file_path = os.path.join(out_dir, "run.log") file_hdlr = logging.FileHandler(file_path) file_hdlr.setFormatter(formatter) strm_hdlr = logging.StreamHandler(sys.stdout) strm_hdlr.setFormatter(formatter) logger.addHandler(file_hdlr) logger.addHandler(strm_hdlr) return logger class ModelTrainer: def __init__(self, args, net: vqvae.TemporalVertexCodec, logger, writer): self.net = net self.warm_up_iter = args.warm_up_iter self.lr = args.lr self.optimizer = optim.AdamW( self.net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay, ) self.scheduler = torch.optim.lr_scheduler.MultiStepLR( self.optimizer, milestones=args.lr_scheduler, gamma=args.gamma ) self.data_format = args.data_format self.loss = torch.nn.SmoothL1Loss() self.loss_vel = args.loss_vel self.commit = args.commit self.logger = logger self.writer = writer self.best_commit = float("inf") self.best_recons = float("inf") self.best_perplexity = float("inf") self.best_iter = 0 self.out_dir = args.out_dir def _masked_l2(self, a, b, mask): loss = self._l2_loss(a, b) loss = sum_flat(loss * mask.float()) n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries mse_loss_val = loss / non_zero_elements return mse_loss_val def _l2_loss(self, motion_pred, motion_gt, mask=None): if mask is not None: return self._masked_l2(motion_pred, motion_gt, mask) else: return self.loss(motion_pred, motion_gt) def _vel_loss(self, motion_pred, motion_gt): model_results_vel = motion_pred[..., :-1] - motion_pred[..., 1:] model_targets_vel = motion_gt[..., :-1] - motion_gt[..., 1:] return self.loss(model_results_vel, model_targets_vel) def _update_lr_warm_up(self, nb_iter): current_lr = self.lr * (nb_iter + 1) / (self.warm_up_iter + 1) for param_group in self.optimizer.param_groups: param_group["lr"] = current_lr return current_lr def run_warmup_steps(self, train_loader_iter, skip_step, logger): avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0 for nb_iter in tqdm(range(1, args.warm_up_iter)): current_lr = self._update_lr_warm_up(nb_iter) gt_motion, cond = next(train_loader_iter) loss_dict = self.run_train_step(gt_motion, cond, skip_step) avg_recons += loss_dict["loss_motion"] avg_perplexity += loss_dict["perplexity"] avg_commit += loss_dict["loss_commit"] if nb_iter % args.print_iter == 0: avg_recons /= args.print_iter avg_perplexity /= args.print_iter avg_commit /= args.print_iter logger.info( f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. 
{avg_recons:.5f}" ) avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0 def run_train_step( self, gt_motion: torch.Tensor, cond: torch.Tensor, skip_step: int ) -> Dict[str, Any]: self.net.train() loss_dict = {} # run model gt_motion = gt_motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float() cond["y"] = { key: val.to(gt_motion.device) if torch.is_tensor(val) else val for key, val in cond["y"].items() } gt_motion = gt_motion[:, ::skip_step, :] pred_motion, loss_commit, perplexity = self.net(gt_motion, mask=None) loss_motion = self._l2_loss(pred_motion, gt_motion).mean() loss_vel = 0.0 if self.loss_vel > 0: loss_vel = self._vel_loss(pred_motion, gt_motion) loss = loss_motion + self.commit * loss_commit + self.loss_vel * loss_vel self.optimizer.zero_grad() loss.backward() self.optimizer.step() # record losses if self.loss_vel > 0: loss_dict["vel"] = loss_vel.item() loss_dict["loss"] = loss.item() loss_dict["loss_motion"] = loss_motion.item() loss_dict["loss_commit"] = loss_commit.item() loss_dict["perplexity"] = perplexity.item() return loss_dict def save_model(self, save_path): torch.save( { "net": self.net.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler, }, save_path, ) def _save_predictions(self, name, unstd_pose, unstd_pred): curr_name = os.path.basename(name) path = os.path.join(self.out_dir, curr_name) for j in range(len(path.split("/")) - 1): if not os.path.exists("/".join(path.split("/")[: j + 1])): os.system("mkdir " + "/".join(path.split("/")[: j + 1])) np.save(os.path.join(self.out_dir, curr_name + "_gt.npy"), unstd_pose) np.save(os.path.join(self.out_dir, curr_name + "_pred.npy"), unstd_pred) def _log_losses( self, commit_loss: float, recons_loss: float, total_perplexity: float, nb_iter: int, nb_sample: int, draw: bool, save: bool, ) -> None: avg_commit = commit_loss / nb_sample avg_recons = recons_loss / nb_sample avg_perplexity = total_perplexity / nb_sample self.logger.info( f"Eval. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}" ) if draw: self.writer.add_scalar("./Val/Perplexity", avg_perplexity, nb_iter) self.writer.add_scalar("./Val/Commit", avg_commit, nb_iter) self.writer.add_scalar("./Val/Recons", avg_recons, nb_iter) if avg_perplexity < self.best_perplexity: msg = f"--> --> \t Perplexity Improved from {self.best_perplexity:.5f} to {avg_perplexity:.5f} !!!" self.logger.info(msg) self.best_perplexity = avg_perplexity if save: print(f"saving checkpoint net_best.pth") self.save_model(os.path.join(self.out_dir, "net_best.pth")) if avg_commit < self.best_commit: msg = f"--> --> \t Commit Improved from {self.best_commit:.5f} to {avg_commit:.5f} !!!" self.logger.info(msg) self.best_commit = avg_commit if avg_recons < self.best_recons: msg = f"--> --> \t Recons Improved from {self.best_recons:.5f} to {avg_recons:.5f} !!!" 
    @torch.no_grad()
    def evaluation_vqvae(
        self,
        val_loader,
        nb_iter: int,
        draw: bool = True,
        save: bool = True,
        savenpy: bool = False,
    ) -> None:
        self.net.eval()
        nb_sample = 0
        commit_loss = 0.0
        recons_loss = 0.0
        total_perplexity = 0.0

        for _, batch in enumerate(val_loader):
            motion, cond = batch
            m_length = cond["y"]["lengths"]
            motion = motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float()
            cond["y"] = {
                key: val.to(motion.device) if torch.is_tensor(val) else val
                for key, val in cond["y"].items()
            }
            motion = motion[:, :: val_loader.dataset.step, :]
            bs, seq = motion.shape[0], motion.shape[1]
            pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda()

            # Evaluate one sequence at a time so each sample is cropped to its
            # own valid length.
            for i in range(bs):
                curr_gt = motion[i : i + 1, : m_length[i]]
                pred, loss_commit, perplexity = self.net(curr_gt)
                l2_loss = self._l2_loss(pred, curr_gt)
                recons_loss += l2_loss.mean().item()
                commit_loss += loss_commit.item()
                total_perplexity += perplexity.item()

                unstd_pred = val_loader.dataset.inv_transform(
                    pred.detach().cpu().numpy(), self.data_format
                )
                unstd_pose = val_loader.dataset.inv_transform(
                    curr_gt.detach().cpu().numpy(), self.data_format
                )
                if savenpy:
                    self._save_predictions(
                        f"b{i:04d}", unstd_pose[:, : m_length[i]], unstd_pred
                    )
                pred_pose_eval[i : i + 1, : m_length[i], :] = pred
            nb_sample += bs

        self._log_losses(
            commit_loss, recons_loss, total_perplexity, nb_iter, nb_sample, draw, save
        )
        if save:
            print("saving checkpoint net_last.pth")
            self.save_model(os.path.join(self.out_dir, "net_last.pth"))
        if nb_iter % 100000 == 0:
            print(f"saving checkpoint net_iter{nb_iter}.pth")
            self.save_model(
                os.path.join(self.out_dir, "net_iter" + str(nb_iter) + ".pth")
            )


def _load_data_info(args, logger):
    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    train_loader = get_dataset_loader(
        args=args, data_dict=data_dict, split="train", add_padding=False
    )
    val_loader = get_dataset_loader(
        args=args, data_dict=data_dict, split="val", add_padding=False
    )
    logger.info(
        f"Training on {args.dataname}, motions are with {args.nb_joints} joints"
    )
    train_loader_iter = cycle(train_loader)
    skip_step = train_loader.dataset.step
    return train_loader_iter, val_loader, skip_step


def _load_checkpoint(args, net, logger):
    cp_dir = os.path.dirname(args.resume_pth)
    with open(f"{cp_dir}/args.json") as f:
        trans_args = json.load(f)
    assert trans_args["data_root"] == args.data_root, "data_root doesn't match"
    logger.info("loading checkpoint from {}".format(args.resume_pth))
    ckpt = torch.load(args.resume_pth, map_location="cpu")
    net.load_state_dict(ckpt["net"], strict=True)
    return net

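# Overall flow of main(), for orientation:
#   1. seed + logging/TensorBoard setup; the resolved args are dumped to
#      <out_dir>/args.json (also used as a consistency check when resuming)
#   2. build train/val loaders; skip_step comes from the dataset's frame step
#   3. build the TemporalVertexCodec VQ-VAE, optionally resuming weights
#   4. LR warmup, a baseline eval at iter 0, then the main train/eval loop
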
def main(args):
    torch.manual_seed(args.seed)
    os.makedirs(args.out_dir, exist_ok=True)
    logger = get_logger(args.out_dir)
    writer = SummaryWriter(args.out_dir)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    if args.data_format == "pose":
        args.nb_joints = 104
    elif args.data_format == "face":
        args.nb_joints = 256

    args_path = os.path.join(args.out_dir, "args.json")
    with open(args_path, "w") as fw:
        json.dump(vars(args), fw, indent=4, sort_keys=True)

    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")
    train_loader_iter, val_loader, skip_step = _load_data_info(args, logger)

    net = vqvae.TemporalVertexCodec(
        n_vertices=args.nb_joints,
        latent_dim=args.output_emb_width,
        categories=args.code_dim,
        residual_depth=args.depth,
    )
    if args.resume_pth:
        net = _load_checkpoint(args, net, logger)
    net.train()
    net.cuda()

    trainer = ModelTrainer(args, net, logger, writer)
    trainer.run_warmup_steps(train_loader_iter, skip_step, logger)

    avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0
    with torch.no_grad():
        trainer.evaluation_vqvae(
            val_loader, 0, save=(args.total_iter > 0), savenpy=True
        )
    for nb_iter in range(1, args.total_iter + 1):
        gt_motion, cond = next(train_loader_iter)
        loss_dict = trainer.run_train_step(gt_motion, cond, skip_step)
        trainer.scheduler.step()
        avg_recons += loss_dict["loss_motion"]
        avg_perplexity += loss_dict["perplexity"]
        avg_commit += loss_dict["loss_commit"]

        if nb_iter % args.print_iter == 0:
            avg_recons /= args.print_iter
            avg_perplexity /= args.print_iter
            avg_commit /= args.print_iter
            writer.add_scalar("./Train/L1", avg_recons, nb_iter)
            writer.add_scalar("./Train/PPL", avg_perplexity, nb_iter)
            writer.add_scalar("./Train/Commit", avg_commit, nb_iter)
            logger.info(
                f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t "
                f"PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}"
            )
            avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0

        if nb_iter % args.eval_iter == 0:
            trainer.evaluation_vqvae(
                val_loader, nb_iter, save=(args.total_iter > 0), savenpy=True
            )


if __name__ == "__main__":
    args = train_args()
    main(args)
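
# Example invocation (illustrative only; the authoritative flag names and
# defaults are defined by utils.vq_parser_utils.train_args, and the values
# below are placeholders):
#   python train_vq.py \
#       --out_dir runs/vq_pose \
#       --data_root <path/to/dataset> \
#       --data_format pose \
#       --total_iter 300000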