# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import time
import torch
import shutil
import logging
import numpy as np
import os.path as osp
from progress.bar import Bar

from configs import constants as _C
from lib.utils import transforms
from lib.utils.utils import AverageMeter, prepare_batch
from lib.eval.eval_utils import (
    compute_accel,
    compute_error_accel,
    batch_align_by_pelvis,
    batch_compute_similarity_transform_torch,
)
from lib.models import build_body_model

logger = logging.getLogger(__name__)


class Trainer():
    def __init__(self,
                 data_loaders,
                 network,
                 optimizer,
                 criterion=None,
                 train_stage='syn',
                 start_epoch=0,
                 checkpoint=None,
                 end_epoch=999,
                 lr_scheduler=None,
                 device=None,
                 writer=None,
                 debug=False,
                 resume=False,
                 logdir='output',
                 performance_type='min',
                 summary_iter=1,
                 ):

        self.train_loader, self.valid_loader = data_loaders

        # Model and optimizer
        self.network = network
        self.optimizer = optimizer

        # Training parameters
        self.train_stage = train_stage
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.writer = writer
        self.debug = debug
        self.resume = resume
        self.logdir = logdir
        self.summary_iter = summary_iter
        self.performance_type = performance_type

        self.train_global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.best_performance = float('inf') if performance_type == 'min' else -float('inf')
        self.summary_loss_keys = ['pose']

        self.evaluation_accumulators = dict.fromkeys(
            ['pred_j3d', 'target_j3d', 'pve'])  # 'pred_verts', 'target_verts'

        self.J_regressor_eval = torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(device)

        if self.writer is None:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.logdir)

        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if checkpoint is not None:
            self.load_pretrained(checkpoint)

    def train(self, ):
        # Single epoch training routine

        losses = AverageMeter()
        kp_2d_loss = AverageMeter()
        kp_3d_loss = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.network.train()
        start = time.time()
        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=len(self.train_loader))
        for i, batch in enumerate(self.train_loader):
            # <======= Feedforward
            x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage == 'stage2')

            timer['data'] = time.time() - start
            start = time.time()
            pred = self.network(x, inits, features, **kwargs)
            timer['forward'] = time.time() - start
            start = time.time()
            # =======>

            # <======= Backprop
            loss, loss_dict = self.criterion(pred, gt)
            timer['loss'] = time.time() - start
            start = time.time()

            # Clip gradients
            self.optimizer.zero_grad()
            loss.backward()
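            # clip_grad_norm_ rescales all gradients in place when their
            # combined L2 norm exceeds the threshold (1.0 here), so a single
            # bad batch cannot produce an exploding update.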
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = loss
            losses.update(total_loss.item(), x.size(0))
            kp_2d_loss.update(loss_dict['2d'].item(), x.size(0))
            kp_3d_loss.update(loss_dict['3d'].item(), x.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer['loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{len(self.train_loader)}) | Total: {bar.elapsed_td} ' \
                             f'| loss: {losses.avg:.2f} | 2d: {kp_2d_loss.avg:.2f} ' \
                             f'| 3d: {kp_3d_loss.avg:.2f} '

            for k, v in loss_dict.items():
                if k in self.summary_loss_keys:
                    summary_string += f' | {k}: {v:.2f}'
                if (i + 1) % self.summary_iter == 0:
                    self.writer.add_scalar('train_loss/' + k, v, global_step=self.train_global_step)

            if (i + 1) % self.summary_iter == 0:
                self.writer.add_scalar('train_loss/loss', total_loss.item(), global_step=self.train_global_step)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next(1)

            if torch.isnan(total_loss):
                exit('NaN value in loss, exiting!...')
            # =======>

        logger.info(summary_string)
        bar.finish()

    def validate(self, ):
        self.network.eval()

        start = time.time()
        summary_string = ''
        bar = Bar('Validation', fill='#', max=len(self.valid_loader))

        if self.evaluation_accumulators is not None:
            for k, v in self.evaluation_accumulators.items():
                self.evaluation_accumulators[k] = []

        with torch.no_grad():
            for i, batch in enumerate(self.valid_loader):
                x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage == 'stage2')

                # <======= Feedforward
                pred = self.network(x, inits, features, **kwargs)

                # The 3DPW dataset has ground-truth vertices.
                # NOTE: Following SPIN, we compute PVE against the ground truth
                # from the gendered SMPL mesh.
                smpl = build_body_model(self.device, batch_size=len(pred['verts_cam']), gender=batch['gender'][0])
                gt_output = smpl.get_output(
                    body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
                    global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
                    betas=gt['betas'][0],
                    pose2rot=False
                )

                pred_j3d = torch.matmul(self.J_regressor_eval, pred['verts_cam']).cpu()
                target_j3d = torch.matmul(self.J_regressor_eval, gt_output.vertices).cpu()
                pred_verts = pred['verts_cam'].cpu()
                target_verts = gt_output.vertices.cpu()

                # Root-center both skeletons at the pelvis (mid-point of the
                # two hip joints, indices 2 and 3 in the 14-joint H36M subset)
                # so MPJPE measures pose error rather than global translation.
                pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
                    [pred_j3d, target_j3d, pred_verts, target_verts], [2, 3]
                )

                self.evaluation_accumulators['pred_j3d'].append(pred_j3d.numpy())
                self.evaluation_accumulators['target_j3d'].append(target_j3d.numpy())
                # Per-vertex error in millimeters
                pve = np.sqrt(np.sum((target_verts.numpy() - pred_verts.numpy()) ** 2, axis=-1)).mean(-1) * 1e3
                self.evaluation_accumulators['pve'].append(pve[:, None])
                # =======>

                batch_time = time.time() - start
                start = time.time()  # reset so the next iteration reports per-batch time

                summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 1000.0:.4}ms | ' \
                                 f'Total: {bar.elapsed_td} | ETA: {bar.eta_td}'

                self.valid_global_step += 1
                bar.suffix = summary_string
                bar.next()

        logger.info(summary_string)
        bar.finish()

    def evaluate(self, ):
        for k, v in self.evaluation_accumulators.items():
            self.evaluation_accumulators[k] = np.vstack(v)

        pred_j3ds = self.evaluation_accumulators['pred_j3d']
        target_j3ds = self.evaluation_accumulators['target_j3d']

        pred_j3ds = torch.from_numpy(pred_j3ds).float()
        target_j3ds = torch.from_numpy(target_j3ds).float()

        print(f'Evaluating on {pred_j3ds.shape[0]} poses...')
        errors = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
        S1_hat = batch_compute_similarity_transform_torch(pred_j3ds, target_j3ds)
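        # S1_hat is the prediction after a per-sample similarity (Procrustes)
        # alignment to the target: rotation, translation, and scale are
        # factored out, so the error below (PA-MPJPE) reflects pose accuracy
        # independent of global misalignment.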
        errors_pa = torch.sqrt(((S1_hat - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()

        m2mm = 1000
        accel = np.mean(compute_accel(pred_j3ds)) * m2mm
        accel_err = np.mean(compute_error_accel(joints_pred=pred_j3ds, joints_gt=target_j3ds)) * m2mm
        mpjpe = np.mean(errors) * m2mm
        pa_mpjpe = np.mean(errors_pa) * m2mm

        eval_dict = {
            'mpjpe': mpjpe,
            'pa-mpjpe': pa_mpjpe,
            'accel': accel,
            'accel_err': accel_err
        }

        # 'pve' is accumulated in validate(), so report it when present.
        if 'pve' in self.evaluation_accumulators.keys():
            eval_dict.update({'pve': self.evaluation_accumulators['pve'].mean()})

        log_str = f'Epoch {self.epoch}, '
        log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in eval_dict.items()])
        logger.info(log_str)

        for k, v in eval_dict.items():
            self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)

        # return (mpjpe + pa_mpjpe) / 2.
        return pa_mpjpe

    def save_model(self, performance, epoch):
        save_dict = {
            'epoch': epoch,
            'model': self.network.state_dict(),
            'performance': performance,
            'optimizer': self.optimizer.state_dict(),
        }

        filename = osp.join(self.logdir, 'checkpoint.pth.tar')
        torch.save(save_dict, filename)

        if self.performance_type == 'min':
            is_best = performance < self.best_performance
        else:
            is_best = performance > self.best_performance

        if is_best:
            logger.info('Best performance achieved, saving it!')
            self.best_performance = performance
            shutil.copyfile(filename, osp.join(self.logdir, 'model_best.pth.tar'))

            with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
                f.write(str(float(performance)))

    def fit(self):
        for epoch in range(self.start_epoch, self.end_epoch):
            self.epoch = epoch
            self.train()
            self.validate()
            performance = self.evaluate()

            self.criterion.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # Log the learning rate
            for param_group in self.optimizer.param_groups[:2]:
                print(f'Learning rate {param_group["lr"]}')
                self.writer.add_scalar('lr', param_group['lr'], global_step=self.epoch)

            logger.info(f'Epoch {epoch + 1} performance: {performance:.4f}')

            self.save_model(performance, epoch)
            self.train_loader.dataset.prepare_video_batch()

        self.writer.close()

    def load_pretrained(self, model_path):
        if osp.isfile(model_path):
            checkpoint = torch.load(model_path)

            # network
            ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient',
                           'smpl.J_regressor_extra', 'smpl.J_regressor_eval']
            ignore_keys2 = [k for k in checkpoint['model'].keys() if 'integrator' in k]
            ignore_keys.extend(ignore_keys2)
            model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys}
            model_state_dict = {k: v for k, v in model_state_dict.items() if k in self.network.state_dict().keys()}
            self.network.load_state_dict(model_state_dict, strict=False)

            if self.resume:
                self.start_epoch = checkpoint['epoch']
                self.best_performance = checkpoint['performance']
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            logger.info(f"=> loaded checkpoint '{model_path}' "
                        f"(epoch {self.start_epoch}, performance {self.best_performance})")
        else:
            logger.info(f"=> no checkpoint found at '{model_path}'")
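
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). How `network`,
# `optimizer`, `criterion`, and the data loaders are built is hypothetical
# here; only the Trainer keyword arguments come from the class above.
#
#   trainer = Trainer(
#       data_loaders=(train_loader, valid_loader),
#       network=network,
#       optimizer=optimizer,
#       criterion=criterion,
#       train_stage='stage2',
#       end_epoch=30,
#       lr_scheduler=lr_scheduler,
#       device='cuda',
#       logdir='output',
#       performance_type='min',   # lower PA-MPJPE is better
#   )
#   trainer.fit()  # per epoch: train -> validate -> evaluate -> checkpoint
# ---------------------------------------------------------------------------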