# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# domainbed/lib/fast_data_loader.py
import torch

from .datasets.ab_dataset import ABDataset

class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch

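# Illustrative sketch (not part of the original module): wrapping a finite
# sampler restarts it transparently once exhausted, e.g. a SequentialSampler
# over two elements yields 0, 1, 0, 1, ... without ever raising StopIteration:
#
#   _it = iter(_InfiniteSampler(torch.utils.data.SequentialSampler(range(2))))
#   [next(_it) for _ in range(5)]  # -> [0, 1, 0, 1, 0]
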
class InfiniteDataLoader:
    """DataLoader wrapper that samples batches with replacement, forever."""

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()

        if weights:
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size
            )
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )

        # DataLoader treats collate_fn=None as "use the default collate", so a
        # single construction path covers both cases.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )
        self.dataset = dataset

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        raise ValueError("InfiniteDataLoader has no length; it yields batches forever.")

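# Usage sketch (hedged; `my_dataset` and `num_steps` are placeholders, not
# defined in this module). Because iteration never terminates, the loader is
# driven by an explicit step count rather than a loop over epochs:
#
#   loader = InfiniteDataLoader(my_dataset, weights=None, batch_size=32,
#                               num_workers=2)
#   it = iter(loader)
#   for step in range(num_steps):
#       x, y = next(it)  # never raises StopIteration
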
class FastDataLoader:
    """
    DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch.
    """

    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
        super().__init__()
        self.num_workers = num_workers

        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=False,
        )

        # As above, collate_fn=None falls back to the default collate, so one
        # DataLoader call suffices. Wrapping the batch sampler in
        # _InfiniteSampler keeps the same worker processes serving every epoch.
        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(
                dataset,
                num_workers=num_workers,
                batch_sampler=_InfiniteSampler(batch_sampler),
                pin_memory=False,
                collate_fn=collate_fn,
            )
        )

        self.dataset = dataset
        self.batch_size = batch_size
        self._length = len(batch_sampler)

    def __iter__(self):
        # Draw exactly one epoch's worth of batches from the infinite iterator.
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length

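# Usage sketch (hedged; `my_dataset` is a placeholder). FastDataLoader behaves
# like a regular epoch-based loader (len() works, iteration stops after one
# pass), but the inner DataLoader iterator is built once, so its workers
# survive across epochs instead of being respawned:
#
#   loader = FastDataLoader(my_dataset, batch_size=64, num_workers=2)
#   for epoch in range(3):
#       for x, y in loader:  # one full pass; same workers every epoch
#           ...
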
def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                     infinite: bool, shuffle_when_finite: bool, collate_fn=None):
    """Build an infinite or a finite (epoch-based) loader over `dataset`."""
    assert batch_size <= len(dataset), \
        f"batch_size ({batch_size}) exceeds dataset size ({len(dataset)})"

    if infinite:
        return InfiniteDataLoader(
            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
    return FastDataLoader(
        dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)

def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int,
                           infinite: bool, shuffle_when_finite: bool):
    # TODO: unimplemented stub.
    pass

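if __name__ == "__main__":
    # Minimal self-test, not part of the original module. It assumes a
    # TensorDataset is an acceptable stand-in for ABDataset, which holds here
    # because build_dataloader only needs len() and indexing.
    toy = torch.utils.data.TensorDataset(
        torch.arange(10).float().unsqueeze(1),  # 10 one-dimensional inputs
        torch.arange(10),                       # 10 integer labels
    )

    finite = build_dataloader(toy, batch_size=4, num_workers=0,
                              infinite=False, shuffle_when_finite=True)
    assert len(finite) == 3  # ceil(10 / 4) batches per epoch (drop_last=False)
    for x, y in finite:
        print(x.shape, y.shape)

    infinite = build_dataloader(toy, batch_size=4, num_workers=0,
                                infinite=True, shuffle_when_finite=False)
    it = iter(infinite)
    for _ in range(5):  # more steps than a single epoch would allow
        x, y = next(it)
    print("infinite loader OK")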