from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os
import torch
import joblib

from configs import constants as _C
from .._dataset import BaseDataset
from ...utils import transforms
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering

FPS = 30


class EvalDataset(BaseDataset):
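    """Evaluation dataset built from pre-parsed labels.

    Loads labels (2D/3D keypoints, SMPL parameters, image features,
    bounding boxes, and camera poses) for the given dataset, split, and
    backbone from a single .pth file and returns one full sequence per index.
    """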
    def __init__(self, cfg, data, split, backbone):
        super(EvalDataset, self).__init__(cfg, False)
        
        self.prefix = ''
        self.data = data
        # Pre-parsed labels for this dataset / split / backbone combination.
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth')
        self.labels = joblib.load(parsed_data_path)

    def load_data(self, index, flip=False):
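        """Return the sequence at `index` as a batch of size one.

        When `flip` is True, the horizontally flipped labels (keys prefixed
        with 'flipped_') are used instead of the originals.
        """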
        if flip:
            self.prefix = 'flipped_'
        else:
            self.prefix = ''
        
        target = self.__getitem__(index)
        for key, val in target.items():
            if isinstance(val, torch.Tensor):
                target[key] = val.unsqueeze(0)
        return target

    def __getitem__(self, index):
        # Assemble one sequence, then apply keypoint and SMPL preprocessing.
        target = self.get_data(index)
        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target

    def __len__(self):
        return len(self.labels['kp2d'])

    def prepare_labels(self, index, target):
        # Ground-truth SMPL parameters
        target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][index]
        target['gender'] = self.labels['gender'][index]
        
        # Sequence information
        target['res'] = self.labels['res'][index][0]
        target['vid'] = self.labels['vid'][index]
        target['frame_id'] = self.labels['frame_id'][index][1:]
        
        # Camera information
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics
        R = self.labels['cam_poses'][index][:, :3, :3].clone()
        if 'emdb' in self.data.lower():
            # Use the ground-truth camera angular velocity.
            # This can be replaced with SLAM results if they are available.
            # Finite-difference angular velocity: the relative rotation between
            # consecutive frames in 6D representation, offset by the identity's
            # 6D encoding and scaled by FPS to a per-second rate.
            cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
            cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS
            target['R'] = R
        else:
            # No camera motion labels are used; fall back to zero angular velocity.
            cam_angvel = torch.zeros((len(target['pose']) - 1, 6))
        target['cam_angvel'] = cam_angvel
        return target

    def prepare_inputs(self, index, target):
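        # Per-frame image features and bounding boxes, excluding the first frame.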
        for key in ['features', 'bbox']:
            data = self.labels[self.prefix + key][index][1:]
            target[key] = data
        
        # Bounding box as (center x, center y, scale); normalize the scale channel by 200.
        bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float()
        bbox[:, 2] = bbox[:, 2] / 200
        
        # Normalize keypoints
        kp2d, bbox = self.keypoints_normalizer(
            self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(), 
            target['res'], target['cam_intrinsics'], 224, 224, bbox)
        target['kp2d'] = kp2d
        target['bbox'] = bbox[1:]
        
        # Mask out low-confidence keypoints (confidence below 0.3) by zeroing them
        mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3
        target['input_kp2d'] = self.labels['kp2d'][index][1:]
        target['input_kp2d'][mask[1:]] *= 0
        target['mask'] = mask[1:]
        
        return target

    def prepare_initialization(self, index, target):
        # Per-frame (single-image) estimates for the first frame of the sequence
        target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1)
        target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu()
        # Initialize the root orientation from the ground-truth global orientation
        pose_root = target['pose'][:, 0].clone()
        target['init_root'] = transforms.matrix_to_rotation_6d(pose_root)
        
        return target
        
    def get_data(self, index):
        target = {}
        
        target = self.prepare_labels(index, target)
        target = self.prepare_inputs(index, target)
        target = self.prepare_initialization(index, target)
        
        return target
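

# Example usage (illustrative sketch; `cfg` and the dataset / split / backbone
# names below are placeholders and must correspond to an existing parsed-label
# file under _C.PATHS.PARSED_DATA):
#
#   dataset = EvalDataset(cfg, data='emdb', split='test', backbone='vit')
#   batch = dataset.load_data(0)                      # original labels
#   batch_flipped = dataset.load_data(0, flip=True)   # horizontally flipped labels
#   # Each tensor in `batch` carries a leading batch dimension of size 1.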