# Upload metadata (copied from the hosting page): uploaded by Techt3o,
# blob e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec,
# commit f561f8b (verified), 13.7 kB.
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import time
import torch
import shutil
import logging
import numpy as np
import os.path as osp
from progress.bar import Bar
from configs import constants as _C
from lib.utils import transforms
from lib.utils.utils import AverageMeter, prepare_batch
from lib.eval.eval_utils import (
compute_accel,
compute_error_accel,
batch_align_by_pelvis,
batch_compute_similarity_transform_torch,
)
from lib.models import build_body_model
logger = logging.getLogger(__name__)


class Trainer:
    """Epoch-based train / validate / evaluate / checkpoint loop.

    Each epoch of :meth:`fit` runs :meth:`train` over ``train_loader``,
    :meth:`validate` over ``valid_loader`` (filling
    ``evaluation_accumulators``), :meth:`evaluate` to reduce those into
    MPJPE / PA-MPJPE / acceleration / PVE metrics, and :meth:`save_model`
    to write checkpoints into ``logdir``.
    """

    def __init__(self,
                 data_loaders,
                 network,
                 optimizer,
                 criterion=None,
                 train_stage='syn',
                 start_epoch=0,
                 checkpoint=None,
                 end_epoch=999,
                 lr_scheduler=None,
                 device=None,
                 writer=None,
                 debug=False,
                 resume=False,
                 logdir='output',
                 performance_type='min',
                 summary_iter=1,
                 ):
        """Set up the trainer.

        Args:
            data_loaders: ``(train_loader, valid_loader)`` pair.
            network: model called as ``network(x, inits, features, **kwargs)``.
            optimizer: torch optimizer over ``network``'s parameters.
            criterion: callable ``criterion(pred, gt) -> (loss, loss_dict)``;
                :meth:`fit` also calls ``criterion.step()`` once per epoch.
            train_stage: when equal to ``'stage2'`` a flag is passed to
                ``prepare_batch``.
            start_epoch / end_epoch: epoch range ``[start_epoch, end_epoch)``.
            checkpoint: optional path handed to :meth:`load_pretrained`.
            lr_scheduler: optional scheduler stepped once per epoch.
            device: target device; defaults to CUDA when available, else CPU.
            writer: TensorBoard ``SummaryWriter``; created in ``logdir`` if None.
            debug: stored, not used by this class.
            resume: when True, :meth:`load_pretrained` also restores the
                optimizer state, best performance and epoch counter.
            logdir: directory for TensorBoard logs and checkpoints.
            performance_type: ``'min'`` if lower evaluation scores are better,
                anything else if higher is better.
            summary_iter: write training scalars every ``summary_iter`` batches.
        """
        self.train_loader, self.valid_loader = data_loaders

        # Model and optimizer
        self.network = network
        self.optimizer = optimizer

        # Training parameters
        self.train_stage = train_stage
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        # Resolve the default device *before* anything is moved onto it, so the
        # evaluation regressor below lands on the same device as self.device.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.writer = writer
        self.debug = debug
        self.resume = resume
        self.logdir = logdir
        self.summary_iter = summary_iter

        self.performance_type = performance_type
        self.train_global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.best_performance = float('inf') if performance_type == 'min' else -float('inf')
        self.summary_loss_keys = ['pose']

        # Filled with per-batch lists by validate(), stacked by evaluate().
        self.evaluation_accumulators = dict.fromkeys(
            ['pred_j3d', 'target_j3d', 'pve'])

        # H36M joint regressor restricted to the J14 subset; used to regress
        # evaluation joints from mesh vertices in validate().
        self.J_regressor_eval = torch.from_numpy(
            np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M)
        )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(self.device)

        if self.writer is None:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.logdir)

        if checkpoint is not None:
            self.load_pretrained(checkpoint)

    def train(self):
        """Run one training epoch over ``self.train_loader``.

        Per batch: forward pass, loss via ``self.criterion``, backward pass
        with gradient-norm clipping at 1.0, optimizer step, then progress-bar
        and TensorBoard logging. Exits the process (``SystemExit``) if the
        loss is NaN, before the NaN can reach the optimizer.
        """
        # Single epoch training routine

        losses = AverageMeter()
        kp_2d_loss = AverageMeter()
        kp_3d_loss = AverageMeter()

        # Wall-clock seconds spent in each phase of the most recent batch.
        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }
        self.network.train()
        start = time.time()
        summary_string = ''

        bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=len(self.train_loader))
        for i, batch in enumerate(self.train_loader):

            # <======= Feedforward
            x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage == 'stage2')
            timer['data'] = time.time() - start
            start = time.time()
            pred = self.network(x, inits, features, **kwargs)
            timer['forward'] = time.time() - start
            start = time.time()
            # =======>

            # <======= Backprop
            loss, loss_dict = self.criterion(pred, gt)
            timer['loss'] = time.time() - start
            start = time.time()

            # Abort before backward()/step(): otherwise NaN gradients would
            # already have been applied to the weights.
            if torch.isnan(loss):
                raise SystemExit('Nan value in loss, exiting!...')

            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
            self.optimizer.step()
            # =======>

            # <======= Log training info
            total_loss = loss
            losses.update(total_loss.item(), x.size(0))
            kp_2d_loss.update(loss_dict['2d'].item(), x.size(0))
            kp_3d_loss.update(loss_dict['3d'].item(), x.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer['loss'] + timer['backward']
            start = time.time()

            summary_string = f'({i + 1}/{len(self.train_loader)}) | Total: {bar.elapsed_td} ' \
                             f'| loss: {losses.avg:.2f} | 2d: {kp_2d_loss.avg:.2f} ' \
                             f'| 3d: {kp_3d_loss.avg:.2f} '

            write_summary = (i + 1) % self.summary_iter == 0
            for k, v in loss_dict.items():
                if k in self.summary_loss_keys:
                    summary_string += f' | {k}: {v:.2f}'
                if write_summary:
                    self.writer.add_scalar('train_loss/' + k, v, global_step=self.train_global_step)

            if write_summary:
                self.writer.add_scalar('train_loss/loss', total_loss.item(), global_step=self.train_global_step)

            self.train_global_step += 1
            bar.suffix = summary_string
            bar.next(1)
            # =======>

        logger.info(summary_string)
        bar.finish()

    def validate(self):
        """Run the network over ``self.valid_loader`` and collect predictions.

        Resets ``self.evaluation_accumulators`` to empty lists, then per batch
        appends joints regressed with ``J_regressor_eval`` from predicted and
        ground-truth vertices (after ``batch_align_by_pelvis``) and a per-frame
        vertex error scaled by 1e3 (mm, assuming vertices are in metres).
        Ground-truth vertices come from the gendered body model, built from
        the first sample's gender. Call :meth:`evaluate` afterwards.
        """
        self.network.eval()

        start = time.time()
        summary_string = ''
        bar = Bar('Validation', fill='#', max=len(self.valid_loader))

        if self.evaluation_accumulators is not None:
            for k in self.evaluation_accumulators:
                self.evaluation_accumulators[k] = []

        with torch.no_grad():
            for i, batch in enumerate(self.valid_loader):
                x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage == 'stage2')

                # <======= Feedforward
                pred = self.network(x, inits, features, **kwargs)

                # 3DPW dataset has groundtruth vertices
                # NOTE: Following SPIN, we compute PVE against ground truth from Gendered SMPL mesh
                smpl = build_body_model(self.device, batch_size=len(pred['verts_cam']), gender=batch['gender'][0])
                gt_output = smpl.get_output(
                    body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]),
                    global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]),
                    betas=gt['betas'][0],
                    pose2rot=False
                )

                pred_j3d = torch.matmul(self.J_regressor_eval, pred['verts_cam']).cpu()
                target_j3d = torch.matmul(self.J_regressor_eval, gt_output.vertices).cpu()
                pred_verts = pred['verts_cam'].cpu()
                target_verts = gt_output.vertices.cpu()

                pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
                    [pred_j3d, target_j3d, pred_verts, target_verts], [2, 3]
                )

                self.evaluation_accumulators['pred_j3d'].append(pred_j3d.numpy())
                self.evaluation_accumulators['target_j3d'].append(target_j3d.numpy())
                pve = np.sqrt(np.sum((target_verts.numpy() - pred_verts.numpy()) ** 2, axis=-1)).mean(-1) * 1e3
                self.evaluation_accumulators['pve'].append(pve[:, None])
                # =======>

                # Time per batch (including data loading): reset the reference
                # point every iteration and convert seconds -> milliseconds.
                batch_time = time.time() - start
                start = time.time()

                summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 1000.0:.1f}ms | ' \
                                 f'Total: {bar.elapsed_td} | ETA: {bar.eta_td}'

                self.valid_global_step += 1
                bar.suffix = summary_string
                bar.next()

        logger.info(summary_string)

        bar.finish()

    def evaluate(self):
        """Reduce the validation accumulators into scalar metrics.

        Stacks the per-batch arrays, computes MPJPE, PA-MPJPE, acceleration,
        acceleration error and (when collected) PVE, logs them and writes
        them to TensorBoard under ``error/<name>``.

        Returns:
            PA-MPJPE (scaled by 1000), used by :meth:`fit` as the
            model-selection score.
        """
        for k, v in self.evaluation_accumulators.items():
            self.evaluation_accumulators[k] = np.vstack(v)

        pred_j3ds = torch.from_numpy(self.evaluation_accumulators['pred_j3d']).float()
        target_j3ds = torch.from_numpy(self.evaluation_accumulators['target_j3d']).float()

        print(f'Evaluating on {pred_j3ds.shape[0]} number of poses...')
        errors = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
        S1_hat = batch_compute_similarity_transform_torch(pred_j3ds, target_j3ds)
        errors_pa = torch.sqrt(((S1_hat - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()

        m2mm = 1000
        accel = np.mean(compute_accel(pred_j3ds)) * m2mm
        accel_err = np.mean(compute_error_accel(joints_pred=pred_j3ds, joints_gt=target_j3ds)) * m2mm
        mpjpe = np.mean(errors) * m2mm
        pa_mpjpe = np.mean(errors_pa) * m2mm

        eval_dict = {
            'mpjpe': mpjpe,
            'pa-mpjpe': pa_mpjpe,
            'accel': accel,
            'accel_err': accel_err
        }

        # validate() stores vertex errors under 'pve' (already scaled by 1e3);
        # checking for the never-populated 'pred_verts' key meant PVE was
        # computed but silently never reported.
        if 'pve' in self.evaluation_accumulators:
            eval_dict['pve'] = self.evaluation_accumulators['pve'].mean()

        log_str = f'Epoch {self.epoch}, '
        log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in eval_dict.items()])
        logger.info(log_str)

        for k, v in eval_dict.items():
            self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)

        return pa_mpjpe

    def save_model(self, performance, epoch):
        """Write ``checkpoint.pth.tar`` and, on improvement, the best model.

        The checkpoint holds the finished ``epoch``, the network and optimizer
        state dicts and ``performance``. If ``performance`` beats
        ``self.best_performance`` (per ``performance_type``), it is copied to
        ``model_best.pth.tar`` and the score is written to ``best.txt``.
        """
        save_dict = {
            'epoch': epoch,
            'model': self.network.state_dict(),
            'performance': performance,
            'optimizer': self.optimizer.state_dict(),
        }

        filename = osp.join(self.logdir, 'checkpoint.pth.tar')
        torch.save(save_dict, filename)

        if self.performance_type == 'min':
            is_best = performance < self.best_performance
        else:
            is_best = performance > self.best_performance

        if is_best:
            logger.info('Best performance achived, saving it!')
            self.best_performance = performance
            shutil.copyfile(filename, osp.join(self.logdir, 'model_best.pth.tar'))

            with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
                f.write(str(float(performance)))

    def fit(self):
        """Train for epochs ``[start_epoch, end_epoch)``.

        Each epoch: train, validate, evaluate, ``criterion.step()``, optional
        LR-scheduler step, learning-rate logging, checkpointing, then
        ``prepare_video_batch()`` on the training dataset (presumably to
        rebuild its clip sampling for the next epoch). Closes the writer at
        the end.
        """
        for epoch in range(self.start_epoch, self.end_epoch):
            self.epoch = epoch
            self.train()
            self.validate()
            performance = self.evaluate()

            self.criterion.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # log the learning rate
            # NOTE(review): both logged param groups share the 'lr' tag, so the
            # second value overwrites the first at each step.
            for param_group in self.optimizer.param_groups[:2]:
                print(f'Learning rate {param_group["lr"]}')
                self.writer.add_scalar('lr', param_group['lr'], global_step=self.epoch)

            logger.info(f'Epoch {epoch+1} performance: {performance:.4f}')

            self.save_model(performance, epoch)
            self.train_loader.dataset.prepare_video_batch()

        self.writer.close()

    def load_pretrained(self, model_path):
        """Load network weights and, when ``self.resume``, training state.

        Checkpoint entries for the body-model buffers listed in
        ``ignore_keys``, any key containing ``'integrator'``, and keys the
        current network does not have are skipped; the rest are loaded with
        ``strict=False``. With ``resume=True`` the optimizer state and best
        performance are restored and training continues at the epoch *after*
        the one stored in the checkpoint. A missing file is only logged.
        """
        if not osp.isfile(model_path):
            logger.info(f"=> no checkpoint found at '{model_path}'")
            return

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # map_location lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies into the existing parameters either way.
        checkpoint = torch.load(model_path, map_location=self.device)

        # network
        ignore_keys = {'smpl.body_pose', 'smpl.betas', 'smpl.global_orient',
                       'smpl.J_regressor_extra', 'smpl.J_regressor_eval'}
        # Build the key set once instead of re-creating the state dict per key.
        network_keys = set(self.network.state_dict().keys())
        model_state_dict = {
            k: v for k, v in checkpoint['model'].items()
            if k not in ignore_keys and 'integrator' not in k and k in network_keys
        }
        self.network.load_state_dict(model_state_dict, strict=False)

        if self.resume:
            # save_model() stores the epoch that just *finished*; resume at the
            # next one instead of training that epoch a second time.
            # NOTE(review): lr_scheduler state is neither saved nor restored.
            self.start_epoch = checkpoint['epoch'] + 1
            self.best_performance = checkpoint['performance']
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        logger.info(f"=> loaded checkpoint '{model_path}' "
                    f"(epoch {checkpoint.get('epoch')}, performance {checkpoint.get('performance')})")