import numpy as np
import torch
from accelerate import Accelerator
import torch.utils
from torch.utils.data import BatchSampler, Sampler
import torch.utils.data


class CustomRandomSampler(Sampler):
    """Random sampler whose indices share a feature and a view count per batch.

    Random sampling under a constraint. Every consecutive run of
    ``batch_size`` yielded indices shares the same ``feat_idx`` and the same
    ``view_idx``:

    - ``feat_idx`` is drawn from ``range(pool_size)``. It could be, for
      instance, an image aspect-ratio bucket.
    - ``view_idx`` is drawn uniformly from
      ``[min_view_size, max_view_size]`` (inclusive).

    Each yielded index is a 3-tuple ``(sample_idx, feat_idx, view_idx)`` of
    numpy integers. Each sample index in ``range(len(dataset))`` appears
    exactly once per epoch, in shuffled order. If ``len(dataset)`` is not a
    multiple of ``batch_size``, the final run is shorter than
    ``batch_size``.

    When ``pool_size > 1``, the first ``pool_size // 2`` feature indices are
    drawn with twice the probability of the remaining ones.

    Iteration is deterministic for a given epoch, because the RNG is seeded
    from the epoch number. :meth:`set_epoch` must be called before
    iterating.
    """

    # Added to the epoch number to form the per-epoch RNG seed.
    _SEED_OFFSET = 788

    def __init__(
        self,
        dataset,
        batch_size,
        pool_size,
        min_view_size,
        max_view_size,
        world_size,
        warmup=1,
        drop_last=True,
    ):
        """
        Args:
            dataset: Only ``len(dataset)`` is used.
            batch_size: Length of each run of indices that share
                ``feat_idx`` / ``view_idx``.
            pool_size: Number of candidate feature indices.
            min_view_size: Inclusive lower bound for ``view_idx``.
            max_view_size: Inclusive upper bound for ``view_idx``.
            world_size: Accepted but not used by this class.
            warmup: Accepted but not used by this class.
            drop_last: Stored on the instance. ``__iter__`` does not read it.
        """
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.min_view_size = min_view_size
        self.max_view_size = max_view_size
        self.drop_last = drop_last
        self.len_dataset = N = len(dataset)
        self.total_size = N
        self.epoch = None
        # Not read anywhere in this class.
        self.epochf = 0.0

    def __len__(self):
        """Return the number of indices yielded per epoch (the dataset length)."""
        return self.total_size

    def set_epoch(self, epoch):
        """Set the epoch number that seeds the next iteration's RNG."""
        self.epoch = epoch

    def _expand_per_batch(self, per_batch):
        """Repeat each per-batch value ``batch_size`` times; trim to ``total_size``."""
        return np.repeat(per_batch, self.batch_size)[: self.total_size]

    def __iter__(self):
        """Yield ``(sample_idx, feat_idx, view_idx)`` tuples for the current epoch.

        Raises:
            ValueError: If :meth:`set_epoch` has not been called.
        """
        if self.epoch is None:
            raise ValueError(
                "Epoch number not set. Please call 'set_epoch(epoch)' before iterating."
            )
        rng = np.random.default_rng(seed=self.epoch + self._SEED_OFFSET)

        # NOTE: the draw order below (shuffle, then feat, then view) fixes the
        # RNG stream for a given seed. Reordering it changes every epoch's
        # sampling.

        # One shuffled pass over the dataset: each sample appears exactly once.
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Ceiling division, so a trailing partial batch still gets its own draw.
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size

        if self.pool_size > 1:
            # First half of the pool is twice as likely as the second half.
            weights = np.ones(self.pool_size)
            weights[: self.pool_size // 2] *= 2
            weights = weights / weights.sum()
            feat_per_batch = rng.choice(self.pool_size, size=n_batches, p=weights)
        else:
            # pool_size == 1 gives all zeros; numpy rejects pool_size < 1.
            feat_per_batch = rng.integers(self.pool_size, size=n_batches)

        view_per_batch = rng.integers(
            self.min_view_size, self.max_view_size + 1, size=n_batches
        )

        idxs = np.column_stack(
            (
                sample_idxs,
                self._expand_per_batch(feat_per_batch),
                self._expand_per_batch(view_per_batch),
            )
        )
        yield from (tuple(idx) for idx in idxs)


class BatchedRandomSampler(BatchSampler):
    """Batch sampler that groups indices from a CustomRandomSampler into batches.

    Batching itself (``__iter__`` / ``__len__``) is inherited from torch's
    ``BatchSampler``, which chunks ``self.sampler`` into lists of
    ``self.batch_size`` and honours ``self.drop_last``. Those attributes are
    assigned directly here; ``BatchSampler.__init__`` is not called.

    NOTE: ``batch_size`` should equal the wrapped sampler's ``batch_size``.
    Otherwise batch boundaries do not line up with the runs that share
    ``feat_idx`` / ``view_idx``.
    """

    def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped sampler."""
        self.sampler.set_epoch(epoch)


def round_by(total, multiple, up=False):
    """Round ``total`` down, or up when ``up`` is true, to a multiple of ``multiple``.

    Uses floor division. For example, ``round_by(10, 4) == 8`` and
    ``round_by(10, 4, up=True) == 12``.
    """
    if up:
        total = total + multiple - 1
    return (total // multiple) * multiple