from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from ..utils.normalizer import Normalizer
from ...models import build_body_model
from ...utils import transforms
from ...utils.kp_utils import root_centering
from ...utils.imutils import compute_cam_intrinsics

# Keypoints whose last channel (presumably detection confidence -- confirm
# against the tracker) is below this value are flagged in the mask returned by
# ``CustomDataset.__getitem__``.
KEYPOINTS_THR = 0.3


def convert_dpvo_to_cam_angvel(traj, fps):
    """Convert a DPVO trajectory into per-frame camera angular velocity.

    Args:
        traj: Array of shape ``(T, 7)``: columns 0-3 hold the translation and
            columns 3-7 a quaternion stored in ``(x, y, z, w)`` order. Either a
            NumPy array or a torch tensor is accepted.
        fps: Frame rate; per-frame rotation changes are multiplied by it.

    Returns:
        Float tensor of shape ``(T, 6)``: the 6D representation of each
        frame-to-frame camera rotation, offset by ``[1, 0, 0, 0, 1, 0]`` and
        scaled by ``fps``. The first row is duplicated at the end so there is
        one row per input frame.
    """
    # Columns 0-3: translation, 3-7: quaternion in (x, y, z, w) order.
    quat = traj[:, 3:]

    # Reorder (x, y, z, w) -> (w, x, y, z), the order quaternion_to_matrix takes.
    quat = quat[:, [3, 0, 1, 2]]

    # The quaternion is the camera-to-world rotation; transposing it gives the
    # world-to-camera rotation R. ``as_tensor`` shares memory with NumPy input
    # (like ``from_numpy``) and also accepts tensors.
    cam2world = transforms.quaternion_to_matrix(torch.as_tensor(quat)).float()
    R = cam2world.mT

    # Relative rotation between consecutive frames (T - 1 of them), as axis-angle.
    cam_angvel = transforms.matrix_to_axis_angle(R[:-1] @ R[1:].transpose(-1, -2))

    # Re-encode the relative rotations in the 6D rotation representation.
    cam_angvel = transforms.matrix_to_rotation_6d(transforms.axis_angle_to_matrix(cam_angvel))

    # Subtract [1, 0, 0, 0, 1, 0] (presumably the 6D encoding of the identity
    # rotation, so "no rotation" maps to zeros), then scale by the frame rate.
    cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)
    cam_angvel = cam_angvel * fps

    # Pad the T - 1 deltas back to T rows by appending a copy of the first one.
    cam_angvel = torch.cat((cam_angvel, cam_angvel[:1]), dim=0)
    return cam_angvel


class CustomDataset(torch.utils.data.Dataset):
    """Per-subject model inputs built from tracking results and a SLAM trajectory.

    Item ``i`` corresponds to the ``i``-th key of ``tracking_results`` in
    sorted order. ``load_data(index, flip=True)`` reads the ``'flipped_'``
    prefixed entries of a subject instead of the regular ones.
    """

    def __init__(self, cfg, tracking_results, slam_results, width, height, fps):
        self.tracking_results = tracking_results
        self.slam_results = slam_results
        self.width = width
        self.height = height
        self.fps = fps
        self.res = torch.tensor([width, height]).float()
        self.intrinsics = compute_cam_intrinsics(self.res)

        self.device = cfg.DEVICE.lower()

        self.smpl = build_body_model('cpu')
        self.keypoints_normalizer = Normalizer(cfg)

        # Key prefix read by ``__getitem__``. Initialised here so that plain
        # indexing / iteration works without a prior ``load_data`` call
        # (previously that raised AttributeError).
        self.prefix = ''

    def _to(self, x):
        """Add a leading batch dimension to ``x`` and move it to ``self.device``."""
        # A method instead of a lambda attribute keeps the dataset picklable
        # (e.g. for DataLoader workers started with the spawn method).
        return x.unsqueeze(0).to(self.device)

    def __len__(self):
        """Number of tracked subjects."""
        return len(self.tracking_results)

    def load_data(self, index, flip=False):
        """Return item ``index``, using the ``'flipped_'`` inputs when ``flip`` is True.

        Side effect: sets ``self.prefix``, which stays in effect for later
        ``__getitem__`` calls until changed again.
        """
        self.prefix = 'flipped_' if flip else ''
        return self.__getitem__(index)

    def __getitem__(self, _index):
        """Build the model inputs for the ``_index``-th subject (sorted key order).

        Entries are read from ``tracking_results[subject]`` using keys prefixed
        by ``self.prefix``.

        Returns:
            ``None`` when ``_index >= len(self)`` (instead of raising
            IndexError -- callers iterating the dataset must stop on ``None``),
            otherwise a tuple of
            ``(subject_id, norm_kp2d, (init_kp, init_smpl), features, mask,
            init_root, cam_angvel, frame_id, kwargs)``; every tensor except
            ``init_root`` gets a leading batch dimension and is moved to
            ``self.device``.
        """
        if _index >= len(self):
            return

        index = sorted(self.tracking_results)[_index]
        subject = self.tracking_results[index]

        # Process 2D keypoints: the last channel is thresholded into a mask;
        # the remaining channels are normalized using the bbox.
        kp2d = torch.from_numpy(subject[self.prefix + 'keypoints']).float()
        mask = kp2d[..., -1] < KEYPOINTS_THR
        bbox = torch.from_numpy(subject[self.prefix + 'bbox']).float()
        norm_kp2d, _ = self.keypoints_normalizer(
            kp2d[..., :-1].clone(), self.res, self.intrinsics, 224, 224, bbox
        )

        # Process image features
        features = subject[self.prefix + 'features']

        # Process initial pose (rotation matrices, hence pose2rot=False).
        init_output = self.smpl.get_output(
            global_orient=subject[self.prefix + 'init_global_orient'],
            body_pose=subject[self.prefix + 'init_body_pose'],
            betas=subject[self.prefix + 'init_betas'],
            pose2rot=False,
            return_full_pose=True
        )
        # Root-centred first 17 joints concatenated with the first frame's
        # normalized 2D keypoints, flattened into a single row.
        init_kp3d = root_centering(init_output.joints[:, :17], 'coco')
        init_kp = torch.cat((init_kp3d.reshape(1, -1), norm_kp2d[0].clone().reshape(1, -1)), dim=-1)
        init_smpl = transforms.matrix_to_rotation_6d(init_output.full_pose)
        init_root = transforms.matrix_to_rotation_6d(init_output.global_orient)

        # Process SLAM results. This depends only on slam_results and fps, so
        # it is the same for every subject.
        cam_angvel = convert_dpvo_to_cam_angvel(self.slam_results, self.fps)

        return (
            index,                                          # subject id
            self._to(norm_kp2d),                            # 2d keypoints
            (self._to(init_kp), self._to(init_smpl)),       # initial pose
            self._to(features),                             # image features
            self._to(mask),                                 # keypoints mask
            init_root.to(self.device),                      # initial root orientation
            self._to(cam_angvel),                           # camera angular velocity
            subject['frame_id'],                            # frame indices
            {'cam_intrinsics': self._to(self.intrinsics),   # other keyword arguments
             'bbox': self._to(bbox),
             'res': self._to(self.res)},
        )