# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II

logger = logging.getLogger(__name__)


@dataclass
class DummyLMConfig(FairseqDataclass):
    dict_size: int = 49996
    dataset_size: int = 100000
    tokens_per_sample: int = field(
        default=512, metadata={"help": "max sequence length"}
    )
    add_bos_token: bool = False
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    max_target_positions: int = II("task.tokens_per_sample")
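
# A hedged sketch (not part of the original file) of how the II("...")
# fields resolve: II just produces an OmegaConf interpolation string such
# as "${dataset.batch_size}", which is filled in from the surrounding
# training config at runtime, e.g.:
#
#   from omegaconf import OmegaConf
#
#   root = OmegaConf.create(
#       {
#           "dataset": {"batch_size": 8, "max_tokens": None},
#           "task": OmegaConf.structured(DummyLMConfig),
#       }
#   )
#   assert root.task.batch_size == 8  # mirrors dataset.batch_size
#   assert root.task.max_target_positions == 512  # mirrors tokens_per_sample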


@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
    def __init__(self, cfg: DummyLMConfig):
        super().__init__(cfg)

        # load dictionary
        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
        logger.info("dictionary: {} types".format(len(self.dictionary)))

        # a fixed ramp of token ids starting just past the pad index;
        # the target is the source shifted by one position
        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
        self.dummy_src = seq[:-1]
        self.dummy_tgt = seq[1:]
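
    # Illustrative sketch (assumed values, not from the original file): with
    # tokens_per_sample=4 and dictionary.pad()=1, seq = [2, 3, 4, 5, 6], so
    # dummy_src = [2, 3, 4, 5] and dummy_tgt = [3, 4, 5, 6], i.e. the standard
    # next-token prediction setup, here with deterministic dummy data.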

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            # derive the batch size from the token budget
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )
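
    # Shape sketch (assumed values, not from the original file): with
    # batch_size=2 and tokens_per_sample=512, every batch this task yields
    # is identical:
    #   net_input["src_tokens"]: LongTensor of shape (2, 512)
    #   net_input["src_lengths"]: LongTensor [512, 512]
    #   target: LongTensor of shape (2, 512)
    # so the benchmark measures model throughput, not data loading.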

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class DummyDataset(FairseqDataset):
    """Dataset that ignores the sampled indices and always collates to one
    precomputed batch."""

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        # every mini-batch is the same precomputed batch, regardless of
        # which indices were sampled
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
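
# Usage sketch (hedged; exact flag spellings can vary across fairseq
# versions): this task exists for benchmarking, e.g. timing a language
# model without any real data on disk:
#
#   fairseq-train --task dummy_lm --arch transformer_lm \
#       --tokens-per-sample 512 --batch-size 8 \
#       --optimizer adam --lr 0.0001 --max-update 100 --log-interval 10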