# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

# Default kwargs of `DataLoader` in the minimum supported PyTorch version (1.4.0).
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}
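
# Illustrative sketch (not part of the original file): these defaults mirror the
# `torch.utils.data.DataLoader` constructor signature. They are used below to read an
# attribute off an existing loader while falling back to the stock default when the
# attribute is missing (e.g. on older PyTorch builds), along the lines of
#
#     num_workers = getattr(some_loader, "num_workers", _PYTORCH_DATALOADER_KWARGS["num_workers"])
#
# where `some_loader` is a placeholder name for any existing `DataLoader`.
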

class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches
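
# Usage sketch (illustrative, not in the upstream file): wrap an existing
# `BatchSampler` so that iteration resumes at a given batch index, e.g.
#
#     from torch.utils.data import SequentialSampler
#     base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)  # 5 batches of 2
#     skipped = SkipBatchSampler(base, skip_batches=3)
#     list(skipped)   # -> [[6, 7], [8, 9]]
#     len(skipped)    # -> 2
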

class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch
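
# Usage sketch (illustrative, not in the upstream file): behaves like a regular
# `DataLoader` except that the first `skip_batches` batches are drawn and then
# discarded, so the skipped batches are still loaded and collated under the hood.
# That is why it is only used when there is no batch sampler to fast-forward
# (i.e. for `IterableDataset`s), e.g.
#
#     loader = SkipDataLoader(my_iterable_dataset, skip_batches=4, batch_size=8)
#
# where `my_iterable_dataset` is a placeholder for any `IterableDataset`.
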

# Adapted from https://github.com/huggingface/accelerate
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = (
            dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        )
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of these since they are dealt with by our new_batch_sampler.
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size and drop_last explicitly, as batch_sampler is None for an IterableDataset.
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size

    if new_batch_sampler is None:
        # Need to manually skip batches in the dataloader.
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
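

# Minimal end-to-end sketch (not part of the upstream accelerate module): resume an
# epoch as if its first few batches had already been consumed, e.g. after restarting
# from a mid-epoch checkpoint. The dataset and batch size here are illustrative only.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(16))
    loader = DataLoader(dataset, batch_size=4, shuffle=False)

    # Skip the first 2 of 4 batches; only the last two batches are yielded.
    resumed_loader = skip_first_batches(loader, num_batches=2)
    for batch in resumed_loader:
        print(batch)  # [tensor([ 8,  9, 10, 11])], then [tensor([12, 13, 14, 15])]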