import math
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data import Dataset


class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that draws a class-balanced subset of the dataset
    each epoch instead of iterating over every index. The dataset must
    implement generate_class_balanced_indices(generator, num_samples_per_class).
    Adapted from the DistributedSampler class.
    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = False,
                 num_samples_per_class: int = 100) -> None:
        if not shuffle:
            raise ValueError("ClassBalancedDistributedSampler requires shuffling, otherwise use DistributedSampler")
        # The dataset must expose the index-generation hook this sampler relies on.
        if not hasattr(dataset, 'generate_class_balanced_indices'):
            raise ValueError("Dataset does not have a generate_class_balanced_indices method")
        self.shuffle = shuffle
        self.seed = seed
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # Calculate the number of samples per rank. Seed the generator with
        # seed + epoch so the index count matches what __iter__ produces for
        # epoch 0 (the original passed a fresh, unseeded torch.Generator() here,
        # leaving g unused).
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        self.num_samples_per_class = num_samples_per_class
        indices = dataset.generate_class_balanced_indices(
            g, num_samples_per_class=num_samples_per_class)
        dataset_size = len(indices)
        # If the balanced subset is evenly divisible by the number of replicas,
        # there is no need to drop any data, since it will be split equally.
        # Note: the check uses the balanced subset size, not len(self.dataset),
        # because this sampler partitions the subset across ranks.
        if self.drop_last and dataset_size % self.num_replicas != 0:
            # Split to the nearest available length that is evenly divisible.
            # This ensures each rank receives the same amount of data when
            # using this sampler.
            self.num_samples = math.ceil(
                (dataset_size - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(dataset_size / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed (shuffle is always
        # True for this sampler; see __init__).
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = self.dataset.generate_class_balanced_indices(
            g, num_samples_per_class=self.num_samples_per_class)
        if not self.drop_last:
            # Add extra samples to make the list evenly divisible across ranks.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Remove the tail of the data to make it evenly divisible.
            indices = indices[:self.total_size]
        # Subsample: each rank takes every num_replicas-th index, offset by its rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler.

        This sampler always shuffles, so calling this before each epoch
        ensures all replicas use a different random ordering per epoch.
        Otherwise, the next iteration of this sampler will yield the same
        ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
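

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The sampler requires the dataset to expose
# generate_class_balanced_indices(generator, num_samples_per_class); the toy
# dataset below is a hypothetical minimal implementation of that contract,
# not the dataset this sampler was originally written for.
# ---------------------------------------------------------------------------
class _ToyLabeledDataset(Dataset):
    def __init__(self, labels):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return idx, self.labels[idx]

    def generate_class_balanced_indices(self, generator, num_samples_per_class=100):
        # Group indices by class label.
        by_class = {}
        for idx, label in enumerate(self.labels):
            by_class.setdefault(label, []).append(idx)
        indices = []
        for class_indices in by_class.values():
            # Reproducible within-class shuffle driven by the provided
            # generator (seeded with seed + epoch by the sampler).
            perm = torch.randperm(len(class_indices), generator=generator).tolist()
            indices.extend(class_indices[i] for i in perm[:num_samples_per_class])
        # Interleave the classes with a final shuffle.
        order = torch.randperm(len(indices), generator=generator).tolist()
        return [indices[i] for i in order]


if __name__ == "__main__":
    # Single-process demo: num_replicas and rank are passed explicitly, so no
    # torch.distributed initialization is needed here.
    from torch.utils.data import DataLoader

    dataset = _ToyLabeledDataset(labels=[i % 3 for i in range(30)])
    sampler = ClassBalancedDistributedSampler(
        dataset, num_replicas=2, rank=0, seed=0, num_samples_per_class=5)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for epoch in range(2):
        # Call set_epoch before each epoch so every replica reshuffles.
        sampler.set_epoch(epoch)
        for batch in loader:
            pass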