Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import numpy as np | |
import torch | |
from fairseq.data import FairseqDataset, data_utils | |
# Module-level logger; used by ``collate`` (alignment warnings) and by
# ``LanguagePairDataset`` (bucketing info messages).
logger = logging.getLogger(name=__name__)
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Collate a list of language-pair samples into a padded mini-batch.

    Args:
        samples (List[dict]): samples as produced by
            ``LanguagePairDataset.__getitem__``. Each has ``"id"`` and
            ``"source"``, and optionally ``"target"``,
            ``"prev_output_tokens"``, ``"alignment"`` and ``"constraints"``.
        pad_idx (int): padding index used for every padded tensor.
        eos_idx (int): end-of-sentence index, forwarded to
            ``data_utils.collate_tokens``.
        left_pad_source (bool): pad source tokens on the left.
        left_pad_target (bool): pad target (and decoder-input) tokens on the
            left.
        input_feeding (bool): when samples carry no ``"prev_output_tokens"``,
            build them from the target with EOS moved to the beginning.
        pad_to_length (dict, optional): ``{"source": int, "target": int}``
            minimum padded lengths; ``None`` or a missing key means no
            minimum for that side.
        pad_to_multiple (int): forwarded to ``data_utils.collate_tokens``.

    Returns:
        dict: the mini-batch, or ``{}`` when ``samples`` is empty. All rows
        (ids, source, target, decoder input, constraints) are reordered by
        descending unpadded source length.
    """
    if len(samples) == 0:
        return {}

    def pad_len_for(key):
        # Minimum padded length requested for ``key`` ("source"/"target").
        return pad_to_length.get(key) if pad_to_length is not None else None

    def merge(key, left_pad, move_eos_to_beginning=False, pad_len=None):
        # Pad the ``key`` field of every sample into one 2D tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_len,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # Reject empty alignments and ones indexing past the sentence. The
        # ``- 1`` bounds are kept from the original code (presumably they
        # exclude the trailing EOS position -- confirm).
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        # ``align_tgt_i`` maps each row to its unique-value slot, so indexing
        # the counts with it gives every row the frequency of its target index.
        align_weights = align_tgt_c[align_tgt_i]
        return 1.0 / align_weights.float()

    target_pad_len = pad_len_for("target")

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source", left_pad=left_pad_source, pad_len=pad_len_for("source")
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target", left_pad=left_pad_target, pad_len=target_pad_len)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            # Pad to the same minimum length as ``target`` so the decoder
            # input and the target keep matching widths under pad_to_length.
            prev_output_tokens = merge(
                "prev_output_tokens",
                left_pad=left_pad_target,
                pad_len=target_pad_len,
            )
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_len=target_pad_len,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        tgt_sz = batch["target"].shape[1]
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # Target positions are made global across the batch (row b starts at
        # b * tgt_sz); left padding shifts token positions by the pad amount.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = []
        for align_idx, offset, src_len, tgt_len in zip(
            sort_order, offsets, src_lengths, tgt_lengths
        ):
            alignment = samples[align_idx]["alignment"].view(-1, 2)
            if check_alignment(alignment, src_len, tgt_len):
                alignments.append(alignment + offset)

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        constraints = torch.zeros((len(samples), max(lens)), dtype=torch.long)
        for i, sample in enumerate(samples):
            constraints[i, : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch
class LanguagePairDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, prepends bos to the source and
            target sentences when it is absent (default: False).
        eos (int, optional): EOS index passed to the collater as ``eos_idx``
            (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): forwarded to the collater as
            ``pad_to_multiple`` (default: 1).
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            # The collater pads source and target with ``src_dict.pad()``, so
            # the special symbols of both vocabularies must agree.
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset).
            # ``np.compat`` was removed in NumPy 2.0, so use np.int64 directly.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, n_tokens) for n_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        # ``(None, num_tokens)`` pairs when bucketing is enabled, else None.
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa.
        # Each step transforms the result of the previous one (instead of
        # re-reading the raw item), so the options compose correctly.
        if self.append_eos_to_target and tgt_item is not None:
            eos = (
                self.tgt_dict.eos() if self.tgt_dict is not None else self.src_dict.eos()
            )
            if len(tgt_item) == 0 or tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = (
                self.tgt_dict.bos() if self.tgt_dict is not None else self.src_dict.bos()
            )
            if tgt_item is not None and (len(tgt_item) == 0 or tgt_item[0] != bos):
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if len(src_item) == 0 or src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if len(src_item) > 0 and src_item[-1] == eos:
                src_item = src_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs (dataset indices), in the same
                  order as the rows of the batch
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
                - `alignments` / `align_weights` / `constraints`: present only
                  when the samples carry alignments / constraints
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length (stable mergesort
            # keeps the target ordering among equal source lengths)
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        # A property, as callers read it as an attribute; a plain method
        # would be a bound method and therefore always truthy.
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )