# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

# kwargs of the DataLoader in min version 1.4.0.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}

class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches

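# A minimal usage sketch for `SkipBatchSampler` (the range-based sampler and batch size
# below are illustrative assumptions, not part of the original file): the wrapped sampler
# still produces all batches, but the first `skip_batches` of them are not yielded.
#
#     from torch.utils.data import BatchSampler, SequentialSampler
#
#     base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
#     skipped = SkipBatchSampler(base, skip_batches=2)
#     print(list(skipped))  # [[4, 5], [6, 7], [8, 9]]
#     print(len(skipped))   # 3
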
class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch

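# A minimal usage sketch for `SkipDataLoader` (the toy dataset and batch size are
# illustrative assumptions): the first `skip_batches` batches are still produced by the
# underlying `DataLoader`, but they are consumed silently instead of being yielded.
#
#     loader = SkipDataLoader(list(range(8)), skip_batches=1, batch_size=2)
#     print([batch.tolist() for batch in loader])  # [[2, 3], [4, 5], [6, 7]]
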
# Adapted from https://github.com/huggingface/accelerate
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = (
            dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        )
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # We ignore all of those since they are all dealt with by our new_batch_sampler
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    # Need to provide batch_size as batch_sampler is None for Iterable dataset
    if new_batch_sampler is None:
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
    if new_batch_sampler is None:
        # Need to manually skip batches in the dataloader
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    elif sampler_is_batch_sampler:
        # The original `DataLoader` received a `BatchSampler` as its `sampler`, so rebuild
        # it the same way and restore the original batch size.
        dataloader = DataLoader(
            dataset, sampler=new_batch_sampler, batch_size=dataloader.batch_size, **kwargs
        )
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
    return dataloader
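
# A minimal usage sketch for `skip_first_batches` (the dataset, batch size and skip count
# are illustrative assumptions), e.g. when resuming training in the middle of an epoch:
#
#     train_loader = DataLoader(list(range(100)), batch_size=10, shuffle=False)
#     resumed_loader = skip_first_batches(train_loader, num_batches=3)
#     first = next(iter(resumed_loader))
#     print(first.tolist())  # [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
#
# Because the skipping happens at the (batch) sampler level for map-style datasets, the
# skipped batches are never loaded or collated; for an `IterableDataset` they still have to
# be read and discarded by `SkipDataLoader`.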