# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
from collections import OrderedDict

import numpy as np

from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
    """
    Task for training cross-lingual language models.

    For more details look at: https://arxiv.org/pdf/1901.07291.pdf

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments per sample",
        )
        parser.add_argument(
            "--monolingual-langs",
            default="en",
            type=str,
            help="comma separated list of languages for which we want to train XLM on",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="shuffle each monolingual dataset while training",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        self.distributed_world_size = args.distributed_world_size
        self.langs2id = self._lang_to_id(args.monolingual_langs)

    def _lang_to_id(self, languages: str):
        """
        Build a map from languages to ids. These ids are used as segment labels
        for cross-lingual LM training.
        """
        lang2id = {}
        langs = [l.strip() for l in languages.split(",")]
        for id, lang in enumerate(langs):
            lang2id[lang] = id
        return lang2id

    @classmethod
    def load_dictionary(cls, filename):
        return MaskedLMDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        d = MaskedLMDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def _load_single_lang_dataset(self, split, epoch):
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
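
        # A split may be sharded across files named "<split>", "<split>1",
        # "<split>2", ...; keep loading shards until one is missing.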
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)

            ds = data_utils.load_indexed_dataset(
                path, self.dictionary, self.args.dataset_impl
            )
            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
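            # (e.g. with --tokens-per-sample 512, blocks hold 511 tokens and the
            # classification token brings each finished sample back to 512.)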
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                )
            )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset_map = OrderedDict()
        for lang in self.langs2id.keys():
            # Datasets are expected to be in "split.lang" format (Eg: train.en)
            language_split = "{}.{}".format(split, lang)

            block_dataset, sizes = self._load_single_lang_dataset(
                split=language_split, epoch=epoch
            )

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, "shuffle", False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )
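
        # MultiCorpusSampledDataset mixes the per-language corpora; with no
        # sampling function supplied it uses its default sampler, which is
        # assumed here to pick a language uniformly at random per sample.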
        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
        # Log the data path actually used for this epoch (round-robin over the
        # colon separated directories, matching _load_single_lang_dataset).
        paths = utils.split_paths(self.args.data)
        logger.info(
            "{} {} {} examples".format(
                paths[(epoch - 1) % len(paths)],
                split,
                len(self.datasets[split]),
            )
        )
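

# Minimal programmatic sketch (an assumption, not part of the upstream module):
# given a preprocessed directory with dict.txt plus binarized train.en/train.fr
# splits, the task could be driven roughly as follows.
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   CrossLingualLMTask.add_args(parser)
#   args = parser.parse_args(["/path/to/data-bin", "--monolingual-langs", "en,fr"])
#   # attributes normally supplied by the full fairseq argument parser
#   args.seed, args.distributed_world_size, args.dataset_impl = 1, 1, "mmap"
#
#   task = CrossLingualLMTask.setup_task(args)
#   task.load_dataset("train")
#   print(len(task.datasets["train"]))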