Spaces:
Runtime error
Runtime error
HF-SillyTavern-Extras
/
modules
/voice_conversion
/fairseq
/tasks
/translation_from_pretrained_bart.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from fairseq import utils | |
| from fairseq.data import LanguagePairDataset | |
| from . import register_task | |
| from .translation import TranslationTask, load_langpair_dataset | |
@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
    """
    Translate from source language to target language with a model
    initialized with a multilingual pretrain.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser.

        Registers every argument of :class:`TranslationTask` plus ``--langs``
        and ``--prepend-bos``.
        """
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual language, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always add all pretraining language idx '
                                 'during finetuning.')
        parser.add_argument('--prepend-bos', action='store_true',
                            help='prepend bos token to each sentence, which matches '
                                 'mBART pretraining')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        """Set up the task and extend both dictionaries with language tags.

        Args:
            args: parsed task arguments; ``args.langs`` must be a
                comma-separated string of language codes.
            src_dict: dictionary for the source language.
            tgt_dict: dictionary for the target language.

        Raises:
            ValueError: if ``--langs`` was not provided.
        """
        super().__init__(args, src_dict, tgt_dict)

        langs = getattr(args, "langs", None)
        if langs is None:
            # Fail with an actionable message instead of an AttributeError
            # on ``None.split``.
            raise ValueError(
                "--langs is required: pass the comma-separated list of "
                'pretraining languages, for example "en,de,fr"'
            )
        self.langs = langs.split(",")

        # Symbols are added in the order given by --langs; the option's help
        # text asks for the same order as in pretraining.
        for dictionary in (src_dict, tgt_dict):
            for lang in self.langs:
                dictionary.add_symbol("[{}]".format(lang))
            dictionary.add_symbol("<mask>")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number, used to cycle through the
                data paths listed in ``args.data``
            combine (bool): forwarded to ``load_langpair_dataset``
            **kwargs: accepted for interface compatibility and ignored
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        # Rotate through the available data paths, one per epoch.
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, "max_source_positions", 1024),
            max_target_positions=getattr(self.args, "max_target_positions", 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, "prepend_bos", False),
            # NOTE(review): presumably makes the loader append the language
            # tag to each sentence; confirm against load_langpair_dataset.
            append_source_id=True,
        )

    def build_generator(self, models, args, **unused):
        """Build a sequence scorer or generator for this task.

        The target-language tag ``[<target_lang>]`` is used as the
        end-of-sentence symbol in both cases.

        Args:
            models: models handed to ``SequenceGenerator`` (not used when
                scoring references).
            args: generation arguments; missing attributes fall back to the
                defaults given below.
            **unused: accepted for interface compatibility and ignored.

        Returns:
            A ``SequenceScorer`` when ``args.score_reference`` is truthy,
            otherwise a ``SequenceGenerator``.
        """
        # Both branches terminate on the same target-language tag.
        eos = self.tgt_dict.index("[{}]".format(self.args.target_lang))

        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(self.target_dictionary, eos=eos)

        from fairseq.sequence_generator import SequenceGenerator

        return SequenceGenerator(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            eos=eos,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        """Wrap raw source tensors in a dataset, tagging each with the source language.

        Args:
            src_tokens: iterable of 1-D token tensors, one per sentence.
            src_lengths: sentence lengths, passed to the dataset unchanged.
            constraints: optional constraints forwarded to the dataset.

        Returns:
            LanguagePairDataset: dataset over the tagged source sentences.
        """
        src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))

        # Append the source-language tag id to every sentence; ``new_full``
        # keeps the dtype and device of the sentence tensor.
        source_tokens = [
            torch.cat([tokens, tokens.new_full((1,), src_lang_id)])
            for tokens in src_tokens
        ]

        return LanguagePairDataset(
            source_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )