File size: 2,206 Bytes
f561f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np

from .amass import AMASSDataset
from .videos import Human36M, ThreeDPW, MPII3D, InstaVariety
from .bedlam import BEDLAMDataset
from lib.utils.data_utils import make_collate_fn


class DataFactory(torch.utils.data.Dataset):
    """Mix several training datasets into one randomly-sampled dataset.

    Each ``__getitem__`` call draws a source dataset at random with
    probability proportional to its weight in ``cfg.DATASET.RATIO`` and
    returns one sample from it. Which datasets are built depends on
    ``train_stage``:

    * ``'stage1'`` -- AMASS only.
    * ``'stage2'`` -- AMASS, 3DPW, Human36M, MPII3D and InstaVariety, plus
      BEDLAM when ``cfg.DATASET.RATIO`` has exactly six entries.

    Attributes:
        datasets: The constructed dataset objects, in weight order.
        dataset_names: Human-readable names parallel to ``datasets``.
        ratio: The per-dataset sampling weights in use.
        partition: Normalized cumulative weights; ``partition[-1] == 1.0``.
        lengths: ``len()`` of each dataset, refreshed by
            ``prepare_video_batch``.
    """

    def __init__(self, cfg, train_stage='syn'):
        """Build the datasets for ``train_stage`` and set sampling weights.

        Args:
            cfg: Config object. ``cfg.DATASET.RATIO`` supplies the sampling
                weights and ``cfg`` is forwarded to every dataset constructor.
            train_stage: ``'stage1'`` or ``'stage2'``.

        Raises:
            ValueError: If ``train_stage`` is not recognized, or the weights
                for the built datasets are invalid (see ``_set_partition``).
        """
        super(DataFactory, self).__init__()

        if train_stage == 'stage1':
            self.datasets = [AMASSDataset(cfg)]
            self.dataset_names = ['AMASS']
        elif train_stage == 'stage2':
            self.datasets = [
                AMASSDataset(cfg), ThreeDPW(cfg),
                Human36M(cfg), MPII3D(cfg), InstaVariety(cfg)
            ]
            self.dataset_names = ['AMASS', '3DPW', 'Human36M', 'MPII3D', 'Insta']

            if len(cfg.DATASET.RATIO) == 6:  # Use BEDLAM
                self.datasets.append(BEDLAMDataset(cfg))
                self.dataset_names.append('BEDLAM')
        else:
            # Fail fast: without a known stage no datasets exist and every
            # later attribute access would fail with a less helpful error.
            raise ValueError(
                "Unknown train_stage {!r}; expected 'stage1' or 'stage2'"
                .format(train_stage))

        # Keep exactly one weight per built dataset. Extra RATIO entries
        # (e.g. trailing weights when stage1 builds AMASS alone) have no
        # dataset behind them and would let the sampler land past the last
        # dataset; a missing entry gets weight 0 so that dataset is never drawn.
        n_datasets = len(self.datasets)
        ratio = list(cfg.DATASET.RATIO)[:n_datasets]
        ratio += [0] * (n_datasets - len(ratio))
        self._set_partition(ratio)
        self.lengths = [len(ds) for ds in self.datasets]

    @property
    def __name__(self):
        """Name reported for this mixed dataset."""
        return 'MixedData'

    def prepare_video_batch(self):
        """Call ``prepare_video_batch`` on every dataset, then refresh ``lengths``.

        The per-dataset call may change a dataset's length (presumably it
        re-chunks videos into clips -- confirm in the dataset classes), so the
        cached lengths are re-read afterwards.
        """
        for ds in self.datasets:
            ds.prepare_video_batch()
        self.lengths = [len(ds) for ds in self.datasets]

    def _set_partition(self, partition):
        """Store sampling weights and their normalized cumulative sum.

        Args:
            partition: Sequence of non-negative weights, one per dataset.
                Integer weights are accepted.

        Raises:
            ValueError: If ``partition`` is empty or not one-dimensional,
                contains a negative weight, or sums to zero.
        """
        # Force float dtype: an integer cumsum cannot be divided in place.
        weights = np.asarray(partition, dtype=np.float64)
        if weights.ndim != 1 or weights.size == 0:
            raise ValueError(
                'partition must be a non-empty 1-D sequence of weights')
        if np.any(weights < 0):
            raise ValueError(
                'partition weights must be non-negative, got {}'
                .format(list(partition)))
        cumulative = weights.cumsum()
        if cumulative[-1] <= 0:
            raise ValueError('partition weights must not all be zero')

        self.ratio = partition
        # Divide by the cumsum's own last element so the final boundary is
        # exactly 1.0 and every draw from [0, 1) maps to some dataset.
        self.partition = cumulative / cumulative[-1]

    def __len__(self):
        """Return the mean length of the datasets that have a positive weight."""
        active = [length for length, weight in zip(self.lengths, self.ratio)
                  if weight > 0]
        return int(np.mean(active))

    def __getitem__(self, index):
        """Return one sample from a randomly chosen dataset.

        Args:
            index: Used (modulo the dataset length) only when there is a
                single dataset. With several datasets a uniformly random index
                into the chosen dataset is drawn and ``index`` is ignored.

        Raises:
            RuntimeError: If ``partition`` does not map the draw onto one of
                ``datasets``; this cannot happen after ``__init__``.
        """
        p = np.random.rand()
        # First i with partition[i] > p. The strict comparison gives a
        # zero-weight dataset an empty interval, so it is never selected.
        i = int(np.searchsorted(self.partition, p, side='right'))
        if i >= len(self.datasets):
            raise RuntimeError(
                'sampling partition does not cover the configured datasets')

        if len(self.datasets) == 1:
            return self.datasets[i][index % self.lengths[i]]
        d_index = np.random.randint(0, self.lengths[i])
        return self.datasets[i][d_index]