import torch
from torch import Tensor

from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have a custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AccedingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            import warnings
            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
                          "You may still have a custom implementation that utilizes it.")

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
#
# Many times we have an abstract class representing a collection/iterable of
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
# implementing a `__len__` method. In such cases, we must make sure to not
# provide a default implementation, because both straightforward default
# implementations have their issues:
#
#   + `return NotImplemented`:
#     Calling `len(subclass_instance)` raises:
#       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
#   + `raise NotImplementedError()`:
#     This prevents triggering some fallback behavior. E.g., the built-in
#     `list(X)` tries to call `len(X)` first, and executes a different code
#     path if the method is not found or `NotImplemented` is returned, while
#     raising a `NotImplementedError` will propagate and make the call fail
#     where it could have used `__iter__` to complete the call.
#
# Thus, the only two sensible things to do are
#
#   + **not** provide a default `__len__`.
#
#   + raise a `TypeError` instead, which is what Python uses when users call
#     a method that is not defined on an object.
#     (@ssnl verifies that this works on at least Python 3.7.)
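#
# An illustrative sketch of the fallback described above (not part of the
# module logic; the class names are made up for the example):
#
#     class NoLen:
#         def __iter__(self):
#             return iter([1, 2, 3])
#
#     list(NoLen())   # OK: list() falls back to __iter__
#
#     class BadLen(NoLen):
#         def __len__(self):
#             raise NotImplementedError
#
#     list(BadLen())  # fails: NotImplementedError escapes list()'s len() pre-pass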


class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
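
    Example (a minimal usage sketch; the output follows directly from the
    sequential ``range``-based iteration implemented below):
        >>> list(SequentialSampler(range(5)))
        [0, 1, 2, 3, 4]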
""" | |
data_source: Sized | |
def __init__(self, data_source: Sized) -> None: | |
self.data_source = data_source | |
def __iter__(self) -> Iterator[int]: | |
return iter(range(len(self.data_source))) | |
def __len__(self) -> int: | |
return len(self.data_source) | |


class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=``len(dataset)``.
        generator (Generator): Generator used in sampling.
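
    Example (an illustrative sketch; the exact order depends on the generator
    seed, so no fixed output is shown):
        >>> # xdoctest: +SKIP
        >>> generator = torch.Generator().manual_seed(0)
        >>> list(RandomSampler(range(5), generator=generator))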
""" | |
data_source: Sized | |
replacement: bool | |
def __init__(self, data_source: Sized, replacement: bool = False, | |
num_samples: Optional[int] = None, generator=None) -> None: | |
self.data_source = data_source | |
self.replacement = replacement | |
self._num_samples = num_samples | |
self.generator = generator | |
if not isinstance(self.replacement, bool): | |
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") | |
if not isinstance(self.num_samples, int) or self.num_samples <= 0: | |
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") | |
def num_samples(self) -> int: | |
# dataset size might change at runtime | |
if self._num_samples is None: | |
return len(self.data_source) | |
return self._num_samples | |
    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # draw with replacement in fixed-size chunks of 32 to amortize the
            # per-call cost of `torch.randint`
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # without replacement: emit full permutations, then a truncated one
            # to reach exactly `num_samples` indices
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples


class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
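
    Example (a small illustrative sketch; each pass yields the given indices
    in a fresh random order, so no fixed output is shown):
        >>> # xdoctest: +SKIP
        >>> list(SubsetRandomSampler([2, 5, 7, 11]))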
""" | |
indices: Sequence[int] | |
def __init__(self, indices: Sequence[int], generator=None) -> None: | |
self.indices = indices | |
self.generator = generator | |
def __iter__(self) -> Iterator[int]: | |
for i in torch.randperm(len(self.indices), generator=self.generator): | |
yield self.indices[i] | |
def __len__(self) -> int: | |
return len(self.indices) | |


class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        yield from rand_tensor.tolist()

    def __len__(self) -> int:
        return self.num_samples


class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]
    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
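

# A sketch of how these samplers plug into `torch.utils.data.DataLoader`
# (illustrative only; `my_dataset` stands in for a map-style dataset):
#
#     from torch.utils.data import DataLoader
#
#     # per-sample shuffling: DataLoader groups the yielded indices into
#     # batches of `batch_size` itself
#     loader = DataLoader(my_dataset, batch_size=4,
#                         sampler=RandomSampler(my_dataset))
#
#     # explicit batch control: BatchSampler yields whole index lists, so it
#     # is passed via `batch_sampler` instead of `sampler`
#     loader = DataLoader(my_dataset,
#                         batch_sampler=BatchSampler(SequentialSampler(my_dataset),
#                                                    batch_size=4, drop_last=False))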