Spaces:
Sleeping
Sleeping
import bisect | |
import itertools | |
import math | |
import warnings | |
from typing import ( | |
cast, | |
Dict, | |
Generic, | |
Iterable, | |
List, | |
Optional, | |
Sequence, | |
Tuple, | |
TypeVar, | |
Union, | |
) | |
# No 'default_generator' in torch/__init__.pyi | |
from torch import default_generator, randperm | |
from ... import Generator, Tensor | |
# Public names re-exported by ``from torch.utils.data.dataset import *``.
__all__ = [
    "Dataset", "IterableDataset", "TensorDataset", "StackDataset",
    "ConcatDataset", "ChainDataset", "Subset", "random_split",
]
# Generic type variables shared by the dataset classes in this module.
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)

# Sample containers used by ``StackDataset``: a name-keyed dict (keyword
# construction) or a tuple (positional construction).
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
T_stack = TypeVar("T_stack", T_tuple, T_dict)
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    Every dataset that maps keys to data samples should subclass it. Subclasses
    must override :meth:`__getitem__` so that it fetches the data sample for a
    given key. Subclasses may also override :meth:`__len__`; many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader` expect it to return the size of
    the dataset. Subclasses may additionally implement :meth:`__getitems__` to
    speed up loading of batched samples: it accepts the list of sample indices
    of a batch and returns the list of corresponding samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        # The base class holds no data; concrete datasets supply the lookup.
        message = "Subclasses of Dataset should implement __getitem__."
        raise NotImplementedError(message)

    # ``__getitems__`` is deliberately NOT defined here (signature would be
    # ``def __getitems__(self, indices: List) -> List[T_co]``): defining it on
    # the base class would cause false positives in the fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher.

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        """Return a :class:`ConcatDataset` of ``self`` followed by ``other``."""
        pieces = [self, other]
        return ConcatDataset(pieces)

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
class IterableDataset(Dataset[T_co], Iterable[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should override :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers (each non-empty worker fetches one item)
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
    """

    def __add__(self, other: Dataset[T_co]) -> "ChainDataset":
        """Return a :class:`ChainDataset` that iterates ``self`` then ``other``."""
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Sample ``i`` is the tuple formed by indexing every wrapped tensor with
    ``i`` along its first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # One index must select a matching row from every tensor, so all of
        # them have to agree on the size of dimension 0.
        assert all(
            t.size(0) == tensors[0].size(0) for t in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        """Return ``(t[index] for t in tensors)`` as a tuple."""
        picked = (t[index] for t in self.tensors)
        return tuple(picked)

    def __len__(self):
        """Return the size of dimension 0 of the first wrapped tensor."""
        leading = self.tensors[0]
        return leading.size(0)
class StackDataset(Dataset[T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data,
    given as datasets. Datasets passed positionally produce tuple samples;
    datasets passed by keyword produce dict samples keyed by argument name.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.

    Raises:
        ValueError: if both ``args`` and ``kwargs`` are given, if neither is
            given, or if the datasets do not all have the same length.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
        if args and kwargs:
            raise ValueError(
                "Supported either ``tuple``- (via ``args``) or "
                "``dict``- (via ``kwargs``) like input/output, but both types are given."
            )
        if not args and not kwargs:
            raise ValueError("At least one dataset should be passed")
        stacked: Union[tuple, dict] = args if args else kwargs
        members = list(args) if args else list(kwargs.values())
        # Every member must contribute exactly one sample per index.
        self._length = len(members[0])  # type: ignore[arg-type]
        if any(self._length != len(dataset) for dataset in members):  # type: ignore[arg-type]
            raise ValueError("Size mismatch between datasets")
        self.datasets = stacked

    def __getitem__(self, index):
        """Return the stacked sample at ``index`` (a dict or a tuple)."""
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    @staticmethod
    def _fetch_nested_batch(dataset, indices: list) -> list:
        """Fetch the samples of one member ``dataset`` at ``indices``.

        Uses the member's own ``__getitems__`` when it has a callable one,
        otherwise falls back to one ``dataset[idx]`` lookup per index.

        Raises:
            ValueError: if the member's ``__getitems__`` returns a number of
                items different from ``len(indices)``.
        """
        if callable(getattr(dataset, "__getitems__", None)):
            items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
            if len(items) != len(indices):
                raise ValueError(
                    "Nested dataset's output size mismatch."
                    f" Expected {len(indices)}, got {len(items)}"
                )
            return items
        return [dataset[idx] for idx in indices]

    def __getitems__(self, indices: list):
        """Fetch a batch of stacked samples, one per entry of ``indices``.

        Member datasets that provide ``__getitems__`` are queried once for the
        whole batch. Returns a list of dicts (keyword construction) or a list
        of tuples (positional construction).
        """
        if isinstance(self.datasets, dict):
            dict_batch: List[T_dict] = [{} for _ in indices]
            for key, dataset in self.datasets.items():
                column = self._fetch_nested_batch(dataset, indices)
                for sample, value in zip(dict_batch, column):
                    sample[key] = value
            return dict_batch
        # tuple data: accumulate into lists, then freeze each sample as a tuple
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            column = self._fetch_nested_batch(dataset, indices)
            for sample, value in zip(list_batch, column):
                sample.append(value)
        tuple_batch: List[T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        """Return the common length of the stacked datasets."""
        return self._length
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` over ``sequence``.

        For example, datasets of lengths 2, 0 and 3 give ``[2, 2, 5]``.
        Must be a staticmethod: ``__init__`` calls it as ``self.cumsum(...)``.
        """
        return list(itertools.accumulate(len(e) for e in sequence))

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return global sample ``idx``; negative indices count from the end.

        Raises:
            ValueError: if a negative ``idx`` has absolute value > ``len(self)``.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # First dataset whose running total exceeds idx holds the sample.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of :attr:`cumulative_sizes`."""
        warnings.warn(
            "cummulative_sizes attribute is renamed to " "cumulative_sizes",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.cumulative_sizes
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    Useful for assembling existing dataset streams. Members are consumed one
    after another on-the-fly, so chaining large-scale datasets is cheap: no
    member is materialized up front.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        """Yield every sample of each member, in member order."""
        for member in self.datasets:
            # Validated lazily, right before each member is consumed.
            assert isinstance(
                member, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            yield from member

    def __len__(self):
        """Return the sum of the members' lengths."""

        def checked_len(member):
            assert isinstance(
                member, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            return len(member)  # type: ignore[arg-type]

        return sum(checked_len(member) for member in self.datasets)
class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Position ``i`` of the subset maps to ``dataset[indices[i]]``.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        """Return the sample(s) at subset position(s) ``idx``.

        A list of positions is translated element-wise and forwarded to the
        parent dataset as a single list lookup.
        """
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        translated = [self.indices[position] for position in idx]
        return self.dataset[translated]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        """Fetch a batch, delegating to the parent's ``__getitems__`` if any.

        See torch.utils.data._utils.fetch._MapDatasetFetcher.
        """
        if not callable(getattr(self.dataset, "__getitems__", None)):
            return [self.dataset[self.indices[position]] for position in indices]
        translated = [self.indices[position] for position in indices]
        return self.dataset.__getitems__(translated)  # type: ignore[attr-defined]

    def __len__(self):
        """Return the number of selected indices."""
        return len(self.indices)
def random_split(
    dataset: Dataset[T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Raises:
        ValueError: if a fraction is outside ``[0, 1]``, if a length is
            negative, or if the lengths do not sum to ``len(dataset)``.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset.",
                    stacklevel=2,
                )

    # Negative lengths can still sum to len(dataset) but would produce
    # subsets whose sizes do not match the request; reject them explicitly.
    for i, length in enumerate(lengths):
        if length < 0:
            raise ValueError(f"Length of split at index {i} is negative: {length}")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    # Consecutive, non-overlapping slices of one random permutation.
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]