# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from fairseq import checkpoint_utils
from fairseq.tasks import register_task
from fairseq.tasks.language_modeling import LanguageModelingTask
from fairseq.tasks.translation import TranslationTask

@register_task('noisy_channel_translation')
class NoisyChannelTranslation(TranslationTask):
    """
    Rescore the top k candidates from each beam using noisy channel modeling.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        TranslationTask.add_args(parser)
        # fmt: off
        parser.add_argument('--channel-model', metavar='FILE',
                            help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.')
        parser.add_argument('--combine-method', default='lm_only',
                            choices=['lm_only', 'noisy_channel'],
                            help="""method for combining direct and channel model scores.
                                    lm_only: decode with P(T|S)P(T)
                                    noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""")
        parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False,
                            help='normalize lm score by target length instead of source length')
        parser.add_argument('--channel-scoring-type', default='log_norm',
                            choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'],
                            help='normalize channel (backward) scores with log softmax, or return them unnormalized')
        parser.add_argument('--top-k-vocab', default=0, type=int,
                            help='top k vocab IDs to use with `src_vocab` in channel model scoring')
        parser.add_argument('--k2', default=50, type=int,
                            help='the top k2 candidates to rescore with the noisy channel model for each beam')
        parser.add_argument('--ch-wt', default=1, type=float,
                            help='weight for the channel model')
        parser.add_argument('--lm-model', metavar='FILE',
                            help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side')
        parser.add_argument('--lm-data', metavar='FILE',
                            help='path to lm model training data for target language, used to properly load LM with correct dictionary')
        parser.add_argument('--lm-wt', default=1, type=float,
                            help='the weight of the lm in joint decoding')
        # fmt: on
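
    # Illustrative invocation (a sketch, not part of this module), assuming this file
    # lives in a --user-dir package such as examples/fast_noisy_channel; the data-bin
    # and checkpoint paths below are placeholders:
    #
    #   fairseq-generate $DATA_BIN \
    #       --user-dir examples/fast_noisy_channel \
    #       --task noisy_channel_translation \
    #       --path $DIRECT_MODEL \
    #       --combine-method noisy_channel \
    #       --channel-model $CHANNEL_MODEL --ch-wt 1 \
    #       --lm-model $LM_MODEL --lm-data $LM_DATA_BIN --lm-wt 1 \
    #       --beam 5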

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        if getattr(args, "score_reference", False):
            raise NotImplementedError()
        else:
            from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator

            use_cuda = torch.cuda.is_available() and not self.args.cpu
            assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!'
            assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs'

            if self.args.channel_model is not None:
                import copy

                # The channel model scores P(S|T), i.e. it translates target -> source,
                # so set it up as a translation task with source and target swapped.
                ch_args_task = copy.deepcopy(self.args)
                ch_args_task.source_lang, ch_args_task.target_lang = \
                    ch_args_task.target_lang, ch_args_task.source_lang
                ch_args_task._name = 'translation'
                channel_task = TranslationTask.setup_task(ch_args_task)
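
            # Build a bare language_modeling task just to load P(T) with the dictionary
            # from --lm-data; sample_break_mode='eos' treats each sentence as one sample.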
            arg_dict = {
                'task': 'language_modeling',
                'sample_break_mode': 'eos',
                'data': self.args.lm_data,
                'output_dictionary_size': -1,
            }
            lm_args = argparse.Namespace(**arg_dict)
            lm_task = LanguageModelingTask.setup_task(lm_args)
            lm_dict = lm_task.output_dictionary
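
            # Load the channel model ensemble P(S|T) on the swapped task;
            # ':'-separated checkpoint paths are loaded as an ensemble.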
            if self.args.channel_model is not None:
                channel_models, _ = checkpoint_utils.load_model_ensemble(
                    self.args.channel_model.split(':'), task=channel_task
                )

                for model in channel_models:
                    model.make_generation_fast_(
                        beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                        need_attn=args.print_alignment,
                    )
                    if self.args.fp16:
                        model.half()
                    if use_cuda:
                        model.cuda()
            else:
                channel_models = None
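
            # The LM ensemble P(T) is always loaded (see the asserts above): it is
            # needed for both the lm_only and noisy_channel combine methods.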
            lm_models, _ = checkpoint_utils.load_model_ensemble(
                self.args.lm_model.split(':'), task=lm_task
            )

            for model in lm_models:
                model.make_generation_fast_(
                    beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                    need_attn=args.print_alignment,
                )
                if self.args.fp16:
                    model.half()
                if use_cuda:
                    model.cuda()

            return NoisyChannelSequenceGenerator(
                combine_method=self.args.combine_method,
                tgt_dict=self.target_dictionary,
                src_dict=self.source_dictionary,
                beam_size=getattr(args, 'beam', 5),
                max_len_a=getattr(args, 'max_len_a', 0),
                max_len_b=getattr(args, 'max_len_b', 200),
                min_len=getattr(args, 'min_len', 1),
                len_penalty=getattr(args, 'lenpen', 1),
                unk_penalty=getattr(args, 'unkpen', 0),
                temperature=getattr(args, 'temperature', 1.),
                match_source_len=getattr(args, 'match_source_len', False),
                no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
                normalize_scores=(not getattr(args, 'unnormalized', False)),
                channel_models=channel_models,
                k2=getattr(self.args, 'k2', 50),
                ch_weight=getattr(self.args, 'ch_wt', 1),
                channel_scoring_type=self.args.channel_scoring_type,
                top_k_vocab=self.args.top_k_vocab,
                lm_models=lm_models,
                lm_dict=lm_dict,
                lm_weight=getattr(self.args, 'lm_wt', 1),
                normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False),
            )
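

# Illustrative sketch only, not used by the task: it spells out, in log space, the
# combination described in the --combine-method help string,
# 1/t P(T|S) + 1/s (P(S|T) P(T)), assuming --ch-wt and --lm-wt scale the channel and
# LM terms. The actual combination (including --normalize-lm-scores-by-tgt-len and the
# --channel-scoring-type variants) lives in NoisyChannelSequenceGenerator.
def _noisy_channel_score_sketch(lprob_direct, lprob_channel, lprob_lm,
                                src_len, tgt_len, ch_weight=1.0, lm_weight=1.0):
    """Hypothetical helper: combine direct, channel and LM log-probabilities."""
    return lprob_direct / tgt_len + (ch_weight * lprob_channel + lm_weight * lprob_lm) / src_len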