| import chainer | |
| from chainer.iterators import MultiprocessIterator | |
| from chainer.iterators import SerialIterator | |
| from chainer.iterators import ShuffleOrderSampler | |
| from chainer.training.extension import Extension | |
| import numpy as np | |
class ShufflingEnabler(Extension):
    """Extension that switches on shuffling for a collection of iterators.

    The switch is performed only once: the first invocation calls
    ``start_shuffle()`` on every stored iterator, and every later
    invocation is a no-op.
    """

    def __init__(self, iterators):
        """Store the iterators and mark shuffling as not yet enabled.

        :param list iterators: Iterators exposing a ``start_shuffle()`` method
        """
        self.iterators = iterators
        # Flag telling whether shuffling has already been switched on.
        self.set = False

    def __call__(self, trainer):
        """Enable shuffling on all stored iterators on the first call.

        :param trainer: The object invoking this extension (not used here)
        """
        if self.set:
            # Shuffling was already enabled by an earlier call.
            return
        for shuffled_iterator in self.iterators:
            shuffled_iterator.start_shuffle()
        # Only reached when every iterator was switched without raising.
        self.set = True
class ToggleableShufflingSerialIterator(SerialIterator):
    """A SerialIterator whose shuffling can be activated during training."""

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        """Initialize the iterator.

        All arguments are forwarded unchanged to ``SerialIterator.__init__``.

        :param dataset: The dataset to take batches from
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat data (allow multiple epochs)
        :param bool shuffle: Whether to shuffle the batches
        """
        super(ToggleableShufflingSerialIterator, self).__init__(
            dataset, batch_size, repeat, shuffle
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the batches.

        Sets the shuffle flag and replaces the current order with a freshly
        sampled permutation of the dataset indices.
        """
        self._shuffle = True
        # Parse the complete major component of the version string. Reading
        # only its first character would turn e.g. "10.0.0" into 1 and
        # wrongly select the legacy (<= 4) branch.
        major_version = int(chainer.__version__.split(".")[0])
        if major_version <= 4:
            # Legacy branch: no order sampler is used, permute directly.
            self._order = np.random.permutation(len(self.dataset))
        else:
            # Newer branch: install an order sampler and draw the order
            # from it, passing the identity order and position 0.
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
    """A MultiprocessIterator whose shuffling can be activated during training."""

    def __init__(
        self,
        dataset,
        batch_size,
        repeat=True,
        shuffle=True,
        n_processes=None,
        n_prefetch=1,
        shared_mem=None,
        maxtasksperchild=20,
    ):
        """Initialize the iterator.

        All arguments are forwarded by keyword to
        ``MultiprocessIterator.__init__``.

        :param dataset: The dataset to take batches from
        :param int batch_size: The batch size
        :param bool repeat: Whether to repeat batches (enables multiple epochs)
        :param bool shuffle: Whether to shuffle the order of the batches
        :param int n_processes: How many processes to use
        :param int n_prefetch: The number of prefetches to use
        :param int shared_mem: How much memory to share between processes
        :param int maxtasksperchild: Maximum number of tasks per child
        """
        super(ToggleableShufflingMultiprocessIterator, self).__init__(
            dataset=dataset,
            batch_size=batch_size,
            repeat=repeat,
            shuffle=shuffle,
            n_processes=n_processes,
            n_prefetch=n_prefetch,
            shared_mem=shared_mem,
            maxtasksperchild=maxtasksperchild,
        )

    def start_shuffle(self):
        """Start shuffling (or reshuffle) the batches.

        Sets the shuffle flag, replaces the current order with a freshly
        sampled permutation of the dataset indices, and then refreshes the
        prefetch state of the parent iterator.
        """
        self.shuffle = True
        # Parse the complete major component of the version string. Reading
        # only its first character would turn e.g. "10.0.0" into 1 and
        # wrongly select the legacy (<= 4) branch.
        major_version = int(chainer.__version__.split(".")[0])
        if major_version <= 4:
            # Legacy branch: no order sampler is used, permute directly.
            self._order = np.random.permutation(len(self.dataset))
        else:
            # Newer branch: install an order sampler and draw the order
            # from it, passing the identity order and position 0.
            self.order_sampler = ShuffleOrderSampler()
            self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
        # Inherited helper; presumably makes the prefetching side pick up
        # the new order -- confirm against the installed Chainer version.
        self._set_prefetch_state()