import math
from typing import Iterator, Optional

import torch
import torch.distributed as dist
from torch.utils.data import Dataset


class ClassBalancedDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that sub-samples a dataset via the dataset's own
    ``generate_class_balanced_indices`` hook.

    Based on the DistributedSampler class
    Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13

    Each epoch the index list is regenerated from a ``torch.Generator`` seeded
    with ``seed + epoch``. Without ``drop_last`` it is padded by repeating its
    head; with ``drop_last`` it is truncated. Either way the list is made evenly
    divisible by ``num_replicas``, and this replica keeps every
    ``num_replicas``-th index starting at ``rank``.

    NOTE(review): the dataset hook is assumed to accept
    ``(generator, num_samples_per_class=...)`` and return a list (or 1-D
    tensor) of integer dataset indices. The per-replica length is computed
    once from the epoch-0 draw, so the hook should return the same number of
    indices every epoch. Confirm against the dataset implementation.

    Args:
        dataset: Dataset exposing ``generate_class_balanced_indices``.
        num_replicas: Number of participating processes. Defaults to the
            world size of the default process group.
        rank: Rank of the current process. Defaults to the rank in the
            default process group.
        shuffle: Must be ``True``; use ``DistributedSampler`` otherwise.
        seed: Base seed added to the epoch number for each epoch's generator.
        drop_last: If ``True``, drop the tail so the index list splits evenly;
            otherwise pad it by repeating indices.
        num_samples_per_class: Forwarded unchanged to the dataset hook.

    Raises:
        ValueError: If ``shuffle`` is false, the dataset lacks the hook, or
            ``rank`` is outside ``[0, num_replicas - 1]``.
        RuntimeError: If ``num_replicas``/``rank`` must be inferred but
            ``torch.distributed`` is unavailable.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        num_samples_per_class: int = 100,
    ) -> None:
        if not shuffle:
            raise ValueError("ClassBalancedDistributedSampler requires shuffling, otherwise use DistributedSampler")
        # The sampler delegates index selection to the dataset, so the hook is mandatory.
        if not hasattr(dataset, 'generate_class_balanced_indices'):
            raise ValueError("Dataset does not have a generate_class_balanced_indices method")
        self.shuffle = shuffle
        self.seed = seed
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.num_samples_per_class = num_samples_per_class

        # Size the split from the same seeded draw __iter__ makes for epoch 0,
        # so the computed length matches what the first epoch yields.
        dataset_size = len(self._generate_indices())

        # The divisibility check must use the *sub-sampled* size, not
        # len(self.dataset): that is the list that actually gets split.
        if self.drop_last and dataset_size % self.num_replicas != 0:
            # Round down to the largest length evenly divisible by num_replicas
            # so every rank receives the same number of samples.
            self.num_samples = math.ceil(
                (dataset_size - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(dataset_size / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def _generate_indices(self) -> list:
        """Draw this epoch's index list from the dataset as a fresh Python list."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = self.dataset.generate_class_balanced_indices(
            g, num_samples_per_class=self.num_samples_per_class)
        # Copy (or convert) so in-place padding never mutates a list the
        # dataset may have cached, and so tensor results are padded as lists.
        if isinstance(indices, torch.Tensor):
            return indices.tolist()
        return list(indices)

    def __iter__(self) -> Iterator[int]:
        """Yield this replica's share of the current epoch's indices."""
        # Deterministic per (seed, epoch); shuffling is always enabled here.
        indices = self._generate_indices()

        if not self.drop_last:
            # Pad by repeating from the head until the list is evenly divisible.
            padding_size = self.total_size - len(indices)
            if padding_size > 0:
                if padding_size <= len(indices):
                    indices += indices[:padding_size]
                else:
                    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Remove the tail so the list is evenly divisible.
            indices = indices[:self.total_size]

        # Strided split: replica `rank` takes positions rank, rank + R, rank + 2R, ...
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        """Return the number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random ordering
        for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch