import numpy as np

from torch.utils.data import Sampler


class TwoDatasetSampler(Sampler):
    """Yields mini-batches of indices drawn equally from two datasets.

    Indices ``[0, size0)`` are assumed to belong to the first dataset and
    indices ``[size0, size0 + size1)`` to the second, as produced e.g. by a
    ``ConcatDataset`` of the two. Each batch holds ``batch_size_half``
    indices from each dataset, for a total batch size of
    ``2 * batch_size_half``.

    Args:
        batch_size_half (int): Number of indices drawn from each dataset
            per batch.
        size0 (int): Size of the first dataset; must satisfy
            ``size0 >= size1``.
        size1 (int): Size of the second dataset.
        shuffle (bool): If ``True``, shuffle the indices of each dataset
            independently on every epoch.
        drop_last (bool): Kept for API compatibility; incomplete batches
            are always dropped.

    Example:
        >>> list(TwoDatasetSampler(2, size0=6, size1=4, shuffle=False))
        [[0, 1, 6, 7], [2, 3, 8, 9]]
    """

    def __init__(self, batch_size_half, size0, size1, shuffle=True, drop_last=True):
        if not isinstance(batch_size_half, int) or isinstance(batch_size_half, bool) or \
                batch_size_half <= 0:
            raise ValueError("batch_size_half should be a positive integer value, "
                             "but got batch_size_half={}".format(batch_size_half))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        assert size0 >= size1
        self.batch_size_half = batch_size_half
        self.size0 = size0
        self.size1 = size1
        self.shuffle = shuffle
        # The number of full batches is limited by the smaller dataset.
        self.n_batches = self.size1 // batch_size_half
        self.drop_last = drop_last

    def get_description(self):
        return "This sampler samples equally from two different datasets"

    def __iter__(self):
        dataset0 = np.arange(self.size0)
        dataset1_init = np.arange(self.size1) + self.size0
        if self.shuffle:
            np.random.shuffle(dataset0)
        # Tile the smaller dataset (reshuffling each pass) so that it is at
        # least as long as the larger one.
        dataset1_parts = []
        for _ in range(self.size0 // self.size1 + 1):
            part = dataset1_init.copy()
            if self.shuffle:
                np.random.shuffle(part)
            dataset1_parts.append(part)
        dataset1 = np.concatenate(dataset1_parts)
        # Truncate both index lists to a whole number of batches.
        dataset0 = dataset0[:self.n_batches * self.batch_size_half]
        dataset1 = dataset1[:self.n_batches * self.batch_size_half]
        for ind_batch in range(self.n_batches):
            lo = ind_batch * self.batch_size_half
            hi = (ind_batch + 1) * self.batch_size_half
            # The first half of each batch comes from dataset 0, the second
            # half from dataset 1.
            yield dataset0[lo:hi].tolist() + dataset1[lo:hi].tolist()

    def __len__(self):
        return self.n_batches
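

# Minimal usage sketch (an assumption, not part of the original module):
# because the sampler yields whole batches of indices, it plugs into
# DataLoader through the `batch_sampler` argument. The TensorDatasets below
# are hypothetical stand-ins for the two real datasets.
if __name__ == "__main__":
    import torch
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    dataset_a = TensorDataset(torch.zeros(10))  # occupies indices [0, 10)
    dataset_b = TensorDataset(torch.ones(6))    # occupies indices [10, 16)
    combined = ConcatDataset([dataset_a, dataset_b])

    sampler = TwoDatasetSampler(batch_size_half=3, size0=len(dataset_a),
                                size1=len(dataset_b))
    loader = DataLoader(combined, batch_sampler=sampler)

    # Each collated batch holds 3 zeros (dataset_a) followed by 3 ones
    # (dataset_b); len(loader) == len(dataset_b) // batch_size_half == 2.
    for (batch,) in loader:
        print(batch)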