import os
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from joblib import Parallel, delayed
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, ConcatDataset
from datasets.augment import build_augmentor
from tools.misc import tqdm_joblib

from .gl3d.gl3d import GL3DDataset
from .gtasfm.gtasfm import GTASfMDataset
from .multifov.multifov import MultiFoVDataset
from .gl3d.gl3d import GL3DDataset as BlendedMVSDataset  # BlendedMVS reuses the GL3D dataset class
from .iclnuim.iclnuim import ICLNUIMDataset
from .scenenet.scenenet import SceneNetDataset
from .eth3d.eth3d import ETH3DDataset
from .kitti.kitti import KITTIDataset
from .robotcar.robotcar import RobotcarDataset

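# Maps a benchmark name from args.trains / args.valids / args.tests to its dataset
# class; _setup_dataset below resolves names via Benchmarks.get(benchmark),
# e.g. Benchmarks.get('KITTI') -> KITTIDataset.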
Benchmarks = dict(
    GL3D            = GL3DDataset,
    GTASfM          = GTASfMDataset,
    MultiFoV        = MultiFoVDataset,
    BlendedMVS      = BlendedMVSDataset,
    ICLNUIM         = ICLNUIMDataset,
    SceneNet        = SceneNetDataset,
    ETH3DO          = ETH3DDataset,
    ETH3DI          = ETH3DDataset,
    KITTI           = KITTIDataset,
    RobotcarNight   = RobotcarDataset,
    RobotcarSeason  = RobotcarDataset,
    RobotcarWeather = RobotcarDataset,
)


class MultiSceneDataModule(pl.LightningDataModule):
    """ 
    For distributed training, each training process is assgined 
    only a part of the training scenes to reduce memory overhead.
    """

    def __init__(self, args, dcfg):
        """
        
        Args:
            args: (ArgumentParser) The only useful args is args.trains and args.valids
                    each one is a list, which contain like [PhotoTourism, MegaDepth,...]
                    We should traverse each item in args.trains and args.valids to build
                    self.train_datasets and self.valid_datasets
            dcfg: (yacs) It contain all configs for each benchmark in args.trains and
                    args.valids
        """
        super().__init__()

        self.args = args
        self.dcfg = dcfg
        self.train_loader_params = {'batch_size': args.batch_size,
                                    'shuffle': True,
                                    'num_workers': args.threads,
                                    'pin_memory': True,
                                    'drop_last': True}
        self.valid_loader_params = {'batch_size': args.batch_size,
                                    'shuffle': False,
                                    'num_workers': args.threads,
                                    'pin_memory': True,
                                    'drop_last': False}
        self.tests_loader_params = {'batch_size': args.batch_size,
                                    'shuffle': False,
                                    'num_workers': args.threads,
                                    'pin_memory': True,
                                    'drop_last': False}

    def setup(self, stage=None):
        """ 
        Setup train/valid/test dataset. This method will be called by PL automatically.
        Args:
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """

        self.gpus = self.trainer.gpus
        self.gpuid = self.trainer.global_rank

        self.train_datasets = None
        self.valid_datasets = None
        self.tests_datasets = None

        # TRAIN
        if stage == 'fit':
            train_datasets = []
            for benchmark in self.args.trains:
                dcfg = self.dcfg.get(benchmark, None)
                assert dcfg is not None, "Training dcfg is None"

                datasets = self._setup_dataset(
                    benchmark=benchmark,
                    data_root=dcfg.DATASET.TRAIN.DATA_ROOT,
                    npz_root=dcfg.DATASET.TRAIN.NPZ_ROOT,
                    scene_list_path=dcfg.DATASET.TRAIN.LIST_PATH,
                    df=self.dcfg.DF,
                    padding=dcfg.DATASET.TRAIN.PADDING,
                    min_overlap_score=dcfg.DATASET.TRAIN.MIN_OVERLAP_SCORE,
                    max_overlap_score=dcfg.DATASET.TRAIN.MAX_OVERLAP_SCORE,
                    max_resize=self.args.img_size,
                    augment_fn=build_augmentor(dcfg.DATASET.TRAIN.AUGMENTATION_TYPE),
                    max_samples=dcfg.DATASET.TRAIN.MAX_SAMPLES,
                    mode='train',
                    njobs=dcfg.NJOBS,
                    cfg=dcfg.DATASET.TRAIN,
                )
                train_datasets += datasets
            self.train_datasets = ConcatDataset(train_datasets)
            os.environ['TOTAL_TRAIN_SAMPLES'] = str(len(self.train_datasets))

            # VALID
            valid_datasets = []
            for benchmark in self.args.valids:
                dcfg = self.dcfg.get(benchmark, None)
                assert dcfg is not None, "Validing dcfg is None"

                datasets = self._setup_dataset(
                    benchmark=benchmark,
                    data_root=dcfg.DATASET.VALID.DATA_ROOT,
                    npz_root=dcfg.DATASET.VALID.NPZ_ROOT,
                    scene_list_path=dcfg.DATASET.VALID.LIST_PATH,
                    df=self.dcfg.DF,
                    padding=dcfg.DATASET.VALID.PADDING,
                    min_overlap_score=dcfg.DATASET.VALID.MIN_OVERLAP_SCORE,
                    max_overlap_score=dcfg.DATASET.VALID.MAX_OVERLAP_SCORE,
                    max_resize=self.args.img_size,
                    augment_fn=build_augmentor(dcfg.DATASET.VALID.AUGMENTATION_TYPE),
                    max_samples=dcfg.DATASET.VALID.MAX_SAMPLES,
                    mode='valid',
                    njobs=dcfg.NJOBS,
                    cfg=dcfg.DATASET.VALID,
                )
                valid_datasets += datasets
            self.valid_datasets = ConcatDataset(valid_datasets)
            os.environ['TOTAL_VALID_SAMPLES'] = str(len(self.valid_datasets))

        # TEST
        if stage == 'test':
            tests_datasets = []
            for benchmark in [self.args.tests]:
                dcfg = self.dcfg.get(benchmark, None)
                assert dcfg is not None, "Validing dcfg is None"

                datasets = self._setup_dataset(
                    benchmark=benchmark,
                    data_root=dcfg.DATASET.TESTS.DATA_ROOT,
                    npz_root=dcfg.DATASET.TESTS.NPZ_ROOT,
                    scene_list_path=dcfg.DATASET.TESTS.LIST_PATH,
                    df=self.dcfg.DF,
                    padding=dcfg.DATASET.TESTS.PADDING,
                    min_overlap_score=dcfg.DATASET.TESTS.MIN_OVERLAP_SCORE,
                    max_overlap_score=dcfg.DATASET.TESTS.MAX_OVERLAP_SCORE,
                    max_resize=self.args.img_size,
                    augment_fn=build_augmentor(dcfg.DATASET.TESTS.AUGMENTATION_TYPE),
                    max_samples=dcfg.DATASET.TESTS.MAX_SAMPLES,
                    mode='test',
                    njobs=dcfg.NJOBS,
                    cfg=dcfg.DATASET.TESTS,
                )
                tests_datasets += datasets
            self.tests_datasets = ConcatDataset(tests_datasets)
            os.environ['TOTAL_TESTS_SAMPLES'] = str(len(self.tests_datasets))
            if self.gpuid == 0: print('TOTAL_TESTS_SAMPLES:', len(self.tests_datasets))

    def _setup_dataset(self, benchmark, data_root, npz_root, scene_list_path, df, padding,
                       min_overlap_score, max_overlap_score, max_resize, augment_fn,
                       max_samples, mode, njobs, cfg):

        # Each benchmark is loaded as a single sequence named after the benchmark itself.
        seq_names = [benchmark.lower()]

        # Build one dataset per sequence in parallel; the progress bar is shown on rank 0 only.
        with tqdm_joblib(tqdm(bar_format="{l_bar}{bar:3}{r_bar}", ncols=100,
                              desc=f'[GPU {self.gpuid}] load {mode} {benchmark:14} data',
                              total=len(seq_names), disable=int(self.gpuid) != 0)):
            datasets = Parallel(n_jobs=njobs)(
                delayed(lambda x: _build_dataset(
                    Benchmarks.get(benchmark),
                    root_dir=data_root,
                    npz_root=npz_root,
                    seq_name=x,
                    mode=mode,
                    min_overlap_score=min_overlap_score,
                    max_overlap_score=max_overlap_score,
                    max_resize=max_resize,
                    df=df,
                    padding=padding,
                    augment_fn=augment_fn,
                    max_samples=max_samples,
                    **cfg
                ))(seqname) for seqname in seq_names)
        return datasets

    def train_dataloader(self, *args, **kwargs):
        return DataLoader(self.train_datasets, collate_fn=collate_fn, **self.train_loader_params)

    def valid_dataloader(self, *args, **kwargs):
        return DataLoader(self.valid_datasets, collate_fn=collate_fn, **self.valid_loader_params)

    def val_dataloader(self, *args, **kwargs):
        return self.valid_dataloader(*args, **kwargs)

    def test_dataloader(self, *args, **kwargs):
        return DataLoader(self.tests_datasets, collate_fn=collate_fn, **self.tests_loader_params)


def collate_fn(batch):
    """Filter out None samples before applying the default collate function."""
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


def _build_dataset(dataset: Dataset, *args, **kwargs):
    """Instantiate the given dataset class with the provided arguments."""
    # noinspection PyCallingNonCallable
    return dataset(*args, **kwargs)
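

# Minimal usage sketch (illustrative only; `build_model` and the argument parsing are
# hypothetical, not part of this repository). It assumes `args` exposes .trains, .valids,
# .tests, .batch_size, .threads and .img_size, and that `dcfg` is a yacs config with one
# node per benchmark, as documented in MultiSceneDataModule.__init__:
#
#   dm = MultiSceneDataModule(args, dcfg)
#   trainer = pl.Trainer(gpus=2)
#   trainer.fit(build_model(), datamodule=dm)   # hypothetical build_model()
#   trainer.test(datamodule=dm)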