Spaces:
Runtime error
Runtime error
File size: 3,131 Bytes
2df809d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import numpy as np
import torch
from accelerate import Accelerator
import torch.utils
from torch.utils.data import BatchSampler, Sampler
import torch.utils.data
class CustomRandomSampler(Sampler):
    """Random sampler whose indices carry a per-batch feature and view count.

    Every yielded index is a 3-tuple ``(sample_idx, feat_idx, view_idx)``:

    * ``sample_idx`` -- a dataset position, taken from a seeded permutation of
      ``range(len(dataset))``.
    * ``feat_idx`` -- an index into a pool of ``pool_size`` 'features' (for
      instance, the 'feature' could be the image aspect-ratio).
    * ``view_idx`` -- an integer drawn uniformly from
      ``[min_view_size, max_view_size]`` (both ends included).

    ``feat_idx`` and ``view_idx`` are drawn once per batch, so each consecutive
    series of ``batch_size`` indices shares the same ``feat_idx`` and the same
    ``view_idx``.

    Args:
        dataset: Any object supporting ``len()``; only its length is used.
        batch_size: Number of consecutive indices sharing one feature/view.
        pool_size: Number of distinct feature indices to choose from.
        min_view_size: Smallest ``view_idx`` that can be drawn.
        max_view_size: Largest ``view_idx`` that can be drawn (inclusive).
        world_size: Accepted for interface compatibility; not used here.
        warmup: Accepted for interface compatibility; not used here.
        drop_last: Stored on the instance; this class never consults it, and
            ``__iter__`` always yields exactly ``len(dataset)`` indices.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        pool_size,
        min_view_size,
        max_view_size,
        world_size,
        warmup=1,
        drop_last=True,
    ):
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.min_view_size = min_view_size
        self.max_view_size = max_view_size
        self.drop_last = drop_last

        self.len_dataset = N = len(dataset)
        self.total_size = N

        # The epoch seeds the RNG in __iter__, so it must be set before use.
        self.epoch = None
        self.epochf = 0.0

    def __len__(self):
        """Return the number of indices yielded per epoch (the dataset length)."""
        return self.total_size

    def set_epoch(self, epoch):
        """Record the epoch number used to seed the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield ``(sample_idx, feat_idx, view_idx)`` tuples for this epoch.

        The random stream is seeded with ``epoch + 788``, so iterating twice
        with the same epoch yields the same sequence.

        Raises:
            ValueError: If ``set_epoch`` has not been called yet.
        """
        if self.epoch is None:
            raise ValueError(
                "Epoch number not set. Please call 'set_epoch(epoch)' before iterating."
            )

        rng = np.random.default_rng(seed=self.epoch + 788)

        # NOTE: the order of the RNG draws below (shuffle, features, views)
        # determines the generated sequence; do not reorder them.

        # Random permutation of every dataset position.
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # Ceil division: the last batch may be partial.
        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size

        # One feature index per batch.
        if self.pool_size > 1:
            # The first half of the pool is twice as likely as the second half.
            weights = np.ones(self.pool_size)
            weights[: self.pool_size // 2] *= 2
            weights = weights / weights.sum()
            batch_feats = rng.choice(self.pool_size, size=n_batches, p=weights)
        else:
            batch_feats = rng.integers(self.pool_size, size=n_batches)

        # One view count per batch, upper bound inclusive.
        batch_views = rng.integers(
            self.min_view_size, self.max_view_size + 1, size=n_batches
        )

        # Expand per-batch values to per-sample values, then trim the padding
        # introduced by a partial last batch.
        feat_idxs = np.repeat(batch_feats, self.batch_size)[: self.total_size]
        view_idxs = np.repeat(batch_views, self.batch_size)[: self.total_size]

        # Stack as columns: one row per sample.
        idxs = np.c_[sample_idxs, feat_idxs, view_idxs]
        yield from map(tuple, idxs)
class BatchedRandomSampler(BatchSampler):
    """Batch sampler that groups indices from a ``CustomRandomSampler``.

    Only the constructor and ``set_epoch`` are defined here; iteration and
    length are left to the ``BatchSampler`` base class, which reads the
    ``sampler``, ``batch_size`` and ``drop_last`` attributes set below.
    """

    def __init__(self, sampler: CustomRandomSampler, batch_size, drop_last=True):
        # The base-class constructor is intentionally not invoked; the three
        # attributes are assigned directly instead.
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.sampler = sampler  # the wrapped CustomRandomSampler instance

    def set_epoch(self, epoch):
        """Forward the epoch number to the wrapped sampler."""
        self.sampler.set_epoch(epoch)
def round_by(total, multiple, up=False):
    """Round ``total`` to a multiple of ``multiple``.

    By default the result is the largest multiple not exceeding ``total``
    (floor division). With ``up`` truthy, ``multiple - 1`` is added first, so
    integer inputs are rounded up to the next multiple instead.
    """
    numerator = total + multiple - 1 if up else total
    return numerator // multiple * multiple
|