# python3.7
"""Contains the distributed data sampler.

This file is mostly borrowed from `torch/utils/data/distributed.py`.

However, initializing the data loader and data sampler can sometimes be
time-consuming (since a large amount of data is loaded at one time). To avoid
re-initializing the data loader again and again, we modify the sampler to
support loading the data only once and then repeating the data loader. Please
use the class member `repeat` to control how many times you want the data
loader to repeat. After `repeat` times, the data will be re-loaded.

NOTE: The number of repeat times should not be very large, especially when
there are many samples in the dataset. We recommend setting `repeat = 500` for
datasets with ~50K samples.
"""
# pylint: disable=line-too-long

import math
from typing import TypeVar, Optional, Iterator

import torch
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist

T_co = TypeVar('T_co', covariant=True)


class DistributedSampler(Sampler):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance
    as a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of
    the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, the world size is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within
            :attr:`num_replicas`. By default, :attr:`rank` is retrieved from
            the current distributed group.
        shuffle (bool, optional): If ``True`` (default), the sampler will
            shuffle the indices.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): If ``True``, the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
        current_iter (int, optional): Index of the current iteration, used to
            seed the index generation when resuming. Default: ``0``.
        repeat (int, optional): Number of times the whole data loader is
            repeated before the data is re-loaded. Default: ``1000``.

    .. warning::
        When resuming training in distributed mode, call :meth:`__reset__`
        with the current iteration **before** creating the
        :class:`DataLoader` iterator, so that shuffling continues properly
        instead of always re-using the same ordering.
    """
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False, current_iter: int = 0,
                 repeat: int = 1000) -> None:
        super().__init__(None)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.iter = current_iter
        self.drop_last = drop_last
        # NOTE: `self.dataset_length` is `repeat * len(self.dataset)`.
        self.repeat = repeat
        self.dataset_length = len(self.dataset) * self.repeat
        if self.drop_last and self.dataset_length % self.num_replicas != 0:
            # Split to the nearest available length that is evenly divisible.
            # This ensures each rank receives the same amount of data when
            # using this sampler.
            self.num_samples = math.ceil(
                (self.dataset_length - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(self.dataset_length / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.__generate_indices__()

    def __generate_indices__(self) -> None:
        """Pre-generates shuffled indices for `repeat` passes over the dataset."""
        g = torch.Generator()
        indices_bank = []
        for iter_ in range(self.iter, self.iter + self.repeat):
            # Re-seed per pass so that each repetition uses a different, yet
            # deterministic, permutation.
            g.manual_seed(self.seed + iter_)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
            indices_bank.extend(indices)
        self.indices = indices_bank

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministically shuffled based on `iter` and `seed`
            # (pre-generated in `__generate_indices__`).
            indices = self.indices
        else:
            indices = list(range(self.dataset_length))

        if not self.drop_last:
            # Add extra samples to make the data evenly divisible.
            indices += indices[:(self.total_size - len(indices))]
        else:
            # Remove the tail of the data to make it evenly divisible.
            indices = indices[:self.total_size]

        # Subsample: each rank takes every `num_replicas`-th index, starting
        # from its own rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
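        # Worked example: with `total_size=8`, `num_replicas=4` and `rank=1`,
        # the slice above keeps `indices[1]` and `indices[5]`, i.e. every 4th
        # element starting from this process's rank.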
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def __reset__(self, iteration: int) -> None:
        """Re-generates the index bank, starting from the given iteration."""
        self.iter = iteration
        self.__generate_indices__()

# pylint: enable=line-too-long
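

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative addition, not part of the borrowed code):
# it shows how the sampler is typically plugged into a `DataLoader` and how
# `__reset__()` can re-seed the index bank when resuming from a saved
# iteration. The toy dataset, batch size, and `resume_iter` value below are
# assumptions chosen only for the demo; `num_replicas` and `rank` are passed
# explicitly so the sketch runs without initializing `torch.distributed`.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    demo_dataset = TensorDataset(torch.arange(100))
    demo_sampler = DistributedSampler(demo_dataset,
                                      num_replicas=1,
                                      rank=0,
                                      shuffle=True,
                                      seed=0,
                                      repeat=4,
                                      current_iter=0)
    demo_loader = DataLoader(demo_dataset, batch_size=10, sampler=demo_sampler)

    # One pass over `demo_loader` now yields `repeat * len(demo_dataset)`
    # samples, so the data loader does not need to be re-created every epoch.
    print(len(demo_sampler))  # 400 with the settings above.

    # When resuming training from an intermediate iteration, re-generate the
    # index bank so that shuffling continues deterministically from there.
    resume_iter = 2
    demo_sampler.__reset__(resume_iter)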