Techt3o's picture
e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
f561f8b verified
raw
history blame
4.34 kB
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import torch
import joblib
from configs import constants as _C
from .._dataset import BaseDataset
from ...utils import transforms
from ...utils import data_utils as d_utils
from ...utils.kp_utils import root_centering
# Multiplier applied to the frame-to-frame camera rotation delta in
# EvalDataset.prepare_labels (presumably the sequences' frame rate in frames
# per second, turning a per-frame delta into a per-second rate -- TODO confirm).
FPS = 30
class EvalDataset(BaseDataset):
    """Evaluation dataset backed by a pre-parsed label file.

    The label file is loaded once in ``__init__`` and indexed per item. Each
    item is a dict assembled by ``prepare_labels``, ``prepare_inputs`` and
    ``prepare_initialization`` and then passed through
    ``d_utils.prepare_keypoints_data`` and ``d_utils.prepare_smpl_data``.
    """

    def __init__(self, cfg, data, split, backbone):
        """Load the parsed labels stored as ``{data}_{split}_{backbone}.pth``.

        Args:
            cfg: Configuration object forwarded to ``BaseDataset.__init__``.
            data: Dataset name. Used in the label file name and, via a
                case-insensitive ``'emdb'`` substring check, to choose how the
                camera angular velocity is built in ``prepare_labels``.
            split: Split name used in the label file name.
            backbone: Backbone name used in the label file name.
        """
        super().__init__(cfg, False)

        # Key prefix selecting plain ('') or 'flipped_' label entries; it is
        # switched by ``load_data``.
        self.prefix = ''
        self.data = data

        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'{data}_{split}_{backbone}.pth')
        # NOTE: joblib.load unpickles the file and can execute arbitrary code;
        # only point PARSED_DATA at label files from a trusted source.
        self.labels = joblib.load(parsed_data_path)

    def load_data(self, index, flip=False):
        """Return item ``index`` with a leading axis added to every tensor.

        Args:
            index: Index into the parsed labels.
            flip: When true, the ``'flipped_'``-prefixed label entries are
                read instead of the plain ones.

        Returns:
            The dict produced by ``__getitem__`` where every ``torch.Tensor``
            value has been unsqueezed at dim 0; other values are untouched.

        Note:
            ``self.prefix`` is updated as a side effect and keeps its value
            after the call, so later ``__getitem__`` calls use the same prefix.
        """
        self.prefix = 'flipped_' if flip else ''

        target = self.__getitem__(index)
        for key, val in target.items():
            if isinstance(val, torch.Tensor):
                target[key] = val.unsqueeze(0)
        return target

    def __getitem__(self, index):
        """Build item ``index`` and run the shared keypoint/SMPL preparation."""
        target = self.get_data(index)
        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)
        return target

    def __len__(self):
        """Return the number of entries stored under ``labels['kp2d']``."""
        return len(self.labels['kp2d'])

    def prepare_labels(self, index, target):
        """Add ground-truth, sequence and camera entries to ``target``.

        Args:
            index: Index into the parsed labels.
            target: Dict to fill in; it is modified in place.

        Returns:
            The same ``target`` dict with ``pose``, ``betas``, ``gender``,
            ``res``, ``vid``, ``frame_id``, ``cam_intrinsics``, ``cam_angvel``
            and, for EMDB data only, ``R``.
        """
        # Ground truth SMPL parameters (axis-angle reshaped to (-1, 24, 3),
        # then converted to rotation matrices).
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][index].reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][index]
        target['gender'] = self.labels['gender'][index]

        # Sequence information
        target['res'] = self.labels['res'][index][0]
        target['vid'] = self.labels['vid'][index]
        # The first frame is dropped here, as it is for the inputs in
        # ``prepare_inputs``.
        target['frame_id'] = self.labels['frame_id'][index][1:]

        # Camera information. ``get_naive_intrinsics`` is expected to populate
        # ``self.cam_intrinsics``, which is read on the next line.
        self.get_naive_intrinsics(target['res'])
        target['cam_intrinsics'] = self.cam_intrinsics
        R = self.labels['cam_poses'][index][:, :3, :3].clone()

        if 'emdb' in self.data.lower():
            # Use groundtruth camera angular velocity.
            # Can be updated with SLAM results if you have it.
            # Relative rotation between consecutive frames, in 6D form.
            cam_angvel = transforms.matrix_to_rotation_6d(
                R[:-1] @ R[1:].transpose(-1, -2))
            # Subtract what is presumably the 6D encoding of the identity
            # rotation (TODO confirm against ``transforms``), then scale by FPS.
            cam_angvel = (
                cam_angvel
                - torch.tensor([[1, 0, 0, 0, 1, 0]]).to(cam_angvel)
            ) * FPS
            target['R'] = R
        else:
            # No camera motion is used: one all-zero row per frame transition.
            cam_angvel = torch.zeros((len(target['pose']) - 1, 6))

        target['cam_angvel'] = cam_angvel
        return target

    def prepare_inputs(self, index, target):
        """Add image features, bounding boxes and 2D keypoints to ``target``.

        Args:
            index: Index into the parsed labels.
            target: Dict to fill in; it is modified in place. It must already
                contain ``res`` and ``cam_intrinsics`` (see ``prepare_labels``).

        Returns:
            The same ``target`` dict with ``features``, ``bbox``, ``kp2d``,
            ``input_kp2d`` and ``mask``.
        """
        kp2d_key = self.prefix + 'kp2d'
        bbox_key = self.prefix + 'bbox'

        # 'bbox' is stored here first and then replaced by the normalizer
        # output further down.
        for key in ['features', 'bbox']:
            target[key] = self.labels[self.prefix + key][index][1:]

        # Keep the first, second and last bbox columns (presumably center x,
        # center y and scale -- TODO confirm) and divide the last one by 200.
        bbox = self.labels[bbox_key][index][..., [0, 1, -1]].clone().float()
        bbox[:, 2] = bbox[:, 2] / 200

        # Normalize keypoints
        kp2d, bbox = self.keypoints_normalizer(
            self.labels[kp2d_key][index][..., :2].clone().float(),
            target['res'], target['cam_intrinsics'], 224, 224, bbox)
        target['kp2d'] = kp2d
        target['bbox'] = bbox[1:]

        # Masking out low confident keypoints (last channel below 0.3).
        mask = self.labels[kp2d_key][index][..., -1] < 0.3
        # Clone before zeroing: the slice is a view of the cached labels, and
        # masking it in place would corrupt them for every later load.
        # NOTE(review): this always reads the un-prefixed 'kp2d' entry even
        # when ``self.prefix`` is 'flipped_', while the mask honours the
        # prefix -- confirm whether that is intended.
        input_kp2d = self.labels['kp2d'][index][1:].clone()
        input_kp2d[mask[1:]] *= 0
        target['input_kp2d'] = input_kp2d
        target['mask'] = mask[1:]
        return target

    def prepare_initialization(self, index, target):
        """Add the initial-state entries to ``target``.

        Args:
            index: Index into the parsed labels.
            target: Dict to fill in; it is modified in place. It must already
                contain ``pose`` (see ``prepare_labels``).

        Returns:
            The same ``target`` dict with ``init_kp3d``, ``init_pose`` and
            ``init_root``.
        """
        # Initial frame per-frame estimation: only the first frame ([:1]) of
        # the stored initial keypoints and pose is used.
        init_kp3d = self.labels[self.prefix + 'init_kp3d'][index][:1, :self.n_joints]
        target['init_kp3d'] = root_centering(init_kp3d).reshape(1, -1)

        init_pose = self.labels[self.prefix + 'init_pose'][index][:1]
        target['init_pose'] = transforms.axis_angle_to_matrix(init_pose).cpu()

        # Rotation of joint 0 of ``target['pose']`` for every frame, in 6D form.
        pose_root = target['pose'][:, 0].clone()
        target['init_root'] = transforms.matrix_to_rotation_6d(pose_root)
        return target

    def get_data(self, index):
        """Assemble the raw item dict for ``index``.

        Runs ``prepare_labels``, ``prepare_inputs`` and
        ``prepare_initialization`` in that order; the later steps read entries
        written by ``prepare_labels``.
        """
        target = {}
        target = self.prepare_labels(index, target)
        target = self.prepare_inputs(index, target)
        target = self.prepare_initialization(index, target)
        return target