from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import joblib
from lib.utils import transforms

from configs import constants as _C

from .amass import perspective_projection
from ..utils.augmentor import VideoAugmentor, SMPLAugmentor
from .._dataset import BaseDataset
from ...models import build_body_model
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering

class BEDLAMDataset(BaseDataset):
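    """Video dataset over BEDLAM training labels.

    Loads precomputed per-frame annotations (SMPL parameters, camera
    extrinsics/intrinsics, bounding boxes, and backbone image features)
    from a single joblib file and serves fixed-length clips for training.
    """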
    def __init__(self, cfg):
        label_pth = _C.PATHS.BEDLAM_LABEL.replace('backbone', cfg.MODEL.BACKBONE)
        super(BEDLAMDataset, self).__init__(cfg, training=True)

        self.labels = joblib.load(label_pth)
        
        self.VideoAugmentor = VideoAugmentor(cfg)
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)
        
        self.smpl = build_body_model('cpu', self.n_frames)
        self.prepare_video_batch()

    @property
    def __name__(self):
        return 'BEDLAM'

    def get_inputs(self, index, target, vis_thr=0.6):
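        """Build model inputs for one clip: noise-augmented, normalized 2D
        keypoints with a visibility mask, plus precomputed image features."""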
        start_index, end_index = self.video_indices[index]
        
        bbox = self.labels['bbox'][start_index:end_index+1].clone()
        bbox[:, 2] = bbox[:, 2] / 200  # convert box size from pixels to the size/200 scale convention used downstream
        
        gt_kp3d = target['kp3d']
        inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
        # kp2d = perspective_projection(inpt_kp3d, target['K'])
        kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
        mask = self.VideoAugmentor.get_mask()
        # kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)

        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d
        target['mask'] = mask[1:]
        
        # Image features
        target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
        
        return target

    def get_groundtruth(self, index, target):
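        """Attach ground truth: 3D/2D keypoints with confidence, 6D root
        pose, and per-frame root velocity in the root frame."""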
        start_index, end_index = self.video_indices[index]

        # GT 1. Joints
        gt_kp3d = target['kp3d']
        # gt_kp2d = perspective_projection(gt_kp3d, target['K'])
        gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
        target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1])), dim=-1)
        # target['full_kp2d'] = torch.cat((gt_kp2d, torch.zeros_like(gt_kp2d[..., :1])), dim=-1)[1:]
        target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1])), dim=-1)[1:]
        target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
        target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
        
        # GT 2. Root pose
        w_transl = self.labels['w_trans'][start_index:end_index+1]
        pose_root = transforms.axis_angle_to_matrix(self.labels['root'][start_index:end_index+1])
        # Per-frame world-frame translation velocity, rotated into the
        # previous frame's root coordinates: v_t = R_t^T (t_{t+1} - t_t)
        vel_world = w_transl[1:] - w_transl[:-1]
        vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
        target['vel_root'] = vel_root.clone()
        target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
        target['init_root'] = target['pose_root'][:1].clone()

        return target

    def forward_smpl(self, target):
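        """Run the SMPL body model on the (augmented) pose, shape, and
        translation to obtain joints, feet positions, and vertices."""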
        output = self.smpl.get_output(
            body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
            global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
            betas=target['betas'],
            transl=target['transl'],
            pose2rot=False)
        
        target['kp3d'] = output.joints + output.offset.unsqueeze(1)
        target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)
        target['verts'] = output.vertices[1:, ].clone()
        
        return target

    def augment_data(self, target):
        # Augmentation 1. SMPL params augmentation
        target = self.SMPLAugmentor(target)
        
        # Get world-coordinate SMPL
        target = self.forward_smpl(target)
        
        return target

    def load_camera(self, index, target):
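        """Load camera extrinsics/intrinsics, derive the image resolution,
        and compute the frame-to-frame camera angular velocity."""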
        start_index, end_index = self.video_indices[index]

        # Get camera info
        extrinsics = self.labels['extrinsics'][start_index:end_index+1].clone()
        R = extrinsics[:, :3, :3]
        T = extrinsics[:, :3, -1]
        K = self.labels['intrinsics'][start_index:end_index+1].clone()
        # Recover image resolution from the principal point (assumed to be at the image center)
        width, height = K[0, 0, 2] * 2, K[0, 1, 2] * 2
        target['R'] = R
        target['res'] = torch.tensor([width, height]).float()
        
        # Compute frame-to-frame camera angular velocity in the 6D rotation representation
        cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
        # Subtract the 6D representation of identity so a static camera yields zero
        cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)
        target['cam_angvel'] = cam_angvel * 30  # scale per-frame difference to a per-second rate (BEDLAM is 30 fps)
        
        target['K'] = K # keep GT camera intrinsics; keypoint projection uses the naive intrinsics instead
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics

        return target

    def load_params(self, index, target):
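        """Load SMPL pose (as rotation matrices), translation, and betas
        for the clip, along with its video id."""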
        start_index, end_index = self.video_indices[index]
        
        # Load BEDLAM SMPL labels
        pose = self.labels['pose'][start_index:end_index+1].clone()
        pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
        transl = self.labels['c_trans'][start_index:end_index+1].clone()
        betas = self.labels['betas'][start_index:end_index+1, :10].clone()
        
        # Stack GT
        target.update({'vid': self.labels['vid'][start_index].clone(), 
                       'pose': pose, 
                       'transl': transl, 
                       'betas': betas})
        
        return target

    def get_single_sequence(self, index):
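        """Assemble one training clip: params -> camera -> augmentation ->
        ground truth -> inputs -> final keypoint/SMPL packing."""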
        target = {'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(True),
                  'has_traj': torch.tensor(False),
                  'has_verts': torch.tensor(True),
                  
                  # Null contact label
                  'contact': torch.ones((self.n_frames - 1, 4)) * (-1),
                  }
        
        target = self.load_params(index, target)
        target = self.load_camera(index, target)
        target = self.augment_data(target)
        target = self.get_groundtruth(index, target)
        target = self.get_inputs(index, target)
        
        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target
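

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). Assumes a
# YACS-style `cfg` as used elsewhere in this repo, with cfg.MODEL.BACKBONE
# set and the BEDLAM label file present at _C.PATHS.BEDLAM_LABEL; the
# batch keys below follow the targets built above.
#
#   from torch.utils.data import DataLoader
#
#   dataset = BEDLAMDataset(cfg)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   kp2d = batch['kp2d']         # normalized 2D keypoint inputs
#   feats = batch['features']    # precomputed backbone image features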