pdiscoformer / utils /data_utils /class_balanced_sampler.py
ananthu-aniraj's picture
add initial files
20239f9
raw
history blame
1.4 kB
import torch
from torch.utils.data import Dataset
class ClassBalancedRandomSampler(torch.utils.data.Sampler):
"""
A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class
This is essentially the non-ddp version of ClassBalancedDistributedSampler
Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
"""
def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
self.dataset = dataset
self.seed = seed
# Calculate the number of samples
self.generator = torch.Generator()
self.generator.manual_seed(self.seed)
self.num_samples_per_class = num_samples_per_class
indices = dataset.generate_class_balanced_indices(self.generator,
num_samples_per_class=num_samples_per_class)
self.num_samples = len(indices)
def __iter__(self):
# Change seed for every function call
seed = int(torch.empty((), dtype=torch.int64).random_().item())
self.generator.manual_seed(seed)
indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
return iter(indices)
def __len__(self) -> int:
return self.num_samples