Spaces:
Sleeping
Sleeping
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import torch | |
import numpy as np | |
from .amass import AMASSDataset | |
from .videos import Human36M, ThreeDPW, MPII3D, InstaVariety | |
from .bedlam import BEDLAMDataset | |
from lib.utils.data_utils import make_collate_fn | |
class DataFactory(torch.utils.data.Dataset):
    """Mix several training datasets and sample from them by configured ratio.

    The member datasets are selected by ``train_stage``:

    * ``'stage1'`` -- AMASS only.
    * ``'stage2'`` -- AMASS, 3DPW, Human36M, MPII3D and InstaVariety, plus
      BEDLAM when ``cfg.DATASET.RATIO`` has six entries.

    ``cfg.DATASET.RATIO`` holds the relative sampling weight of each member
    dataset, in the order listed above.
    """

    def __init__(self, cfg, train_stage='syn'):
        """Build the member datasets for ``train_stage``.

        Args:
            cfg: Config object. ``cfg.DATASET.RATIO`` is read here, and the
                whole object is forwarded to every dataset constructor.
            train_stage: ``'stage1'`` or ``'stage2'``. The default ``'syn'``
                matches neither branch and is kept only for signature
                compatibility.

        Raises:
            ValueError: if ``train_stage`` is not recognised, or if the
                ratios are empty or do not sum to a positive value.
        """
        super(DataFactory, self).__init__()

        if train_stage == 'stage1':
            self.datasets = [AMASSDataset(cfg)]
            self.dataset_names = ['AMASS']
        elif train_stage == 'stage2':
            self.datasets = [
                AMASSDataset(cfg), ThreeDPW(cfg),
                Human36M(cfg), MPII3D(cfg), InstaVariety(cfg),
            ]
            self.dataset_names = ['AMASS', '3DPW', 'Human36M', 'MPII3D', 'Insta']

            if len(cfg.DATASET.RATIO) == 6:  # Use BEDLAM
                self.datasets.append(BEDLAMDataset(cfg))
                self.dataset_names.append('BEDLAM')
        else:
            # Without this, an unknown stage surfaced later as an
            # AttributeError on the never-assigned ``self.datasets``.
            raise ValueError(
                "Unknown train_stage {!r}; expected 'stage1' or 'stage2'."
                .format(train_stage)
            )

        self._set_partition(cfg.DATASET.RATIO)
        self.lengths = [len(ds) for ds in self.datasets]

    def __name__(self):
        """Return the display name of this mixed dataset."""
        return 'MixedData'

    def prepare_video_batch(self):
        """Call ``prepare_video_batch`` on every member, then refresh lengths."""
        for ds in self.datasets:
            ds.prepare_video_batch()
        # Re-batching may change member lengths, so recompute the cache.
        self.lengths = [len(ds) for ds in self.datasets]

    def _set_partition(self, partition):
        """Store sampling weights and their normalised cumulative sum.

        Args:
            partition: Sequence of non-negative relative weights, one per
                member dataset.

        Sets ``self.ratio`` to ``partition`` unchanged. Sets
        ``self.partition`` to the float array of cumulative weights, scaled
        so that its last entry is 1.0.

        Raises:
            ValueError: if ``partition`` is empty or its sum is not positive.
        """
        self.ratio = partition
        # Force a float dtype. An all-int ratio list would otherwise give an
        # int array, and dividing it in place by its total raises a
        # numpy casting error.
        cumulative = np.asarray(partition, dtype=float).cumsum()
        if cumulative.size == 0 or not cumulative[-1] > 0:
            raise ValueError(
                'DATASET.RATIO must be non-empty with a positive sum, got {!r}'
                .format(partition)
            )
        self.partition = cumulative / cumulative[-1]

    def __len__(self):
        """Return the mean length of the member datasets with positive ratio."""
        active = [
            length for length, ratio in zip(self.lengths, self.ratio)
            if ratio > 0
        ]
        return int(np.array(active).mean())

    def __getitem__(self, index):
        """Return one sample from a member dataset chosen by ratio.

        With a single member dataset, ``index`` is used directly, wrapped
        modulo that dataset's length. With several members, ``index`` is
        ignored: a dataset is drawn according to ``self.partition``, and a
        uniformly random sample is returned from it.

        Raises:
            RuntimeError: if the draw lands in a partition bucket that has
                no loaded dataset, i.e. ``cfg.DATASET.RATIO`` has more
                entries than there are member datasets.
        """
        # Draw before branching so the NumPy RNG stream is consumed the same
        # way on every path.
        p = np.random.rand()

        if len(self.datasets) == 1:
            # The single-source path is deterministic. Previously it could
            # return None when RATIO listed extra weights.
            return self.datasets[0][index % self.lengths[0]]

        # Find the first bucket whose upper edge exceeds p. Since p < 1.0 and
        # partition[-1] == 1.0, such a bucket always exists. Using
        # side='right' means a zero-weight (empty) bucket is never chosen.
        i = int(np.searchsorted(self.partition, p, side='right'))
        if i >= len(self.datasets):
            raise RuntimeError(
                'Sampled partition {} but only {} datasets are loaded; '
                'DATASET.RATIO has more entries than datasets.'
                .format(i, len(self.datasets))
            )
        d_index = np.random.randint(0, self.lengths[i])
        return self.datasets[i][d_index]