File size: 4,750 Bytes
f561f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch

from ..utils.normalizer import Normalizer
from ...models import build_body_model
from ...utils import transforms
from ...utils.kp_utils import root_centering
from ...utils.imutils import compute_cam_intrinsics

# Threshold on the last channel of each 2D keypoint: joints whose value there
# is below this are flagged True in the keypoint mask built by CustomDataset.
KEYPOINTS_THR = 0.3

def convert_dpvo_to_cam_angvel(traj, fps):
    """Convert a DPVO camera trajectory to per-frame camera angular velocity.

    Args:
        traj: (T, 7) numpy array or torch tensor of DPVO poses. Columns 0-2
            hold the translation and columns 3-6 the rotation quaternion,
            stored scalar-last as (x, y, z, w).
        fps: Video frame rate; the per-frame rotation change is multiplied by
            it to turn it into a rate.

    Returns:
        (T, 6) float32 tensor. Row t is the relative rotation between frames
        t and t + 1 in 6D representation, minus the 6D encoding of the
        identity (so "no rotation" is all zeros), times ``fps``. The final row
        repeats the first row so the output has one row per input frame.
        With fewer than two frames there is no rotation change to measure,
        and an all-zero (T, 6) tensor is returned.
    """
    # Accept numpy arrays (as before) as well as torch tensors.
    traj = torch.as_tensor(traj)
    num_frames = traj.shape[0]
    if num_frames < 2:
        # Without a frame pair the consecutive-frame slices below are empty and
        # the result would have 0 rows instead of T. An identity relative
        # rotation maps to zeros after the identity offset, so zeros are the
        # consistent "no motion" value.
        return torch.zeros((num_frames, 6), dtype=torch.float32)

    # Reorder the quaternion from (x, y, z, w) to (w, x, y, z).
    quat = traj[:, 3:][:, [3, 0, 1, 2]]

    # The quaternion is the camera-to-world rotation; transpose it to obtain
    # the world-to-camera rotation.
    cam2world = transforms.quaternion_to_matrix(quat).float()
    world2cam = cam2world.mT

    # Rotational change between consecutive frames: R[t] @ R[t + 1]^T.
    rel_rot = world2cam[:-1] @ world2cam[1:].transpose(-1, -2)

    # Pass through axis-angle back to a matrix, then encode in 6D.
    cam_angvel = transforms.matrix_to_rotation_6d(
        transforms.axis_angle_to_matrix(transforms.matrix_to_axis_angle(rel_rot))
    )

    # Subtract the 6D identity so a static camera yields zeros, then scale by
    # the frame rate.
    identity_6d = torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)
    cam_angvel = (cam_angvel - identity_6d) * fps

    # T - 1 relative rotations: pad back to T rows by repeating the first row.
    return torch.cat((cam_angvel, cam_angvel[:1]), dim=0)


class CustomDataset(torch.utils.data.Dataset):
    """Per-subject inference inputs built from tracking and SLAM outputs.

    One item exists per key of ``tracking_results``; keys are visited in
    sorted order. Each item bundles normalized 2D keypoints, the initial SMPL
    pose, image features, a keypoint mask, the initial root orientation,
    camera angular velocity, frame indices and camera metadata. Tensors are
    given a leading batch dimension and moved to ``cfg.DEVICE``.
    """

    # Prefix for the per-subject field names read from tracking_results.
    # load_data(flip=True) switches it to 'flipped_'. Declared at class level
    # so indexing the dataset directly (``dataset[i]`` or ``for batch in
    # dataset``) works without a prior load_data call.
    prefix = ''

    def __init__(self, cfg, tracking_results, slam_results, width, height, fps):
        self.tracking_results = tracking_results
        self.slam_results = slam_results
        self.width = width
        self.height = height
        self.fps = fps
        self.res = torch.tensor([width, height]).float()
        self.intrinsics = compute_cam_intrinsics(self.res)

        self.device = cfg.DEVICE.lower()

        self.smpl = build_body_model('cpu')
        self.keypoints_normalizer = Normalizer(cfg)

    def _to(self, x):
        """Add a leading batch dimension to ``x`` and move it to ``self.device``."""
        # A method rather than a lambda stored on the instance: lambda
        # attributes cannot be pickled (e.g. for DataLoader worker processes).
        return x.unsqueeze(0).to(self.device)

    def __len__(self):
        return len(self.tracking_results)

    def load_data(self, index, flip=False):
        """Return item ``index``, reading the 'flipped_' fields when ``flip`` is True.

        The selected prefix is stored on ``self.prefix`` and therefore also
        applies to later direct ``dataset[i]`` calls.
        """
        self.prefix = 'flipped_' if flip else ''
        return self.__getitem__(index)

    def __getitem__(self, _index):
        """Build the inputs for the ``_index``-th subject (sorted key order).

        Returns ``None`` when ``_index >= len(self)``. Callers iterating with
        ``for batch in dataset`` must stop on that ``None``.
        """
        if _index >= len(self):
            return

        index = sorted(self.tracking_results)[_index]
        subject = self.tracking_results[index]
        prefix = self.prefix

        # Process 2D keypoints: the last channel is compared against
        # KEYPOINTS_THR to build the mask; the rest is normalized.
        kp2d = torch.from_numpy(subject[prefix + 'keypoints']).float()
        mask = kp2d[..., -1] < KEYPOINTS_THR
        bbox = torch.from_numpy(subject[prefix + 'bbox']).float()

        norm_kp2d, _ = self.keypoints_normalizer(
            kp2d[..., :-1].clone(), self.res, self.intrinsics, 224, 224, bbox
        )

        # Process image features
        features = subject[prefix + 'features']

        # Process initial pose: run the body model on the initial SMPL
        # parameters, given as rotation matrices (pose2rot=False).
        init_output = self.smpl.get_output(
            global_orient=subject[prefix + 'init_global_orient'],
            body_pose=subject[prefix + 'init_body_pose'],
            betas=subject[prefix + 'init_betas'],
            pose2rot=False,
            return_full_pose=True
        )
        init_kp3d = root_centering(init_output.joints[:, :17], 'coco')
        # Flattened root-centred 3D joints concatenated with the first frame's
        # normalized 2D keypoints.
        init_kp = torch.cat(
            (init_kp3d.reshape(1, -1), norm_kp2d[0].clone().reshape(1, -1)), dim=-1
        )
        init_smpl = transforms.matrix_to_rotation_6d(init_output.full_pose)
        init_root = transforms.matrix_to_rotation_6d(init_output.global_orient)

        # Process SLAM results
        cam_angvel = convert_dpvo_to_cam_angvel(self.slam_results, self.fps)

        return (
            index,                                          # subject id
            self._to(norm_kp2d),                            # 2d keypoints
            (self._to(init_kp), self._to(init_smpl)),       # initial pose
            self._to(features),                             # image features
            self._to(mask),                                 # keypoints mask
            init_root.to(self.device),                      # initial root orientation
            self._to(cam_angvel),                           # camera angular velocity
            subject['frame_id'],                            # frame indices
            {'cam_intrinsics': self._to(self.intrinsics),   # other keyword arguments
             'bbox': self._to(bbox),
             'res': self._to(self.res)},
            )