from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import joblib

from lib.utils import transforms
from configs import constants as _C
from ..utils.augmentor import *
from .._dataset import BaseDataset
from ...models import build_body_model
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering

def compute_contact_label(feet, thr=1e-2, alpha=5):
    """Convert per-frame foot speed into a soft (sigmoid) contact label in [0, 1]."""
    vel = torch.zeros_like(feet[..., 0])
    label = torch.zeros_like(feet[..., 0])

    # Central-difference speed; edge frames copy their neighbors.
    vel[1:-1] = (feet[2:] - feet[:-2]).norm(dim=-1) / 2.0
    vel[0] = vel[1].clone()
    vel[-1] = vel[-2].clone()

    # Speeds well below thr map to ~1 (contact), well above thr map to ~0.
    label = 1 / (1 + torch.exp(alpha * (thr ** -1) * (vel - thr)))
    return label
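
# With the default thr=1e-2 and alpha=5, a stationary foot (vel = 0) maps to
# about 1 / (1 + exp(-5)) ~= 0.993, a foot moving exactly at the threshold maps
# to 0.5, and a clearly moving foot (vel = 3e-2) maps to roughly 1 / (1 + exp(10)) ~= 5e-5.
# Example call (the shapes here are illustrative, not prescribed by this file):
#   feet = torch.randn(81, 4, 3) * 1e-2        # (frames, foot joints, xyz)
#   contact = compute_contact_label(feet)      # (frames, foot joints), values in [0, 1]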

class AMASSDataset(BaseDataset):
    def __init__(self, cfg):
        label_pth = _C.PATHS.AMASS_LABEL
        super(AMASSDataset, self).__init__(cfg, training=True)

        self.supervise_pose = cfg.TRAIN.STAGE == 'stage1'
        self.labels = joblib.load(label_pth)
        self.SequenceAugmentor = SequenceAugmentor(cfg.DATASET.SEQLEN + 1)

        # Load augmentators
        self.VideoAugmentor = VideoAugmentor(cfg)
        self.SMPLAugmentor = SMPLAugmentor(cfg)

        self.d_img_feature = _C.IMG_FEAT_DIM[cfg.MODEL.BACKBONE]
        self.n_frames = int(cfg.DATASET.SEQLEN * self.SequenceAugmentor.l_factor) + 1
        self.smpl = build_body_model('cpu', self.n_frames)
        self.prepare_video_batch()

        # Naive assumption of image intrinsics
        self.img_w, self.img_h = 1000, 1000
        self.get_naive_intrinsics((self.img_w, self.img_h))
        self.CameraAugmentor = CameraAugmentor(cfg.DATASET.SEQLEN + 1, self.img_w, self.img_h, self.focal_length)
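
    # Note: `get_naive_intrinsics`, `keypoints_normalizer`, `prepare_video_batch`,
    # `self.focal_length`, `self.cam_intrinsics`, and `self.n_joints` are not
    # defined here; they appear to be provided by the BaseDataset parent. Since
    # AMASS has no real camera, the intrinsics above assume a fixed 1000 x 1000
    # image, with the focal length and principal point chosen by BaseDataset.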

    @property
    def __name__(self):
        return 'AMASS'

    def get_input(self, target):
        # Perturb the GT 3D keypoints, project them to 2D, and normalize w.r.t. the bounding box.
        gt_kp3d = target['kp3d']
        inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
        kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
        mask = self.VideoAugmentor.get_mask()
        kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)

        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d
        target['mask'] = mask[1:]
        target['features'] = torch.zeros((self.SMPLAugmentor.n_frames, self.d_img_feature)).float()
        return target

    def get_groundtruth(self, target):
        # GT 1. Joints
        gt_kp3d = target['kp3d']
        gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
        target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1]) * float(self.supervise_pose)), dim=-1)
        target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1]) * float(self.supervise_pose)), dim=-1)[1:]
        target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
        target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
        target['verts'] = torch.zeros((self.SMPLAugmentor.n_frames, 6890, 3)).float()

        # GT 2. Root pose: world-frame root velocity rotated into the root's local frame
        vel_world = (target['transl'][1:] - target['transl'][:-1])
        pose_root = target['pose_root'].clone()
        vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
        target['vel_root'] = vel_root.clone()
        target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
        target['init_root'] = target['pose_root'][:1].clone()

        # GT 3. Foot contact
        contact = compute_contact_label(target['feet'])
        if 'tread' in target['vid']:
            # Treadmill sequences get a sentinel value instead of contact supervision
            target['contact'] = torch.ones_like(contact) * (-1)
        else:
            target['contact'] = contact

        return target

    def forward_smpl(self, target):
        output = self.smpl.get_output(
            body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
            global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
            betas=target['betas'],
            pose2rot=False)

        target['transl'] = target['transl'] - output.offset
        target['transl'] = target['transl'] - target['transl'][0]
        target['kp3d'] = output.joints
        target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)

        return target

    def augment_data(self, target):
        # Augmentation 1. SMPL params augmentation
        target = self.SMPLAugmentor(target)

        # Augmentation 2. Sequence speed augmentation
        target = self.SequenceAugmentor(target)

        # Get world-coordinate SMPL
        target = self.forward_smpl(target)

        # Augmentation 3. Virtual camera generation
        target = self.CameraAugmentor(target)

        return target

    def load_amass(self, index, target):
        start_index, end_index = self.video_indices[index]

        # Load AMASS labels
        pose = torch.from_numpy(self.labels['pose'][start_index:end_index+1].copy())
        pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
        transl = torch.from_numpy(self.labels['transl'][start_index:end_index+1].copy())
        betas = torch.from_numpy(self.labels['betas'][start_index:end_index+1].copy())

        # Stack GT
        target.update({'vid': self.labels['vid'][start_index],
                       'pose': pose,
                       'transl': transl,
                       'betas': betas})
        return target

    def get_single_sequence(self, index):
        target = {'res': torch.tensor([self.img_w, self.img_h]).float(),
                  'cam_intrinsics': self.cam_intrinsics.clone(),
                  'has_full_screen': torch.tensor(True),
                  'has_smpl': torch.tensor(self.supervise_pose),
                  'has_traj': torch.tensor(True),
                  'has_verts': torch.tensor(False)}

        target = self.load_amass(index, target)
        target = self.augment_data(target)
        target = self.get_groundtruth(target)
        target = self.get_input(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)

        return target

def perspective_projection(points, cam_intrinsics, rotation=None, translation=None):
    K = cam_intrinsics
    if rotation is not None:
        points = torch.matmul(rotation, points.transpose(1, 2)).transpose(1, 2)
    if translation is not None:
        points = points + translation.unsqueeze(1)

    # Divide by depth, apply the intrinsics, and drop the homogeneous coordinate.
    projected_points = points / points[:, :, -1].unsqueeze(-1)
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points.float())
    return projected_points[:, :, :-1]
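
# Minimal usage sketch for perspective_projection (shapes are assumptions, not
# prescribed by this file; the intrinsics must carry a batch dimension for the einsum):
#   K = torch.tensor([[[1000., 0., 500.],
#                      [0., 1000., 500.],
#                      [0., 0., 1.]]])                                 # (1, 3, 3)
#   points = torch.randn(1, 31, 3) + torch.tensor([0., 0., 5.])       # (1, N, 3), z > 0
#   kp2d = perspective_projection(points, K)                          # (1, N, 2) pixel coords
#
# The dataset itself is constructed with the project's training config, e.g.
# `dataset = AMASSDataset(cfg)`, where `cfg` and the AMASS label path come from
# the surrounding training code.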