Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import numpy as np | |
import torch | |
from fairseq.data import Dictionary, FairseqDataset | |
from fairseq.tasks import LegacyFairseqTask, register_task | |
logger = logging.getLogger(__name__) | |
class DummyMTTask(LegacyFairseqTask): | |
def add_args(parser): | |
"""Add task-specific arguments to the parser.""" | |
parser.add_argument("--dict-size", default=49996, type=int) | |
parser.add_argument("--dataset-size", default=100000, type=int) | |
parser.add_argument("--src-len", default=30, type=int) | |
parser.add_argument("--tgt-len", default=30, type=int) | |
def __init__(self, args, dictionary): | |
super().__init__(args) | |
self.dictionary = dictionary | |
self.seed = args.seed | |
dictionary.pad_to_multiple_(8) # often faster if divisible by 8 | |
self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1 | |
self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1 | |
def setup_task(cls, args, **kwargs): | |
"""Setup the task. """ | |
dictionary = Dictionary() | |
for i in range(args.dict_size): | |
dictionary.add_symbol("word{}".format(i)) | |
logger.info("dictionary: {} types".format(len(dictionary))) | |
args.max_source_positions = args.src_len + dictionary.pad() + 2 | |
args.max_target_positions = args.tgt_len + dictionary.pad() + 2 | |
return cls(args, dictionary) | |
def load_dataset(self, split, epoch=1, combine=False, **kwargs): | |
"""Load a given dataset split. | |
Args: | |
split (str): name of the split (e.g., train, valid, test) | |
""" | |
item_size = max(self.args.src_len, self.args.tgt_len) | |
if self.args.batch_size is not None: | |
bsz = self.args.batch_size | |
else: | |
bsz = max(1, self.args.max_tokens // item_size) | |
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) | |
self.datasets[split] = DummyDataset( | |
{ | |
"id": 1, | |
"net_input": { | |
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), | |
"src_lengths": torch.full( | |
(bsz,), self.args.src_len, dtype=torch.long | |
), | |
"prev_output_tokens": tgt.clone(), | |
}, | |
"target": tgt, | |
"nsentences": bsz, | |
"ntokens": bsz * self.args.tgt_len, | |
}, | |
num_items=self.args.dataset_size, | |
item_size=item_size, | |
) | |
def source_dictionary(self): | |
return self.dictionary | |
def target_dictionary(self): | |
return self.dictionary | |
class DummyDataset(FairseqDataset): | |
def __init__(self, batch, num_items, item_size): | |
super().__init__() | |
self.batch = batch | |
self.num_items = num_items | |
self.item_size = item_size | |
def __getitem__(self, index): | |
return index | |
def __len__(self): | |
return self.num_items | |
def collater(self, samples): | |
return self.batch | |
def sizes(self): | |
return np.array([self.item_size] * self.num_items) | |
def num_tokens(self, index): | |
return self.item_size | |
def size(self, index): | |
return self.item_size | |
def ordered_indices(self): | |
return np.arange(self.num_items) | |
def supports_prefetch(self): | |
return False | |