FormFighterAIStack / lib /utils /data_utils.py
Techt3o's picture
ca5705cc9c8581d916aca37e6759c44f0b1e70429e49ce83e658a0517cd3d6fe
c87d1bc verified
raw
history blame
3.37 kB
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import numpy as np
from lib.utils import transforms
def make_collate_fn():
    """Build a collate function for batching dict samples.

    Returns:
        ``collate_fn(items)``. It drops ``None`` samples and returns a dict
        ``batch``:

        * ``'vid'`` / ``'gender'`` are gathered into plain Python lists when
          every remaining sample has that key.
        * Every key of the first remaining sample is stacked with
          ``torch.stack``. Keys that cannot be stacked are left out of the
          batch. A key cannot be stacked when it is missing in some sample,
          has non-tensor values, or has mismatched shapes. The exception is
          ``'vid'`` / ``'gender'``, which keep their list form.
        * If no samples remain, an empty dict is returned.
    """
    def collate_fn(items):
        items = [item for item in items if item is not None]
        batch = dict()
        # All samples were None (or the batch was empty): nothing to collate.
        # Without this guard, items[0] below raises IndexError.
        if not items:
            return batch

        # Metadata that is typically not a tensor; kept as plain lists.
        for meta_key in ('vid', 'gender'):
            try:
                batch[meta_key] = [item[meta_key] for item in items]
            except KeyError:
                pass

        for key in items[0].keys():
            try:
                batch[key] = torch.stack([item[key] for item in items])
            except (KeyError, TypeError, RuntimeError):
                # Best-effort collation:
                #   KeyError     -> key missing in some sample
                #   TypeError    -> values are not tensors (e.g. str, int)
                #   RuntimeError -> tensor shapes do not match
                # The key is skipped; 'vid'/'gender' keep their list form.
                pass
        return batch

    return collate_fn
def prepare_keypoints_data(target):
    """Split the leading keypoint entry off as the initial observation.

    ``target`` is modified in place and also returned:

    * ``init_kp2d`` <- ``kp2d[:1]`` (the first entry along axis 0,
      presumably the first frame).
    * ``kp2d`` <- ``kp2d[1:]``.
    * ``kp3d`` <- ``kp3d[1:]``, only if ``kp3d`` is present.

    Raises:
        KeyError: if ``target`` has no ``'kp2d'`` entry.
    """
    kp2d = target['kp2d']
    target['init_kp2d'], target['kp2d'] = kp2d[:1], kp2d[1:]

    if 'kp3d' in target:
        target['kp3d'] = target['kp3d'][1:]

    return target
def prepare_smpl_data(target):
    """Convert SMPL entries of ``target`` in place and return the same dict.

    * ``pose`` (optional): converted with
      ``transforms.matrix_to_rotation_6d``; then the first entry along
      axis 0 is dropped.
    * ``betas`` (optional): the first entry along axis 0 is dropped.
    * ``transl`` (optional): its ``[1:]`` slice is stored under the new key
      ``cam``. ``transl`` itself is left unchanged.
    * ``init_pose`` (required): converted with
      ``transforms.matrix_to_rotation_6d``.

    Raises:
        KeyError: if ``target`` has no ``'init_pose'`` entry.
    """
    if 'pose' in target:
        # ``[:]`` is a full slice: every joint is kept.
        rot6d = transforms.matrix_to_rotation_6d(target['pose'][:])
        target['pose'] = rot6d[1:]

    if 'betas' in target:
        target['betas'] = target['betas'][1:]

    if 'transl' in target:
        # Translation is exposed downstream under the key 'cam'.
        target['cam'] = target['transl'][1:]

    target['init_pose'] = transforms.matrix_to_rotation_6d(target['init_pose'])
    return target
def append_target(target, label, key_list, idx1, idx2=None, pad=True):
    """Copy an indexed selection of each ``label[key]`` into ``target[key]``.

    Args:
        target: dict written into (mutated in place).
        label: mapping from key to an indexable sequence / array.
        key_list: keys to copy from ``label``.
        idx1: start index, or the single index taken when ``idx2`` is None.
        idx2: inclusive end index. With the default ``None``, only
            ``label[key][idx1]`` is taken.
        pad: when False, the first two entries of the selection are dropped.

    Returns:
        ``target``.
    """
    for key in key_list:
        source = label[key]
        selected = source[idx1] if idx2 is None else source[idx1:idx2 + 1]
        target[key] = selected if pad else selected[2:]
    return target
def map_dmpl_to_smpl(pose):
    """Map AMASS DMPL pose representation to SMPL pose representation.

    Joints 0-22 are kept as-is. Joint 23 (right hand) is taken from input
    joint 37. All joints beyond index 23 are dropped. The input is NOT
    modified; a fresh copy is returned.

    Args:
        pose: tensor / array with shape (n_frames, 156). Any
            (n_frames, 3 * k) input with k >= 38 is accepted.

    Returns:
        tensor / array (same type as ``pose``) with shape (n_frames, 24, 3).
    """
    pose = pose.reshape(pose.shape[0], -1, 3)
    # ``reshape`` returns a view of the caller's data when possible, so copy
    # the first 24 joints *before* writing. Writing into ``pose`` directly
    # would silently overwrite the caller's input.
    if isinstance(pose, np.ndarray):
        smpl_pose = pose[:, :24].copy()
    else:
        smpl_pose = pose[:, :24].clone()
    smpl_pose[:, 23] = pose[:, 37]  # right hand
    return smpl_pose
def transform_global_coordinate(pose, T, transl=None):
    """Re-express a dataset's global orientation/translation via matrix ``T``.

    Datasets use different global coordinate systems. This left-multiplies
    by ``T`` to bring them all into one canonical coordinate system.

    The pose is converted axis-angle -> rotation matrix via
    ``transforms.axis_angle_to_matrix``. Only index 0 along the second axis
    (presumably the root joint -- confirm) is left-multiplied by ``T``.
    Everything is then converted back with
    ``transforms.matrix_to_axis_angle``.

    Args:
        pose: SMPL pose; torch tensor or numpy array.
        T: transformation matrix.
        transl: optional SMPL translation. When ``pose`` is a numpy array,
            it must be a numpy array too.

    Returns:
        ``(pose, transl)``. ``transl`` is ``None`` when not given. If
        ``pose`` was a numpy array, the outputs are numpy float32 arrays.
    """
    to_numpy = isinstance(pose, np.ndarray)
    if to_numpy:
        pose = torch.from_numpy(pose).float()
        if transl is not None:
            transl = torch.from_numpy(transl).float()

    rotmats = transforms.axis_angle_to_matrix(pose)
    rotmats[:, 0] = T @ rotmats[:, 0]
    pose = transforms.matrix_to_axis_angle(rotmats)

    if transl is not None:
        # NOTE(review): ``squeeze()`` removes every size-1 axis, so a
        # single-frame (1, 3) translation comes back with shape (3,).
        # Confirm that callers expect this.
        transl = (T @ transl.T).squeeze().T

    if to_numpy:
        pose = pose.detach().numpy()
        transl = None if transl is None else transl.detach().numpy()

    return pose, transl