|
|
|
"""Contains the iteration-based data loader class."""
|
|
|
import argparse |
|
|
|
from torch.utils.data import DataLoader |
|
from .distributed_sampler import DistributedSampler |
|
from .datasets import BaseDataset |
|
|
|
|
|
# Public API of this module.
__all__ = [
    'IterDataLoader',
]
|
|
|
|
|
class IterDataLoader(object):
    """Iteration-based data loader.

    Wraps a `torch.utils.data.DataLoader` that is driven by a
    `DistributedSampler` and serves batches through `next()`. When the
    underlying loader is exhausted, the sampler is reset and a new pass over
    the loader is started. The number of batches served so far is tracked and
    exposed through the `iter` property.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 shuffle=True,
                 num_workers=1,
                 current_iter=0,
                 repeat=1):
        """Initializes the data loader.

        Args:
            dataset: The dataset to load data from.
            batch_size: The batch size on each GPU.
            shuffle: Whether to shuffle the data. (default: True)
            num_workers: Number of data workers for each GPU. (default: 1)
            current_iter: The current number of iterations. (default: 0)
            repeat: The repeating number of the whole dataloader. (default: 1)
        """
        self._dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self._dataloader = None
        self.iter_loader = None
        self._iter = current_iter
        self.repeat = repeat
        self.build_dataloader()

    def build_dataloader(self):
        """Builds (or rebuilds) the sampler, the data loader and its iterator.

        Uses the current values of `batch_size`, `shuffle`, `num_workers`,
        `repeat` and the iteration counter, so it can be called again after
        any of them changes. Any pass that was in progress is discarded.
        """
        dist_sampler = DistributedSampler(self._dataset,
                                          shuffle=self.shuffle,
                                          current_iter=self._iter,
                                          repeat=self.repeat)

        self._dataloader = DataLoader(
            self._dataset,
            batch_size=self.batch_size,
            # A sampler is always supplied, so this evaluates to False;
            # `DataLoader` does not accept `shuffle=True` with a sampler.
            shuffle=(dist_sampler is None),
            num_workers=self.num_workers,
            # The incomplete last batch is dropped only when shuffling.
            drop_last=self.shuffle,
            pin_memory=True,
            sampler=dist_sampler)
        self.iter_loader = iter(self._dataloader)

    def overwrite_param(self, batch_size=None, resolution=None):
        """Overwrites some parameters for progressive training.

        The data loader is rebuilt only if at least one of the given values
        differs from the one currently in use; otherwise this is a no-op.

        Args:
            batch_size: New batch size. A falsy value (`None`, `0`) keeps the
                current batch size. (default: None)
            resolution: New value for the dataset's `resolution` attribute.
                A falsy value keeps the current one. (default: None)
        """
        # Evaluate each parameter independently, so that passing only one of
        # them with its current value does not trigger a needless rebuild,
        # and `dataset.resolution` is read only when a resolution is given.
        batch_size_changed = (
            bool(batch_size) and batch_size != self.batch_size)
        resolution_changed = (
            bool(resolution) and resolution != self._dataset.resolution)
        if not (batch_size_changed or resolution_changed):
            return

        if batch_size_changed:
            self.batch_size = batch_size
        if resolution_changed:
            self._dataset.resolution = resolution
        self.build_dataloader()

    @property
    def iter(self):
        """Returns the current iteration."""
        return self._iter

    @property
    def dataset(self):
        """Returns the dataset."""
        return self._dataset

    @property
    def dataloader(self):
        """Returns the data loader."""
        return self._dataloader

    def __next__(self):
        """Returns the next batch and advances the iteration counter.

        When the current pass is exhausted, the sampler is reset with the
        current iteration and a new pass is started. If that new pass yields
        nothing either, `StopIteration` propagates to the caller and the
        counter is left untouched.
        """
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # NOTE(review): `__reset__` is defined by the project's
            # `DistributedSampler`; presumably it re-plans sampling from the
            # given iteration -- confirm against that class.
            self._dataloader.sampler.__reset__(self._iter)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
        # Count a batch only once it has actually been fetched.
        self._iter += 1
        return data

    def __len__(self):
        """Returns the length of the underlying data loader."""
        return len(self._dataloader)
|
|
|
|
|
def dataloader_test(root_dir, test_num=10):
    """Runs a smoke test of `IterDataLoader` with progressive parameters.

    Starting from resolution 2 and batch size 2, fetches one batch per round,
    checks the shape of its `'image'` entry, then doubles the resolution and
    increases the batch size by one before the next round.

    Args:
        root_dir: Root directory of the dataset, passed to `BaseDataset`.
        test_num: Number of rounds to run. (default: 10)
    """
    resolution, batch = 2, 2
    loader = IterDataLoader(
        dataset=BaseDataset(root_dir=root_dir, resolution=resolution),
        batch_size=batch,
        shuffle=False)

    for _ in range(test_num):
        image = next(loader)['image']
        assert image.shape == (batch, 3, resolution, resolution)
        # Grow both parameters for the next round.
        resolution, batch = resolution * 2, batch + 1
        loader.overwrite_param(batch_size=batch, resolution=resolution)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the dataset location and the number of
    # test rounds, then run the data loader smoke test.
    cli = argparse.ArgumentParser(description='Test Data Loader.')
    cli.add_argument(
        'root_dir',
        type=str,
        help='Root directory of the dataset.')
    cli.add_argument(
        '--test_num',
        type=int,
        default=10,
        help='Number of tests. (default: %(default)s)')
    options = cli.parse_args()
    dataloader_test(root_dir=options.root_dir, test_num=options.test_num)
|
|