from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import joblib import numpy as np from .._dataset import BaseDataset from ..utils.augmentor import * from ...utils import data_utils as d_utils from ...utils import transforms from ...models import build_body_model from ...utils.kp_utils import convert_kps, root_centering class Dataset3D(BaseDataset): def __init__(self, cfg, fname, training): super(Dataset3D, self).__init__(cfg, training) self.epoch = 0 self.labels = joblib.load(fname) self.n_frames = cfg.DATASET.SEQLEN + 1 if self.training: self.prepare_video_batch() self.smpl = build_body_model('cpu', self.n_frames) self.SMPLAugmentor = SMPLAugmentor(cfg, False) self.VideoAugmentor = VideoAugmentor(cfg) def __getitem__(self, index): return self.get_single_sequence(index) def get_inputs(self, index, target, vis_thr=0.6): start_index, end_index = self.video_indices[index] # 2D keypoints detection kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone() bbox = self.labels['bbox'][start_index:end_index+1][..., [0, 1, -1]].clone() bbox[:, 2] = bbox[:, 2] / 200 kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox) target['bbox'] = bbox[1:] target['kp2d'] = kp2d target['mask'] = self.labels['kp2d'][start_index+1:end_index+1][..., -1] < vis_thr # Image features target['features'] = self.labels['features'][start_index+1:end_index+1].clone() return target def get_labels(self, index, target): start_index, end_index = self.video_indices[index] # SMPL parameters # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input. # We do not supervise the network on SMPL parameters. target['pose'] = transforms.axis_angle_to_matrix( self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3)) target['betas'] = self.labels['betas'][start_index:end_index+1].clone() # No t # Apply SMPL augmentor (y-axis rotation and initial frame noise) target = self.SMPLAugmentor(target) # 3D and 2D keypoints if self.__name__ == 'ThreeDPW': # 3DPW has SMPL labels gt_kp3d = self.labels['joints3D'][start_index:end_index+1].clone() gt_kp2d = self.labels['joints2D'][start_index+1:end_index+1, ..., :2].clone() gt_kp3d = root_centering(gt_kp3d.clone()) else: # Human36m and MPII do not have SMPL labels gt_kp3d = torch.zeros((self.n_frames, self.n_joints + 14, 3)) gt_kp3d[:, self.n_joints:] = convert_kps(self.labels['joints3D'][start_index:end_index+1], 'spin', 'common') gt_kp2d = torch.zeros((self.n_frames - 1, self.n_joints + 14, 2)) gt_kp2d[:, self.n_joints:] = convert_kps(self.labels['joints2D'][start_index+1:end_index+1, ..., :2], 'spin', 'common') conf = self.mask.repeat(self.n_frames, 1).unsqueeze(-1) gt_kp2d = torch.cat((gt_kp2d, conf[1:]), dim=-1) gt_kp3d = torch.cat((gt_kp3d, conf), dim=-1) target['kp3d'] = gt_kp3d target['full_kp2d'] = gt_kp2d target['weak_kp2d'] = torch.zeros_like(gt_kp2d) if self.__name__ != 'ThreeDPW': # 3DPW does not contain world-coordinate motion # Foot ground contact labels for Human36M and MPII3D target['contact'] = self.labels['stationaries'][start_index+1:end_index+1].clone() else: # No foot ground contact label available for 3DPW target['contact'] = torch.ones((self.n_frames - 1, 4)) * (-1) if self.has_verts: # SMPL vertices available for 3DPW with torch.no_grad(): start_index, end_index = self.video_indices[index] gender = self.labels['gender'][start_index].item() output = self.smpl_gender[gender]( body_pose=target['pose'][1:, 1:], global_orient=target['pose'][1:, :1], betas=target['betas'][1:], pose2rot=False, ) target['verts'] = output.vertices.clone() else: # No SMPL vertices available target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float() return target def get_init_frame(self, target): # Prepare initial frame output = self.smpl.get_output( body_pose=target['init_pose'][:, 1:], global_orient=target['init_pose'][:, :1], betas=target['betas'][:1], pose2rot=False ) target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1) return target def get_camera_info(self, index, target): start_index, end_index = self.video_indices[index] # Intrinsics target['res'] = self.labels['res'][start_index:end_index+1][0].clone() self.get_naive_intrinsics(target['res']) target['cam_intrinsics'] = self.cam_intrinsics.clone() # Extrinsics pose R = self.labels['cam_poses'][start_index:end_index+1, :3, :3].clone().float() yaw = transforms.axis_angle_to_matrix(torch.tensor([[0, 2 * np.pi * np.random.uniform(), 0]])).float() if self.__name__ == 'Human36M': # Map Z-up to Y-down coordinate zup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[-np.pi/2, 0, 0]])).float() zup2ydown = torch.matmul(yaw, zup2ydown) R = torch.matmul(R, zup2ydown) elif self.__name__ == 'MPII3D': # Map Y-up to Y-down coordinate yup2ydown = transforms.axis_angle_to_matrix(torch.tensor([[np.pi, 0, 0]])).float() yup2ydown = torch.matmul(yaw, yup2ydown) R = torch.matmul(R, yup2ydown) return target def get_single_sequence(self, index): # Universal target target = {'has_full_screen': torch.tensor(True), 'has_smpl': torch.tensor(self.has_smpl), 'has_traj': torch.tensor(self.has_traj), 'has_verts': torch.tensor(self.has_verts), 'transl': torch.zeros((self.n_frames, 3)), # Null camera motion 'R': torch.eye(3).repeat(self.n_frames, 1, 1), 'cam_angvel': torch.zeros((self.n_frames - 1, 6)), # Null root orientation and velocity 'pose_root': torch.zeros((self.n_frames, 6)), 'vel_root': torch.zeros((self.n_frames - 1, 3)), 'init_root': torch.zeros((1, 6)), } self.get_camera_info(index, target) self.get_inputs(index, target) self.get_labels(index, target) self.get_init_frame(target) target = d_utils.prepare_keypoints_data(target) target = d_utils.prepare_smpl_data(target) return target