Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import os | |
import torch | |
from configs import constants as _C | |
from .dataset3d import Dataset3D | |
from .dataset2d import Dataset2D | |
from ...utils.kp_utils import convert_kps | |
from smplx import SMPL | |
class Human36M(Dataset3D):
    """``Dataset3D`` subclass backed by the parsed ``human36m_*`` label file.

    Marks 3D keypoints and trajectory as available, SMPL parameters and
    vertices as unavailable, and switches on only the trailing 14 entries of
    the keypoint mask.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and configure availability flags.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` (lower-cased)
                replaces the ``backbone`` placeholder in the label path.
            dset: Split name embedded in the file name. The base class
                receives ``dset == 'train'`` as its third positional argument.
        """
        template = os.path.join(_C.PATHS.PARSED_DATA, f'human36m_{dset}_backbone.pth')
        # The substitution runs over the whole joined path, not just the file name.
        label_path = template.replace('backbone', cfg.MODEL.BACKBONE.lower())
        is_train = dset == 'train'
        super().__init__(cfg, label_path, is_train)

        # Which kinds of supervision this dataset provides.
        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among the 31-joint format, only the 14 common joints are available,
        # so only the last 14 mask entries are enabled.
        mask_size = _C.KEYPOINTS.NUM_JOINTS + 14
        self.mask = torch.zeros(mask_size)
        self.mask[-14:] = 1

    def __name__(self):
        """Return the dataset identifier string."""
        return 'Human36M'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints stored at ``index`` in the 14-joint subset.

        The ``joints3D`` label entry is converted from the ``'spin'`` layout to
        the ``'h36m'`` layout with ``convert_kps``; the columns listed in
        ``_C.KEYPOINTS.H36M_TO_J14`` are then picked along the second axis and
        the result is cast to float.
        """
        h36m_joints = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return h36m_joints[:, _C.KEYPOINTS.H36M_TO_J14].float()
class MPII3D(Dataset3D):
    """``Dataset3D`` subclass backed by the parsed ``mpii3d_*`` label file.

    Marks 3D keypoints and trajectory as available, SMPL parameters and
    vertices as unavailable, and switches on only the trailing 14 entries of
    the keypoint mask.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and configure availability flags.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` (lower-cased)
                replaces the ``backbone`` placeholder in the label path.
            dset: Split name embedded in the file name. The base class
                receives ``dset == 'train'`` as its third positional argument.
        """
        template = os.path.join(_C.PATHS.PARSED_DATA, f'mpii3d_{dset}_backbone.pth')
        # The substitution runs over the whole joined path, not just the file name.
        label_path = template.replace('backbone', cfg.MODEL.BACKBONE.lower())
        is_train = dset == 'train'
        super().__init__(cfg, label_path, is_train)

        # Which kinds of supervision this dataset provides.
        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among the 31-joint format, only the 14 common joints are available,
        # so only the last 14 mask entries are enabled.
        mask_size = _C.KEYPOINTS.NUM_JOINTS + 14
        self.mask = torch.zeros(mask_size)
        self.mask[-14:] = 1

    def __name__(self):
        """Return the dataset identifier string."""
        return 'MPII3D'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints stored at ``index`` in the J17 subset.

        The ``joints3D`` label entry is converted from the ``'spin'`` layout to
        the ``'h36m'`` layout with ``convert_kps``; the columns listed in
        ``_C.KEYPOINTS.H36M_TO_J17`` are then picked along the second axis and
        the result is cast to float.
        """
        # NOTE(review): this selects H36M_TO_J17 while the mask above enables
        # 14 entries -- confirm the mismatch is intentional.
        h36m_joints = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return h36m_joints[:, _C.KEYPOINTS.H36M_TO_J17].float()
class ThreeDPW(Dataset3D):
    """``Dataset3D`` subclass backed by the parsed ``3dpw_*`` label file.

    Marks 3D keypoints, SMPL parameters and vertices as available, trajectory
    as unavailable, enables every keypoint-mask entry except the trailing 14,
    and builds one SMPL body model per gender index.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file, set flags and build the SMPL models.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` (lower-cased)
                replaces the ``backbone`` placeholder in the label path.
            dset: Split name embedded in the file name. The base class
                receives ``dset == 'train'`` as its third positional argument.
        """
        template = os.path.join(_C.PATHS.PARSED_DATA, f'3dpw_{dset}_backbone.pth')
        # The substitution runs over the whole joined path, not just the file name.
        label_path = template.replace('backbone', cfg.MODEL.BACKBONE.lower())
        is_train = dset == 'train'
        super().__init__(cfg, label_path, is_train)

        # Which kinds of supervision this dataset provides.
        self.has_3d = True
        self.has_traj = False
        self.has_smpl = True
        self.has_verts = True  # In testing

        # Enable every mask entry except the trailing 14.
        # NOTE(review): the previous comment here spoke of "14 common joints"
        # being available, which does not describe this slice -- confirm intent.
        mask_size = _C.KEYPOINTS.NUM_JOINTS + 14
        self.mask = torch.zeros(mask_size)
        self.mask[:-14] = 1

        # Gender index -> SMPL body model (0: male, 1: female).
        self.smpl_gender = {
            gender_idx: SMPL(_C.BMODEL.FLDR, gender=gender_name, num_betas=10)
            for gender_idx, gender_name in enumerate(('male', 'female'))
        }

    def __name__(self):
        """Return the dataset identifier string."""
        return 'ThreeDPW'

    def compute_3d_keypoints(self, index):
        """Return the stored ``joints3D`` label entry at ``index`` unchanged."""
        joints = self.labels['joints3D']
        return joints[index]
class InstaVariety(Dataset2D):
    """``Dataset2D`` subclass backed by the parsed ``insta_*`` label file.

    Marks 3D keypoints, trajectory and SMPL parameters as unavailable and
    switches on only the leading 17 entries of the keypoint mask.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and configure availability flags.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` (lower-cased)
                replaces the ``backbone`` placeholder in the label path.
            dset: Split name embedded in the file name. The base class
                receives ``dset == 'train'`` as its third positional argument.
        """
        template = os.path.join(_C.PATHS.PARSED_DATA, f'insta_{dset}_backbone.pth')
        # The substitution runs over the whole joined path, not just the file name.
        label_path = template.replace('backbone', cfg.MODEL.BACKBONE.lower())
        is_train = dset == 'train'
        super().__init__(cfg, label_path, is_train)

        # Which kinds of supervision this dataset provides.
        self.has_3d = False
        self.has_traj = False
        self.has_smpl = False

        # Among the 31-joint format, only the 17 COCO joints are available,
        # so only the first 17 mask entries are enabled.
        mask_size = _C.KEYPOINTS.NUM_JOINTS + 14
        self.mask = torch.zeros(mask_size)
        self.mask[:17] = 1

    def __name__(self):
        """Return the dataset identifier string."""
        return 'InstaVariety'