# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data

from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for receiving updates whenever the epoch increments."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
        this dataset across epochs.

        This needs to return ``False`` if the sample sizes can change across
        epochs, in which case we may need to regenerate batches at each epoch.
        If your dataset relies on ``set_epoch`` then you should consider setting
        this to ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Will receive the updated epoch number at the beginning of the epoch."""
        pass
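

# Illustrative sketch (not part of fairseq): a hypothetical user of the
# EpochListening mixin whose sample sizes change with the epoch (e.g. dynamic
# subsampling), so the epoch batch iterator should not be reused. The class
# name below is made up for demonstration.
class _EpochDependentExample(EpochListening):
    """Hypothetical example; names here are illustrative only."""

    def __init__(self):
        self.epoch = 1

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # Batches depend on the current epoch, so they must be regenerated.
        return False

    def set_epoch(self, epoch):
        # Record the epoch so that item sizes/content can vary with it.
        self.epoch = epoch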


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self), dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return a list of valid batch shapes, for example::

            [(8, 512), (16, 256), (32, 128)]

        The first dimension of each tuple is the batch size and can be ``None``
        to automatically infer the max batch size based on ``--max-tokens``.
        The second dimension of each tuple is the max supported length as given
        by :func:`fairseq.data.FairseqDataset.num_tokens`.

        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
        to restrict batch shapes. This is useful on TPUs to avoid too many
        dynamic shapes (and recompilations).
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Given an ordered set of indices, return batches according to
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
        """
        from fairseq.data import data_utils

        fixed_shapes = self.get_batch_shapes()
        if fixed_shapes is not None:

            def adjust_bsz(bsz, num_tokens):
                if bsz is None:
                    assert max_tokens is not None, "Must specify --max-tokens"
                    bsz = max_tokens // num_tokens
                if max_sentences is not None:
                    bsz = min(bsz, max_sentences)
                elif (
                    bsz >= required_batch_size_multiple
                    and bsz % required_batch_size_multiple != 0
                ):
                    bsz -= bsz % required_batch_size_multiple
                return bsz

            fixed_shapes = np.array(
                [
                    [adjust_bsz(bsz, num_tokens), num_tokens]
                    for (bsz, num_tokens) in fixed_shapes
                ]
            )

        try:
            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            num_tokens_vec = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=num_tokens_vec,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=fixed_shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Filter a list of sample indices. Remove those that are longer than
        specified in *max_sizes*.

        WARNING: don't update this method; override it in child classes instead.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if isinstance(max_sizes, float) or isinstance(max_sizes, int):
            if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
                ignored = indices[self.sizes[indices] > max_sizes].tolist()
                indices = indices[self.sizes[indices] <= max_sizes]
            elif (
                hasattr(self, "sizes")
                and isinstance(self.sizes, list)
                and len(self.sizes) == 1
            ):
                ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
                indices = indices[self.sizes[0][indices] <= max_sizes]
            else:
                indices, ignored = data_utils._filter_by_size_dynamic(
                    indices, self.size, max_sizes
                )
        else:
            indices, ignored = data_utils._filter_by_size_dynamic(
                indices, self.size, max_sizes
            )
        return indices, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True
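

# Illustrative sketch (not part of fairseq): a minimal in-memory FairseqDataset
# built from a list of 1-D LongTensors. The class name ``ExampleTensorDataset``
# and the default ``pad_idx`` are hypothetical; real datasets such as
# LanguagePairDataset implement the same interface with richer collation.
class ExampleTensorDataset(FairseqDataset):
    """Hypothetical example showing the minimum overrides needed for batching."""

    def __init__(self, tensors, pad_idx=1):
        self.tensors = tensors  # list of 1-D torch.LongTensor
        self.pad_idx = pad_idx
        # Exposing ``sizes`` as an np.ndarray lets filter_indices_by_size take
        # its vectorized fast path above.
        self.sizes = np.array([len(t) for t in tensors], dtype=np.int64)

    def __getitem__(self, index):
        return {"id": index, "source": self.tensors[index]}

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # Right-pad every source to the longest sample in the mini-batch.
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                "src_tokens": data_utils.collate_tokens(
                    [s["source"] for s in samples], self.pad_idx
                ),
            },
        }

    def num_tokens(self, index):
        return self.sizes[index]

    def num_tokens_vec(self, indices):
        return self.sizes[indices]

    def size(self, index):
        return self.sizes[index]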


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    For datasets that need to be read sequentially, usually because the data is
    being streamed or otherwise can't be manipulated on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError
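

# Illustrative sketch (not part of fairseq): how the helpers above compose into
# a batching pipeline. The tensors and token budget are made up for
# demonstration and assume the hypothetical ExampleTensorDataset defined above.
if __name__ == "__main__":
    dataset = ExampleTensorDataset(
        [
            torch.LongTensor([4, 5, 6]),
            torch.LongTensor([7, 8]),
            torch.LongTensor([9]),
        ]
    )
    indices = dataset.ordered_indices()
    # Drop samples longer than 2 tokens, then group the rest under a
    # 4-token budget per mini-batch.
    indices, ignored = dataset.filter_indices_by_size(indices, max_sizes=2)
    print("ignored (too long):", ignored)
    batches = dataset.batch_by_size(indices, max_tokens=4)
    for batch_indices in batches:
        batch = dataset.collater([dataset[i] for i in batch_indices])
        print("batch ids:", batch["id"].tolist())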