Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import torch | |
import joblib | |
from .._dataset import BaseDataset | |
from ..utils.augmentor import * | |
from ...utils import data_utils as d_utils | |
from ...utils import transforms | |
from ...models import build_body_model | |
from ...utils.kp_utils import convert_kps, root_centering | |
class Dataset2D(BaseDataset):
    """Sequence dataset built from 2D-keypoint-only labels.

    Each item is a window of ``cfg.DATASET.SEQLEN + 1`` frames. The extra
    leading frame seeds the initial state (initial pose / initial 3D joints);
    most per-frame targets cover only the remaining ``SEQLEN`` frames.
    Trajectory, camera motion, root orientation/velocity, contact, 3D
    keypoints and vertices are not available for this data, so those targets
    are filled with placeholders (zeros, identity rotations, or -1).
    """

    def __init__(self, cfg, fname, training):
        """Load the label file and build the helpers used per sample.

        Args:
            cfg: experiment config; ``cfg.DATASET.SEQLEN`` sets the window length.
            fname: path of the label file, read with ``joblib.load``.
            training: training flag forwarded to ``BaseDataset``.
        """
        super(Dataset2D, self).__init__(cfg, training)

        self.epoch = 0
        # One extra leading frame is kept for the initial state.
        self.n_frames = cfg.DATASET.SEQLEN + 1
        # NOTE(review): joblib.load unpickles the file; only load trusted labels.
        self.labels = joblib.load(fname)

        if self.training:
            # NOTE(review): presumably builds ``self.video_indices``, the
            # (start, end) windows read by get_inputs/get_labels — confirm in BaseDataset.
            self.prepare_video_batch()

        self.smpl = build_body_model('cpu', self.n_frames)
        # Second argument is False; see SMPLAugmentor for what it disables.
        self.SMPLAugmentor = SMPLAugmentor(cfg, False)

    def __getitem__(self, index):
        """Return the fully assembled target dict for window ``index``."""
        return self.get_single_sequence(index)

    def get_inputs(self, index, target, vis_thr=0.6):
        """Add network inputs (normalized 2D keypoints, mask, features) to ``target``.

        Args:
            index: window index into ``self.video_indices``.
            target: dict being assembled; must already hold ``res``,
                ``cam_intrinsics`` and ``bbox``.
            vis_thr: currently unused; kept for interface compatibility.

        Returns:
            ``target`` with the new entries (modified in place).
        """
        start_index, end_index = self.video_indices[index]

        # 2D keypoint detections for the whole window, including the initial frame.
        kp2d = self.labels['kp2d'][start_index:end_index + 1][..., :2].clone()
        kp2d, bbox = self.keypoints_normalizer(
            kp2d, target['res'], target['cam_intrinsics'], 224, 224, target['bbox'])
        # Drop the initial frame so the boxes align with the supervised frames.
        target['bbox'] = bbox[1:]
        target['kp2d'] = kp2d

        # Detection mask: True where the last channel of ``joints2D`` is zero
        # (presumably a confidence/visibility channel — confirm with the label format).
        target['mask'] = ~self.labels['joints2D'][start_index + 1:end_index + 1][..., -1].clone().bool()

        # Image features for the supervised frames.
        target['features'] = self.labels['features'][start_index + 1:end_index + 1].clone()
        return target

    def get_labels(self, index, target):
        """Add supervision targets to ``target``.

        SMPL parameters are loaded only to seed the initial frame (they are not
        supervised). 2D keypoints are packed into 31 joint slots with a
        validity channel. Full-frame 2D keypoints, 3D keypoints and vertices are
        zero placeholders. ``target['bbox']`` must already hold the boxes of the
        supervised frames (as set by ``get_inputs``).

        Returns:
            The updated target dict. Use the return value: ``SMPLAugmentor``'s
            result is what gets populated here.
        """
        start_index, end_index = self.video_indices[index]
        n_sup = self.n_frames - 1  # number of supervised frames

        # SMPL parameters
        # NOTE: We use NeuralAnnot labels for Human36m and MPII3D only for the 0th frame input.
        # We do not supervise the network on SMPL parameters.
        target['pose'] = transforms.axis_angle_to_matrix(
            self.labels['pose'][start_index:end_index + 1].clone().reshape(-1, 24, 3))
        target['betas'] = self.labels['betas'][start_index:end_index + 1].clone()

        # Apply SMPL augmentor (y-axis rotation and initial frame noise)
        target = self.SMPLAugmentor(target)

        # 2D keypoints: the 17 labelled joints fill the first 17 of 31 slots.
        kp2d = self.labels['kp2d'][start_index:end_index + 1].clone().float()[..., :2]
        gt_kp2d = torch.zeros((n_sup, 31, 2))
        gt_kp2d[:, :17] = kp2d[1:]

        # Validity: the label's last channel is non-zero AND the keypoint is not
        # stored as (0, 0). Slots 17..30 stay invalid. (A per-coordinate test is
        # used because a mean-based test would also reject points with x == -y.)
        conf = torch.zeros((n_sup, 31))
        conf[:, :17] = self.labels['joints2D'][start_index + 1:end_index + 1][..., -1]
        valid = torch.logical_and((gt_kp2d != 0).any(-1), conf)
        gt_kp2d = torch.cat((gt_kp2d, valid.float().unsqueeze(-1)), dim=-1)

        # Bbox-normalized ("weak") keypoints, processed frame by frame; copies
        # guard the source tensors against in-place changes by j2d_processing.
        weak_kp2d = gt_kp2d.clone()
        for idx, frame_xy in enumerate(gt_kp2d[..., :2]):
            weak_kp2d[idx, :, :2] = torch.from_numpy(
                self.j2d_processing(frame_xy.numpy().copy(),
                                    target['bbox'][idx].numpy().copy()))

        target['weak_kp2d'] = weak_kp2d
        target['full_kp2d'] = torch.zeros_like(gt_kp2d)
        # Spans the whole window (n_frames), unlike the other per-frame labels;
        # presumably trimmed downstream in d_utils.prepare_keypoints_data — TODO confirm.
        target['kp3d'] = torch.zeros((kp2d.shape[0], 31, 4))

        # No SMPL vertices available
        target['verts'] = torch.zeros((n_sup, 6890, 3), dtype=torch.float32)
        return target

    def get_init_frame(self, target):
        """Add ``target['init_kp3d']``: root-centered 3D joints of the initial frame.

        Runs the body model on ``target['init_pose']`` with ``pose2rot=False``.
        Index 0 along dim 1 is passed as ``global_orient`` and the rest as
        ``body_pose``; betas come from the first frame. ``init_pose`` is
        presumably set by ``SMPLAugmentor`` in ``get_labels`` — confirm.

        Returns:
            ``target`` with ``init_kp3d`` of shape (1, -1) (modified in place).
        """
        output = self.smpl.get_output(
            body_pose=target['init_pose'][:, 1:],
            global_orient=target['init_pose'][:, :1],
            betas=target['betas'][:1],
            pose2rot=False,
        )
        # First self.n_joints joints of the first output, root-centered, flattened.
        target['init_kp3d'] = root_centering(output.joints[:1, :self.n_joints]).reshape(1, -1)
        return target

    def get_single_sequence(self, index):
        """Assemble the complete sample dict for window ``index``."""
        # Camera parameters: a fixed 224x224 image with one constant bbox
        # ([112, 112, 1.12]) repeated for every frame of the window.
        res = torch.tensor((224.0, 224.0))
        # NOTE(review): presumably sets self.cam_intrinsics (read below).
        self.get_naive_intrinsics(res)
        bbox = torch.tensor([112.0, 112.0, 1.12]).repeat(self.n_frames, 1)

        # Universal target
        target = {
            'has_full_screen': torch.tensor(False),
            'has_smpl': torch.tensor(self.has_smpl),
            'has_traj': torch.tensor(self.has_traj),
            'has_verts': torch.tensor(False),
            'transl': torch.zeros((self.n_frames, 3)),

            # Camera parameters and bbox
            'res': res,
            'cam_intrinsics': self.cam_intrinsics,
            'bbox': bbox,

            # Null camera motion
            'R': torch.eye(3).repeat(self.n_frames, 1, 1),
            'cam_angvel': torch.zeros((self.n_frames - 1, 6)),

            # Null root orientation and velocity
            'pose_root': torch.zeros((self.n_frames, 6)),
            'vel_root': torch.zeros((self.n_frames - 1, 3)),
            'init_root': torch.zeros((1, 6)),

            # Null contact label
            'contact': torch.ones((self.n_frames - 1, 4)) * (-1),
        }

        # Thread the dict through each stage instead of relying on in-place
        # mutation: get_labels rebinds ``target`` to SMPLAugmentor's result.
        target = self.get_inputs(index, target)
        target = self.get_labels(index, target)
        target = self.get_init_frame(target)

        target = d_utils.prepare_keypoints_data(target)
        target = d_utils.prepare_smpl_data(target)
        return target