# Web file-view header captured with this file (not code):
#   uploaded by Techt3o, commit f561f8b (verified), size 3.77 kB
#   hash e01e49325338173592071b44501d91416d6a9072d1040c9d9f5aecf816533bec
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import torch
from configs import constants as _C
from .dataset3d import Dataset3D
from .dataset2d import Dataset2D
from ...utils.kp_utils import convert_kps
from smplx import SMPL
class Human36M(Dataset3D):
    """Human3.6M video dataset backed by a pre-parsed ``.pth`` label file.

    The label file is ``human36m_{dset}_{backbone}.pth`` inside
    ``_C.PATHS.PARSED_DATA``, where ``backbone`` is ``cfg.MODEL.BACKBONE``
    lower-cased.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and set the dataset's label flags.

        Args:
            cfg: Config object; ``cfg.MODEL.BACKBONE`` selects which parsed
                file is loaded.
            dset: Split name interpolated into the file name
                (default ``'train'``).
        """
        # Put the backbone name straight into the file name. The previous
        # approach str.replace()'d a 'backbone' placeholder over the whole
        # joined path, which also rewrote any directory component that
        # happened to contain the substring 'backbone'.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'human36m_{dset}_{backbone}.pth')
        # Third argument is True only for the 'train' split (presumably the
        # base class's training switch -- confirm in Dataset3D).
        super().__init__(cfg, parsed_data_path, dset == 'train')

        # Which kinds of labels this dataset provides; presumably read by
        # Dataset3D / training code -- confirm at the read sites.
        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # The mask spans NUM_JOINTS + 14 entries (the "31 joints format" per
        # the original note). Only the trailing 14 common joints are
        # available for this dataset.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self, ):
        """Dataset name string."""
        return 'Human36M'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints of sample ``index`` as a float tensor.

        ``self.labels['joints3D'][index]`` is converted from the 'spin' to the
        'h36m' keypoint format by ``convert_kps``. Axis 1 is then indexed with
        ``_C.KEYPOINTS.H36M_TO_J14``, and the result is cast to float.
        Assumes axis 1 is the joint axis, e.g. (frames, joints, 3) -- TODO
        confirm.
        """
        kp3d = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return kp3d[:, _C.KEYPOINTS.H36M_TO_J14].float()
class MPII3D(Dataset3D):
    """MPI-INF-3DHP video dataset backed by a pre-parsed ``.pth`` label file.

    The label file is ``mpii3d_{dset}_{backbone}.pth`` inside
    ``_C.PATHS.PARSED_DATA``, where ``backbone`` is ``cfg.MODEL.BACKBONE``
    lower-cased.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and set the dataset's label flags.

        Args:
            cfg: Config object; ``cfg.MODEL.BACKBONE`` selects which parsed
                file is loaded.
            dset: Split name interpolated into the file name
                (default ``'train'``).
        """
        # Put the backbone name straight into the file name rather than
        # str.replace()-ing a placeholder over the whole joined path, which
        # would also rewrite directories containing 'backbone'.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'mpii3d_{dset}_{backbone}.pth')
        # Third argument is True only for the 'train' split (presumably the
        # base class's training switch -- confirm in Dataset3D).
        super().__init__(cfg, parsed_data_path, dset == 'train')

        # Which kinds of labels this dataset provides; presumably read by
        # Dataset3D / training code -- confirm at the read sites.
        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # The mask spans NUM_JOINTS + 14 entries (the "31 joints format" per
        # the original note). Only the trailing 14 common joints are
        # available for this dataset.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self, ):
        """Dataset name string."""
        return 'MPII3D'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints of sample ``index`` as a float tensor.

        ``self.labels['joints3D'][index]`` is converted from the 'spin' to the
        'h36m' keypoint format by ``convert_kps``. Axis 1 is then indexed with
        ``_C.KEYPOINTS.H36M_TO_J17``, and the result is cast to float.
        Assumes axis 1 is the joint axis -- TODO confirm.
        """
        # NOTE(review): this uses H36M_TO_J17, whereas Human36M uses
        # H36M_TO_J14 and the mask above marks only 14 joints. Confirm the
        # joint count expected downstream before changing it.
        kp3d = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return kp3d[:, _C.KEYPOINTS.H36M_TO_J17].float()
