Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import torch | |
import numpy as np | |
from lib.utils import transforms | |
def make_collate_fn():
    """Build the collate function used to assemble samples into a batch.

    Returns:
        A ``collate_fn(items)`` callable suitable for a data loader.
    """

    def collate_fn(items):
        """Merge a list of per-sample dicts into a single batch dict.

        ``None`` samples are dropped. The ``'vid'`` and ``'gender'`` entries
        are gathered as plain Python lists when every sample provides them.
        Every key of the first sample whose values can be combined with
        ``torch.stack`` is stacked along a new leading dimension; keys that
        cannot be stacked (non-tensor values, mismatching shapes, or a key
        missing from some sample) are skipped.

        Args:
            items: list of sample dicts (entries may be ``None``).

        Returns:
            dict holding the batched entries; an empty dict when no valid
            sample remains after dropping ``None`` entries.
        """
        items = [item for item in items if item is not None]
        batch = dict()
        if not items:
            # Nothing to collate: return an empty batch instead of failing
            # on ``items[0]`` below.
            return batch

        # Metadata entries are kept as lists; they are optional.
        for key in ('vid', 'gender'):
            try:
                batch[key] = [item[key] for item in items]
            except KeyError:
                pass

        for key in items[0].keys():
            try:
                batch[key] = torch.stack([item[key] for item in items])
            except (KeyError, TypeError, RuntimeError):
                # KeyError: key absent from some sample.
                # TypeError: values are not tensors (e.g. strings).
                # RuntimeError: tensors whose shapes cannot be stacked.
                # Such keys are skipped on purpose (best-effort batching).
                pass
        return batch

    return collate_fn
def prepare_keypoints_data(target):
    """Split off the leading keypoint entry and trim the sequences in place.

    The first element of ``target['kp2d']`` is stored (as a length-1 slice)
    under ``'init_kp2d'``; the remaining elements replace ``target['kp2d']``.
    When ``'kp3d'`` is present, its first element is dropped as well.

    Args:
        target: dict with a sliceable ``'kp2d'`` entry and, optionally, a
            sliceable ``'kp3d'`` entry. It is modified in place.

    Returns:
        The same ``target`` dict.
    """
    keypoints_2d = target['kp2d']
    target['init_kp2d'], target['kp2d'] = keypoints_2d[:1], keypoints_2d[1:]

    if 'kp3d' not in target:
        return target

    target['kp3d'] = target['kp3d'][1:]
    return target
def prepare_smpl_data(target):
    """Convert SMPL rotation entries of ``target`` and trim the leading entry.

    Performed in place on ``target``:
      * ``'pose'`` (if present) is passed through
        ``transforms.matrix_to_rotation_6d`` and its first element dropped.
      * ``'betas'`` (if present) has its first element dropped.
      * ``'transl'`` (if present) minus its first element is stored under
        ``'cam'``; ``'transl'`` itself is left untouched.
      * ``'init_pose'`` is always converted with
        ``transforms.matrix_to_rotation_6d`` (a missing key raises KeyError).

    Args:
        target: dict of SMPL data; presumably rotation-matrix tensors indexed
            by frame along the first axis -- TODO confirm against the caller.

    Returns:
        The same ``target`` dict.
    """
    if 'pose' in target:
        # Rotation matrices -> 6D representation, then drop the first entry.
        rot6d = transforms.matrix_to_rotation_6d(target['pose'][:])
        target['pose'] = rot6d[1:]

    if 'betas' in target:
        target['betas'] = target['betas'][1:]

    if 'transl' in target:
        # The trimmed translation is exposed under the 'cam' key.
        target['cam'] = target['transl'][1:]

    # The initial pose is required and is converted to 6D as well.
    initial_pose = target['init_pose']
    target['init_pose'] = transforms.matrix_to_rotation_6d(initial_pose)

    return target
def append_target(target, label, key_list, idx1, idx2=None, pad=True):
    """Copy selected entries of ``label`` into ``target``.

    Args:
        target: dict that receives the selected data (modified in place).
        label: mapping from key to an indexable sequence.
        key_list: keys to copy from ``label`` into ``target``.
        idx1: index of the element to take, or start of the range.
        idx2: inclusive end of the range; when ``None`` only the single
            element ``label[key][idx1]`` is taken.
        pad: when ``False``, the first two elements of the selected data
            are dropped before storing.

    Returns:
        The same ``target`` dict.
    """
    take_single = idx2 is None
    for name in key_list:
        source = label[name]
        # idx2 is inclusive, hence the +1 on the slice end.
        selected = source[idx1] if take_single else source[idx1:idx2 + 1]
        target[name] = selected if pad else selected[2:]

    return target
def map_dmpl_to_smpl(pose):
    """Map AMASS DMPL pose representation to SMPL pose representation.

    The flat per-frame pose is viewed as 3-vectors per joint. The first 24
    joints are kept, and joint 23 of the result is taken from joint 37 of
    the input (the right hand). The input is never modified.

    Args:
        pose: tensor / array with shape of (n_frames, 156)

    Returns:
        New tensor / array with shape of (n_frames, 24, 3), of the same kind
        (NumPy array or torch tensor) as the input.
    """
    joints = pose.reshape(pose.shape[0], -1, 3)

    # Copy *before* patching joint 23: ``reshape`` usually returns a view of
    # the caller's buffer, so assigning into ``joints`` directly would
    # silently overwrite the caller's data.
    if isinstance(pose, np.ndarray):
        smpl_pose = joints[:, :24].copy()
    else:
        smpl_pose = joints[:, :24].clone()

    smpl_pose[:, 23] = joints[:, 37]  # right hand
    return smpl_pose
def transform_global_coordinate(pose, T, transl=None):
    """Re-express the global coordinates of a pose with respect to ``T``.

    Datasets use different global coordinate systems; this helper is used to
    bring them into one canonical coordinate system.

    The pose is converted with ``transforms.axis_angle_to_matrix``, the entry
    at index 0 of its second axis (presumably the global/root orientation --
    TODO confirm) is left-multiplied by ``T``, and the result is converted
    back with ``transforms.matrix_to_axis_angle``. When ``transl`` is given,
    ``T`` is applied to it as well.

    Args:
        pose: SMPL pose; tensor / array.
        T: Transformation matrix.
        transl: SMPL translation (optional).

    Returns:
        Tuple ``(pose, transl)``. NumPy arrays are returned when ``pose`` was
        passed as a NumPy array, torch tensors otherwise; ``transl`` is
        ``None`` when it was not provided.
    """
    came_as_numpy = isinstance(pose, np.ndarray)
    if came_as_numpy:
        # All computation happens in torch (float32); convert back at the end.
        pose = torch.from_numpy(pose).float()
        if transl is not None:
            transl = torch.from_numpy(transl).float()

    rotmat = transforms.axis_angle_to_matrix(pose)
    rotmat[:, 0] = T @ rotmat[:, 0]
    pose = transforms.matrix_to_axis_angle(rotmat)

    if transl is not None:
        # NOTE(review): squeeze() drops every size-1 axis, so a single-frame
        # (1, 3) translation comes back as shape (3,) -- confirm callers
        # expect this.
        transl = (T @ transl.T).squeeze().T

    if not came_as_numpy:
        return pose, transl

    pose = pose.detach().numpy()
    if transl is not None:
        transl = transl.detach().numpy()
    return pose, transl