# Nadine Rueegg
# initial commit with code and data
# 753fd9a
import numpy as np
import random
import copy
import time
import warnings
from torch.utils.data import Sampler
from torch._six import int_classes as _int_classes
# from configs.dog_breeds.dog_breed_class import get_partial_summary
class TwoDatasetSampler(Sampler):
    """Batch sampler that takes half of every mini-batch from each of two datasets.

    The yielded indices address a combined index space in which dataset 0
    occupies ``0 .. size0-1`` and dataset 1 occupies ``size0 .. size0+size1-1``
    (presumably a ``ConcatDataset([dataset0, dataset1])`` -- confirm against
    the caller). Each batch is a list of ``2 * batch_size_half`` indices: the
    first half from dataset 0, the second half from dataset 1.

    An epoch has ``size1 // batch_size_half`` batches, so its length is set by
    the smaller dataset 1. Only that many samples of dataset 0 are used per
    epoch (a fresh random subset each epoch when ``shuffle=True``).

    Args:
        batch_size_half (int): Number of indices taken from each dataset per batch.
        size0 (int): Number of samples in dataset 0; must be >= ``size1``.
        size1 (int): Number of samples in dataset 1; must be >= 0.
        shuffle (bool): If ``True``, both index lists are shuffled every epoch
            using NumPy's global random state.
        drop_last (bool): Validated and stored for interface compatibility;
            incomplete batches are always dropped regardless of its value.

    Raises:
        ValueError: If ``batch_size_half`` is not a positive int, ``drop_last``
            is not a bool, or the sizes do not satisfy ``size0 >= size1 >= 0``.

    Example:
        >>> list(TwoDatasetSampler(2, 10, 4, shuffle=False))
        [[0, 1, 10, 11], [2, 3, 12, 13]]
    """

    def __init__(self, batch_size_half, size0, size1, shuffle=True, drop_last=True):
        # `bool` is a subclass of `int`, so it has to be excluded explicitly.
        if (not isinstance(batch_size_half, int) or isinstance(batch_size_half, bool)
                or batch_size_half <= 0):
            # Report the argument as given; formatting `batch_size_half * 2`
            # would itself raise TypeError for inputs such as None.
            raise ValueError("batch_size_half should be a positive integer value, "
                             "but got batch_size_half={}".format(batch_size_half))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if size1 < 0 or size0 < size1:
            raise ValueError("expected size0 >= size1 >= 0, but got "
                             "size0={}, size1={}".format(size0, size1))
        self.batch_size_half = batch_size_half
        self.size0 = size0
        self.size1 = size1
        self.shuffle = shuffle
        # The smaller dataset 1 determines how many full batches fit in an epoch.
        self.n_batches = self.size1 // batch_size_half
        # Stored only; __iter__ always yields full batches.
        self.drop_last = drop_last

    def get_description(self):
        """Return a one-line human-readable description of this sampler."""
        return "This sampler samples equally from two different datasets"

    def __iter__(self):
        """Yield ``n_batches`` lists of ``2 * batch_size_half`` Python int indices."""
        n_used = self.n_batches * self.batch_size_half
        dataset0 = np.arange(self.size0)
        # Dataset-1 indices are offset by size0 into the combined index space.
        dataset1 = np.arange(self.size1) + self.size0
        if self.shuffle:
            np.random.shuffle(dataset0)
            np.random.shuffle(dataset1)
        # n_used <= size1 <= size0, so one pass over each index array is enough
        # (no replication of dataset 1 and no division by size1 is needed).
        for start in range(0, n_used, self.batch_size_half):
            stop = start + self.batch_size_half
            # .tolist() yields plain Python ints rather than numpy scalars.
            yield dataset0[start:stop].tolist() + dataset1[start:stop].tolist()

    def __len__(self):
        """Return the number of batches per epoch (``size1 // batch_size_half``)."""
        return self.n_batches