Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import os | |
import torch | |
import joblib | |
from configs import constants as _C | |
from .._dataset import BaseDataset | |
from ...utils import transforms | |
from ...utils import data_utils as d_utils | |
from ...utils.kp_utils import root_centering | |
# Scale factor applied to the frame-to-frame camera rotation delta in
# EvalDataset.prepare_labels.
# NOTE(review): presumably the video frame rate in frames per second -- confirm.
FPS = 30
class EvalDataset(BaseDataset):
    """Evaluation dataset assembled from a pre-parsed label file.

    The labels are loaded once with ``joblib`` from
    ``<_C.PATHS.PARSED_DATA>/<data>_<split>_<backbone>.pth``. Every item is
    built on demand from those labels by ``get_data`` and then post-processed
    by ``d_utils.prepare_keypoints_data`` and ``d_utils.prepare_smpl_data``.
    """

    def __init__(self, cfg, data, split, backbone):
        """Load the parsed labels for one dataset/split/backbone combination.

        Args:
            cfg: Configuration object forwarded to ``BaseDataset.__init__``.
            data: Dataset name. Used in the label file name and, in
                ``prepare_labels``, to select the camera-motion branch.
            split: Split name used in the label file name.
            backbone: Backbone name used in the label file name.
        """
        super(EvalDataset, self).__init__(cfg, False)

        # Prefix prepended to input label keys; switched to 'flipped_' by load_data.
        self.prefix = ''
        self.data = data
        parsed_data_path = os.path.join(_C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth')
        # NOTE: joblib.load unpickles the file, so only load label files from a trusted source.
        self.labels = joblib.load(parsed_data_path)

    def load_data(self, index, flip=False):
        """Return item ``index`` with a leading batch dimension on every tensor.

        Args:
            index: Index of the sequence in the parsed labels.
            flip: When true, inputs are read from the ``flipped_``-prefixed label entries.

        Returns:
            dict: The item produced by ``__getitem__`` where every
            ``torch.Tensor`` value has been unsqueezed at dim 0.
        """
        # The prefix is stored on the instance, so it stays in effect for
        # later __getitem__ calls until load_data is called again.
        self.prefix = 'flipped_' if flip else ''

        target = self.__getitem__(index)
        for key, val in target.items():
            if isinstance(val, torch.Tensor):
                target[key] = val.unsqueeze(0)
        return target

    def __getitem__(self, index):
        """Build item ``index`` and run the keypoint and SMPL post-processing helpers."""
        target = self.get_data(index)
        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)
        return target

    def __len__(self):
        """Return the number of entries stored under the ``'kp2d'`` label."""
        return len(self.labels['kp2d'])

    def prepare_labels(self, index, target):
        """Fill ``target`` with ground-truth, sequence and camera entries.

        Args:
            index: Index of the sequence in the parsed labels.
            target: Dict to populate; it is modified in place and returned.

        Returns:
            dict: ``target`` with the keys ``pose``, ``betas``, ``gender``,
            ``res``, ``vid``, ``frame_id``, ``cam_intrinsics``, ``cam_angvel``
            and, for datasets whose name contains ``emdb``, ``R``.
        """
        # Ground truth SMPL parameters
        target['pose'] = transforms.axis_angle_to_matrix(self.labels['pose'][index].reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][index]
        target['gender'] = self.labels['gender'][index]

        # Sequence information
        target['res'] = self.labels['res'][index][0]
        target['vid'] = self.labels['vid'][index]
        # The first frame is dropped here, as it is for the inputs in prepare_inputs.
        target['frame_id'] = self.labels['frame_id'][index][1:]

        # Camera information
        # get_naive_intrinsics is expected to set self.cam_intrinsics as a
        # side effect (it is read on the next line) -- defined outside this class.
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics
        R = self.labels['cam_poses'][index][:, :3, :3].clone()
        if 'emdb' in self.data.lower():
            # Use groundtruth camera angular velocity.
            # Can be updated with SLAM results if you have it.
            cam_angvel = transforms.matrix_to_rotation_6d(R[:-1] @ R[1:].transpose(-1, -2))
            # Subtract the constant [1, 0, 0, 0, 1, 0] (presumably the 6D encoding
            # of the identity rotation -- confirm) and scale the delta by FPS.
            cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)) * FPS
            target['R'] = R
        else:
            # No camera motion is used for the other datasets: all-zero
            # velocity with one row fewer than there are pose frames.
            cam_angvel = torch.zeros((len(target['pose']) - 1, 6))
        target['cam_angvel'] = cam_angvel
        return target

    def prepare_inputs(self, index, target):
        """Fill ``target`` with the network inputs read from the (optionally flipped) labels.

        Requires ``target['res']`` and ``target['cam_intrinsics']``, which
        ``prepare_labels`` sets.

        Args:
            index: Index of the sequence in the parsed labels.
            target: Dict to populate; it is modified in place and returned.

        Returns:
            dict: ``target`` with the keys ``features``, ``kp2d``, ``bbox``,
            ``input_kp2d`` and ``mask``.
        """
        target['features'] = self.labels[self.prefix + 'features'][index][1:]

        # Keep the first two and the last bbox components; the last one is divided by 200.
        bbox = self.labels[self.prefix + 'bbox'][index][..., [0, 1, -1]].clone().float()
        bbox[:, 2] = bbox[:, 2] / 200

        # Normalize keypoints
        kp2d, bbox = self.keypoints_normalizer(
            self.labels[self.prefix + 'kp2d'][index][..., :2].clone().float(),
            target['res'], target['cam_intrinsics'], 224, 224, bbox)
        target['kp2d'] = kp2d
        target['bbox'] = bbox[1:]

        # Masking out low confident keypoints
        mask = self.labels[self.prefix + 'kp2d'][index][..., -1] < 0.3
        # Clone before masking: slicing returns a view, so zeroing it in place
        # would permanently overwrite the cached labels and change the result
        # of every later access to this index.
        # NOTE(review): this reads the unprefixed 'kp2d' entry even when
        # self.prefix is 'flipped_', while the mask comes from the prefixed
        # entry -- confirm this is intended.
        input_kp2d = self.labels['kp2d'][index][1:].clone()
        input_kp2d[mask[1:]] *= 0
        target['input_kp2d'] = input_kp2d
        target['mask'] = mask[1:]
        return target

    def prepare_initialization(self, index, target):
        """Fill ``target`` with the initial-frame estimates.

        Requires ``target['pose']``, which ``prepare_labels`` sets.

        Args:
            index: Index of the sequence in the parsed labels.
            target: Dict to populate; it is modified in place and returned.

        Returns:
            dict: ``target`` with the keys ``init_kp3d``, ``init_pose`` and ``init_root``.
        """
        # Initial frame per-frame estimation
        target['init_kp3d'] = root_centering(self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]).reshape(1, -1)
        target['init_pose'] = transforms.axis_angle_to_matrix(self.labels[self.prefix + 'init_pose'][index][:1]).cpu()
        # Joint 0 of the ground-truth pose, for every frame, in 6D rotation form.
        pose_root = target['pose'][:, 0].clone()
        target['init_root'] = transforms.matrix_to_rotation_6d(pose_root)
        return target

    def get_data(self, index):
        """Assemble the raw item for ``index`` from labels, inputs and initialization.

        The order matters: ``prepare_inputs`` and ``prepare_initialization``
        read entries that ``prepare_labels`` writes.
        """
        target = {}
        target = self.prepare_labels(index, target)
        target = self.prepare_inputs(index, target)
        target = self.prepare_initialization(index, target)
        return target