from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib

from .._dataset import BaseDataset
from ..utils.augmentor import *
from ...utils import data_utils as d_utils
from ...utils import transforms
from ...models import build_body_model
from ...utils.kp_utils import convert_kps, root_centering


class Dataset2D(BaseDataset):
    def __init__(self, cfg, fname, training):
        super(Dataset2D, self).__init__(cfg, training)

        self.epoch = 0
        self.n_frames = cfg.DATASET.SEQLEN + 1  # sequence length plus one initial frame
        self.labels = joblib.load(fname)

        if self.training:
            self.prepare_video_batch()

        self.smpl = build_body_model('cpu', self.n_frames)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)

    def __getitem__(self, index):
        return self.get_single_sequence(index)

    def get_inputs(self, index, target, vis_thr=0.6):
        start_index, end_index = self.video_indices[index]

        # 2D keypoints detection
        kp2d = self.labels['kp2d'][start_index:end_index+1][..., :2].clone()
        kp2d, bbox = self.keypoints_normalizer(
            kp2d, target['res'], target['cam_intrinsics'], 224, 224, target['bbox'])

        target['bbox'] = bbox[1:]  # drop the initial (0th) frame
        target['kp2d'] = kp2d

        # Detection mask (True where the 2D detection is missing)
        target['mask'] = ~self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone().bool()

        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()

        return target

    def get_labels(self, index, target):
        start_index, end_index = self.video_indices[index]

        # SMPL parameters
        # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
        #       We do not supervise the network on SMPL parameters.
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][start_index:end_index+1].clone().reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][start_index:end_index+1].clone()  # No t

        # Apply SMPL augmentor (y-axis rotation and initial frame noise)
        target = self.SMPLAugmentor(target)

        # 2D keypoints
        kp2d = self.labels['kp2d'][start_index:end_index+1].clone().float()[..., :2]
        gt_kp2d = torch.zeros((self.n_frames - 1, 31, 2))
        gt_kp2d[:, :17] = kp2d[1:].clone()

        # Set confidence to 0 for the masked keypoints
        mask = torch.zeros((self.n_frames - 1, 31))
        mask[:, :17] = self.labels['joints2D'][start_index+1:end_index+1][..., -1].clone()
        mask = torch.logical_and(gt_kp2d.mean(-1) != 0, mask)
        gt_kp2d = torch.cat((gt_kp2d, mask.float().unsqueeze(-1)), dim=-1)

        _gt_kp2d = gt_kp2d.clone()
        for idx in range(len(_gt_kp2d)):
            _gt_kp2d[idx][..., :2] = torch.from_numpy(
                self.j2d_processing(gt_kp2d[idx][..., :2].numpy().copy(),
                                    target['bbox'][idx].numpy().copy()))

        target['weak_kp2d'] = _gt_kp2d.clone()
        target['full_kp2d'] = torch.zeros_like(gt_kp2d)
        target['kp3d'] = torch.zeros((kp2d.shape[0], 31, 4))

        # No SMPL vertices available
        target['verts'] = torch.zeros((self.n_frames - 1, 6890, 3)).float()

        return target

    def get_init_frame(self, target):
        # Prepare initial frame
        output = self.smpl.get_output(
            body_pose=target['init_pose'][:, 1:],
            global_orient=target['init_pose'][:, :1],
            betas=target['betas'][:1],
            pose2rot=False
        )
        target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)

        return target

    def get_single_sequence(self, index):
        # Camera parameters
        res = (224.0, 224.0)
        bbox = torch.tensor([112.0, 112.0, 1.12])  # fixed bbox: center (112, 112) and scale
        res = torch.tensor(res)
        self.get_naive_intrinsics(res)
        bbox = bbox.repeat(self.n_frames, 1)

        # Universal target
        target = {'has_full_screen': torch.tensor(False),
                  'has_smpl': torch.tensor(self.has_smpl),
                  'has_traj': torch.tensor(self.has_traj),
                  'has_verts': torch.tensor(False),
                  'transl': torch.zeros((self.n_frames, 3)),

                  # Camera parameters and bbox
                  'res': res,
                  'cam_intrinsics': self.cam_intrinsics,
                  'bbox': bbox,

                  # Null camera motion
                  'R': torch.eye(3).repeat(self.n_frames, 1, 1),
                  'cam_angvel': torch.zeros((self.n_frames - 1, 6)),

                  # Null root orientation and velocity
                  'pose_root': torch.zeros((self.n_frames, 6)),
                  'vel_root': torch.zeros((self.n_frames - 1, 3)),
                  'init_root': torch.zeros((1, 6)),

                  # Null contact label
                  'contact': torch.ones((self.n_frames - 1, 4)) * (-1)
                  }

        self.get_inputs(index, target)
        self.get_labels(index, target)
        self.get_init_frame(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target
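

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how a
# Dataset2D would typically be constructed and consumed. The `cfg` object and
# the label-file path below are hypothetical placeholders; the class expects a
# config exposing `cfg.DATASET.SEQLEN` and a joblib-serialized label file.
#
#     from torch.utils.data import DataLoader
#
#     dataset = Dataset2D(cfg, fname='parsed_data/coco_train.pth', training=True)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)
#
#     batch = next(iter(loader))
#     batch['kp2d']      # normalized 2D keypoints (network input)
#     batch['features']  # pre-extracted image features for frames 1..SEQLEN
#     batch['mask']      # True where the 2D detection is missing
# ---------------------------------------------------------------------------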