class ThreeDPW(Dataset3D):
    """3DPW video dataset backed by a pre-parsed ``.pth`` label file.

    The label file is ``3dpw_{dset}_{backbone}.pth`` inside
    ``_C.PATHS.PARSED_DATA``, where ``backbone`` is ``cfg.MODEL.BACKBONE``
    lower-cased. Gendered SMPL body models are built at construction time.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file, set label flags, build SMPL models.

        Args:
            cfg: Config object; ``cfg.MODEL.BACKBONE`` selects which parsed
                file is loaded.
            dset: Split name interpolated into the file name
                (default ``'train'``).
        """
        # Put the backbone name straight into the file name rather than
        # str.replace()-ing a placeholder over the whole joined path, which
        # would also rewrite directories containing 'backbone'.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'3dpw_{dset}_{backbone}.pth')
        # Third argument is True only for the 'train' split (presumably the
        # base class's training switch -- confirm in Dataset3D).
        super().__init__(cfg, parsed_data_path, dset == 'train')

        # Which kinds of labels this dataset provides; presumably read by
        # Dataset3D / training code -- confirm at the read sites.
        self.has_3d = True
        self.has_traj = False
        self.has_smpl = True
        self.has_verts = True  # original note: "In testing"

        # The mask spans NUM_JOINTS + 14 entries. Here every entry EXCEPT the
        # trailing 14 is marked available (i.e. the leading NUM_JOINTS).
        # NOTE(review): the original comment claimed "14 common joints are
        # available", which does not match this slice -- confirm the intent.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:-14] = 1

        # SMPL body models keyed 0 -> male, 1 -> female (10 shape
        # coefficients each). Presumably indexed by a per-sample gender label
        # in the parsed data -- verify against the caller.
        self.smpl_gender = {
            0: SMPL(_C.BMODEL.FLDR, gender='male', num_betas=10),
            1: SMPL(_C.BMODEL.FLDR, gender='female', num_betas=10),
        }

    @property
    def __name__(self, ):
        """Dataset name string."""
        return 'ThreeDPW'

    def compute_3d_keypoints(self, index):
        """Return ``self.labels['joints3D'][index]`` unchanged.

        No format conversion or dtype cast is applied here, unlike
        Human36M/MPII3D.
        """
        return self.labels['joints3D'][index]
class InstaVariety(Dataset2D):
    """InstaVariety 2D-only video dataset backed by a parsed ``.pth`` file.

    The label file is ``insta_{dset}_{backbone}.pth`` inside
    ``_C.PATHS.PARSED_DATA``, where ``backbone`` is ``cfg.MODEL.BACKBONE``
    lower-cased.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and set the dataset's label flags.

        Args:
            cfg: Config object; ``cfg.MODEL.BACKBONE`` selects which parsed
                file is loaded.
            dset: Split name interpolated into the file name
                (default ``'train'``).
        """
        # Put the backbone name straight into the file name rather than
        # str.replace()-ing a placeholder over the whole joined path, which
        # would also rewrite directories containing 'backbone'.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'insta_{dset}_{backbone}.pth')
        # Third argument is True only for the 'train' split (presumably the
        # base class's training switch -- confirm in Dataset2D).
        super().__init__(cfg, parsed_data_path, dset == 'train')

        # Which kinds of labels this dataset provides; presumably read by
        # Dataset2D / training code -- confirm at the read sites.
        self.has_3d = False
        self.has_traj = False
        self.has_smpl = False

        # The mask spans NUM_JOINTS + 14 entries. Only the first 17 (the COCO
        # joints, per the original note) are available for this dataset.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:17] = 1

    @property
    def __name__(self, ):
        """Dataset name string."""
        return 'InstaVariety'