diff --git a/.gitattributes b/.gitattributes index 5045717190f487fc0ef3792d1ccee60d016d36a9..b7c35705b9fcfdb94022672a28dd5b2d9f4b8762 100644 --- a/.gitattributes +++ b/.gitattributes @@ -37,3 +37,7 @@ examples/drone_video.mp4 filter=lfs diff=lfs merge=lfs -text examples/IMG_9730.mov filter=lfs diff=lfs merge=lfs -text examples/IMG_9731.mov filter=lfs diff=lfs merge=lfs -text examples/IMG_9732.mov filter=lfs diff=lfs merge=lfs -text +examples/test16.mov filter=lfs diff=lfs merge=lfs -text +examples/test17.mov filter=lfs diff=lfs merge=lfs -text +examples/test18.mov filter=lfs diff=lfs merge=lfs -text +examples/test19.mov filter=lfs diff=lfs merge=lfs -text diff --git a/examples/test16.mov b/examples/test16.mov new file mode 100644 index 0000000000000000000000000000000000000000..b00eafc06cd037c17fd61b660030fefabda6a38a --- /dev/null +++ b/examples/test16.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f068400bf962e732e5517af45397694f84fae0a6592085b9dd3781fdbacaa550 +size 1567779 diff --git a/examples/test17.mov b/examples/test17.mov new file mode 100644 index 0000000000000000000000000000000000000000..de4b0669d6867fb70d5acbfe24c0929eb96b2c79 --- /dev/null +++ b/examples/test17.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce06d8885332fd0b770273010dbd4da20a0867a386dc55925f85198651651253 +size 2299497 diff --git a/examples/test18.mov b/examples/test18.mov new file mode 100644 index 0000000000000000000000000000000000000000..97d0466333a92cb4877c80e0d22eeedde01732ee --- /dev/null +++ b/examples/test18.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66fc6eb20e1c8525070c8004bed621e0acc2712accace1dbf1eb72fced62bb14 +size 2033756 diff --git a/examples/test19.mov b/examples/test19.mov new file mode 100644 index 0000000000000000000000000000000000000000..5414b6061cf39260b47ebd8bcfb124980be6fe3e --- /dev/null +++ b/examples/test19.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:878219571dbf0e8ff56f4ba4bf325f90f46a730b57a35a2df91f4f509af616d8 +size 1940593 diff --git a/fetch_demo_data.sh b/fetch_demo_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..1f5fcfd2a4558a3df705854404c734cbb5a82f66 --- /dev/null +++ b/fetch_demo_data.sh @@ -0,0 +1,50 @@ +#!/bin/bash +urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } + +# SMPL Neutral model +echo -e "\nYou need to register at https://smplify.is.tue.mpg.de" +read -p "Username (SMPLify):" username +read -p "Password (SMPLify):" password +username=$(urle $username) +password=$(urle $password) + +mkdir -p dataset/body_models/smpl +wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&resume=1&sfile=mpips_smplify_public_v2.zip' -O './dataset/body_models/smplify.zip' --no-check-certificate --continue +unzip dataset/body_models/smplify.zip -d dataset/body_models/smplify +mv dataset/body_models/smplify/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_NEUTRAL.pkl +rm -rf dataset/body_models/smplify +rm -rf dataset/body_models/smplify.zip + +# SMPL Male and Female model +echo -e "\nYou need to register at https://smpl.is.tue.mpg.de" +read -p "Username (SMPL):" username +read -p "Password (SMPL):" password +username=$(urle $username) +password=$(urle $password) + +wget --post-data 
"username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip' -O './dataset/body_models/smpl.zip' --no-check-certificate --continue +unzip dataset/body_models/smpl.zip -d dataset/body_models/smpl +mv dataset/body_models/smpl/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_FEMALE.pkl +mv dataset/body_models/smpl/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl dataset/body_models/smpl/SMPL_MALE.pkl +rm -rf dataset/body_models/smpl/smpl +rm -rf dataset/body_models/smpl.zip + +# Auxiliary SMPL-related data +wget "https://drive.google.com/uc?id=1pbmzRbWGgae6noDIyQOnohzaVnX_csUZ&export=download&confirm=t" -O 'dataset/body_models.tar.gz' +tar -xvf dataset/body_models.tar.gz -C dataset/ +rm -rf dataset/body_models.tar.gz + +# Checkpoints +mkdir checkpoints +gdown "https://drive.google.com/uc?id=1i7kt9RlCCCNEW2aYaDWVr-G778JkLNcB&export=download&confirm=t" -O 'checkpoints/wham_vit_w_3dpw.pth.tar' +gdown "https://drive.google.com/uc?id=19qkI-a6xuwob9_RFNSPWf1yWErwVVlks&export=download&confirm=t" -O 'checkpoints/wham_vit_bedlam_w_3dpw.pth.tar' +gdown "https://drive.google.com/uc?id=1J6l8teyZrL0zFzHhzkC7efRhU0ZJ5G9Y&export=download&confirm=t" -O 'checkpoints/hmr2a.ckpt' +gdown "https://drive.google.com/uc?id=1kXTV4EYb-BI3H7J-bkR3Bc4gT9zfnHGT&export=download&confirm=t" -O 'checkpoints/dpvo.pth' +gdown "https://drive.google.com/uc?id=1zJ0KP23tXD42D47cw1Gs7zE2BA_V_ERo&export=download&confirm=t" -O 'checkpoints/yolov8x.pt' +gdown "https://drive.google.com/uc?id=1xyF7F3I7lWtdq82xmEPVQ5zl4HaasBso&export=download&confirm=t" -O 'checkpoints/vitpose-h-multi-coco.pth' + +# Demo videos +gdown "https://drive.google.com/uc?id=1KjfODCcOUm_xIMLLR54IcjJtf816Dkc7&export=download&confirm=t" -O 'examples.tar.gz' +tar -xvf examples.tar.gz +rm -rf examples.tar.gz + diff --git a/lib/core/loss.py b/lib/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1fc182812479a90cfab611a140e57a787dab3ce3 --- /dev/null +++ b/lib/core/loss.py @@ -0,0 +1,438 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from configs import constants as _C +from lib.utils import transforms +from lib.utils.kp_utils import root_centering + +class WHAMLoss(nn.Module): + def __init__( + self, + cfg=None, + device=None, + ): + super(WHAMLoss, self).__init__() + + self.cfg = cfg + self.n_joints = _C.KEYPOINTS.NUM_JOINTS + self.criterion = nn.MSELoss() + self.criterion_noreduce = nn.MSELoss(reduction='none') + + self.pose_loss_weight = cfg.LOSS.POSE_LOSS_WEIGHT + self.shape_loss_weight = cfg.LOSS.SHAPE_LOSS_WEIGHT + self.keypoint_2d_loss_weight = cfg.LOSS.JOINT2D_LOSS_WEIGHT + self.keypoint_3d_loss_weight = cfg.LOSS.JOINT3D_LOSS_WEIGHT + self.cascaded_loss_weight = cfg.LOSS.CASCADED_LOSS_WEIGHT + self.vertices_loss_weight = cfg.LOSS.VERTS3D_LOSS_WEIGHT + self.contact_loss_weight = cfg.LOSS.CONTACT_LOSS_WEIGHT + self.root_vel_loss_weight = cfg.LOSS.ROOT_VEL_LOSS_WEIGHT + self.root_pose_loss_weight = cfg.LOSS.ROOT_POSE_LOSS_WEIGHT + self.sliding_loss_weight = cfg.LOSS.SLIDING_LOSS_WEIGHT + self.camera_loss_weight = cfg.LOSS.CAMERA_LOSS_WEIGHT + self.loss_weight = cfg.LOSS.LOSS_WEIGHT + + kp_weights = [ + 0.5, 0.5, 0.5, 0.5, 0.5, # Face + 1.5, 1.5, 4, 4, 4, 4, # Arms + 1.5, 1.5, 4, 4, 4, 4, # Legs + 4, 4, 1.5, 1.5, 4, 4, # Legs + 4, 4, 1.5, 1.5, 4, 4, # Arms + 0.5, 0.5 # Head + ] + + 
theta_weights = [ + 0.1, 1.0, 1.0, 1.0, 1.0, # pelvis, lhip, rhip, spine1, lknee + 1.0, 1.0, 1.0, 1.0, 1.0, # rknn, spine2, lankle, rankle, spin3 + 0.1, 0.1, # Foot + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, # neck, lisldr, risldr, head, losldr, rosldr, + 1.0, 1.0, 1.0, 1.0, # lelbow, relbow, lwrist, rwrist + 0.1, 0.1, # Hand + ] + self.theta_weights = torch.tensor([[theta_weights]]).float().to(device) + self.theta_weights /= self.theta_weights.mean() + self.kp_weights = torch.tensor([kp_weights]).float().to(device) + + self.epoch = -1 + self.step() + + def step(self): + self.epoch += 1 + self.skip_camera_loss = self.epoch < self.cfg.LOSS.CAMERA_LOSS_SKIP_EPOCH + + def forward(self, pred, gt): + + loss = 0.0 + b, f = gt['kp3d'].shape[:2] + + # <======= Predictions and Groundtruths + pred_betas = pred['betas'] + pred_pose = pred['pose'].reshape(b, f, -1, 6) + pred_kp3d_nn = pred['kp3d_nn'] + pred_kp3d_smpl = root_centering(pred['kp3d'].reshape(b, f, -1, 3)) + pred_full_kp2d = pred['full_kp2d'] + pred_weak_kp2d = pred['weak_kp2d'] + pred_contact = pred['contact'] + pred_vel_root = pred['vel_root'] + pred_pose_root = pred['poses_root_r6d'][:, 1:] + pred_vel_root_ref = pred['vel_root_refined'] + pred_pose_root_ref = pred['poses_root_r6d_refined'][:, 1:] + pred_cam_r = transforms.matrix_to_rotation_6d(pred['R']) + + gt_betas = gt['betas'] + gt_pose = gt['pose'] + gt_kp3d = root_centering(gt['kp3d']) + gt_full_kp2d = gt['full_kp2d'] + gt_weak_kp2d = gt['weak_kp2d'] + gt_contact = gt['contact'] + gt_vel_root = gt['vel_root'] + gt_pose_root = gt['pose_root'][:, 1:] + gt_cam_angvel = gt['cam_angvel'] + gt_cam_r = transforms.matrix_to_rotation_6d(gt['R'][:, 1:]) + bbox = gt['bbox'] + # =======> + + loss_keypoints_full = full_projected_keypoint_loss( + pred_full_kp2d, + gt_full_kp2d, + bbox, + self.kp_weights, + criterion=self.criterion_noreduce, + ) + + loss_keypoints_weak = weak_projected_keypoint_loss( + pred_weak_kp2d, + gt_weak_kp2d, + self.kp_weights, + criterion=self.criterion_noreduce + ) + + # Compute 3D keypoint loss + loss_keypoints_3d_nn = keypoint_3d_loss( + pred_kp3d_nn, + gt_kp3d[:, :, :self.n_joints], + self.kp_weights[:, :self.n_joints], + criterion=self.criterion_noreduce, + ) + + loss_keypoints_3d_smpl = keypoint_3d_loss( + pred_kp3d_smpl, + gt_kp3d, + self.kp_weights, + criterion=self.criterion_noreduce, + ) + + loss_cascaded = keypoint_3d_loss( + pred_kp3d_nn, + torch.cat((pred_kp3d_smpl[:, :, :self.n_joints], gt_kp3d[:, :, :self.n_joints, -1:]), dim=-1), + self.kp_weights[:, :self.n_joints] * 0.5, + criterion=self.criterion_noreduce, + ) + + loss_vertices = vertices_loss( + pred['verts_cam'], + gt['verts'], + gt['has_verts'], + criterion=self.criterion_noreduce, + ) + + # Compute loss on SMPL parameters + smpl_mask = gt['has_smpl'] + loss_regr_pose, loss_regr_betas = smpl_losses( + pred_pose, + pred_betas, + gt_pose, + gt_betas, + self.theta_weights, + smpl_mask, + criterion=self.criterion_noreduce + ) + + # Compute loss on foot contact + loss_contact = contact_loss( + pred_contact, + gt_contact, + self.criterion_noreduce + ) + + # Compute loss on root velocity and angular velocity + loss_vel_root, loss_pose_root = root_loss( + pred_vel_root, + pred_pose_root, + gt_vel_root, + gt_pose_root, + gt_contact, + self.criterion_noreduce + ) + + # Root loss after trajectory refinement + loss_vel_root_ref, loss_pose_root_ref = root_loss( + pred_vel_root_ref, + pred_pose_root_ref, + gt_vel_root, + gt_pose_root, + gt_contact, + self.criterion_noreduce + ) + + # Camera prediction loss + loss_camera = 
camera_loss( + pred_cam_r, + gt_cam_r, + gt_cam_angvel[:, 1:], + gt['has_traj'], + self.criterion_noreduce, + self.skip_camera_loss + ) + + # Foot sliding loss + loss_sliding = sliding_loss( + pred['feet'], + gt_contact, + ) + + # Foot sliding loss + loss_sliding_ref = sliding_loss( + pred['feet_refined'], + gt_contact, + ) + + loss_keypoints = loss_keypoints_full + loss_keypoints_weak + loss_keypoints *= self.keypoint_2d_loss_weight + loss_keypoints_3d_smpl *= self.keypoint_3d_loss_weight + loss_keypoints_3d_nn *= self.keypoint_3d_loss_weight + loss_cascaded *= self.cascaded_loss_weight + loss_vertices *= self.vertices_loss_weight + loss_contact *= self.contact_loss_weight + loss_root = loss_vel_root * self.root_vel_loss_weight + loss_pose_root * self.root_pose_loss_weight + loss_root_ref = loss_vel_root_ref * self.root_vel_loss_weight + loss_pose_root_ref * self.root_pose_loss_weight + + loss_regr_pose *= self.pose_loss_weight + loss_regr_betas *= self.shape_loss_weight + + loss_sliding *= self.sliding_loss_weight + loss_camera *= self.camera_loss_weight + loss_sliding_ref *= self.sliding_loss_weight + + loss_dict = { + 'pose': loss_regr_pose * self.loss_weight, + 'betas': loss_regr_betas * self.loss_weight, + '2d': loss_keypoints * self.loss_weight, + '3d': loss_keypoints_3d_smpl * self.loss_weight, + '3d_nn': loss_keypoints_3d_nn * self.loss_weight, + 'casc': loss_cascaded * self.loss_weight, + 'v3d': loss_vertices * self.loss_weight, + 'contact': loss_contact * self.loss_weight, + 'root': loss_root * self.loss_weight, + 'root_ref': loss_root_ref * self.loss_weight, + 'sliding': loss_sliding * self.loss_weight, + 'camera': loss_camera * self.loss_weight, + 'sliding_ref': loss_sliding_ref * self.loss_weight, + } + + loss = sum(loss for loss in loss_dict.values()) + + return loss, loss_dict + + +def root_loss( + pred_vel_root, + pred_pose_root, + gt_vel_root, + gt_pose_root, + stationary, + criterion +): + + mask_r = (gt_pose_root != 0.0).all(dim=-1).all(dim=-1) + mask_v = (gt_vel_root != 0.0).all(dim=-1).all(dim=-1) + mask_s = (stationary != -1).any(dim=1).any(dim=1) + mask_v = mask_v * mask_s + + if mask_r.any(): + loss_r = criterion(pred_pose_root, gt_pose_root)[mask_r].mean() + else: + loss_r = torch.FloatTensor(1).fill_(0.).to(gt_pose_root.device)[0] + + if mask_v.any(): + loss_v = 0 + T = gt_vel_root.shape[0] + ws_list = [1, 3, 9, 27] + for ws in ws_list: + tmp_v = 0 + for m in range(T//ws): + cumulative_v = torch.sum(pred_vel_root[:, m:(m+1)*ws] - gt_vel_root[:, m:(m+1)*ws], dim=1) + tmp_v += torch.norm(cumulative_v, dim=-1) + loss_v += tmp_v + loss_v = loss_v[mask_v].mean() + else: + loss_v = torch.FloatTensor(1).fill_(0.).to(gt_vel_root.device)[0] + + return loss_v, loss_r + + +def contact_loss( + pred_stationary, + gt_stationary, + criterion, +): + + mask = gt_stationary != -1 + if mask.any(): + loss = criterion(pred_stationary, gt_stationary)[mask].mean() + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_stationary.device)[0] + return loss + + + +def full_projected_keypoint_loss( + pred_keypoints_2d, + gt_keypoints_2d, + bbox, + weight, + criterion, +): + + scale = bbox[..., 2:] * 200. 
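+    # Note: bbox[..., 2] stores the box scale in units of 200 px (the dataset code divides by 200),
+    # so `scale` recovers the bounding-box size in pixels; dividing the reprojection error by it
+    # below keeps the 2D loss comparable across subjects that occupy different image areas.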
+ conf = gt_keypoints_2d[..., -1] + + if (conf > 0).any(): + loss = torch.mean( + weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1) + ) / scale, dim=1).mean() * conf.mean() + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0] + return loss + + +def weak_projected_keypoint_loss( + pred_keypoints_2d, + gt_keypoints_2d, + weight, + criterion, +): + + conf = gt_keypoints_2d[..., -1] + if (conf > 0).any(): + loss = torch.mean( + weight * (conf * torch.norm(pred_keypoints_2d - gt_keypoints_2d[..., :2], dim=-1) + ), dim=1).mean() * conf.mean() * 5 + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_2d.device)[0] + return loss + + +def keypoint_3d_loss( + pred_keypoints_3d, + gt_keypoints_3d, + weight, + criterion, +): + + conf = gt_keypoints_3d[..., -1] + if (conf > 0).any(): + if weight.shape[-2] > 17: + pred_keypoints_3d[..., -14:] = pred_keypoints_3d[..., -14:] - pred_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True) + gt_keypoints_3d[..., -14:] = gt_keypoints_3d[..., -14:] - gt_keypoints_3d[..., -14:].mean(dim=-2, keepdims=True) + + loss = torch.mean( + weight * (conf * torch.norm(pred_keypoints_3d - gt_keypoints_3d[..., :3], dim=-1) + ), dim=1).mean() * conf.mean() + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_keypoints_3d.device)[0] + return loss + + +def vertices_loss( + pred_verts, + gt_verts, + mask, + criterion, +): + + if mask.sum() > 0: + # Align + pred_verts = pred_verts.view_as(gt_verts) + pred_verts = pred_verts - pred_verts.mean(-2, True) + gt_verts = gt_verts - gt_verts.mean(-2, True) + + # loss = criterion(pred_verts, gt_verts).mean() * mask.float().mean() + # loss = torch.mean( + # `(torch.norm(pred_verts - gt_verts, dim=-1)[mask]` + # ), dim=1).mean() * mask.float().mean() + loss = torch.mean( + (torch.norm(pred_verts - gt_verts, p=1, dim=-1)[mask] + ), dim=1).mean() * mask.float().mean() + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_verts.device)[0] + return loss + + +def smpl_losses( + pred_pose, + pred_betas, + gt_pose, + gt_betas, + weight, + mask, + criterion, +): + + if mask.any().item(): + loss_regr_pose = torch.mean( + weight * torch.square(pred_pose - gt_pose)[mask].mean(-1) + ) * mask.float().mean() + loss_regr_betas = F.mse_loss(pred_betas, gt_betas, reduction='none')[mask].mean() * mask.float().mean() + else: + loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0] + loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(gt_pose.device)[0] + + return loss_regr_pose, loss_regr_betas + + +def camera_loss( + pred_cam_r, + gt_cam_r, + cam_angvel, + mask, + criterion, + skip +): + # mask = (gt_cam_r != 0.0).all(dim=-1).all(dim=-1) + + if mask.any() and not skip: + # Camera pose loss in 6D representation + loss_r = criterion(pred_cam_r, gt_cam_r)[mask].mean() + + # Reconstruct camera angular velocity and compute reconstruction loss + pred_R = transforms.rotation_6d_to_matrix(pred_cam_r) + cam_angvel_from_R = transforms.matrix_to_rotation_6d(pred_R[:, :-1] @ pred_R[:, 1:].transpose(-1, -2)) + cam_angvel_from_R = (cam_angvel_from_R - torch.tensor([[[1, 0, 0, 0, 1, 0]]]).to(cam_angvel)) * 30 + loss_a = criterion(cam_angvel, cam_angvel_from_R)[mask].mean() + + loss = loss_r + loss_a + else: + loss = torch.FloatTensor(1).fill_(0.).to(gt_cam_r.device)[0] + + return loss + + +def sliding_loss( + foot_position, + contact_prob, +): + """ Compute foot skate loss when foot is assumed to be on contact with ground + + foot_position: 3D foot (heel and toe) position, torch.Tensor (B, F, 
4, 3) + contact_prob: contact probability of foot (heel and toe), torch.Tensor (B, F, 4) + """ + + contact_mask = (contact_prob > 0.5).detach().float() + foot_velocity = foot_position[:, 1:] - foot_position[:, :-1] + loss = (torch.norm(foot_velocity, dim=-1) * contact_mask[:, 1:]).mean() + return loss diff --git a/lib/core/trainer.py b/lib/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..90d10d4a4bfec7e47373b2520dd78c4eef9ee2ae --- /dev/null +++ b/lib/core/trainer.py @@ -0,0 +1,341 @@ +# -*- coding: utf-8 -*- + +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is +# holder of all proprietary rights on this computer program. +# You can only use this computer program if you have closed +# a license agreement with MPG or you get the right to use the computer +# program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and +# liable to prosecution. +# +# Copyright©2019 Max-Planck-Gesellschaft zur Förderung +# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute +# for Intelligent Systems. All rights reserved. +# +# Contact: ps-license@tuebingen.mpg.de + +import time +import torch +import shutil +import logging +import numpy as np +import os.path as osp +from progress.bar import Bar + +from configs import constants as _C +from lib.utils import transforms +from lib.utils.utils import AverageMeter, prepare_batch +from lib.eval.eval_utils import ( + compute_accel, + compute_error_accel, + batch_align_by_pelvis, + batch_compute_similarity_transform_torch, +) +from lib.models import build_body_model + +logger = logging.getLogger(__name__) + +class Trainer(): + def __init__(self, + data_loaders, + network, + optimizer, + criterion=None, + train_stage='syn', + start_epoch=0, + checkpoint=None, + end_epoch=999, + lr_scheduler=None, + device=None, + writer=None, + debug=False, + resume=False, + logdir='output', + performance_type='min', + summary_iter=1, + ): + + self.train_loader, self.valid_loader = data_loaders + + # Model and optimizer + self.network = network + self.optimizer = optimizer + + # Training parameters + self.train_stage = train_stage + self.start_epoch = start_epoch + self.end_epoch = end_epoch + self.criterion = criterion + self.lr_scheduler = lr_scheduler + self.device = device + self.writer = writer + self.debug = debug + self.resume = resume + self.logdir = logdir + self.summary_iter = summary_iter + + self.performance_type = performance_type + self.train_global_step = 0 + self.valid_global_step = 0 + self.epoch = 0 + self.best_performance = float('inf') if performance_type == 'min' else -float('inf') + self.summary_loss_keys = ['pose'] + + self.evaluation_accumulators = dict.fromkeys( + ['pred_j3d', 'target_j3d', 'pve'])# 'pred_verts', 'target_verts']) + + self.J_regressor_eval = torch.from_numpy( + np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M) + )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(device) + + if self.writer is None: + from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(log_dir=self.logdir) + + if self.device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + if checkpoint is not None: + self.load_pretrained(checkpoint) + + def train(self, ): + # Single epoch training routine + + losses = AverageMeter() + kp_2d_loss = AverageMeter() + kp_3d_loss = AverageMeter() + + timer = { + 'data': 0, + 'forward': 0, + 'loss': 0, + 'backward': 0, + 'batch': 0, + } + 
self.network.train() + start = time.time() + summary_string = '' + + bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=len(self.train_loader)) + for i, batch in enumerate(self.train_loader): + + # <======= Feedforward + x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2') + timer['data'] = time.time() - start + start = time.time() + pred = self.network(x, inits, features, **kwargs) + timer['forward'] = time.time() - start + start = time.time() + # =======> + + # <======= Backprop + loss, loss_dict = self.criterion(pred, gt) + timer['loss'] = time.time() - start + start = time.time() + + # Clip gradients + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) + self.optimizer.step() + # =======> + + # <======= Log training info + total_loss = loss + losses.update(total_loss.item(), x.size(0)) + kp_2d_loss.update(loss_dict['2d'].item(), x.size(0)) + kp_3d_loss.update(loss_dict['3d'].item(), x.size(0)) + + timer['backward'] = time.time() - start + timer['batch'] = timer['data'] + timer['forward'] + timer['loss'] + timer['backward'] + start = time.time() + + summary_string = f'({i + 1}/{len(self.train_loader)}) | Total: {bar.elapsed_td} ' \ + f'| loss: {losses.avg:.2f} | 2d: {kp_2d_loss.avg:.2f} ' \ + f'| 3d: {kp_3d_loss.avg:.2f} ' + + for k, v in loss_dict.items(): + if k in self.summary_loss_keys: + summary_string += f' | {k}: {v:.2f}' + if (i + 1) % self.summary_iter == 0: + self.writer.add_scalar('train_loss/'+k, v, global_step=self.train_global_step) + + if (i + 1) % self.summary_iter == 0: + self.writer.add_scalar('train_loss/loss', total_loss.item(), global_step=self.train_global_step) + + self.train_global_step += 1 + bar.suffix = summary_string + bar.next(1) + + if torch.isnan(total_loss): + exit('Nan value in loss, exiting!...') + # =======> + + logger.info(summary_string) + bar.finish() + + def validate(self, ): + self.network.eval() + + start = time.time() + summary_string = '' + bar = Bar('Validation', fill='#', max=len(self.valid_loader)) + + if self.evaluation_accumulators is not None: + for k,v in self.evaluation_accumulators.items(): + self.evaluation_accumulators[k] = [] + + with torch.no_grad(): + for i, batch in enumerate(self.valid_loader): + x, inits, features, kwargs, gt = prepare_batch(batch, self.device, self.train_stage=='stage2') + + # <======= Feedforward + pred = self.network(x, inits, features, **kwargs) + + # 3DPW dataset has groundtruth vertices + # NOTE: Following SPIN, we compute PVE against ground truth from Gendered SMPL mesh + smpl = build_body_model(self.device, batch_size=len(pred['verts_cam']), gender=batch['gender'][0]) + gt_output = smpl.get_output( + body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]), + global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]), + betas=gt['betas'][0], + pose2rot=False + ) + + pred_j3d = torch.matmul(self.J_regressor_eval, pred['verts_cam']).cpu() + target_j3d = torch.matmul(self.J_regressor_eval, gt_output.vertices).cpu() + pred_verts = pred['verts_cam'].cpu() + target_verts = gt_output.vertices.cpu() + + pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis( + [pred_j3d, target_j3d, pred_verts, target_verts], [2, 3] + ) + + self.evaluation_accumulators['pred_j3d'].append(pred_j3d.numpy()) + self.evaluation_accumulators['target_j3d'].append(target_j3d.numpy()) + pve = np.sqrt(np.sum((target_verts.numpy() - pred_verts.numpy()) ** 2, axis=-1)).mean(-1) * 1e3 + 
self.evaluation_accumulators['pve'].append(pve[:, None]) + # =======> + + batch_time = time.time() - start + + summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 10.0:.4}ms | ' \ + f'Total: {bar.elapsed_td} | ETA: {bar.eta_td:}' + + self.valid_global_step += 1 + bar.suffix = summary_string + bar.next() + + logger.info(summary_string) + + bar.finish() + + def evaluate(self, ): + for k, v in self.evaluation_accumulators.items(): + self.evaluation_accumulators[k] = np.vstack(v) + + pred_j3ds = self.evaluation_accumulators['pred_j3d'] + target_j3ds = self.evaluation_accumulators['target_j3d'] + + pred_j3ds = torch.from_numpy(pred_j3ds).float() + target_j3ds = torch.from_numpy(target_j3ds).float() + + print(f'Evaluating on {pred_j3ds.shape[0]} number of poses...') + errors = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + S1_hat = batch_compute_similarity_transform_torch(pred_j3ds, target_j3ds) + errors_pa = torch.sqrt(((S1_hat - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + + m2mm = 1000 + accel = np.mean(compute_accel(pred_j3ds)) * m2mm + accel_err = np.mean(compute_error_accel(joints_pred=pred_j3ds, joints_gt=target_j3ds)) * m2mm + mpjpe = np.mean(errors) * m2mm + pa_mpjpe = np.mean(errors_pa) * m2mm + + eval_dict = { + 'mpjpe': mpjpe, + 'pa-mpjpe': pa_mpjpe, + 'accel': accel, + 'accel_err': accel_err + } + + if 'pred_verts' in self.evaluation_accumulators.keys(): + eval_dict.update({'pve': self.evaluation_accumulators['pve'].mean()}) + + log_str = f'Epoch {self.epoch}, ' + log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in eval_dict.items()]) + logger.info(log_str) + + for k,v in eval_dict.items(): + self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch) + + # return (mpjpe + pa_mpjpe) / 2. 
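+        # PA-MPJPE (Procrustes-aligned joint error) is returned as the model-selection metric;
+        # the commented-out line above averaged it with the unaligned MPJPE instead.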
+ return pa_mpjpe + + def save_model(self, performance, epoch): + save_dict = { + 'epoch': epoch, + 'model': self.network.state_dict(), + 'performance': performance, + 'optimizer': self.optimizer.state_dict(), + } + + filename = osp.join(self.logdir, 'checkpoint.pth.tar') + torch.save(save_dict, filename) + + if self.performance_type == 'min': + is_best = performance < self.best_performance + else: + is_best = performance > self.best_performance + + if is_best: + logger.info('Best performance achived, saving it!') + self.best_performance = performance + shutil.copyfile(filename, osp.join(self.logdir, 'model_best.pth.tar')) + + with open(osp.join(self.logdir, 'best.txt'), 'w') as f: + f.write(str(float(performance))) + + def fit(self): + for epoch in range(self.start_epoch, self.end_epoch): + self.epoch = epoch + self.train() + self.validate() + performance = self.evaluate() + + self.criterion.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + # log the learning rate + for param_group in self.optimizer.param_groups[:2]: + print(f'Learning rate {param_group["lr"]}') + self.writer.add_scalar('lr', param_group['lr'], global_step=self.epoch) + + logger.info(f'Epoch {epoch+1} performance: {performance:.4f}') + + self.save_model(performance, epoch) + self.train_loader.dataset.prepare_video_batch() + + self.writer.close() + + def load_pretrained(self, model_path): + if osp.isfile(model_path): + checkpoint = torch.load(model_path) + + # network + ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval'] + ignore_keys2 = [k for k in checkpoint['model'].keys() if 'integrator' in k] + ignore_keys.extend(ignore_keys2) + model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys} + model_state_dict = {k: v for k, v in model_state_dict.items() if k in self.network.state_dict().keys()} + self.network.load_state_dict(model_state_dict, strict=False) + + if self.resume: + self.start_epoch = checkpoint['epoch'] + self.best_performance = checkpoint['performance'] + self.optimizer.load_state_dict(checkpoint['optimizer']) + + logger.info(f"=> loaded checkpoint '{model_path}' " + f"(epoch {self.start_epoch}, performance {self.best_performance})") + else: + logger.info(f"=> no checkpoint found at '{model_path}'") \ No newline at end of file diff --git a/lib/data/__init__.py b/lib/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/data/__pycache__/__init__.cpython-39.pyc b/lib/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a50ecdd3317f879777921e8717d6f0c6a46e68d Binary files /dev/null and b/lib/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/data/__pycache__/_dataset.cpython-39.pyc b/lib/data/__pycache__/_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6eec1e688837ce128f0d6eded25c07424d45025 Binary files /dev/null and b/lib/data/__pycache__/_dataset.cpython-39.pyc differ diff --git a/lib/data/_dataset.py b/lib/data/_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17889ecacbf4c5bf601bb4c058d51e8b4f603745 --- /dev/null +++ b/lib/data/_dataset.py @@ -0,0 +1,77 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import numpy as np +from skimage.util.shape import view_as_windows + 
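+# view_as_windows is used in prepare_video_batch below to split each video's frame indices
+# into fixed-length windows (stride = n_frames // 4) before batching.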
+from configs import constants as _C
+from .utils.normalizer import Normalizer
+from ..utils.imutils import transform
+
+class BaseDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, training=True):
+        super(BaseDataset, self).__init__()
+        self.epoch = 0
+        self.training = training
+        self.n_joints = _C.KEYPOINTS.NUM_JOINTS
+        self.n_frames = cfg.DATASET.SEQLEN + 1
+        self.keypoints_normalizer = Normalizer(cfg)
+
+    def prepare_video_batch(self):
+        r = self.epoch % 4
+
+        self.video_indices = []
+        vid_name = self.labels['vid']
+        if isinstance(vid_name, torch.Tensor): vid_name = vid_name.numpy()
+        video_names_unique, group = np.unique(
+            vid_name, return_index=True)
+        perm = np.argsort(group)
+        group_perm = group[perm]
+        indices = np.split(
+            np.arange(0, self.labels['vid'].shape[0]), group_perm[1:]
+        )
+        for idx in range(len(video_names_unique)):
+            indexes = indices[idx]
+            if indexes.shape[0] < self.n_frames: continue
+            chunks = view_as_windows(
+                indexes, (self.n_frames), step=self.n_frames // 4
+            )
+            start_finish = chunks[r::4, (0, -1)].tolist()
+            self.video_indices += start_finish
+
+        self.epoch += 1
+
+    def __len__(self):
+        if self.training:
+            return len(self.video_indices)
+        else:
+            return len(self.labels['kp2d'])
+
+    def __getitem__(self, index):
+        return self.get_single_sequence(index)
+
+    def get_single_sequence(self, index):
+        raise NotImplementedError('get_single_sequence is not implemented')
+
+    def get_naive_intrinsics(self, res):
+        # Naive intrinsics: principal point at the image center, focal length set to the
+        # image diagonal (roughly a 53-degree diagonal FOV)
+        img_w, img_h = res
+        self.focal_length = (img_w * img_w + img_h * img_h) ** 0.5
+        self.cam_intrinsics = torch.eye(3).repeat(1, 1, 1).float()
+        self.cam_intrinsics[:, 0, 0] = self.focal_length
+        self.cam_intrinsics[:, 1, 1] = self.focal_length
+        self.cam_intrinsics[:, 0, 2] = img_w/2.
+        self.cam_intrinsics[:, 1, 2] = img_h/2.
+
+    def j2d_processing(self, kp, bbox):
+        center = bbox[..., :2]
+        scale = bbox[..., -1:]
+        nparts = kp.shape[0]
+        for i in range(nparts):
+            kp[i, 0:2] = transform(kp[i, 0:2] + 1, center, scale,
+                                   [224, 224])
+        kp[:, :2] = 2. * kp[:, :2] / 224 - 1.
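+        # cast to float32 so the normalized keypoints match the dtype expected downstream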
+ kp = kp.astype('float32') + return kp \ No newline at end of file diff --git a/lib/data/dataloader.py b/lib/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff536d354d01bc2f1d5cb0f3531d7daeacd7bc1 --- /dev/null +++ b/lib/data/dataloader.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch + +from .datasets import EvalDataset, DataFactory +from ..utils.data_utils import make_collate_fn + + +def setup_eval_dataloader(cfg, data, split='test', backbone=None): + if backbone is None: + backbone = cfg.MODEL.BACKBONE + + dataset = EvalDataset(cfg, data, split, backbone) + dloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + num_workers=0, + shuffle=False, + pin_memory=True, + collate_fn=make_collate_fn() + ) + return dloader + + +def setup_train_dataloader(cfg, ): + n_workers = 0 if cfg.DEBUG else cfg.NUM_WORKERS + + train_dataset = DataFactory(cfg, cfg.TRAIN.STAGE) + dloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=cfg.TRAIN.BATCH_SIZE, + num_workers=n_workers, + shuffle=True, + pin_memory=True, + collate_fn=make_collate_fn() + ) + return dloader + + +def setup_dloaders(cfg, dset='3dpw', split='val'): + test_dloader = setup_eval_dataloader(cfg, dset, split, cfg.MODEL.BACKBONE) + train_dloader = setup_train_dataloader(cfg) + + return train_dloader, test_dloader \ No newline at end of file diff --git a/lib/data/datasets/__init__.py b/lib/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20c4b909be3802c5a99b87be26487d4ecb33a499 --- /dev/null +++ b/lib/data/datasets/__init__.py @@ -0,0 +1,3 @@ +from .dataset_eval import EvalDataset +from .dataset_custom import CustomDataset +from .mixed_dataset import DataFactory \ No newline at end of file diff --git a/lib/data/datasets/__pycache__/__init__.cpython-39.pyc b/lib/data/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebff022afe5a20e7a69e827fe14a7a503ac7434 Binary files /dev/null and b/lib/data/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/amass.cpython-39.pyc b/lib/data/datasets/__pycache__/amass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb96f2f3ee07b0c663cd09f34b2c10b06651c514 Binary files /dev/null and b/lib/data/datasets/__pycache__/amass.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/bedlam.cpython-39.pyc b/lib/data/datasets/__pycache__/bedlam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c83e0fae3c4c94927cd2f3eb1c26cdc87720700f Binary files /dev/null and b/lib/data/datasets/__pycache__/bedlam.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc b/lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61a27240a9b9ae75d82cd58289162ab1fa308b62 Binary files /dev/null and b/lib/data/datasets/__pycache__/dataset2d.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc b/lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7900ba39389db4bb73e550b2b803d6f7a490e8d0 Binary files /dev/null and b/lib/data/datasets/__pycache__/dataset3d.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc 
b/lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5dad3c63cb490577cc37e517db84a51e0d4fb6a Binary files /dev/null and b/lib/data/datasets/__pycache__/dataset_custom.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc b/lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3654858aa591072f5b712193bcc7a448579881be Binary files /dev/null and b/lib/data/datasets/__pycache__/dataset_eval.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc b/lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de4a4c9cc8900aa6a027fbe161198f8c2e21fc5a Binary files /dev/null and b/lib/data/datasets/__pycache__/mixed_dataset.cpython-39.pyc differ diff --git a/lib/data/datasets/__pycache__/videos.cpython-39.pyc b/lib/data/datasets/__pycache__/videos.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9c7b3e16f22dd028aad45cf03b716e145506643 Binary files /dev/null and b/lib/data/datasets/__pycache__/videos.cpython-39.pyc differ diff --git a/lib/data/datasets/amass.py b/lib/data/datasets/amass.py new file mode 100644 index 0000000000000000000000000000000000000000..e05cef153787bc17d143272bb7ad8c077d03902b --- /dev/null +++ b/lib/data/datasets/amass.py @@ -0,0 +1,173 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import joblib +from lib.utils import transforms + +from configs import constants as _C + +from ..utils.augmentor import * +from .._dataset import BaseDataset +from ...models import build_body_model +from ...utils import data_utils as d_utils +from ...utils.kp_utils import root_centering + + + +def compute_contact_label(feet, thr=1e-2, alpha=5): + vel = torch.zeros_like(feet[..., 0]) + label = torch.zeros_like(feet[..., 0]) + + vel[1:-1] = (feet[2:] - feet[:-2]).norm(dim=-1) / 2.0 + vel[0] = vel[1].clone() + vel[-1] = vel[-2].clone() + + label = 1 / (1 + torch.exp(alpha * (thr ** -1) * (vel - thr))) + return label + + +class AMASSDataset(BaseDataset): + def __init__(self, cfg): + label_pth = _C.PATHS.AMASS_LABEL + super(AMASSDataset, self).__init__(cfg, training=True) + + self.supervise_pose = cfg.TRAIN.STAGE == 'stage1' + self.labels = joblib.load(label_pth) + self.SequenceAugmentor = SequenceAugmentor(cfg.DATASET.SEQLEN + 1) + + # Load augmentators + self.VideoAugmentor = VideoAugmentor(cfg) + self.SMPLAugmentor = SMPLAugmentor(cfg) + self.d_img_feature = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE] + + self.n_frames = int(cfg.DATASET.SEQLEN * self.SequenceAugmentor.l_factor) + 1 + self.smpl = build_body_model('cpu', self.n_frames) + self.prepare_video_batch() + + # Naive assumption of image intrinsics + self.img_w, self.img_h = 1000, 1000 + self.get_naive_intrinsics((self.img_w, self.img_h)) + + self.CameraAugmentor = CameraAugmentor(cfg.DATASET.SEQLEN + 1, self.img_w, self.img_h, self.focal_length) + + + @property + def __name__(self, ): + return 'AMASS' + + def get_input(self, target): + gt_kp3d = target['kp3d'] + inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone()) + kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics) + mask = self.VideoAugmentor.get_mask() + kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224) + + 
target['bbox'] = bbox[1:] + target['kp2d'] = kp2d + target['mask'] = mask[1:] + target['features'] = torch.zeros((self.SMPLAugmentor.n_frames, self.d_img_feature)).float() + return target + + def get_groundtruth(self, target): + # GT 1. Joints + gt_kp3d = target['kp3d'] + gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics) + target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1]) * float(self.supervise_pose)), dim=-1) + target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1]) * float(self.supervise_pose)), dim=-1)[1:] + target['weak_kp2d'] = torch.zeros_like(target['full_kp2d']) + target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1) + target['verts'] = torch.zeros((self.SMPLAugmentor.n_frames, 6890, 3)).float() + + # GT 2. Root pose + vel_world = (target['transl'][1:] - target['transl'][:-1]) + pose_root = target['pose_root'].clone() + vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1) + target['vel_root'] = vel_root.clone() + target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root) + target['init_root'] = target['pose_root'][:1].clone() + + # GT 3. Foot contact + contact = compute_contact_label(target['feet']) + if 'tread' in target['vid']: + target['contact'] = torch.ones_like(contact) * (-1) + else: + target['contact'] = contact + + return target + + def forward_smpl(self, target): + output = self.smpl.get_output( + body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])), + global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])), + betas=target['betas'], + pose2rot=False) + + target['transl'] = target['transl'] - output.offset + target['transl'] = target['transl'] - target['transl'][0] + target['kp3d'] = output.joints + target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2) + + return target + + def augment_data(self, target): + # Augmentation 1. SMPL params augmentation + target = self.SMPLAugmentor(target) + + # Augmentation 2. Sequence speed augmentation + target = self.SequenceAugmentor(target) + + # Get world-coordinate SMPL + target = self.forward_smpl(target) + + # Augmentation 3. 
Virtual camera generation + target = self.CameraAugmentor(target) + + return target + + def load_amass(self, index, target): + start_index, end_index = self.video_indices[index] + + # Load AMASS labels + pose = torch.from_numpy(self.labels['pose'][start_index:end_index+1].copy()) + pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3)) + transl = torch.from_numpy(self.labels['transl'][start_index:end_index+1].copy()) + betas = torch.from_numpy(self.labels['betas'][start_index:end_index+1].copy()) + + # Stack GT + target.update({'vid': self.labels['vid'][start_index], + 'pose': pose, + 'transl': transl, + 'betas': betas}) + + return target + + def get_single_sequence(self, index): + target = {'res': torch.tensor([self.img_w, self.img_h]).float(), + 'cam_intrinsics': self.cam_intrinsics.clone(), + 'has_full_screen': torch.tensor(True), + 'has_smpl': torch.tensor(self.supervise_pose), + 'has_traj': torch.tensor(True), + 'has_verts': torch.tensor(False),} + + target = self.load_amass(index, target) + target = self.augment_data(target) + target = self.get_groundtruth(target) + target = self.get_input(target) + + target = d_utils.prepare_keypoints_data(target) + target = d_utils.prepare_smpl_data(target) + + return target + + +def perspective_projection(points, cam_intrinsics, rotation=None, translation=None): + K = cam_intrinsics + if rotation is not None: + points = torch.matmul(rotation, points.transpose(1, 2)).transpose(1, 2) + if translation is not None: + points = points + translation.unsqueeze(1) + projected_points = points / points[:, :, -1].unsqueeze(-1) + projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float()) + return projected_points[:, :, :-1] \ No newline at end of file diff --git a/lib/data/datasets/bedlam.py b/lib/data/datasets/bedlam.py new file mode 100644 index 0000000000000000000000000000000000000000..9aba9757388218c16b641aa2ad9fcaa00fc37900 --- /dev/null +++ b/lib/data/datasets/bedlam.py @@ -0,0 +1,165 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import joblib +from lib.utils import transforms + +from configs import constants as _C + +from .amass import compute_contact_label, perspective_projection +from ..utils.augmentor import * +from .._dataset import BaseDataset +from ...models import build_body_model +from ...utils import data_utils as d_utils +from ...utils.kp_utils import root_centering + +class BEDLAMDataset(BaseDataset): + def __init__(self, cfg): + label_pth = _C.PATHS.BEDLAM_LABEL.replace('backbone', cfg.MODEL.BACKBONE) + super(BEDLAMDataset, self).__init__(cfg, training=True) + + self.labels = joblib.load(label_pth) + + self.VideoAugmentor = VideoAugmentor(cfg) + self.SMPLAugmentor = SMPLAugmentor(cfg, False) + + self.smpl = build_body_model('cpu', self.n_frames) + self.prepare_video_batch() + + @property + def __name__(self, ): + return 'BEDLAM' + + def get_inputs(self, index, target, vis_thr=0.6): + start_index, end_index = self.video_indices[index] + + bbox = self.labels['bbox'][start_index:end_index+1].clone() + bbox[:, 2] = bbox[:, 2] / 200 + + gt_kp3d = target['kp3d'] + inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone()) + # kp2d = perspective_projection(inpt_kp3d, target['K']) + kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics) + mask = self.VideoAugmentor.get_mask() + # kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox) + kp2d, bbox = 
self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224) + + target['bbox'] = bbox[1:] + target['kp2d'] = kp2d + target['mask'] = mask[1:] + + # Image features + target['features'] = self.labels['features'][start_index+1:end_index+1].clone() + + return target + + def get_groundtruth(self, index, target): + start_index, end_index = self.video_indices[index] + + # GT 1. Joints + gt_kp3d = target['kp3d'] + # gt_kp2d = perspective_projection(gt_kp3d, target['K']) + gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics) + target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1])), dim=-1) + # target['full_kp2d'] = torch.cat((gt_kp2d, torch.zeros_like(gt_kp2d[..., :1])), dim=-1)[1:] + target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1])), dim=-1)[1:] + target['weak_kp2d'] = torch.zeros_like(target['full_kp2d']) + target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1) + + # GT 2. Root pose + w_transl = self.labels['w_trans'][start_index:end_index+1] + pose_root = transforms.axis_angle_to_matrix(self.labels['root'][start_index:end_index+1]) + vel_world = (w_transl[1:] - w_transl[:-1]) + vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1) + target['vel_root'] = vel_root.clone() + target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root) + target['init_root'] = target['pose_root'][:1].clone() + + return target + + def forward_smpl(self, target): + output = self.smpl.get_output( + body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])), + global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])), + betas=target['betas'], + transl=target['transl'], + pose2rot=False) + + target['kp3d'] = output.joints + output.offset.unsqueeze(1) + target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2) + target['verts'] = output.vertices[1:, ].clone() + + return target + + def augment_data(self, target): + # Augmentation 1. 
SMPL params augmentation + target = self.SMPLAugmentor(target) + + # Get world-coordinate SMPL + target = self.forward_smpl(target) + + return target + + def load_camera(self, index, target): + start_index, end_index = self.video_indices[index] + + # Get camera info + extrinsics = self.labels['extrinsics'][start_index:end_index+1].clone() + R = extrinsics[:, :3, :3] + T = extrinsics[:, :3, -1] + K = self.labels['intrinsics'][start_index:end_index+1].clone() + width, height = K[0, 0, 2] * 2, K[0, 1, 2] * 2 + target['R'] = R + target['res'] = torch.tensor([width, height]).float() + + # Compute angular velocity + cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2)) + cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize + target['cam_angvel'] = cam_angvel * 3e1 # BEDLAM is 30-fps + + target['K'] = K # Use GT camera intrinsics for projecting keypoints + self.get_naive_intrinsics(target['res']) + target['cam_intrinsics'] = self.cam_intrinsics + + return target + + def load_params(self, index, target): + start_index, end_index = self.video_indices[index] + + # Load AMASS labels + pose = self.labels['pose'][start_index:end_index+1].clone() + pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3)) + transl = self.labels['c_trans'][start_index:end_index+1].clone() + betas = self.labels['betas'][start_index:end_index+1, :10].clone() + + # Stack GT + target.update({'vid': self.labels['vid'][start_index].clone(), + 'pose': pose, + 'transl': transl, + 'betas': betas}) + + return target + + + def get_single_sequence(self, index): + target = {'has_full_screen': torch.tensor(True), + 'has_smpl': torch.tensor(True), + 'has_traj': torch.tensor(False), + 'has_verts': torch.tensor(True), + + # Null contact label + 'contact': torch.ones((self.n_frames - 1, 4)) * (-1), + } + + target = self.load_params(index, target) + target = self.load_camera(index, target) + target = self.augment_data(target) + target = self.get_groundtruth(index, target) + target = self.get_inputs(index, target) + + target = d_utils.prepare_keypoints_data(target) + target = d_utils.prepare_smpl_data(target) + + return target \ No newline at end of file diff --git a/lib/data/datasets/dataset2d.py b/lib/data/datasets/dataset2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b06f8b5d61c951922f55ca17e4982cc3a0ba07 --- /dev/null +++ b/lib/data/datasets/dataset2d.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import joblib + +from .._dataset import BaseDataset +from ..utils.augmentor import * +from ...utils import data_utils as d_utils +from ...utils import transforms +from ...models import build_body_model +from ...utils.kp_utils import convert_kps, root_centering + + +class Dataset2D(BaseDataset): + def __init__(self, cfg, fname, training): + super(Dataset2D, self).__init__(cfg, training) + + self.epoch = 0 + self.n_frames = cfg.DATASET.SEQLEN + 1 + self.labels = joblib.load(fname) + + if self.training: + self.prepare_video_batch() + + self.smpl = build_body_model('cpu', self.n_frames) + self.SMPLAugmentor = SMPLAugmentor(cfg, False) + + def __getitem__(self, index): + return self.get_single_sequence(index) + + def get_inputs(self, index, target, vis_thr=0.6): + start_index, end_index = self.video_indices[index] + + # 2D keypoints detection + kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone() + kp2d, bbox = 
self.keypoints_normalizer(kp2d, target['res'], target['cam_intrinsics'], 224, 224, target['bbox']) + target['bbox'] = bbox[1:] + target['kp2d'] = kp2d + + # Detection mask + target['mask'] = ~self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone().bool() + + # Image features + target['features'] = self.labels['features'][start_index+1:end_index+1].clone() + + return target + + def get_labels(self, index, target): + start_index, end_index = self.video_indices[index] + + # SMPL parameters + # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input. + # We do not supervise the network on SMPL parameters. + target['pose'] = transforms.axis_angle_to_matrix( + self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3)) + target['betas'] = self.labels['betas'][start_index:end_index+1].clone() # No t + + # Apply SMPL augmentor (y-axis rotation and initial frame noise) + target = self.SMPLAugmentor(target) + + # 2D keypoints + kp2d = self.labels['kp2d'][start_index:end_index+1].clone().float()[..., :2] + gt_kp2d = torch.zeros((self.n_frames - 1, 31, 2)) + gt_kp2d[:, :17] = kp2d[1:].clone() + + # Set 0 confidence to the masked keypoints + mask = torch.zeros((self.n_frames - 1, 31)) + mask[:, :17] = self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone() + mask = torch.logical_and(gt_kp2d.mean(-1) != 0, mask) + gt_kp2d = torch.cat((gt_kp2d, mask.float().unsqueeze(-1)), dim=-1) + + _gt_kp2d = gt_kp2d.clone() + for idx in range(len(_gt_kp2d)): + _gt_kp2d[idx][..., :2] = torch.from_numpy( + self.j2d_processing(gt_kp2d[idx][..., :2].numpy().copy(), + target['bbox'][idx].numpy().copy())) + + target['weak_kp2d'] = _gt_kp2d.clone() + target['full_kp2d'] = torch.zeros_like(gt_kp2d) + target['kp3d'] = torch.zeros((kp2d.shape[0], 31, 4)) + + # No SMPL vertices available + target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float() + return target + + def get_init_frame(self, target): + # Prepare initial frame + output = self.smpl.get_output( + body_pose=target['init_pose'][:, 1:], + global_orient=target['init_pose'][:, :1], + betas=target['betas'][:1], + pose2rot=False + ) + target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1) + + return target + + def get_single_sequence(self, index): + # Camera parameters + res = (224.0, 224.0) + bbox = torch.tensor([112.0, 112.0, 1.12]) + res = torch.tensor(res) + self.get_naive_intrinsics(res) + bbox = bbox.repeat(self.n_frames, 1) + + # Universal target + target = {'has_full_screen': torch.tensor(False), + 'has_smpl': torch.tensor(self.has_smpl), + 'has_traj': torch.tensor(self.has_traj), + 'has_verts': torch.tensor(False), + 'transl': torch.zeros((self.n_frames, 3)), + + # Camera parameters and bbox + 'res': res, + 'cam_intrinsics': self.cam_intrinsics, + 'bbox': bbox, + + # Null camera motion + 'R': torch.eye(3).repeat(self.n_frames, 1, 1), + 'cam_angvel': torch.zeros((self.n_frames - 1, 6)), + + # Null root orientation and velocity + 'pose_root': torch.zeros((self.n_frames, 6)), + 'vel_root': torch.zeros((self.n_frames - 1, 3)), + 'init_root': torch.zeros((1, 6)), + + # Null contact label + 'contact': torch.ones((self.n_frames - 1, 4)) * (-1) + } + + self.get_inputs(index, target) + self.get_labels(index, target) + self.get_init_frame(target) + + target = d_utils.prepare_keypoints_data(target) + target = d_utils.prepare_smpl_data(target) + + return target \ No newline at end of file diff --git a/lib/data/datasets/dataset3d.py b/lib/data/datasets/dataset3d.py new 
file mode 100644 index 0000000000000000000000000000000000000000..ed8e3c122d1df5f96a9c0b3695da9927460a2ee0 --- /dev/null +++ b/lib/data/datasets/dataset3d.py @@ -0,0 +1,172 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import joblib +import numpy as np + +from .._dataset import BaseDataset +from ..utils.augmentor import * +from ...utils import data_utils as d_utils +from ...utils import transforms +from ...models import build_body_model +from ...utils.kp_utils import convert_kps, root_centering + + +class Dataset3D(BaseDataset): + def __init__(self, cfg, fname, training): + super(Dataset3D, self).__init__(cfg, training) + + self.epoch = 0 + self.labels = joblib.load(fname) + self.n_frames = cfg.DATASET.SEQLEN + 1 + + if self.training: + self.prepare_video_batch() + + self.smpl = build_body_model('cpu', self.n_frames) + self.SMPLAugmentor = SMPLAugmentor(cfg, False) + self.VideoAugmentor = VideoAugmentor(cfg) + + def __getitem__(self, index): + return self.get_single_sequence(index) + + def get_inputs(self, index, target, vis_thr=0.6): + start_index, end_index = self.video_indices[index] + + # 2D keypoints detection + kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone() + bbox = self.labels['bbox'][start_index:end_index+1][..., [0, 1, -1]].clone() + bbox[:, 2] = bbox[:, 2] / 200 + kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox) + + target['bbox'] = bbox[1:] + target['kp2d'] = kp2d + target['mask'] = self.labels['kp2d'][start_index+1:end_index+1][..., -1] < vis_thr + + # Image features + target['features'] = self.labels['features'][start_index+1:end_index+1].clone() + + return target + + def get_labels(self, index, target): + start_index, end_index = self.video_indices[index] + + # SMPL parameters + # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input. + # We do not supervise the network on SMPL parameters. 
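+        # (i.e., for Human36M/MPII3D the pose below only seeds the 0th-frame input; their
+        # supervision comes from the 2D/3D keypoints prepared later in this method.)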
+ target['pose'] = transforms.axis_angle_to_matrix( + self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3)) + target['betas'] = self.labels['betas'][start_index:end_index+1].clone() # No t + + # Apply SMPL augmentor (y-axis rotation and initial frame noise) + target = self.SMPLAugmentor(target) + + # 3D and 2D keypoints + if self.__name__ == 'ThreeDPW': # 3DPW has SMPL labels + gt_kp3d = self.labels['joints3D'][start_index:end_index+1].clone() + gt_kp2d = self.labels['joints2D'][start_index+1:end_index+1, ..., :2].clone() + gt_kp3d = root_centering(gt_kp3d.clone()) + + else: # Human36m and MPII do not have SMPL labels + gt_kp3d = torch.zeros((self.n_frames, self.n_joints + 14, 3)) + gt_kp3d[:, self.n_joints:] = convert_kps(self.labels['joints3D'][start_index:end_index+1], 'spin', 'common') + gt_kp2d = torch.zeros((self.n_frames - 1, self.n_joints + 14, 2)) + gt_kp2d[:, self.n_joints:] = convert_kps(self.labels['joints2D'][start_index+1:end_index+1, ..., :2], 'spin', 'common') + + conf = self.mask.repeat(self.n_frames, 1).unsqueeze(-1) + gt_kp2d = torch.cat((gt_kp2d, conf[1:]), dim=-1) + gt_kp3d = torch.cat((gt_kp3d, conf), dim=-1) + target['kp3d'] = gt_kp3d + target['full_kp2d'] = gt_kp2d + target['weak_kp2d'] = torch.zeros_like(gt_kp2d) + + if self.__name__ != 'ThreeDPW': # 3DPW does not contain world-coordinate motion + # Foot ground contact labels for Human36M and MPII3D + target['contact'] = self.labels['stationaries'][start_index+1:end_index+1].clone() + else: + # No foot ground contact label available for 3DPW + target['contact'] = torch.ones((self.n_frames - 1, 4)) * (-1) + + if self.has_verts: + # SMPL vertices available for 3DPW + with torch.no_grad(): + start_index, end_index = self.video_indices[index] + gender = self.labels['gender'][start_index].item() + output = self.smpl_gender[gender]( + body_pose=target['pose'][1:, 1:], + global_orient=target['pose'][1:, :1], + betas=target['betas'][1:], + pose2rot=False, + ) + target['verts'] = output.vertices.clone() + else: + # No SMPL vertices available + target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float() + + return target + + def get_init_frame(self, target): + # Prepare initial frame + output = self.smpl.get_output( + body_pose=target['init_pose'][:, 1:], + global_orient=target['init_pose'][:, :1], + betas=target['betas'][:1], + pose2rot=False + ) + target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1) + + return target + + def get_camera_info(self, index, target): + start_index, end_index = self.video_indices[index] + + # Intrinsics + target['res'] = self.labels['res'][start_index:end_index+1][0].clone() + self.get_naive_intrinsics(target['res']) + target['cam_intrinsics'] = self.cam_intrinsics.clone() + + # Extrinsics pose + R = self.labels['cam_poses'][start_index:end_index+1, :3, :3].clone().float() + yaw = transforms.axis_angle_to_matrix(torch.tensor([[0, 2 * np.pi * np.random.uniform(), 0]])).float() + if self.__name__ == 'Human36M': + # Map Z-up to Y-down coordinate + zup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[-np.pi/2, 0, 0]])).float() + zup2ydown = torch.matmul(yaw, zup2ydown) + R = torch.matmul(R, zup2ydown) + elif self.__name__ == 'MPII3D': + # Map Y-up to Y-down coordinate + yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float() + yup2ydown = torch.matmul(yaw, yup2ydown) + R = torch.matmul(R, yup2ydown) + + return target + + def get_single_sequence(self, index): + # Universal target + target = 
{'has_full_screen': torch.tensor(True), + 'has_smpl': torch.tensor(self.has_smpl), + 'has_traj': torch.tensor(self.has_traj), + 'has_verts': torch.tensor(self.has_verts), + 'transl': torch.zeros((self.n_frames, 3)), + + # Null camera motion + 'R': torch.eye(3).repeat(self.n_frames, 1, 1), + 'cam_angvel': torch.zeros((self.n_frames - 1, 6)), + + # Null root orientation and velocity + 'pose_root': torch.zeros((self.n_frames, 6)), + 'vel_root': torch.zeros((self.n_frames - 1, 3)), + 'init_root': torch.zeros((1, 6)), + } + + self.get_camera_info(index, target) + self.get_inputs(index, target) + self.get_labels(index, target) + self.get_init_frame(target) + + target = d_utils.prepare_keypoints_data(target) + target = d_utils.prepare_smpl_data(target) + + return target \ No newline at end of file diff --git a/lib/data/datasets/dataset_custom.py b/lib/data/datasets/dataset_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..dd27140050725aa715a42f411dc270c12c13723d --- /dev/null +++ b/lib/data/datasets/dataset_custom.py @@ -0,0 +1,115 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch + +from ..utils.normalizer import Normalizer +from ...models import build_body_model +from ...utils import transforms +from ...utils.kp_utils import root_centering +from ...utils.imutils import compute_cam_intrinsics + +KEYPOINTS_THR = 0.3 + +def convert_dpvo_to_cam_angvel(traj, fps): + """Function to convert DPVO trajectory output to camera angular velocity""" + + # 0 ~ 3: translation, 3 ~ 7: Quaternion + quat = traj[:, 3:] + + # Convert (x,y,z,q) to (q,x,y,z) + quat = quat[:, [3, 0, 1, 2]] + + # Quat is camera to world transformation. Convert it to world to camera + world2cam = transforms.quaternion_to_matrix(torch.from_numpy(quat)).float() + R = world2cam.mT + + # Compute the rotational changes over time. 
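+    # [Illustrative comment, not part of the original code] R[:-1] @ R[1:].transpose(-1, -2)
+    # is the relative rotation between consecutive frames. [1, 0, 0, 0, 1, 0] is the 6D
+    # encoding of the identity rotation, so after the subtraction below a static camera
+    # yields a zero vector, and multiplying by fps turns the per-frame change into an
+    # approximate angular rate. Quick sanity check with a static camera:
+    #   R_static = torch.eye(3).repeat(5, 1, 1)
+    #   rel = transforms.matrix_to_rotation_6d(R_static[:-1] @ R_static[1:].transpose(-1, -2))
+    #   rel - torch.tensor([[1., 0., 0., 0., 1., 0.]])     # -> all zeros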
+ cam_angvel = transforms.matrix_to_axis_angle(R[:-1] @ R[1:].transpose(-1, -2)) + + # Convert matrix to 6D representation + cam_angvel = transforms.matrix_to_rotation_6d(transforms.axis_angle_to_matrix(cam_angvel)) + + # Normalize 6D angular velocity + cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize + cam_angvel = cam_angvel * fps + cam_angvel = torch.cat((cam_angvel, cam_angvel[:1]), dim=0) + return cam_angvel + + +class CustomDataset(torch.utils.data.Dataset): + def __init__(self, cfg, tracking_results, slam_results, width, height, fps): + + self.tracking_results = tracking_results + self.slam_results = slam_results + self.width = width + self.height = height + self.fps = fps + self.res = torch.tensor([width, height]).float() + self.intrinsics = compute_cam_intrinsics(self.res) + + self.device = cfg.DEVICE.lower() + + self.smpl = build_body_model('cpu') + self.keypoints_normalizer = Normalizer(cfg) + + self._to = lambda x: x.unsqueeze(0).to(self.device) + + def __len__(self): + return len(self.tracking_results.keys()) + + def load_data(self, index, flip=False): + if flip: + self.prefix = 'flipped_' + else: + self.prefix = '' + + return self.__getitem__(index) + + def __getitem__(self, _index): + if _index >= len(self): return + + index = sorted(list(self.tracking_results.keys()))[_index] + + # Process 2D keypoints + kp2d = torch.from_numpy(self.tracking_results[index][self.prefix + 'keypoints']).float() + mask = kp2d[..., -1] < KEYPOINTS_THR + bbox = torch.from_numpy(self.tracking_results[index][self.prefix + 'bbox']).float() + + norm_kp2d, _ = self.keypoints_normalizer( + kp2d[..., :-1].clone(), self.res, self.intrinsics, 224, 224, bbox + ) + + # Process image features + features = self.tracking_results[index][self.prefix + 'features'] + + # Process initial pose + init_output = self.smpl.get_output( + global_orient=self.tracking_results[index][self.prefix + 'init_global_orient'], + body_pose=self.tracking_results[index][self.prefix + 'init_body_pose'], + betas=self.tracking_results[index][self.prefix + 'init_betas'], + pose2rot=False, + return_full_pose=True + ) + init_kp3d = root_centering(init_output.joints[:, :17], 'coco') + init_kp = torch.cat((init_kp3d.reshape(1, -1), norm_kp2d[0].clone().reshape(1, -1)), dim=-1) + init_smpl = transforms.matrix_to_rotation_6d(init_output.full_pose) + init_root = transforms.matrix_to_rotation_6d(init_output.global_orient) + + # Process SLAM results + cam_angvel = convert_dpvo_to_cam_angvel(self.slam_results, self.fps) + + return ( + index, # subject id + self._to(norm_kp2d), # 2d keypoints + (self._to(init_kp), self._to(init_smpl)), # initial pose + self._to(features), # image features + self._to(mask), # keypoints mask + init_root.to(self.device), # initial root orientation + self._to(cam_angvel), # camera angular velocity + self.tracking_results[index]['frame_id'], # frame indices + {'cam_intrinsics': self._to(self.intrinsics), # other keyword arguments + 'bbox': self._to(bbox), + 'res': self._to(self.res)}, + ) \ No newline at end of file diff --git a/lib/data/datasets/dataset_eval.py b/lib/data/datasets/dataset_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..85d7569027b09b08b91c6cc9c35404a800709aee --- /dev/null +++ b/lib/data/datasets/dataset_eval.py @@ -0,0 +1,113 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import torch +import joblib + +from configs import constants as _C +from .._dataset 
import BaseDataset +from ...utils import transforms +from ...utils import data_utils as d_utils +from ...utils.kp_utils import root_centering + +FPS = 30 +class EvalDataset(BaseDataset): + def __init__(self, cfg, data, split, backbone): + super(EvalDataset, self).__init__(cfg, False) + + self.prefix = '' + self.data = data + parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth') + self.labels = joblib.load(parsed_data_path) + + def load_data(self, index, flip=False): + if flip: + self.prefix = 'flipped_' + else: + self.prefix = '' + + target = self.__getitem__(index) + for key, val in target.items(): + if isinstance(val, torch.Tensor): + target[key] = val.unsqueeze(0) + return target + + def __getitem__(self, index): + target = {} + target = self.get_data(index) + target = d_utils.prepare_keypoints_data(target) + target = d_utils.prepare_smpl_data(target) + + return target + + def __len__(self): + return len(self.labels['kp2d']) + + def prepare_labels(self, index, target): + # Ground truth SMPL parameters + target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3)) + target['betas'] = self.labels['betas'][index] + target['gender'] = self.labels['gender'][index] + + # Sequence information + target['res'] = self.labels['res'][index][0] + target['vid'] = self.labels['vid'][index] + target['frame_id'] = self.labels['frame_id'][index][1:] + + # Camera information + self.get_naive_intrinsics(target['res']) + target['cam_intrinsics'] = self.cam_intrinsics + R = self.labels['cam_poses'][index][:, :3, :3].clone() + if 'emdb' in self.data.lower(): + # Use groundtruth camera angular velocity. + # Can be updated with SLAM results if you have it. + cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2)) + cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS + target['R'] = R + else: + cam_angvel = torch.zeros((len(target['pose']) - 1, 6)) + target['cam_angvel'] = cam_angvel + return target + + def prepare_inputs(self, index, target): + for key in ['features', 'bbox']: + data = self.labels[self.prefix + key][index][1:] + target[key] = data + + bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float() + bbox[:, 2] = bbox[:, 2] / 200 + + # Normalize keypoints + kp2d, bbox = self.keypoints_normalizer( + self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(), + target['res'], target['cam_intrinsics'], 224, 224, bbox) + target['kp2d'] = kp2d + target['bbox'] = bbox[1:] + + # Masking out low confident keypoints + mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3 + target['input_kp2d'] = self.labels['kp2d'][index][1:] + target['input_kp2d'][mask[1:]] *= 0 + target['mask'] = mask[1:] + + return target + + def prepare_initialization(self, index, target): + # Initial frame per-frame estimation + target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1) + target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu() + pose_root = target['pose'][:, 0].clone() + target['init_root'] = transforms.matrix_to_rotation_6d(pose_root) + + return target + + def get_data(self, index): + target = {} + + target = self.prepare_labels(index, target) + target = self.prepare_inputs(index, target) + target = self.prepare_initialization(index, target) + + return target \ No newline at end of file diff --git a/lib/data/datasets/mixed_dataset.py 
b/lib/data/datasets/mixed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..33168238d2745014b92c8d9809d54c8e422291f4 --- /dev/null +++ b/lib/data/datasets/mixed_dataset.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import numpy as np + +from .amass import AMASSDataset +from .videos import Human36M, ThreeDPW, MPII3D, InstaVariety +from .bedlam import BEDLAMDataset +from lib.utils.data_utils import make_collate_fn + + +class DataFactory(torch.utils.data.Dataset): + def __init__(self, cfg, train_stage='syn'): + super(DataFactory, self).__init__() + + if train_stage == 'stage1': + self.datasets = [AMASSDataset(cfg)] + self.dataset_names = ['AMASS'] + elif train_stage == 'stage2': + self.datasets = [ + AMASSDataset(cfg), ThreeDPW(cfg), + Human36M(cfg), MPII3D(cfg), InstaVariety(cfg) + ] + self.dataset_names = ['AMASS', '3DPW', 'Human36M', 'MPII3D', 'Insta'] + + if len(cfg.DATASET.RATIO) == 6: # Use BEDLAM + self.datasets.append(BEDLAMDataset(cfg)) + self.dataset_names.append('BEDLAM') + + self._set_partition(cfg.DATASET.RATIO) + self.lengths = [len(ds) for ds in self.datasets] + + @property + def __name__(self, ): + return 'MixedData' + + def prepare_video_batch(self): + [ds.prepare_video_batch() for ds in self.datasets] + self.lengths = [len(ds) for ds in self.datasets] + + def _set_partition(self, partition): + self.partition = partition + self.ratio = partition + self.partition = np.array(self.partition).cumsum() + self.partition /= self.partition[-1] + + def __len__(self): + return int(np.array([l for l, r in zip(self.lengths, self.ratio) if r > 0]).mean()) + + def __getitem__(self, index): + # Get the dataset to sample from + p = np.random.rand() + for i in range(len(self.datasets)): + if p <= self.partition[i]: + if len(self.datasets) == 1: + return self.datasets[i][index % self.lengths[i]] + else: + d_index = np.random.randint(0, self.lengths[i]) + return self.datasets[i][d_index] \ No newline at end of file diff --git a/lib/data/datasets/videos.py b/lib/data/datasets/videos.py new file mode 100644 index 0000000000000000000000000000000000000000..2001e86c64774849ef5f47f36248b7cef8f54afa --- /dev/null +++ b/lib/data/datasets/videos.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import torch + +from configs import constants as _C +from .dataset3d import Dataset3D +from .dataset2d import Dataset2D +from ...utils.kp_utils import convert_kps +from smplx import SMPL + + +class Human36M(Dataset3D): + def __init__(self, cfg, dset='train'): + parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'human36m_{dset}_backbone.pth') + parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower()) + super(Human36M, self).__init__(cfg, parsed_data_path, dset=='train') + + self.has_3d = True + self.has_traj = True + self.has_smpl = False + self.has_verts = False + + # Among 31 joints format, 14 common joints are avaialable + self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14) + self.mask[-14:] = 1 + + @property + def __name__(self, ): + return 'Human36M' + + def compute_3d_keypoints(self, index): + return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m' + )[:, _C.KEYPOINTS.H36M_TO_J14].float() + +class MPII3D(Dataset3D): + def __init__(self, cfg, dset='train'): + parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, 
f'mpii3d_{dset}_backbone.pth') + parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower()) + super(MPII3D, self).__init__(cfg, parsed_data_path, dset=='train') + + self.has_3d = True + self.has_traj = True + self.has_smpl = False + self.has_verts = False + + # Among 31 joints format, 14 common joints are avaialable + self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14) + self.mask[-14:] = 1 + + @property + def __name__(self, ): + return 'MPII3D' + + def compute_3d_keypoints(self, index): + return convert_kps(self.labels['joints3D'][index], 'spin', 'h36m' + )[:, _C.KEYPOINTS.H36M_TO_J17].float() + +class ThreeDPW(Dataset3D): + def __init__(self, cfg, dset='train'): + parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'3dpw_{dset}_backbone.pth') + parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower()) + super(ThreeDPW, self).__init__(cfg, parsed_data_path, dset=='train') + + self.has_3d = True + self.has_traj = False + self.has_smpl = True + self.has_verts = True # In testing + + # Among 31 joints format, 14 common joints are avaialable + self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14) + self.mask[:-14] = 1 + + self.smpl_gender = { + 0: SMPL(_C.BMODEL.FLDR, gender='male', num_betas=10), + 1: SMPL(_C.BMODEL.FLDR, gender='female', num_betas=10) + } + + @property + def __name__(self, ): + return 'ThreeDPW' + + def compute_3d_keypoints(self, index): + return self.labels['joints3D'][index] + + +class InstaVariety(Dataset2D): + def __init__(self, cfg, dset='train'): + parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'insta_{dset}_backbone.pth') + parsed_data_path = parsed_data_path.replace('backbone', cfg.MODEL.BACKBONE.lower()) + super(InstaVariety, self).__init__(cfg, parsed_data_path, dset=='train') + + self.has_3d = False + self.has_traj = False + self.has_smpl = False + + # Among 31 joints format, 17 coco joints are avaialable + self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14) + self.mask[:17] = 1 + + @property + def __name__(self, ): + return 'InstaVariety' \ No newline at end of file diff --git a/lib/data/utils/__pycache__/augmentor.cpython-39.pyc b/lib/data/utils/__pycache__/augmentor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0330486fd2e6ed222ac687d420bb112eea0349b Binary files /dev/null and b/lib/data/utils/__pycache__/augmentor.cpython-39.pyc differ diff --git a/lib/data/utils/__pycache__/normalizer.cpython-39.pyc b/lib/data/utils/__pycache__/normalizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c46765d436822c52ba518881a0c17eca7da9b165 Binary files /dev/null and b/lib/data/utils/__pycache__/normalizer.cpython-39.pyc differ diff --git a/lib/data/utils/augmentor.py b/lib/data/utils/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..baa338013ba668192f854b39faf8cf2399803996 --- /dev/null +++ b/lib/data/utils/augmentor.py @@ -0,0 +1,292 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +from configs import constants as _C + +import torch +import numpy as np +from torch.nn import functional as F + +from ...utils import transforms + +__all__ = ['VideoAugmentor', 'SMPLAugmentor', 'SequenceAugmentor', 'CameraAugmentor'] + + +num_joints = _C.KEYPOINTS.NUM_JOINTS +class VideoAugmentor(): + def __init__(self, cfg, train=True): + self.train = train + self.l = cfg.DATASET.SEQLEN + 1 + self.aug_dict = torch.load(_C.KEYPOINTS.COCO_AUG_DICT) + + def 
get_jitter(self, ): + """Guassian jitter modeling.""" + jittering_noise = torch.normal( + mean=torch.zeros((self.l, num_joints, 3)), + std=self.aug_dict['jittering'].reshape(1, num_joints, 1).expand(self.l, -1, 3) + ) * _C.KEYPOINTS.S_JITTERING + return jittering_noise + + def get_lfhp(self, ): + """Low-frequency high-peak noise modeling.""" + def get_peak_noise_mask(): + peak_noise_mask = torch.rand(self.l, num_joints).float() * self.aug_dict['pmask'].squeeze(0) + peak_noise_mask = peak_noise_mask < _C.KEYPOINTS.S_PEAK_MASK + return peak_noise_mask + + peak_noise_mask = get_peak_noise_mask() + peak_noise = peak_noise_mask.float().unsqueeze(-1).repeat(1, 1, 3) + peak_noise = peak_noise * torch.randn(3) * self.aug_dict['peak'].reshape(1, -1, 1) * _C.KEYPOINTS.S_PEAK + return peak_noise + + def get_bias(self, ): + """Bias noise modeling.""" + bias_noise = torch.normal( + mean=torch.zeros((num_joints, 3)), std=self.aug_dict['bias'].reshape(num_joints, 1) + ).unsqueeze(0) * _C.KEYPOINTS.S_BIAS + return bias_noise + + def get_mask(self, scale=None): + """Mask modeling.""" + + if scale is None: + scale = _C.KEYPOINTS.S_MASK + # Per-frame and joint + mask = torch.rand(self.l, num_joints) < scale + visible = (~mask).clone() + for child in range(num_joints): + parent = _C.KEYPOINTS.TREE[child] + if parent == -1: continue + if isinstance(parent, list): + visible[:, child] *= (visible[:, parent[0]] * visible[:, parent[1]]) + else: + visible[:, child] *= visible[:, parent] + mask = (~visible).clone() + + return mask + + def __call__(self, keypoints): + keypoints += self.get_bias() + self.get_jitter() + self.get_lfhp() + return keypoints + + +class SMPLAugmentor(): + noise_scale = 1e-2 + + def __init__(self, cfg, augment=True): + self.n_frames = cfg.DATASET.SEQLEN + self.augment = augment + + def __call__(self, target): + if not self.augment: + # Only add initial frame augmentation + if not 'init_pose' in target: + target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation() + return target + + n_frames = target['pose'].shape[0] + + # Global rotation + rmat = self.get_global_augmentation() + target['pose'][:, 0] = rmat @ target['pose'][:, 0] + target['transl'] = (rmat.squeeze() @ target['transl'].T).T + + # Shape + shape_noise = self.get_shape_augmentation(n_frames) + target['betas'] = target['betas'] + shape_noise + + # Initial frames mis-prediction + target['init_pose'] = target['pose'][:1] @ self.get_initial_pose_augmentation() + + return target + + def get_global_augmentation(self, ): + """Global coordinate augmentation. Random rotation around y-axis""" + + angle_y = torch.rand(1) * 2 * np.pi * float(self.augment) + aa = torch.tensor([0.0, angle_y, 0.0]).float().unsqueeze(0) + rmat = transforms.axis_angle_to_matrix(aa) + + return rmat + + def get_shape_augmentation(self, n_frames): + """Shape noise modeling.""" + + shape_noise = torch.normal( + mean=torch.zeros((1, 10)), + std=torch.ones((1, 10)) * 0.1 * float(self.augment)).expand(n_frames, 10) + + return shape_noise + + def get_initial_pose_augmentation(self, ): + """Initial frame pose noise modeling. 
Random rotation around all joints.""" + + euler = torch.normal( + mean=torch.zeros((24, 3)), + std=torch.ones((24, 3)) + ) * self.noise_scale #* float(self.augment) + rmat = transforms.axis_angle_to_matrix(euler) + + return rmat.unsqueeze(0) + + +class SequenceAugmentor: + """Augment the play speed of the motion sequence""" + l_factor = 1.5 + def __init__(self, l_default): + self.l_default = l_default + + def __call__(self, target): + l = torch.randint(low=int(self.l_default / self.l_factor), high=int(self.l_default * self.l_factor), size=(1, )) + + pose = transforms.matrix_to_rotation_6d(target['pose']) + resampled_pose = F.interpolate( + pose[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True + ).permute(2, 0, 1) + resampled_pose = transforms.rotation_6d_to_matrix(resampled_pose) + + transl = target['transl'].unsqueeze(1) + resampled_transl = F.interpolate( + transl[:l].permute(1, 2, 0), self.l_default, mode='linear', align_corners=True + ).squeeze(0).T + + target['pose'] = resampled_pose + target['transl'] = resampled_transl + target['betas'] = target['betas'][:self.l_default] + + return target + + +class CameraAugmentor: + rx_factor = np.pi/8 + ry_factor = np.pi/4 + rz_factor = np.pi/8 + + pitch_std = np.pi/8 + pitch_mean = np.pi/36 + roll_std = np.pi/24 + t_factor = 1 + + tz_scale = 10 + tz_min = 2 + + motion_prob = 0.75 + interp_noise = 0.2 + + def __init__(self, l, w, h, f): + self.l = l + self.w = w + self.h = h + self.f = f + self.fov_tol = 1.2 * (0.5 ** 0.5) + + def __call__(self, target): + + R, T = self.create_camera(target) + + if np.random.rand() < self.motion_prob: + R = self.create_rotation_move(R) + T = self.create_translation_move(T) + + return self.apply(target, R, T) + + def create_camera(self, target): + """Create the initial frame camera pose""" + yaw = np.random.rand() * 2 * np.pi + pitch = np.random.normal(scale=self.pitch_std) + self.pitch_mean + roll = np.random.normal(scale=self.roll_std) + + yaw_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, yaw, 0]]).float()) + pitch_rm = transforms.axis_angle_to_matrix(torch.tensor([[pitch, 0, 0]]).float()) + roll_rm = transforms.axis_angle_to_matrix(torch.tensor([[0, 0, roll]]).float()) + R = (roll_rm @ pitch_rm @ yaw_rm) + + # Place people in the scene + tz = np.random.rand() * self.tz_scale + self.tz_min + max_d = self.w * tz / self.f / 2 + tx = np.random.normal(scale=0.25) * max_d + ty = np.random.normal(scale=0.25) * max_d + dist = torch.tensor([tx, ty, tz]).float() + T = dist - torch.matmul(R, target['transl'][0]) + + return R.repeat(self.l, 1, 1), T.repeat(self.l, 1) + + def create_rotation_move(self, R): + """Create rotational move for the camera""" + + # Create final camera pose + rx = np.random.normal(scale=self.rx_factor) + ry = np.random.normal(scale=self.ry_factor) + rz = np.random.normal(scale=self.rz_factor) + Rf = R[0] @ transforms.axis_angle_to_matrix(torch.tensor([rx, ry, rz]).float()) + + # Inbetweening two poses + Rs = torch.stack((R[0], Rf)) + rs = transforms.matrix_to_rotation_6d(Rs).numpy() + rs_move = self.noisy_interpolation(rs) + R_move = transforms.rotation_6d_to_matrix(torch.from_numpy(rs_move).float()) + return R_move + + def create_translation_move(self, T): + """Create translational move for the camera""" + + # Create final camera position + tx = np.random.normal(scale=self.t_factor) + ty = np.random.normal(scale=self.t_factor) + tz = np.random.normal(scale=self.t_factor) + Ts = np.array([[0, 0, 0], [tx, ty, tz]]) + + T_move = self.noisy_interpolation(Ts) + T_move = 
torch.from_numpy(T_move).float() + return T_move + T + + def noisy_interpolation(self, data): + """Non-linear interpolation with noise""" + + dim = data.shape[-1] + output = np.zeros((self.l, dim)) + + linspace = np.stack([np.linspace(0, 1, self.l) for _ in range(dim)]) + noise = (linspace[0, 1] - linspace[0, 0]) * self.interp_noise + space_noise = np.stack([np.random.uniform(-noise, noise, self.l - 2) for _ in range(dim)]) + + linspace[:, 1:-1] = linspace[:, 1:-1] + space_noise + for i in range(dim): + output[:, i] = np.interp(linspace[i], np.array([0., 1.,]), data[:, i]) + return output + + def apply(self, target, R, T): + target['R'] = R + target['T'] = T + + # Recompute the translation + transl_cam = torch.matmul(R, target['transl'].unsqueeze(-1)).squeeze(-1) + transl_cam = transl_cam + T + if transl_cam[..., 2].min() < 0.5: # If the person is too close to the camera + transl_cam[..., 2] = transl_cam[..., 2] + (1.0 - transl_cam[..., 2].min()) + + # If the subject is away from the field of view, put the camera behind + fov = torch.div(transl_cam[..., :2], transl_cam[..., 2:]).abs() + if fov.max() > self.fov_tol: + t_max = transl_cam[fov.max(1)[0].max(0)[1].item()] + z_trg = t_max[:2].abs().max(0)[0] / self.fov_tol + pad = z_trg - t_max[2] + transl_cam[..., 2] = transl_cam[..., 2] + pad + + target['transl_cam'] = transl_cam + + # Transform world coordinate to camera coordinate + target['pose_root'] = target['pose'][:, 0].clone() + target['pose'][:, 0] = R @ target['pose'][:, 0] # pose + target['init_pose'][:, 0] = R[:1] @ target['init_pose'][:, 0] # init pose + + # Compute angular velocity + cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2)) + cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize + target['cam_angvel'] = cam_angvel * 3e1 # assume 30-fps + + if 'kp3d' in target: + target['kp3d'] = torch.matmul(R, target['kp3d'].transpose(1, 2)).transpose(1, 2) + target['transl_cam'].unsqueeze(1) + + return target \ No newline at end of file diff --git a/lib/data/utils/normalizer.py b/lib/data/utils/normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea8c049f424b4896c9c473802ff6908328e8b36 --- /dev/null +++ b/lib/data/utils/normalizer.py @@ -0,0 +1,105 @@ +import torch +import numpy as np + +from ...utils.imutils import transform_keypoints + +class Normalizer: + def __init__(self, cfg): + pass + + def __call__(self, kp_2d, res, cam_intrinsics, patch_width=224, patch_height=224, bbox=None, mask=None): + if bbox is None: + bbox = compute_bbox_from_keypoints(kp_2d, do_augment=True, mask=mask) + + out_kp_2d = self.bbox_normalization(kp_2d, bbox, res, patch_width, patch_height) + return out_kp_2d, bbox + + def bbox_normalization(self, kp_2d, bbox, res, patch_width, patch_height): + to_torch = False + if isinstance(kp_2d, torch.Tensor): + to_torch = True + kp_2d = kp_2d.numpy() + bbox = bbox.numpy() + + out_kp_2d = np.zeros_like(kp_2d) + for idx in range(len(out_kp_2d)): + out_kp_2d[idx] = transform_keypoints(kp_2d[idx], bbox[idx][:3], patch_width, patch_height)[0] + out_kp_2d[idx] = normalize_keypoints_to_patch(out_kp_2d[idx], patch_width) + + if to_torch: + out_kp_2d = torch.from_numpy(out_kp_2d) + bbox = torch.from_numpy(bbox) + + centers = normalize_keypoints_to_image(bbox[:, :2].unsqueeze(1), res).squeeze(1) + scale = bbox[:, 2:] * 200 / res.max() + location = torch.cat((centers, scale), dim=-1) + + out_kp_2d = out_kp_2d.reshape(out_kp_2d.shape[0], -1) + out_kp_2d = torch.cat((out_kp_2d, 
location), dim=-1) + return out_kp_2d + + +def normalize_keypoints_to_patch(kp_2d, crop_size=224, inv=False): + # Normalize keypoints between -1, 1 + if not inv: + ratio = 1.0 / crop_size + kp_2d = 2.0 * kp_2d * ratio - 1.0 + else: + ratio = 1.0 / crop_size + kp_2d = (kp_2d + 1.0)/(2*ratio) + + return kp_2d + + +def normalize_keypoints_to_image(x, res): + res = res.to(x.device) + scale = res.max(-1)[0].reshape(-1) + mean = torch.stack([res[..., 0] / scale, res[..., 1] / scale], dim=-1).to(x.device) + x = (2 * x / scale.reshape(*[1 for i in range(len(x.shape[1:]))]) - \ + mean.reshape(*[1 for i in range(len(x.shape[1:-1]))], -1)) + return x + + +def compute_bbox_from_keypoints(X, do_augment=False, mask=None): + def smooth_bbox(bb): + # Smooth bounding box detection + import scipy.signal as signal + smoothed = np.array([signal.medfilt(param, int(30 / 2)) for param in bb]) + return smoothed + + def do_augmentation(scale_factor=0.2, trans_factor=0.05): + _scaleFactor = np.random.uniform(1.0 - scale_factor, 1.2 + scale_factor) + _trans_x = np.random.uniform(-trans_factor, trans_factor) + _trans_y = np.random.uniform(-trans_factor, trans_factor) + + return _scaleFactor, _trans_x, _trans_y + + if do_augment: + scaleFactor, trans_x, trans_y = do_augmentation() + else: + scaleFactor, trans_x, trans_y = 1.2, 0.0, 0.0 + + if mask is None: + bbox = [X[:, :, 0].min(-1)[0], X[:, :, 1].min(-1)[0], + X[:, :, 0].max(-1)[0], X[:, :, 1].max(-1)[0]] + else: + bbox = [] + for x, _mask in zip(X, mask): + if _mask.sum() > 10: + _mask[:] = False + _bbox = [x[~_mask, 0].min(-1)[0], x[~_mask, 1].min(-1)[0], + x[~_mask, 0].max(-1)[0], x[~_mask, 1].max(-1)[0]] + bbox.append(_bbox) + bbox = torch.tensor(bbox).T + + cx, cy = [(bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2] + bbox_w = bbox[2] - bbox[0] + bbox_h = bbox[3] - bbox[1] + bbox_size = torch.stack((bbox_w, bbox_h)).max(0)[0] + scale = bbox_size * scaleFactor + bbox = torch.stack((cx + trans_x * scale, cy + trans_y * scale, scale / 200)) + + if do_augment: + bbox = torch.from_numpy(smooth_bbox(bbox.numpy())) + + return bbox.T \ No newline at end of file diff --git a/lib/data_utils/amass_utils.py b/lib/data_utils/amass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94e4e5b6dad8772568928272ba97d972837a8dda --- /dev/null +++ b/lib/data_utils/amass_utils.py @@ -0,0 +1,107 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import os.path as osp +from collections import defaultdict + +import torch +import joblib +import numpy as np +from tqdm import tqdm +from smplx import SMPL + +from configs import constants as _C +from lib.utils.data_utils import map_dmpl_to_smpl, transform_global_coordinate + + +@torch.no_grad() +def process_amass(): + target_fps = 30 + + _, seqs, _ = next(os.walk(_C.PATHS.AMASS_PTH)) + + zup2ydown = torch.Tensor( + [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + ).unsqueeze(0).float() + + smpl_dict = {'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'), + 'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'), + 'neutral': SMPL(model_path=_C.BMODEL.FLDR)} + processed_data = defaultdict(list) + + for seq in (seq_bar := tqdm(sorted(seqs), leave=True)): + seq_bar.set_description(f'Dataset: {seq}') + seq_fldr = osp.join(_C.PATHS.AMASS_PTH, seq) + _, subjs, _ = next(os.walk(seq_fldr)) + + for subj in (subj_bar := tqdm(sorted(subjs), leave=False)): + subj_bar.set_description(f'Subject: {subj}') + subj_fldr = osp.join(seq_fldr, subj) + acts = [x for x in 
os.listdir(subj_fldr) if x.endswith('.npz')] + + for act in (act_bar := tqdm(sorted(acts), leave=False)): + act_bar.set_description(f'Action: {act}') + + # Load data + fname = osp.join(subj_fldr, act) + if fname.endswith('shape.npz') or fname.endswith('stagei.npz'): + # Skip shape and stagei files + continue + data = dict(np.load(fname, allow_pickle=True)) + + # Resample data to target_fps + key = [k for k in data.keys() if 'mocap_frame' in k][0] + mocap_framerate = data[key] + retain_freq = int(mocap_framerate / target_fps + 0.5) + num_frames = len(data['poses'][::retain_freq]) + + # Skip if the sequence is too short + if num_frames < 25: continue + + # Get SMPL groundtruth from MoSh fitting + pose = map_dmpl_to_smpl(torch.from_numpy(data['poses'][::retain_freq]).float()) + transl = torch.from_numpy(data['trans'][::retain_freq]).float() + betas = torch.from_numpy( + np.repeat(data['betas'][:10][np.newaxis], pose.shape[0], axis=0)).float() + + # Convert Z-up coordinate to Y-down + pose, transl = transform_global_coordinate(pose, zup2ydown, transl) + pose = pose.reshape(-1, 72) + + # Create SMPL mesh + gender = str(data['gender']) + if not gender in ['male', 'female', 'neutral']: + if 'female' in gender: gender = 'female' + elif 'neutral' in gender: gender = 'neutral' + elif 'male' in gender: gender = 'male' + + output = smpl_dict[gender](body_pose=pose[:, 3:], + global_orient=pose[:, :3], + betas=betas, + transl=transl) + vertices = output.vertices + + # Assume motion starts with 0-height + init_height = vertices[0].max(0)[0][1] + transl[:, 1] = transl[:, 1] + init_height + vertices[:, :, 1] = vertices[:, :, 1] - init_height + + # Append data + processed_data['pose'].append(pose.numpy()) + processed_data['betas'].append(betas.numpy()) + processed_data['transl'].append(transl.numpy()) + processed_data['vid'].append(np.array([f'{seq}_{subj}_{act}'] * pose.shape[0])) + + for key, val in processed_data.items(): + processed_data[key] = np.concatenate(val) + + joblib.dump(processed_data, _C.PATHS.AMASS_LABEL) + print('\nDone!') + +if __name__ == '__main__': + out_path = '/'.join(_C.PATHS.AMASS_LABEL.split('/')[:-1]) + os.makedirs(out_path, exist_ok=True) + + process_amass() \ No newline at end of file diff --git a/lib/data_utils/emdb_eval_utils.py b/lib/data_utils/emdb_eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d519cb9d289c68534c84bef9dfa47183e415a67e --- /dev/null +++ b/lib/data_utils/emdb_eval_utils.py @@ -0,0 +1,189 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import os.path as osp +from glob import glob +from collections import defaultdict + +import cv2 +import torch +import pickle +import joblib +import argparse +import numpy as np +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from lib.models.smpl import SMPL +from lib.models.preproc.extractor import FeatureExtractor +from lib.models.preproc.backbone.utils import process_image +from lib.utils import transforms +from lib.utils.imutils import ( + flip_kp, flip_bbox +) + +dataset = defaultdict(list) +detection_results_dir = 'dataset/detection_results/EMDB' + +def is_dset(emdb_pkl_file, dset): + target_dset = 'emdb' + dset + with open(emdb_pkl_file, "rb") as f: + data = pickle.load(f) + return data[target_dset] + +@torch.no_grad() +def preprocess(dset, batch_size): + + tt = lambda x: torch.from_numpy(x).float() + device = torch.device('cuda') if torch.cuda.is_available() 
else torch.device('cpu') + save_pth = osp.join(_C.PATHS.PARSED_DATA, f'emdb_{dset}_vit.pth') # Use ViT feature extractor + extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size) + + all_emdb_pkl_files = sorted(glob(os.path.join(_C.PATHS.EMDB_PTH, "*/*/*_data.pkl"))) + emdb_sequence_roots = [] + both = [] + for emdb_pkl_file in all_emdb_pkl_files: + if is_dset(emdb_pkl_file, dset): + emdb_sequence_roots.append(os.path.dirname(emdb_pkl_file)) + + smpl = { + 'neutral': SMPL(model_path=_C.BMODEL.FLDR), + 'male': SMPL(model_path=_C.BMODEL.FLDR, gender='male'), + 'female': SMPL(model_path=_C.BMODEL.FLDR, gender='female'), + } + + for sequence in emdb_sequence_roots: + subj, seq = sequence.split('/')[-2:] + annot_pth = glob(osp.join(sequence, '*_data.pkl'))[0] + annot = pickle.load(open(annot_pth, 'rb')) + + # Get ground truth data + gender = annot['gender'] + masks = annot['good_frames_mask'] + poses_body = annot["smpl"]["poses_body"] + poses_root = annot["smpl"]["poses_root"] + betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0) + extrinsics = annot["camera"]["extrinsics"] + width, height = annot['camera']['width'], annot['camera']['height'] + xyxys = annot['bboxes']['bboxes'] + + # Map to camear coordinate + poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root))) + poses = np.concatenate([poses_root_cam, poses_body], axis=-1) + + pred_kp2d = np.load(osp.join(detection_results_dir, f'{subj}_{seq}.npy')) + + # ======== Extract features ======== # + imname_list = sorted(glob(osp.join(sequence, 'images/*'))) + bboxes, frame_ids, patch_list, features, flipped_features = [], [], [], [], [] + bar = Bar(f'Load images', fill='#', max=len(imname_list)) + for idx, (imname, xyxy, mask) in enumerate(zip(imname_list, xyxys, masks)): + if not mask: continue + + # ========= Load image ========= # + img_rgb = cv2.cvtColor(cv2.imread(imname), cv2.COLOR_BGR2RGB) + + # ========= Load bbox ========= # + x1, y1, x2, y2 = xyxy + bbox = np.array([(x1 + x2)/2., (y1 + y2)/2., max(x2 - x1, y2 - y1) / 1.1]) + + # ========= Process image ========= # + norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256) + + patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float()) + bboxes.append(bbox) + frame_ids.append(idx) + bar.next() + + patch_list = torch.split(torch.cat(patch_list), batch_size) + bboxes = torch.from_numpy(np.stack(bboxes)).float() + for i, patch in enumerate(patch_list): + bbox = bboxes[i*batch_size:min((i+1)*batch_size, len(frame_ids))].float().cuda() + bbox_center = bbox[:, :2] + bbox_scale = bbox[:, 2] / 200 + + feature = extractor.model(patch.cuda(), encode=True) + features.append(feature.cpu()) + + flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True) + flipped_features.append(flipped_feature.cpu()) + + if i == 0: + init_patch = patch[[0]].clone() + + features = torch.cat(features) + flipped_features = torch.cat(flipped_features) + res_h, res_w = img_rgb.shape[:2] + + # ======== Append data ======== # + dataset['gender'].append(gender) + dataset['bbox'].append(bboxes) + dataset['res'].append(torch.tensor([[width, height]]).repeat(len(frame_ids), 1).float()) + dataset['vid'].append(f'{subj}_{seq}') + dataset['pose'].append(tt(poses)[frame_ids]) + dataset['betas'].append(tt(betas)[frame_ids]) + dataset['kp2d'].append(tt(pred_kp2d)[frame_ids]) + dataset['frame_id'].append(torch.from_numpy(np.array(frame_ids))) + 
dataset['cam_poses'].append(tt(extrinsics)[frame_ids]) + dataset['features'].append(features) + dataset['flipped_features'].append(flipped_features) + + # Flipped data + dataset['flipped_bbox'].append( + torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float() + ) + dataset['flipped_kp2d'].append( + torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float() + ) + # ======== Append data ======== # + + # Pad 1 frame + for key, val in dataset.items(): + if isinstance(val[-1], torch.Tensor): + dataset[key][-1] = torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0) + + # Initial predictions + bbox = bboxes[:1].clone().cuda() + bbox_center = bbox[:, :2].clone() + bbox_scale = bbox[:, 2].clone() / 200 + kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(), + 'img_h': torch.tensor(res_h).repeat(1).float().cuda(), + 'bbox_center': bbox_center, 'bbox_scale': bbox_scale} + + pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs) + pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(), + body_pose=pred_pose.cpu(), + betas=pred_shape.cpu(), + pose2rot=False) + init_kp3d = pred_output.joints + init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1)) + + dataset['init_kp3d'].append(init_kp3d) + dataset['init_pose'].append(init_pose.cpu()) + + # Flipped initial predictions + bbox_center[:, 0] = res_w - bbox_center[:, 0] + pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs) + pred_output = smpl['neutral'].get_output(global_orient=pred_global_orient.cpu(), + body_pose=pred_pose.cpu(), + betas=pred_shape.cpu(), + pose2rot=False) + init_kp3d = pred_output.joints + init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1)) + + dataset['flipped_init_kp3d'].append(init_kp3d) + dataset['flipped_init_pose'].append(init_pose.cpu()) + + joblib.dump(dataset, save_pth) + logger.info(f'==> Done !') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split') + parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split') + args = parser.parse_args() + + preprocess(args.split, args.batch_size) \ No newline at end of file diff --git a/lib/data_utils/rich_eval_utils.py b/lib/data_utils/rich_eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a174ad2a867159da007e59819a4030910728a7 --- /dev/null +++ b/lib/data_utils/rich_eval_utils.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os +import os.path as osp +from glob import glob +from collections import defaultdict + +import cv2 +import torch +import pickle +import joblib +import argparse +import numpy as np +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from lib.models.smpl import SMPL +from lib.models.preproc.extractor import FeatureExtractor +from lib.models.preproc.backbone.utils import process_image +from lib.utils import transforms +from lib.utils.imutils import ( + flip_kp, flip_bbox +) + +dataset = defaultdict(list) +detection_results_dir = 'dataset/detection_results/RICH' + +def extract_cam_param_xml(xml_path='', dtype=torch.float32): + + import xml.etree.ElementTree as ET + tree = ET.parse(xml_path) + + extrinsics_mat = 
[float(s) for s in tree.find('./CameraMatrix/data').text.split()] + intrinsics_mat = [float(s) for s in tree.find('./Intrinsics/data').text.split()] + # distortion_vec = [float(s) for s in tree.find('./Distortion/data').text.split()] + + focal_length_x = intrinsics_mat[0] + focal_length_y = intrinsics_mat[4] + center = torch.tensor([[intrinsics_mat[2], intrinsics_mat[5]]], dtype=dtype) + + rotation = torch.tensor([[extrinsics_mat[0], extrinsics_mat[1], extrinsics_mat[2]], + [extrinsics_mat[4], extrinsics_mat[5], extrinsics_mat[6]], + [extrinsics_mat[8], extrinsics_mat[9], extrinsics_mat[10]]], dtype=dtype) + + translation = torch.tensor([[extrinsics_mat[3], extrinsics_mat[7], extrinsics_mat[11]]], dtype=dtype) + + # t = -Rc --> c = -R^Tt + cam_center = [ -extrinsics_mat[0]*extrinsics_mat[3] - extrinsics_mat[4]*extrinsics_mat[7] - extrinsics_mat[8]*extrinsics_mat[11], + -extrinsics_mat[1]*extrinsics_mat[3] - extrinsics_mat[5]*extrinsics_mat[7] - extrinsics_mat[9]*extrinsics_mat[11], + -extrinsics_mat[2]*extrinsics_mat[3] - extrinsics_mat[6]*extrinsics_mat[7] - extrinsics_mat[10]*extrinsics_mat[11]] + + cam_center = torch.tensor([cam_center], dtype=dtype) + + return focal_length_x, focal_length_y, center, rotation, translation, cam_center + +@torch.no_grad() +def preprocess(dset, batch_size): + import pdb; pdb.set_trace() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--split', type=str, choices=['1', '2'], help='Data split') + parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split') + args = parser.parse_args() + + preprocess(args.split, args.batch_size) \ No newline at end of file diff --git a/lib/data_utils/threedpw_eval_utils.py b/lib/data_utils/threedpw_eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac981693c3a639fbbfd65af49d76b53c3e44c9e0 --- /dev/null +++ b/lib/data_utils/threedpw_eval_utils.py @@ -0,0 +1,185 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os.path as osp +from glob import glob +from collections import defaultdict + +import cv2 +import torch +import pickle +import joblib +import argparse +import numpy as np +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from lib.models.smpl import SMPL +from lib.models.preproc.extractor import FeatureExtractor +from lib.models.preproc.backbone.utils import process_image +from lib.utils import transforms +from lib.utils.imutils import ( + flip_kp, flip_bbox +) + + +dataset = defaultdict(list) +detection_results_dir = 'dataset/detection_results/3DPW' +tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_dset_db.pt' + +@torch.no_grad() +def preprocess(dset, batch_size): + + if dset == 'val': _dset = 'validation' + else: _dset = dset + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_{dset}_vit.pth') # Use ViT feature extractor + extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size) + + tcmr_data = joblib.load(tcmr_annot_pth.replace('dset', dset)) + smpl_neutral = SMPL(model_path=_C.BMODEL.FLDR) + + annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True) + idxs = idxs.tolist() + annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)] + annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', _dset, annot_file[:-2] + '.pkl') for annot_file 
in annot_file_list] + annot_file_list = list(dict.fromkeys(annot_file_list)) + + for annot_file in annot_file_list: + seq = annot_file.split('/')[-1].split('.')[0] + + data = pickle.load(open(annot_file, 'rb'), encoding='latin1') + + num_people = len(data['poses']) + num_frames = len(data['img_frame_ids']) + assert (data['poses2d'][0].shape[0] == num_frames) + + K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float() + + for p_id in range(num_people): + + logger.info(f'==> {seq} {p_id}') + gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]] + + # ======== Add TCMR data ======== # + vid_name = f'{seq}_{p_id}' + tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v] + frame_ids = tcmr_data['frame_id'][tcmr_ids] + + pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids] + shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1) + pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float() # Camera coordinate + cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float() + + # ======== Get detection results ======== # + fname = f'{seq}_{p_id}.npy' + pred_kp2d = torch.from_numpy( + np.load(osp.join(detection_results_dir, fname)) + ).float()[frame_ids] + # ======== Get detection results ======== # + + img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg'))) + img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids] + img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2] + vid_idxs = fname.split('.')[0] + + # ======== Append data ======== # + dataset['gender'].append(gender) + dataset['vid'].append(vid_idxs) + dataset['pose'].append(pose) + dataset['betas'].append(shape) + dataset['cam_poses'].append(cam_poses) + dataset['frame_id'].append(torch.from_numpy(frame_ids)) + dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float()) + dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float()) + dataset['kp2d'].append(pred_kp2d) + + # Flipped data + dataset['flipped_bbox'].append( + torch.from_numpy(flip_bbox(dataset['bbox'][-1].clone().numpy(), res_w, res_h)).float() + ) + dataset['flipped_kp2d'].append( + torch.from_numpy(flip_kp(dataset['kp2d'][-1].clone().numpy(), res_w)).float() + ) + # ======== Append data ======== # + + # ======== Extract features ======== # + patch_list = [] + bboxes = dataset['bbox'][-1].clone().numpy() + bar = Bar(f'Load images', fill='#', max=len(img_paths)) + + for img_path, bbox in zip(img_paths, bboxes): + img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) + norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256) + patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float()) + bar.next() + + patch_list = torch.split(torch.cat(patch_list), batch_size) + features, flipped_features = [], [] + for i, patch in enumerate(patch_list): + feature = extractor.model(patch.cuda(), encode=True) + features.append(feature.cpu()) + + flipped_feature = extractor.model(torch.flip(patch, (3, )).cuda(), encode=True) + flipped_features.append(flipped_feature.cpu()) + + if i == 0: + init_patch = patch[[0]].clone() + + features = torch.cat(features) + flipped_features = torch.cat(flipped_features) + dataset['features'].append(features) + dataset['flipped_features'].append(flipped_features) + # ======== Extract features ======== # + + # Pad 1 frame + for key, val in dataset.items(): + if isinstance(val[-1], torch.Tensor): + dataset[key][-1] = 
torch.cat((val[-1][:1].clone(), val[-1][:]), dim=0) + + # Initial predictions + bbox = torch.from_numpy(bboxes[:1].copy()).float().cuda() + bbox_center = bbox[:, :2].clone() + bbox_scale = bbox[:, 2].clone() / 200 + kwargs = {'img_w': torch.tensor(res_w).repeat(1).float().cuda(), + 'img_h': torch.tensor(res_h).repeat(1).float().cuda(), + 'bbox_center': bbox_center, 'bbox_scale': bbox_scale} + + pred_global_orient, pred_pose, pred_shape, _ = extractor.model(init_patch.cuda(), **kwargs) + pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(), + body_pose=pred_pose.cpu(), + betas=pred_shape.cpu(), + pose2rot=False) + init_kp3d = pred_output.joints + init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1)) + + dataset['init_kp3d'].append(init_kp3d) + dataset['init_pose'].append(init_pose.cpu()) + + # Flipped initial predictions + bbox_center[:, 0] = res_w - bbox_center[:, 0] + pred_global_orient, pred_pose, pred_shape, _ = extractor.model(torch.flip(init_patch, (3, )).cuda(), **kwargs) + pred_output = smpl_neutral.get_output(global_orient=pred_global_orient.cpu(), + body_pose=pred_pose.cpu(), + betas=pred_shape.cpu(), + pose2rot=False) + init_kp3d = pred_output.joints + init_pose = transforms.matrix_to_axis_angle(torch.cat((pred_global_orient, pred_pose), dim=1)) + + dataset['flipped_init_kp3d'].append(init_kp3d) + dataset['flipped_init_pose'].append(init_pose.cpu()) + + joblib.dump(dataset, save_pth) + logger.info(f'\n ==> Done !') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--split', type=str, choices=['val', 'test'], help='Data split') + parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split') + args = parser.parse_args() + + preprocess(args.split, args.batch_size) \ No newline at end of file diff --git a/lib/data_utils/threedpw_train_utils.py b/lib/data_utils/threedpw_train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..637b0e22066f35d55ea57f7d4216e0b09e88db37 --- /dev/null +++ b/lib/data_utils/threedpw_train_utils.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import os.path as osp +from glob import glob +from collections import defaultdict + +import cv2 +import torch +import pickle +import joblib +import argparse +import numpy as np +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from lib.models.smpl import SMPL +from lib.models.preproc.extractor import FeatureExtractor +from lib.models.preproc.backbone.utils import process_image + +dataset = defaultdict(list) +detection_results_dir = 'dataset/detection_results/3DPW' +tcmr_annot_pth = 'dataset/parsed_data/TCMR_preproc/3dpw_train_db.pt' + + +@torch.no_grad() +def preprocess(batch_size): + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + save_pth = osp.join(_C.PATHS.PARSED_DATA, f'3pdw_train_vit.pth') # Use ViT feature extractor + extractor = FeatureExtractor(device, flip_eval=True, max_batch_size=batch_size) + + tcmr_data = joblib.load(tcmr_annot_pth) + + annot_file_list, idxs = np.unique(tcmr_data['vid_name'], return_index=True) + idxs = idxs.tolist() + annot_file_list = [annot_file_list[idxs.index(idx)] for idx in sorted(idxs)] + annot_file_list = [osp.join(_C.PATHS.THREEDPW_PTH, 'sequenceFiles', 'train', annot_file[:-2] + '.pkl') for annot_file in annot_file_list] + annot_file_list = 
list(dict.fromkeys(annot_file_list)) + + vid_idx = 0 + for annot_file in annot_file_list: + seq = annot_file.split('/')[-1].split('.')[0] + + data = pickle.load(open(annot_file, 'rb'), encoding='latin1') + + num_people = len(data['poses']) + num_frames = len(data['img_frame_ids']) + assert (data['poses2d'][0].shape[0] == num_frames) + + K = torch.from_numpy(data['cam_intrinsics']).unsqueeze(0).float() + + for p_id in range(num_people): + + logger.info(f'==> {seq} {p_id}') + gender = {'m': 'male', 'f': 'female'}[data['genders'][p_id]] + smpl_gender = SMPL(model_path=_C.BMODEL.FLDR, gender=gender) + + # ======== Add TCMR data ======== # + vid_name = f'{seq}_{p_id}' + tcmr_ids = [i for i, v in enumerate(tcmr_data['vid_name']) if vid_name in v] + frame_ids = tcmr_data['frame_id'][tcmr_ids] + + pose = torch.from_numpy(data['poses'][p_id]).float()[frame_ids] + shape = torch.from_numpy(data['betas'][p_id][:10]).float().repeat(pose.size(0), 1) + trans = torch.from_numpy(data['trans'][p_id]).float()[frame_ids] + cam_poses = torch.from_numpy(data['cam_poses'][frame_ids]).float() + + # ======== Align the mesh params ======== # + Rc = cam_poses[:, :3, :3] + Tc = cam_poses[:, :3, 3] + org_output = smpl_gender.get_output(betas=shape, body_pose=pose[:,3:], global_orient=pose[:,:3], transl=trans) + org_v0 = (org_output.vertices + org_output.offset.unsqueeze(1)).mean(1) + pose = torch.from_numpy(tcmr_data['pose'][tcmr_ids]).float() + + output = smpl_gender.get_output(betas=shape, body_pose=pose[:,3:], global_orient=pose[:,:3]) + v0 = (output.vertices + output.offset.unsqueeze(1)).mean(1) + trans = (Rc @ org_v0.reshape(-1, 3, 1)).reshape(-1, 3) + Tc - v0 + j3d = output.joints + (output.offset + trans).unsqueeze(1) + j2d = torch.div(j3d, j3d[..., 2:]) + kp2d = torch.matmul(K, j2d.transpose(-1, -2)).transpose(-1, -2)[..., :2] + # ======== Align the mesh params ======== # + + # ======== Get detection results ======== # + fname = f'{seq}_{p_id}.npy' + pred_kp2d = torch.from_numpy( + np.load(osp.join(detection_results_dir, fname)) + ).float()[frame_ids] + # ======== Get detection results ======== # + + img_paths = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', seq, '*.jpg'))) + img_paths = [img_path for i, img_path in enumerate(img_paths) if i in frame_ids] + img = cv2.imread(img_paths[0]); res_h, res_w = img.shape[:2] + vid_idxs = torch.from_numpy(np.array([vid_idx] * len(img_paths)).astype(int)) + vid_idx += 1 + + # ======== Append data ======== # + dataset['bbox'].append(torch.from_numpy(tcmr_data['bbox'][tcmr_ids].copy()).float()) + dataset['res'].append(torch.tensor([[res_w, res_h]]).repeat(len(frame_ids), 1).float()) + dataset['vid'].append(vid_idxs) + dataset['pose'].append(pose) + dataset['betas'].append(shape) + dataset['transl'].append(trans) + dataset['kp2d'].append(pred_kp2d) + dataset['joints3D'].append(j3d) + dataset['joints2D'].append(kp2d) + dataset['frame_id'].append(torch.from_numpy(frame_ids)) + dataset['cam_poses'].append(cam_poses) + dataset['gender'].append(torch.tensor([['male','female'].index(gender)]).repeat(len(frame_ids))) + # ======== Append data ======== # + + # ======== Extract features ======== # + patch_list = [] + bboxes = dataset['bbox'][-1].clone().numpy() + bar = Bar(f'Load images', fill='#', max=len(img_paths)) + + for img_path, bbox in zip(img_paths, bboxes): + img_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) + norm_img, crop_img = process_image(img_rgb, bbox[:2], bbox[2] / 200, 256, 256) + 
patch_list.append(torch.from_numpy(norm_img).unsqueeze(0).float()) + bar.next() + + patch_list = torch.split(torch.cat(patch_list), batch_size) + features = [] + for i, patch in enumerate(patch_list): + pred = extractor.model(patch.cuda(), encode=True) + features.append(pred.cpu()) + + features = torch.cat(features) + dataset['features'].append(features) + # ======== Extract features ======== # + + for key in dataset.keys(): + dataset[key] = torch.cat(dataset[key]) + + joblib.dump(dataset, save_pth) + logger.info(f'\n ==> Done !') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-b', '--batch_size', type=int, default=128, help='Data split') + args = parser.parse_args() + + preprocess(args.batch_size) \ No newline at end of file diff --git a/lib/eval/eval_utils.py b/lib/eval/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..270491c09c942c3aa47dfc53224f050dea91353c --- /dev/null +++ b/lib/eval/eval_utils.py @@ -0,0 +1,482 @@ +# Some functions are borrowed from https://github.com/akanazawa/human_dynamics/blob/master/src/evaluation/eval_util.py +# Adhere to their licence to use these functions +from pathlib import Path + +import torch +import numpy as np +from matplotlib import pyplot as plt + + +def compute_accel(joints): + """ + Computes acceleration of 3D joints. + Args: + joints (Nx25x3). + Returns: + Accelerations (N-2). + """ + velocities = joints[1:] - joints[:-1] + acceleration = velocities[1:] - velocities[:-1] + acceleration_normed = np.linalg.norm(acceleration, axis=2) + return np.mean(acceleration_normed, axis=1) + + +def compute_error_accel(joints_gt, joints_pred, vis=None): + """ + Computes acceleration error: + 1/(n-2) \sum_{i=1}^{n-1} X_{i-1} - 2X_i + X_{i+1} + Note that for each frame that is not visible, three entries in the + acceleration error should be zero'd out. + Args: + joints_gt (Nx14x3). + joints_pred (Nx14x3). + vis (N). + Returns: + error_accel (N-2). + """ + # (N-2)x14x3 + accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:] + accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:] + + normed = np.linalg.norm(accel_pred - accel_gt, axis=2) + + if vis is None: + new_vis = np.ones(len(normed), dtype=bool) + else: + invis = np.logical_not(vis) + invis1 = np.roll(invis, -1) + invis2 = np.roll(invis, -2) + new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2] + new_vis = np.logical_not(new_invis) + + return np.mean(normed[new_vis], axis=1) + + +def compute_error_verts(pred_verts, target_verts=None, target_theta=None): + """ + Computes MPJPE over 6890 surface vertices. + Args: + verts_gt (Nx6890x3). + verts_pred (Nx6890x3). + Returns: + error_verts (N). 
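+    Note:
+        The returned value is the mean per-vertex Euclidean distance for each frame
+        (commonly reported as PVE / MPVPE), in the same units as the input vertices.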
+ """ + + if target_verts is None: + from lib.models.smpl import SMPL_MODEL_DIR + from lib.models.smpl import SMPL + device = 'cpu' + smpl = SMPL( + SMPL_MODEL_DIR, + batch_size=1, # target_theta.shape[0], + ).to(device) + + betas = torch.from_numpy(target_theta[:,75:]).to(device) + pose = torch.from_numpy(target_theta[:,3:75]).to(device) + + target_verts = [] + b_ = torch.split(betas, 5000) + p_ = torch.split(pose, 5000) + + for b,p in zip(b_,p_): + output = smpl(betas=b, body_pose=p[:, 3:], global_orient=p[:, :3], pose2rot=True) + target_verts.append(output.vertices.detach().cpu().numpy()) + + target_verts = np.concatenate(target_verts, axis=0) + + assert len(pred_verts) == len(target_verts) + error_per_vert = np.sqrt(np.sum((target_verts - pred_verts) ** 2, axis=2)) + return np.mean(error_per_vert, axis=1) + + +def compute_similarity_transform(S1, S2): + ''' + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + ''' + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert(S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = np.sum(X1**2) + + # 3. The outer product of X1 and X2. + K = X1.dot(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, Vh = np.linalg.svd(K) + V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = np.eye(U.shape[0]) + Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) + # Construct R. + R = V.dot(Z.dot(U.T)) + + # 5. Recover scale. + scale = np.trace(R.dot(K)) / var1 + + # 6. Recover translation. + t = mu2 - scale*(R.dot(mu1)) + + # 7. Error: + S1_hat = scale*R.dot(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + + +def compute_similarity_transform_torch(S1, S2): + ''' + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + ''' + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert (S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # print('X1', X1.shape) + + # 2. Compute variance of X1 used for scale. + var1 = torch.sum(X1 ** 2) + + # print('var', var1.shape) + + # 3. The outer product of X1 and X2. + K = X1.mm(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, V = torch.svd(K) + # V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[0], device=S1.device) + Z[-1, -1] *= torch.sign(torch.det(U @ V.T)) + # Construct R. + R = V.mm(Z.mm(U.T)) + + # print('R', X1.shape) + + # 5. Recover scale. + scale = torch.trace(R.mm(K)) / var1 + # print(R.shape, mu1.shape) + # 6. Recover translation. + t = mu2 - scale * (R.mm(mu1)) + # print(t.shape) + + # 7. 
Error: + S1_hat = scale * R.mm(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + + +def batch_compute_similarity_transform_torch(S1, S2): + ''' + Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + ''' + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.permute(0,2,1) + S2 = S2.permute(0,2,1) + transposed = True + assert(S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=-1, keepdims=True) + mu2 = S2.mean(axis=-1, keepdims=True) + + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = torch.sum(X1**2, dim=1).sum(dim=1) + + # 3. The outer product of X1 and X2. + K = X1.bmm(X2.permute(0,2,1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, V = torch.svd(K) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) + Z = Z.repeat(U.shape[0],1,1) + Z[:,-1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0,2,1)))) + + # Construct R. + R = V.bmm(Z.bmm(U.permute(0,2,1))) + + # 5. Recover scale. + scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 + + # 6. Recover translation. + t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) + + # 7. Error: + S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t + + if transposed: + S1_hat = S1_hat.permute(0,2,1) + + return S1_hat + + +def align_by_pelvis(joints): + """ + Assumes joints is 14 x 3 in LSP order. + Then hips are: [3, 2] + Takes mid point of these points, then subtracts it. + """ + + left_id = 2 + right_id = 3 + + pelvis = (joints[left_id, :] + joints[right_id, :]) / 2.0 + return joints - np.expand_dims(pelvis, axis=0) + + +def compute_errors(gt3ds, preds): + """ + Gets MPJPE after pelvis alignment + MPJPE after Procrustes. + Evaluates on the 14 common joints. + Inputs: + - gt3ds: N x 14 x 3 + - preds: N x 14 x 3 + """ + errors, errors_pa = [], [] + for i, (gt3d, pred) in enumerate(zip(gt3ds, preds)): + gt3d = gt3d.reshape(-1, 3) + # Root align. + gt3d = align_by_pelvis(gt3d) + pred3d = align_by_pelvis(pred) + + joint_error = np.sqrt(np.sum((gt3d - pred3d)**2, axis=1)) + errors.append(np.mean(joint_error)) + + # Get PA error. + pred3d_sym = compute_similarity_transform(pred3d, gt3d) + pa_error = np.sqrt(np.sum((gt3d - pred3d_sym)**2, axis=1)) + errors_pa.append(np.mean(pa_error)) + + return errors, errors_pa + + +def batch_align_by_pelvis(data_list, pelvis_idxs): + """ + Assumes data is given as [pred_j3d, target_j3d, pred_verts, target_verts]. + Each data is in shape of (frames, num_points, 3) + Pelvis is notated as one / two joints indices. + Align all data to the corresponding pelvis location. 
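+
+    Illustrative usage (a sketch; tensor shapes and hip indices are assumptions
+    taken from the evaluation scripts in this patch):
+        pelvis_idxs = [1, 2]  # left/right hip indices of the joint set in use
+        pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis(
+            [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs)
+        # each tensor is now expressed relative to its own mid-hip (pelvis) point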
+ """ + + pred_j3d, target_j3d, pred_verts, target_verts = data_list + + pred_pelvis = pred_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone() + target_pelvis = target_j3d[:, pelvis_idxs].mean(dim=1, keepdims=True).clone() + + # Align to the pelvis + pred_j3d = pred_j3d - pred_pelvis + target_j3d = target_j3d - target_pelvis + pred_verts = pred_verts - pred_pelvis + target_verts = target_verts - target_pelvis + + return (pred_j3d, target_j3d, pred_verts, target_verts) + +def compute_jpe(S1, S2): + return torch.sqrt(((S1 - S2) ** 2).sum(dim=-1)).mean(dim=-1).numpy() + + +# The functions below are borrowed from SLAHMR official implementation. +# Reference: https://github.com/vye16/slahmr/blob/main/slahmr/eval/tools.py +def global_align_joints(gt_joints, pred_joints): + """ + :param gt_joints (T, J, 3) + :param pred_joints (T, J, 3) + """ + s_glob, R_glob, t_glob = align_pcl( + gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3) + ) + pred_glob = ( + s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None] + ) + return pred_glob + + +def first_align_joints(gt_joints, pred_joints): + """ + align the first two frames + :param gt_joints (T, J, 3) + :param pred_joints (T, J, 3) + """ + # (1, 1), (1, 3, 3), (1, 3) + s_first, R_first, t_first = align_pcl( + gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3) + ) + pred_first = ( + s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None] + ) + return pred_first + + +def local_align_joints(gt_joints, pred_joints): + """ + :param gt_joints (T, J, 3) + :param pred_joints (T, J, 3) + """ + s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints) + pred_loc = ( + s_loc[:, None] * torch.einsum("tij,tnj->tni", R_loc, pred_joints) + + t_loc[:, None] + ) + return pred_loc + + +def align_pcl(Y, X, weight=None, fixed_scale=False): + """align similarity transform to align X with Y using umeyama method + X' = s * R * X + t is aligned with Y + :param Y (*, N, 3) first trajectory + :param X (*, N, 3) second trajectory + :param weight (*, N, 1) optional weight of valid correspondences + :returns s (*, 1), R (*, 3, 3), t (*, 3) + """ + *dims, N, _ = Y.shape + N = torch.ones(*dims, 1, 1) * N + + if weight is not None: + Y = Y * weight + X = X * weight + N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1) + + # subtract mean + my = Y.sum(dim=-2) / N[..., 0] # (*, 3) + mx = X.sum(dim=-2) / N[..., 0] + y0 = Y - my[..., None, :] # (*, N, 3) + x0 = X - mx[..., None, :] + + if weight is not None: + y0 = y0 * weight + x0 = x0 * weight + + # correlation + C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3) + U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3) + + S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1) + neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0 + S[neg, 2, 2] = -1 + + R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3) + + D = torch.diag_embed(D) # (*, 3, 3) + if fixed_scale: + s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32) + else: + var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1) + s = ( + torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum( + dim=-1, keepdim=True + ) + / var[..., 0] + ) # (*, 1) + + t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3) + + return s, R, t + + +def compute_foot_sliding(target_output, pred_output, masks, thr=1e-2): + """compute foot sliding error + The foot ground contact label is computed by the threshold of 1 cm/frame + Args: + target_output (SMPL 
ModelOutput). + pred_output (SMPL ModelOutput). + masks (N). + Returns: + error (N frames in contact). + """ + + # Foot vertices idxs + foot_idxs = [3216, 3387, 6617, 6787] + + # Compute contact label + foot_loc = target_output.vertices[masks][:, foot_idxs] + foot_disp = (foot_loc[1:] - foot_loc[:-1]).norm(2, dim=-1) + contact = foot_disp[:] < thr + + pred_feet_loc = pred_output.vertices[:, foot_idxs] + pred_disp = (pred_feet_loc[1:] - pred_feet_loc[:-1]).norm(2, dim=-1) + + error = pred_disp[contact] + + return error.cpu().numpy() + + +def compute_jitter(pred_output, fps=30): + """compute jitter of the motion + Args: + pred_output (SMPL ModelOutput). + fps (float). + Returns: + jitter (N-3). + """ + + pred3d = pred_output.joints[:, :24] + + pred_jitter = torch.norm( + (pred3d[3:] - 3 * pred3d[2:-1] + 3 * pred3d[1:-2] - pred3d[:-3]) * (fps**3), + dim=2, + ).mean(dim=-1) + + return pred_jitter.cpu().numpy() / 10.0 + + +def compute_rte(target_trans, pred_trans): + # Compute the global alignment + _, rot, trans = align_pcl(target_trans[None, :], pred_trans[None, :], fixed_scale=True) + pred_trans_hat = ( + torch.einsum("tij,tnj->tni", rot, pred_trans[None, :]) + trans[None, :] + )[0] + + # Compute the entire displacement of ground truth trajectory + disps, disp = [], 0 + for p1, p2 in zip(target_trans, target_trans[1:]): + delta = (p2 - p1).norm(2, dim=-1) + disp += delta + disps.append(disp) + + # Compute absolute root-translation-error (RTE) + rte = torch.norm(target_trans - pred_trans_hat, 2, dim=-1) + + # Normalize it to the displacement + return (rte / disp).numpy() \ No newline at end of file diff --git a/lib/eval/evaluate_3dpw.py b/lib/eval/evaluate_3dpw.py new file mode 100644 index 0000000000000000000000000000000000000000..36736288f5c02d128d17fffc1598606ada009631 --- /dev/null +++ b/lib/eval/evaluate_3dpw.py @@ -0,0 +1,181 @@ +import os +import time +import os.path as osp +from glob import glob +from collections import defaultdict + +import torch +import imageio +import numpy as np +from smplx import SMPL +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from configs.config import parse_args +from lib.data.dataloader import setup_eval_dataloader +from lib.models import build_network, build_body_model +from lib.eval.eval_utils import ( + compute_error_accel, + batch_align_by_pelvis, + batch_compute_similarity_transform_torch, +) +from lib.utils import transforms +from lib.utils.utils import prepare_output_dir +from lib.utils.utils import prepare_batch +from lib.utils.imutils import avg_preds + +try: + from lib.vis.renderer import Renderer + _render = True +except: + print("PyTorch3D is not properly installed! 
Cannot render the SMPL mesh") + _render = False + + +m2mm = 1e3 +@torch.no_grad() +def main(cfg, args): + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + logger.info(f'GPU name -> {torch.cuda.get_device_name()}') + logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}') + + # ========= Dataloaders ========= # + eval_loader = setup_eval_dataloader(cfg, '3dpw', 'test', cfg.MODEL.BACKBONE) + logger.info(f'Dataset loaded') + + # ========= Load WHAM ========= # + smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN + smpl = build_body_model(cfg.DEVICE, smpl_batch_size) + network = build_network(cfg, smpl) + network.eval() + + # Build SMPL models with each gender + smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']} + + # Load vertices -> joints regression matrix to evaluate + J_regressor_eval = torch.from_numpy( + np.load(_C.BMODEL.JOINTS_REGRESSOR_H36M) + )[_C.KEYPOINTS.H36M_TO_J14, :].unsqueeze(0).float().to(cfg.DEVICE) + pelvis_idxs = [2, 3] + + accumulator = defaultdict(list) + bar = Bar('Inference', fill='#', max=len(eval_loader)) + with torch.no_grad(): + for i in range(len(eval_loader)): + # Original batch + batch = eval_loader.dataset.load_data(i, False) + x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2') + + if cfg.FLIP_EVAL: + flipped_batch = eval_loader.dataset.load_data(i, True) + f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2') + + # Forward pass with flipped input + flipped_pred = network(f_x, f_inits, f_features, **f_kwargs) + + # Forward pass with normal input + pred = network(x, inits, features, **kwargs) + + if cfg.FLIP_EVAL: + # Merge two predictions + flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0) + pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0) + flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6) + avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape) + avg_pose = avg_pose.reshape(-1, 144) + + # Refine trajectory with merged prediction + network.pred_pose = avg_pose.view_as(network.pred_pose) + network.pred_shape = avg_shape.view_as(network.pred_shape) + pred = network.forward_smpl(**kwargs) + + # <======= Build predicted SMPL + pred_output = smpl['neutral'](body_pose=pred['poses_body'], + global_orient=pred['poses_root_cam'], + betas=pred['betas'].squeeze(0), + pose2rot=False) + pred_verts = pred_output.vertices.cpu() + pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu() + # =======> + + # <======= Build groundtruth SMPL + target_output = smpl[batch['gender']]( + body_pose=transforms.rotation_6d_to_matrix(gt['pose'][0, :, 1:]), + global_orient=transforms.rotation_6d_to_matrix(gt['pose'][0, :, :1]), + betas=gt['betas'][0], + pose2rot=False) + target_verts = target_output.vertices.cpu() + target_j3d = torch.matmul(J_regressor_eval, target_output.vertices).cpu() + # =======> + + # <======= Compute performance of the current sequence + pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis( + [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs + ) + S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d) + pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + pve = 
torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1] + accel = accel * (30 ** 2) # convert from per frame^2 to per s^2 (30 fps) + # =======> + + summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}' + bar.suffix = summary_string + bar.next() + + # <======= Accumulate the results over entire sequences + accumulator['pa_mpjpe'].append(pa_mpjpe) + accumulator['mpjpe'].append(mpjpe) + accumulator['pve'].append(pve) + accumulator['accel'].append(accel) + # =======> + + # <======= (Optional) Render the prediction + if not (_render and args.render): + # Skip if PyTorch3D is not installed or the rendering argument is not passed. + continue + + # Save path + viz_pth = osp.join('output', 'visualization') + os.makedirs(viz_pth, exist_ok=True) + + # Build Renderer + width, height = batch['cam_intrinsics'][0][0, :2, -1].numpy() * 2 + focal_length = batch['cam_intrinsics'][0][0, 0, 0].item() + renderer = Renderer(width, height, focal_length, cfg.DEVICE, smpl['neutral'].faces) + + # Get images and writer + frame_list = batch['frame_id'][0].numpy() + imname_list = sorted(glob(osp.join(_C.PATHS.THREEDPW_PTH, 'imageFiles', batch['vid'][:-2], '*.jpg'))) + writer = imageio.get_writer(osp.join(viz_pth, batch['vid'] + '.mp4'), + mode='I', format='FFMPEG', fps=30, macro_block_size=1) + + # Skip the invalid frames + for i, frame in enumerate(frame_list): + image = imageio.imread(imname_list[frame]) + + # NOTE: pred['verts'] is different from pred_verts as we subtracted the offset from the SMPL mesh. + # Check line 121 in lib/models/smpl.py + vertices = pred['verts_cam'][i] + pred['trans_cam'][[i]] + image = renderer.render_mesh(vertices, image) + writer.append_data(image) + writer.close() + # =======> + + for k, v in accumulator.items(): + accumulator[k] = np.concatenate(v).mean() + + print('') + log_str = 'Evaluation on 3DPW, ' + log_str += ' '.join([f'{k.upper()}: {v:.4f},' for k, v in accumulator.items()]) + logger.info(log_str) + +if __name__ == '__main__': + cfg, cfg_file, args = parse_args(test=True) + cfg = prepare_output_dir(cfg, cfg_file) + + main(cfg, args) \ No newline at end of file diff --git a/lib/eval/evaluate_emdb.py b/lib/eval/evaluate_emdb.py new file mode 100644 index 0000000000000000000000000000000000000000..09a0b7be239c6d13f03dd3e00a621aa9245e3da8 --- /dev/null +++ b/lib/eval/evaluate_emdb.py @@ -0,0 +1,228 @@ +import os +import time +import os.path as osp +from glob import glob +from collections import defaultdict + +import torch +import pickle +import numpy as np +from smplx import SMPL +from loguru import logger +from progress.bar import Bar + +from configs import constants as _C +from configs.config import parse_args +from lib.data.dataloader import setup_eval_dataloader +from lib.models import build_network, build_body_model +from lib.eval.eval_utils import ( + compute_jpe, + compute_rte, + compute_jitter, + compute_error_accel, + compute_foot_sliding, + batch_align_by_pelvis, + first_align_joints, + global_align_joints, + batch_compute_similarity_transform_torch, +) +from lib.utils import transforms +from lib.utils.utils import prepare_output_dir +from lib.utils.utils import prepare_batch +from lib.utils.imutils import avg_preds + +""" +This is a tentative script to evaluate WHAM on EMDB dataset.
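+
+Editor's note on the global metrics computed below (a description of the code, not
+added behaviour): the trajectory is split into 100-frame chunks; WA-MPJPE aligns the
+whole predicted chunk to the ground truth with a similarity transform
+(global_align_joints) before measuring joint error, W-MPJPE aligns only the first two
+frames (first_align_joints), and RTE is the root translation error after rigid
+alignment, normalized by the total ground-truth displacement.
+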
+Current implementation requires EMDB dataset downloaded at ./datasets/EMDB/ +""" + +m2mm = 1e3 +@torch.no_grad() +def main(cfg, args): + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + logger.info(f'GPU name -> {torch.cuda.get_device_name()}') + logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}') + + # ========= Dataloaders ========= # + eval_loader = setup_eval_dataloader(cfg, 'emdb', args.eval_split, cfg.MODEL.BACKBONE) + logger.info(f'Dataset loaded') + + # ========= Load WHAM ========= # + smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN + smpl = build_body_model(cfg.DEVICE, smpl_batch_size) + network = build_network(cfg, smpl) + network.eval() + + # Build SMPL models with each gender + smpl = {k: SMPL(_C.BMODEL.FLDR, gender=k).to(cfg.DEVICE) for k in ['male', 'female', 'neutral']} + + # Load vertices -> joints regression matrix to evaluate + pelvis_idxs = [1, 2] + + # WHAM uses Y-down coordinate system, while EMDB dataset uses Y-up one. + yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float().to(cfg.DEVICE) + + # To torch tensor function + tt = lambda x: torch.from_numpy(x).float().to(cfg.DEVICE) + accumulator = defaultdict(list) + bar = Bar('Inference', fill='#', max=len(eval_loader)) + with torch.no_grad(): + for i in range(len(eval_loader)): + # Original batch + batch = eval_loader.dataset.load_data(i, False) + x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2') + + # Align with groundtruth data to the first frame + cam2yup = batch['R'][0][:1].to(cfg.DEVICE) + cam2ydown = cam2yup @ yup2ydown + cam2root = transforms.rotation_6d_to_matrix(inits[1][:, 0, 0]) + ydown2root = cam2ydown.mT @ cam2root + ydown2root = transforms.matrix_to_rotation_6d(ydown2root) + kwargs['init_root'][:, 0] = ydown2root + + if cfg.FLIP_EVAL: + flipped_batch = eval_loader.dataset.load_data(i, True) + f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE == 'stage2') + + # Forward pass with flipped input + flipped_pred = network(f_x, f_inits, f_features, **f_kwargs) + + # Forward pass with normal input + pred = network(x, inits, features, **kwargs) + + if cfg.FLIP_EVAL: + # Merge two predictions + flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0) + pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0) + flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6) + avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape) + avg_pose = avg_pose.reshape(-1, 144) + avg_contact = (flipped_pred['contact'][..., [2, 3, 0, 1]] + pred['contact']) / 2 + + # Refine trajectory with merged prediction + network.pred_pose = avg_pose.view_as(network.pred_pose) + network.pred_shape = avg_shape.view_as(network.pred_shape) + network.pred_contact = avg_contact.view_as(network.pred_contact) + output = network.forward_smpl(**kwargs) + pred = network.refine_trajectory(output, return_y_up=True, **kwargs) + + # <======= Prepare groundtruth data + subj, seq = batch['vid'][:2], batch['vid'][3:] + annot_pth = glob(osp.join(_C.PATHS.EMDB_PTH, subj, seq, '*_data.pkl'))[0] + annot = pickle.load(open(annot_pth, 'rb')) + + masks = annot['good_frames_mask'] + gender = annot['gender'] + poses_body = annot["smpl"]["poses_body"] + poses_root = annot["smpl"]["poses_root"] + betas = np.repeat(annot["smpl"]["betas"].reshape((1, -1)), repeats=annot["n_frames"], axis=0) + trans = 
annot["smpl"]["trans"] + extrinsics = annot["camera"]["extrinsics"] + + # # Map to camear coordinate + poses_root_cam = transforms.matrix_to_axis_angle(tt(extrinsics[:, :3, :3]) @ transforms.axis_angle_to_matrix(tt(poses_root))) + + # Groundtruth global motion + target_glob = smpl[gender](body_pose=tt(poses_body), global_orient=tt(poses_root), betas=tt(betas), transl=tt(trans)) + target_j3d_glob = target_glob.joints[:, :24][masks] + + # Groundtruth local motion + target_cam = smpl[gender](body_pose=tt(poses_body), global_orient=poses_root_cam, betas=tt(betas)) + target_verts_cam = target_cam.vertices[masks] + target_j3d_cam = target_cam.joints[:, :24][masks] + # =======> + + # Convert WHAM global orient to Y-up coordinate + poses_root = pred['poses_root_world'].squeeze(0) + pred_trans = pred['trans_world'].squeeze(0) + poses_root = yup2ydown.mT @ poses_root + pred_trans = (yup2ydown.mT @ pred_trans.unsqueeze(-1)).squeeze(-1) + + # <======= Build predicted motion + # Predicted global motion + pred_glob = smpl['neutral'](body_pose=pred['poses_body'], global_orient=poses_root.unsqueeze(1), betas=pred['betas'].squeeze(0), transl=pred_trans, pose2rot=False) + pred_j3d_glob = pred_glob.joints[:, :24] + + # Predicted local motion + pred_cam = smpl['neutral'](body_pose=pred['poses_body'], global_orient=pred['poses_root_cam'], betas=pred['betas'].squeeze(0), pose2rot=False) + pred_verts_cam = pred_cam.vertices + pred_j3d_cam = pred_cam.joints[:, :24] + # =======> + + # <======= Evaluation on the local motion + pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam = batch_align_by_pelvis( + [pred_j3d_cam, target_j3d_cam, pred_verts_cam, target_verts_cam], pelvis_idxs + ) + S1_hat = batch_compute_similarity_transform_torch(pred_j3d_cam, target_j3d_cam) + pa_mpjpe = torch.sqrt(((S1_hat - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm + mpjpe = torch.sqrt(((pred_j3d_cam - target_j3d_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm + pve = torch.sqrt(((pred_verts_cam - target_verts_cam) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * m2mm + accel = compute_error_accel(joints_pred=pred_j3d_cam.cpu(), joints_gt=target_j3d_cam.cpu())[1:-1] + accel = accel * (30 ** 2) # per frame^s to per s^2 + + summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}' + bar.suffix = summary_string + bar.next() + # =======> + + # <======= Evaluation on the global motion + chunk_length = 100 + w_mpjpe, wa_mpjpe = [], [] + for start in range(0, masks.sum(), chunk_length): + end = min(masks.sum(), start + chunk_length) + + target_j3d = target_j3d_glob[start:end].clone().cpu() + pred_j3d = pred_j3d_glob[start:end].clone().cpu() + + w_j3d = first_align_joints(target_j3d, pred_j3d) + wa_j3d = global_align_joints(target_j3d, pred_j3d) + + w_jpe = compute_jpe(target_j3d, w_j3d) + wa_jpe = compute_jpe(target_j3d, wa_j3d) + w_mpjpe.append(w_jpe) + wa_mpjpe.append(wa_jpe) + + w_mpjpe = np.concatenate(w_mpjpe) * m2mm + wa_mpjpe = np.concatenate(wa_mpjpe) * m2mm + + # Additional metrics + rte = compute_rte(torch.from_numpy(trans[masks]), pred_trans.cpu()) * 1e2 + jitter = compute_jitter(pred_glob, fps=30) + foot_sliding = compute_foot_sliding(target_glob, pred_glob, masks) * m2mm + # =======> + + # Additional metrics + rte = compute_rte(torch.from_numpy(trans[masks]), pred_trans.cpu()) * 1e2 + jitter = compute_jitter(pred_glob, fps=30) + foot_sliding = compute_foot_sliding(target_glob, pred_glob, masks) * m2mm + + # <======= Accumulate 
the results over entire sequences + accumulator['pa_mpjpe'].append(pa_mpjpe) + accumulator['mpjpe'].append(mpjpe) + accumulator['pve'].append(pve) + accumulator['accel'].append(accel) + accumulator['wa_mpjpe'].append(wa_mpjpe) + accumulator['w_mpjpe'].append(w_mpjpe) + accumulator['RTE'].append(rte) + accumulator['jitter'].append(jitter) + accumulator['FS'].append(foot_sliding) + # =======> + + for k, v in accumulator.items(): + accumulator[k] = np.concatenate(v).mean() + + print('') + log_str = f'Evaluation on EMDB {args.eval_split}, ' + log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in accumulator.items()]) + logger.info(log_str) + +if __name__ == '__main__': + cfg, cfg_file, args = parse_args(test=True) + cfg = prepare_output_dir(cfg, cfg_file) + + main(cfg, args) \ No newline at end of file diff --git a/lib/eval/evaluate_rich.py b/lib/eval/evaluate_rich.py new file mode 100644 index 0000000000000000000000000000000000000000..5504ca6862dd73b27a5e88420699862e507f9ad9 --- /dev/null +++ b/lib/eval/evaluate_rich.py @@ -0,0 +1,156 @@ +import os +import os.path as osp +from collections import defaultdict +from time import time + +import torch +import joblib +import numpy as np +from loguru import logger +from smplx import SMPL, SMPLX +from progress.bar import Bar + +from configs import constants as _C +from configs.config import parse_args +from lib.data.dataloader import setup_eval_dataloader +from lib.models import build_network, build_body_model +from lib.eval.eval_utils import ( + compute_error_accel, + batch_align_by_pelvis, + batch_compute_similarity_transform_torch, +) +from lib.utils import transforms +from lib.utils.utils import prepare_output_dir +from lib.utils.utils import prepare_batch +from lib.utils.imutils import avg_preds + +m2mm = 1e3 +smplx2smpl = torch.from_numpy(joblib.load(_C.BMODEL.SMPLX2SMPL)['matrix']).unsqueeze(0).float().cuda() +@torch.no_grad() +def main(cfg, args): + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + logger.info(f'GPU name -> {torch.cuda.get_device_name()}') + logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}') + + # ========= Dataloaders ========= # + eval_loader = setup_eval_dataloader(cfg, 'rich', 'test', cfg.MODEL.BACKBONE) + logger.info(f'Dataset loaded') + + # ========= Load WHAM ========= # + smpl_batch_size = cfg.TRAIN.BATCH_SIZE * cfg.DATASET.SEQLEN + smpl = build_body_model(cfg.DEVICE, smpl_batch_size) + network = build_network(cfg, smpl) + network.eval() + + # Build neutral SMPL model for WHAM and gendered SMPLX models for the groundtruth data + smpl = SMPL(_C.BMODEL.FLDR, gender='neutral').to(cfg.DEVICE) + + # Load vertices -> joints regression matrix to evaluate + J_regressor_eval = smpl.J_regressor.clone().unsqueeze(0) + pelvis_idxs = [1, 2] + + accumulator = defaultdict(list) + bar = Bar('Inference', fill='#', max=len(eval_loader)) + with torch.no_grad(): + for i in range(len(eval_loader)): + time_dict = {} + _t = time() + + # Original batch + batch = eval_loader.dataset.load_data(i, False) + x, inits, features, kwargs, gt = prepare_batch(batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2') + + # <======= Inference + if cfg.FLIP_EVAL: + flipped_batch = eval_loader.dataset.load_data(i, True) + f_x, f_inits, f_features, f_kwargs, _ = prepare_batch(flipped_batch, cfg.DEVICE, cfg.TRAIN.STAGE=='stage2') + + # Forward pass with flipped input + flipped_pred = network(f_x, f_inits, f_features, **f_kwargs) + time_dict['inference_flipped'] = time() - _t; _t = time() + + # Forward pass 
with normal input + pred = network(x, inits, features, **kwargs) + time_dict['inference'] = time() - _t; _t = time() + + if cfg.FLIP_EVAL: + # Merge two predictions + flipped_pose, flipped_shape = flipped_pred['pose'].squeeze(0), flipped_pred['betas'].squeeze(0) + pose, shape = pred['pose'].squeeze(0), pred['betas'].squeeze(0) + flipped_pose, pose = flipped_pose.reshape(-1, 24, 6), pose.reshape(-1, 24, 6) + avg_pose, avg_shape = avg_preds(pose, shape, flipped_pose, flipped_shape) + avg_pose = avg_pose.reshape(-1, 144) + + # Refine trajectory with merged prediction + network.pred_pose = avg_pose.view_as(network.pred_pose) + network.pred_shape = avg_shape.view_as(network.pred_shape) + pred = network.forward_smpl(**kwargs) + time_dict['averaging'] = time() - _t; _t = time() + # =======> + + # <======= Build predicted SMPL + pred_output = smpl(body_pose=pred['poses_body'], + global_orient=pred['poses_root_cam'], + betas=pred['betas'].squeeze(0), + pose2rot=False) + pred_verts = pred_output.vertices.cpu() + pred_j3d = torch.matmul(J_regressor_eval, pred_output.vertices).cpu() + time_dict['building prediction'] = time() - _t; _t = time() + # =======> + + # <======= Build groundtruth SMPL (from SMPLX) + smplx = SMPLX(_C.BMODEL.FLDR.replace('smpl', 'smplx'), + gender=batch['gender'], + batch_size=len(pred_verts) + ).to(cfg.DEVICE) + gt_pose = transforms.matrix_to_axis_angle(transforms.rotation_6d_to_matrix(gt['pose'][0])) + target_output = smplx( + body_pose=gt_pose[:, 1:-2].reshape(-1, 63), + global_orient=gt_pose[:, 0], + betas=gt['betas'][0]) + target_verts = torch.matmul(smplx2smpl, target_output.vertices.cuda()).cpu() + target_j3d = torch.matmul(J_regressor_eval, target_verts.to(cfg.DEVICE)).cpu() + time_dict['building target'] = time() - _t; _t = time() + # =======> + + # <======= Compute performance of the current sequence + pred_j3d, target_j3d, pred_verts, target_verts = batch_align_by_pelvis( + [pred_j3d, target_j3d, pred_verts, target_verts], pelvis_idxs + ) + S1_hat = batch_compute_similarity_transform_torch(pred_j3d, target_j3d) + pa_mpjpe = torch.sqrt(((S1_hat - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + mpjpe = torch.sqrt(((pred_j3d - target_j3d) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + pve = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).numpy() * m2mm + accel = compute_error_accel(joints_pred=pred_j3d, joints_gt=target_j3d)[1:-1] + accel = accel * (30 ** 2) # per frame^s to per s^2 + time_dict['evaluating'] = time() - _t; _t = time() + # =======> + + # summary_string = f'{batch["vid"]} | PA-MPJPE: {pa_mpjpe.mean():.1f} MPJPE: {mpjpe.mean():.1f} PVE: {pve.mean():.1f}' + summary_string = f'{batch["vid"]} | ' + ' '.join([f'{k}: {v:.1f} s' for k, v in time_dict.items()]) + bar.suffix = summary_string + bar.next() + + # <======= Accumulate the results over entire sequences + accumulator['pa_mpjpe'].append(pa_mpjpe) + accumulator['mpjpe'].append(mpjpe) + accumulator['pve'].append(pve) + accumulator['accel'].append(accel) + + # =======> + + for k, v in accumulator.items(): + accumulator[k] = np.concatenate(v).mean() + + print('') + log_str = 'Evaluation on RICH, ' + log_str += ' '.join([f'{k.upper()}: {v:.4f},'for k,v in accumulator.items()]) + logger.info(log_str) + +if __name__ == '__main__': + cfg, cfg_file, args = parse_args(test=True) + cfg = prepare_output_dir(cfg, cfg_file) + + main(cfg, args) \ No newline at end of file diff --git a/lib/models/__init__.py b/lib/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..d203b287a2024e1d1d5c0ccf9846f59ca87efe55 --- /dev/null +++ b/lib/models/__init__.py @@ -0,0 +1,40 @@ +import os, sys +import yaml +import torch +from loguru import logger + +from configs import constants as _C +from .smpl import SMPL + + +def build_body_model(device, batch_size=1, gender='neutral', **kwargs): + sys.stdout = open(os.devnull, 'w') + body_model = SMPL( + model_path=_C.BMODEL.FLDR, + gender=gender, + batch_size=batch_size, + create_transl=False).to(device) + sys.stdout = sys.__stdout__ + return body_model + + +def build_network(cfg, smpl): + from .wham import Network + + with open(cfg.MODEL_CONFIG, 'r') as f: + model_config = yaml.safe_load(f) + model_config.update({'d_feat': _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]}) + + network = Network(smpl, **model_config).to(cfg.DEVICE) + + # Load Checkpoint + if os.path.isfile(cfg.TRAIN.CHECKPOINT): + checkpoint = torch.load(cfg.TRAIN.CHECKPOINT) + ignore_keys = ['smpl.body_pose', 'smpl.betas', 'smpl.global_orient', 'smpl.J_regressor_extra', 'smpl.J_regressor_eval'] + model_state_dict = {k: v for k, v in checkpoint['model'].items() if k not in ignore_keys} + network.load_state_dict(model_state_dict, strict=False) + logger.info(f"=> loaded checkpoint '{cfg.TRAIN.CHECKPOINT}' ") + else: + logger.info(f"=> Warning! no checkpoint found at '{cfg.TRAIN.CHECKPOINT}'.") + + return network \ No newline at end of file diff --git a/lib/models/__pycache__/__init__.cpython-39.pyc b/lib/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6448bf47a6483675d1066ee310b79de26e5a1a65 Binary files /dev/null and b/lib/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/models/__pycache__/smpl.cpython-39.pyc b/lib/models/__pycache__/smpl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2795edbd166325aa0584f5a5bc89e048a22a3ace Binary files /dev/null and b/lib/models/__pycache__/smpl.cpython-39.pyc differ diff --git a/lib/models/__pycache__/wham.cpython-39.pyc b/lib/models/__pycache__/wham.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0825def69bd50c6e3b95bdf4577e0b2fd1f9b437 Binary files /dev/null and b/lib/models/__pycache__/wham.cpython-39.pyc differ diff --git a/lib/models/layers/__init__.py b/lib/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd2ab8be0f475741e9cec1d896930fd1c40021 --- /dev/null +++ b/lib/models/layers/__init__.py @@ -0,0 +1,2 @@ +from .modules import MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator +from .utils import rollout_global_motion, compute_camera_pose, reset_root_velocity, compute_camera_motion \ No newline at end of file diff --git a/lib/models/layers/__pycache__/__init__.cpython-39.pyc b/lib/models/layers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6793839e899cd00867dd79278b82420a3a74f9 Binary files /dev/null and b/lib/models/layers/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/models/layers/__pycache__/modules.cpython-39.pyc b/lib/models/layers/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c9734807e8a2080b4ff90378dc5c98ab8604423 Binary files /dev/null and b/lib/models/layers/__pycache__/modules.cpython-39.pyc differ