Techt3o's picture
e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
f561f8b verified
raw
history blame
6.77 kB
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import joblib
from lib.utils import transforms
from configs import constants as _C
from .amass import compute_contact_label, perspective_projection
from ..utils.augmentor import *
from .._dataset import BaseDataset
from ...models import build_body_model
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering
class BEDLAMDataset(BaseDataset):
def __init__(self, cfg):
label_pth = _C.PATHS.BEDLAM_LABEL.replace('backbone', cfg.MODEL.BACKBONE)
super(BEDLAMDataset, self).__init__(cfg, training=True)
self.labels = joblib.load(label_pth)
self.VideoAugmentor = VideoAugmentor(cfg)
self.SMPLAugmentor = SMPLAugmentor(cfg, False)
self.smpl = build_body_model('cpu', self.n_frames)
self.prepare_video_batch()
@property
def __name__(self, ):
return 'BEDLAM'
def get_inputs(self, index, target, vis_thr=0.6):
start_index, end_index = self.video_indices[index]
bbox = self.labels['bbox'][start_index:end_index+1].clone()
bbox[:, 2] = bbox[:, 2] / 200
gt_kp3d = target['kp3d']
inpt_kp3d = self.VideoAugmentor(gt_kp3d[:, :self.n_joints, :-1].clone())
# kp2d = perspective_projection(inpt_kp3d, target['K'])
kp2d = perspective_projection(inpt_kp3d, self.cam_intrinsics)
mask = self.VideoAugmentor.get_mask()
# kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224, bbox)
kp2d, bbox = self.keypoints_normalizer(kp2d, target['res'], self.cam_intrinsics, 224, 224)
target['bbox'] = bbox[1:]
target['kp2d'] = kp2d
target['mask'] = mask[1:]
# Image features
target['features'] = self.labels['features'][start_index+1:end_index+1].clone()
return target
def get_groundtruth(self, index, target):
start_index, end_index = self.video_indices[index]
# GT 1. Joints
gt_kp3d = target['kp3d']
# gt_kp2d = perspective_projection(gt_kp3d, target['K'])
gt_kp2d = perspective_projection(gt_kp3d, self.cam_intrinsics)
target['kp3d'] = torch.cat((gt_kp3d, torch.ones_like(gt_kp3d[..., :1])), dim=-1)
# target['full_kp2d'] = torch.cat((gt_kp2d, torch.zeros_like(gt_kp2d[..., :1])), dim=-1)[1:]
target['full_kp2d'] = torch.cat((gt_kp2d, torch.ones_like(gt_kp2d[..., :1])), dim=-1)[1:]
target['weak_kp2d'] = torch.zeros_like(target['full_kp2d'])
target['init_kp3d'] = root_centering(gt_kp3d[:1, :self.n_joints].clone()).reshape(1, -1)
# GT 2. Root pose
w_transl = self.labels['w_trans'][start_index:end_index+1]
pose_root = transforms.axis_angle_to_matrix(self.labels['root'][start_index:end_index+1])
vel_world = (w_transl[1:] - w_transl[:-1])
vel_root = (pose_root[:-1].transpose(-1, -2) @ vel_world.unsqueeze(-1)).squeeze(-1)
target['vel_root'] = vel_root.clone()
target['pose_root'] = transforms.matrix_to_rotation_6d(pose_root)
target['init_root'] = target['pose_root'][:1].clone()
return target
def forward_smpl(self, target):
output = self.smpl.get_output(
body_pose=torch.cat((target['init_pose'][:, 1:], target['pose'][1:, 1:])),
global_orient=torch.cat((target['init_pose'][:, :1], target['pose'][1:, :1])),
betas=target['betas'],
transl=target['transl'],
pose2rot=False)
target['kp3d'] = output.joints + output.offset.unsqueeze(1)
target['feet'] = output.feet[1:] + target['transl'][1:].unsqueeze(-2)
target['verts'] = output.vertices[1:, ].clone()
return target
def augment_data(self, target):
# Augmentation 1. SMPL params augmentation
target = self.SMPLAugmentor(target)
# Get world-coordinate SMPL
target = self.forward_smpl(target)
return target
def load_camera(self, index, target):
start_index, end_index = self.video_indices[index]
# Get camera info
extrinsics = self.labels['extrinsics'][start_index:end_index+1].clone()
R = extrinsics[:, :3, :3]
T = extrinsics[:, :3, -1]
K = self.labels['intrinsics'][start_index:end_index+1].clone()
width, height = K[0, 0, 2] * 2, K[0, 1, 2] * 2
target['R'] = R
target['res'] = torch.tensor([width, height]).float()
# Compute angular velocity
cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
cam_angvel = cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel) # Normalize
target['cam_angvel'] = cam_angvel * 3e1 # BEDLAM is 30-fps
target['K'] = K # Use GT camera intrinsics for projecting keypoints
self.get_naive_intrinsics(target['res'])
target['cam_intrinsics'] = self.cam_intrinsics
return target
def load_params(self, index, target):
start_index, end_index = self.video_indices[index]
# Load AMASS labels
pose = self.labels['pose'][start_index:end_index+1].clone()
pose = transforms.axis_angle_to_matrix(pose.reshape(-1, 24, 3))
transl = self.labels['c_trans'][start_index:end_index+1].clone()
betas = self.labels['betas'][start_index:end_index+1, :10].clone()
# Stack GT
target.update({'vid': self.labels['vid'][start_index].clone(),
'pose': pose,
'transl': transl,
'betas': betas})
return target
def get_single_sequence(self, index):
target = {'has_full_screen': torch.tensor(True),
'has_smpl': torch.tensor(True),
'has_traj': torch.tensor(False),
'has_verts': torch.tensor(True),
# Null contact label
'contact': torch.ones((self.n_frames - 1, 4)) * (-1),
}
target = self.load_params(index, target)
target = self.load_camera(index, target)
target = self.augment_data(target)
target = self.get_groundtruth(index, target)
target = self.get_inputs(index, target)
target = d_utils.prepare_keypoints_data(target)
target = d_utils.prepare_smpl_data(target)
return target