Spaces:
Sleeping
Sleeping
File size: 3,765 Bytes
f561f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import torch
from configs import constants as _C
from .dataset3d import Dataset3D
from .dataset2d import Dataset2D
from ...utils.kp_utils import convert_kps
from smplx import SMPL
class Human36M(Dataset3D):
    """3D dataset reading the parsed ``human36m_{dset}_{backbone}.pth`` labels.

    The instance advertises 3D keypoints and trajectory labels, but no SMPL
    parameters and no vertices.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and initialise the base dataset.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` names the
                backbone whose lower-cased name is part of the file name.
            dset: Split name used in the file name. The base class receives
                ``dset == 'train'`` as its third positional argument.
        """
        # Build the file name directly from the backbone name. Substituting
        # the 'backbone' placeholder on the already-joined path would also
        # rewrite any 'backbone' substring in _C.PATHS.PARSED_DATA itself.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'human36m_{dset}_{backbone}.pth')
        super().__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among 31 joints format, 14 common joints are available:
        # only the trailing 14 mask entries are switched on.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self):
        """Return the dataset identifier string."""
        return 'Human36M'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints stored at ``index`` as a float tensor.

        ``self.labels['joints3D'][index]`` is converted from the 'spin' to
        the 'h36m' layout with ``convert_kps``; the second axis is then
        re-indexed with ``_C.KEYPOINTS.H36M_TO_J14``.
        """
        # NOTE(review): MPII3D selects H36M_TO_J17 in the same place while
        # this class selects H36M_TO_J14 -- confirm the difference is intended.
        h36m_kps = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return h36m_kps[:, _C.KEYPOINTS.H36M_TO_J14].float()
class MPII3D(Dataset3D):
    """3D dataset reading the parsed ``mpii3d_{dset}_{backbone}.pth`` labels.

    The instance advertises 3D keypoints and trajectory labels, but no SMPL
    parameters and no vertices.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and initialise the base dataset.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` names the
                backbone whose lower-cased name is part of the file name.
            dset: Split name used in the file name. The base class receives
                ``dset == 'train'`` as its third positional argument.
        """
        # Build the file name directly from the backbone name. Substituting
        # the 'backbone' placeholder on the already-joined path would also
        # rewrite any 'backbone' substring in _C.PATHS.PARSED_DATA itself.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'mpii3d_{dset}_{backbone}.pth')
        super().__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = True
        self.has_smpl = False
        self.has_verts = False

        # Among 31 joints format, 14 common joints are available:
        # only the trailing 14 mask entries are switched on.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[-14:] = 1

    @property
    def __name__(self):
        """Return the dataset identifier string."""
        return 'MPII3D'

    def compute_3d_keypoints(self, index):
        """Return the 3D keypoints stored at ``index`` as a float tensor.

        ``self.labels['joints3D'][index]`` is converted from the 'spin' to
        the 'h36m' layout with ``convert_kps``; the second axis is then
        re-indexed with ``_C.KEYPOINTS.H36M_TO_J17``.
        """
        h36m_kps = convert_kps(self.labels['joints3D'][index], 'spin', 'h36m')
        return h36m_kps[:, _C.KEYPOINTS.H36M_TO_J17].float()
class ThreeDPW(Dataset3D):
    """3D dataset reading the parsed ``3dpw_{dset}_{backbone}.pth`` labels.

    The instance advertises 3D keypoints, SMPL parameters and vertices, but
    no trajectory labels. It also holds one SMPL body model per gender id.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and initialise the base dataset.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` names the
                backbone whose lower-cased name is part of the file name.
            dset: Split name used in the file name. The base class receives
                ``dset == 'train'`` as its third positional argument.
        """
        # Build the file name directly from the backbone name. Substituting
        # the 'backbone' placeholder on the already-joined path would also
        # rewrite any 'backbone' substring in _C.PATHS.PARSED_DATA itself.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'3dpw_{dset}_{backbone}.pth')
        super().__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = True
        self.has_traj = False
        self.has_smpl = True
        self.has_verts = True  # In testing

        # Among 31 joints format, every entry except the trailing 14 is
        # switched on (the complement of the Human36M / MPII3D mask).
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:-14] = 1

        # SMPL body models keyed by gender id: 0 -> male, 1 -> female.
        self.smpl_gender = {
            0: SMPL(_C.BMODEL.FLDR, gender='male', num_betas=10),
            1: SMPL(_C.BMODEL.FLDR, gender='female', num_betas=10),
        }

    @property
    def __name__(self):
        """Return the dataset identifier string."""
        return 'ThreeDPW'

    def compute_3d_keypoints(self, index):
        """Return ``self.labels['joints3D'][index]`` without any conversion."""
        return self.labels['joints3D'][index]
class InstaVariety(Dataset2D):
    """2D dataset reading the parsed ``insta_{dset}_{backbone}.pth`` labels.

    The instance advertises neither 3D keypoints, trajectory labels nor SMPL
    parameters.
    """

    def __init__(self, cfg, dset='train'):
        """Locate the parsed label file and initialise the base dataset.

        Args:
            cfg: Configuration object; ``cfg.MODEL.BACKBONE`` names the
                backbone whose lower-cased name is part of the file name.
            dset: Split name used in the file name. The base class receives
                ``dset == 'train'`` as its third positional argument.
        """
        # Build the file name directly from the backbone name. Substituting
        # the 'backbone' placeholder on the already-joined path would also
        # rewrite any 'backbone' substring in _C.PATHS.PARSED_DATA itself.
        backbone = cfg.MODEL.BACKBONE.lower()
        parsed_data_path = os.path.join(
            _C.PATHS.PARSED_DATA, f'insta_{dset}_{backbone}.pth')
        super().__init__(cfg, parsed_data_path, dset == 'train')

        self.has_3d = False
        self.has_traj = False
        self.has_smpl = False

        # Among 31 joints format, 17 coco joints are available:
        # only the leading 17 mask entries are switched on.
        self.mask = torch.zeros(_C.KEYPOINTS.NUM_JOINTS + 14)
        self.mask[:17] = 1

    @property
    def __name__(self):
        """Return the dataset identifier string."""
        return 'InstaVariety'