from __future__ import absolute_import from __future__ import print_function from __future__ import division import os import torch import joblib from configs import constants as _C from .._dataset import BaseDataset from ...utils import transforms from ...utils import data_utils as d_utils from ...utils.kp_utils import root_centering FPS = 30 class EvalDataset(BaseDataset): def __init__(self, cfg, data, split, backbone): super(EvalDataset, self).__init__(cfg, False) self.prefix = '' self.data = data parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth') self.labels = joblib.load(parsed_data_path) def load_data(self, index, flip=False): if flip: self.prefix = 'flipped_' else: self.prefix = '' target = self.__getitem__(index) for key, val in target.items(): if isinstance(val, torch.Tensor): target[key] = val.unsqueeze(0) return target def __getitem__(self, index): target = {} target = self.get_data(index) target = d_utils.prepare_keypoints_data(target) target = d_utils.prepare_smpl_data(target) return target def __len__(self): return len(self.labels['kp2d']) def prepare_labels(self, index, target): # Ground truth SMPL parameters target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3)) target['betas'] = self.labels['betas'][index] target['gender'] = self.labels['gender'][index] # Sequence information target['res'] = self.labels['res'][index][0] target['vid'] = self.labels['vid'][index] target['frame_id'] = self.labels['frame_id'][index][1:] # Camera information self.get_naive_intrinsics(target['res']) target['cam_intrinsics'] = self.cam_intrinsics R = self.labels['cam_poses'][index][:, :3, :3].clone() if 'emdb' in self.data.lower(): # Use groundtruth camera angular velocity. # Can be updated with SLAM results if you have it. cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2)) cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS target['R'] = R else: cam_angvel = torch.zeros((len(target['pose']) - 1, 6)) target['cam_angvel'] = cam_angvel return target def prepare_inputs(self, index, target): for key in ['features', 'bbox']: data = self.labels[self.prefix + key][index][1:] target[key] = data bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float() bbox[:, 2] = bbox[:, 2] / 200 # Normalize keypoints kp2d, bbox = self.keypoints_normalizer( self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(), target['res'], target['cam_intrinsics'], 224, 224, bbox) target['kp2d'] = kp2d target['bbox'] = bbox[1:] # Masking out low confident keypoints mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3 target['input_kp2d'] = self.labels['kp2d'][index][1:] target['input_kp2d'][mask[1:]] *= 0 target['mask'] = mask[1:] return target def prepare_initialization(self, index, target): # Initial frame per-frame estimation target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1) target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu() pose_root = target['pose'][:, 0].clone() target['init_root'] = transforms.matrix_to_rotation_6d(pose_root) return target def get_data(self, index): target = {} target = self.prepare_labels(index, target) target = self.prepare_inputs(index, target) target = self.prepare_initialization(index, target) return target