# Spaces: Sleeping
import torch | |
from torch.utils.data import Dataset | |
class ClassBalancedRandomSampler(torch.utils.data.Sampler):
    """
    A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class.
    This is essentially the non-DDP version of ClassBalancedDistributedSampler.

    The selection of indices is delegated entirely to the dataset, which must provide
    ``generate_class_balanced_indices(generator, num_samples_per_class=...)`` and return
    a sized iterable of indices.

    Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
    """

    def __init__(self, dataset: Dataset, num_samples_per_class: int = 100, seed: int = 0) -> None:
        """
        Args:
            dataset: Dataset to sample from; must implement
                ``generate_class_balanced_indices``.
            num_samples_per_class: Forwarded unchanged to
                ``dataset.generate_class_balanced_indices``.
            seed: Seed for the generator used by the draw made in this
                constructor. Every ``__iter__`` call re-seeds the generator,
                so ``seed`` does not fix the order of later epochs.
        """
        self.dataset = dataset
        self.seed = seed
        self.num_samples_per_class = num_samples_per_class

        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)

        # Draw once up front so that __len__ is available before the first
        # iteration (only the number of indices is kept).
        indices = dataset.generate_class_balanced_indices(
            self.generator,
            num_samples_per_class=num_samples_per_class,
        )
        self.num_samples = len(indices)

    def __iter__(self):
        """Draw a fresh set of class-balanced indices and return an iterator over it."""
        # Change seed for every call: the new seed comes from torch's global
        # RNG, so successive epochs see different subsets.
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self.generator.manual_seed(seed)
        indices = self.dataset.generate_class_balanced_indices(
            self.generator,
            num_samples_per_class=self.num_samples_per_class,
        )
        # Keep __len__ in agreement with the indices actually handed out, in
        # case the dataset returns a different number of indices between draws.
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self) -> int:
        """Return the number of indices produced by the most recent draw."""
        return self.num_samples