#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Generate n-best translations using a trained model.
"""

from contextlib import redirect_stdout
import os
import subprocess
import sys

import rerank_utils
from examples.noisychannel import rerank_options
from fairseq import options
import generate
import preprocess


def _apply_bpe(bpe_params):
    """Run the bundled ``subword-nmt`` ``apply_bpe.py`` script with ``bpe_params``.

    The script is launched with the interpreter running this process rather
    than whatever ``python`` happens to be first on ``PATH``.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    script = os.path.join(os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py")
    subprocess.check_call([sys.executable, script] + bpe_params, shell=False)


def _binarize(preprocess_param):
    """Parse ``preprocess_param`` with the preprocessing parser and run ``preprocess.main``."""
    preprocess_parser = options.get_preprocessing_parser()
    input_args = preprocess_parser.parse_args(preprocess_param)
    preprocess.main(input_args)


def gen_and_reprocess_nbest(args):
    """Generate (or load) an n-best list and write the data used to rescore it.

    The work is done in three printed steps:

    1. Unless ``args.nbest_list`` is given, run ``generate.main`` with stdout
       redirected to ``<pre_gen>/generate_output_bpe.txt`` (skipped when that
       file already exists).
    2. Write plain-text rescoring files with ``rerank_utils.write_reprocessed``.
    3. Binarize those files with ``preprocess.main``.

    Steps 2 and 3 only run when the score file of model 1 (or of model 2, when
    ``args.score_model2`` is set) is missing and that model is not treated as
    the generation model.

    Args:
        args: namespace produced by the reranking parser.  ``args.score_dict_dir``
            is overwritten with ``args.data`` when it is ``None``.

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` built from the n-best file.

    Raises:
        AssertionError: if incompatible options are combined (prefix length with
            right-to-left models, backwards with right-to-left, prefix length
            with target prefix fraction, n-best list with a second score model).
        subprocess.CalledProcessError: if re-applying BPE (``args.diff_bpe``) fails.
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
        assert args.right_to_left2 is False, "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

    # A "backwards" scorer is run with source and target languages swapped.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    # exist_ok avoids the check-then-create race when several shards start at once.
    store_data = os.path.join(os.path.dirname(__file__))+"/rerank_data/"+args.data_dir_name
    os.makedirs(store_data, exist_ok=True)

    pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
        backwards_preprocessed_dir, lm_preprocessed_dir = \
        rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                     args.gen_model_name, args.shard_id, args.num_shards,
                                     args.sampling, args.prefix_len, args.target_prefix_frac,
                                     args.source_prefix_frac)
    assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
    assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
    assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
        "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    os.makedirs(pre_gen, exist_ok=True)

    # A scorer identical to the generation model needs no separate rescoring data,
    # unless only a fraction of the source is scored.
    rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
    rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    for preprocessed_dir in (left_to_right_preprocessed_dir, right_to_left_preprocessed_dir,
                             lm_preprocessed_dir, backwards_preprocessed_dir):
        os.makedirs(preprocessed_dir, exist_ok=True)

    score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                 target_prefix_frac=args.target_prefix_frac,
                                                 source_prefix_frac=args.source_prefix_frac,
                                                 backwards=args.backwards1)
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards2)

    predictions_bpe_file = pre_gen+"/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    elif not os.path.isfile(predictions_bpe_file):
        print("STEP 1: generate predictions using the p(T|S) model with bpe")
        print(args.data)
        param1 = [args.data,
                  "--path", args.gen_model,
                  "--shard-id", str(args.shard_id),
                  "--num-shards", str(args.num_shards),
                  "--nbest", str(args.num_rescore),
                  "--batch-size", str(args.batch_size),
                  "--beam", str(args.num_rescore),
                  "--max-sentences", str(args.num_rescore),
                  "--gen-subset", args.gen_subset,
                  "--source-lang", args.source_lang,
                  "--target-lang", args.target_lang]
        if args.sampling:
            param1 += ["--sampling"]

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        # generate.main prints its hypotheses; capture them in the predictions file.
        with open(predictions_bpe_file, 'w') as f:
            with redirect_stdout(f):
                generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                  nbest=using_nbest, prefix_len=args.prefix_len,
                                                  target_prefix_frac=args.target_prefix_frac)

    if args.diff_bpe:
        # The scorer uses a different BPE: dump the BPE-free text and re-encode it
        # with the scorer's codes into <pre_gen>/rescore_data.<lang>.
        rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                                       gen_output.no_bpe_target, pre_gen+"/source_gen_bpe."+args.source_lang,
                                       pre_gen+"/target_gen_bpe."+args.target_lang,
                                       pre_gen+"/reference_gen_bpe."+args.target_lang)
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = ["-c", bitext_bpe,
                         "--input", pre_gen+"/source_gen_bpe."+args.source_lang,
                         "--output", pre_gen+"/rescore_data."+args.source_lang]
        bpe_tgt_param = ["-c", bitext_bpe,
                         "--input", pre_gen+"/target_gen_bpe."+args.target_lang,
                         "--output", pre_gen+"/rescore_data."+args.target_lang]

        _apply_bpe(bpe_src_param)
        _apply_bpe(bpe_tgt_param)

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
            (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
        print("STEP 2: process the output of generate.py so we have clean text files with the translations")

        rescore_file = "/rescore_data"
        # Default both binarization inputs to the base rescore data so that they are
        # always bound; previously they were only assigned when diff_bpe was off,
        # which made STEP 3 fail with an unbound local under --diff-bpe.
        # NOTE(review): with diff_bpe the prefix-truncated variants below are not
        # written, so the full-length rescore data is binarized - confirm intended.
        bw_rescore_file = rescore_file
        fw_rescore_file = rescore_file

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen+rescore_file+"."+args.source_lang,
                                               pre_gen+rescore_file+"."+args.target_lang,
                                               pre_gen+"/reference_file", bpe_symbol=args.remove_bpe)
                if args.prefix_len is not None:
                    bw_rescore_file = rescore_file + "prefix"+str(args.prefix_len)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen+bw_rescore_file+"."+args.source_lang,
                                                   pre_gen+bw_rescore_file+"."+args.target_lang,
                                                   pre_gen+"/reference_file", prefix_len=args.prefix_len,
                                                   bpe_symbol=args.remove_bpe)
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = rescore_file + "target_prefix_frac"+str(args.target_prefix_frac)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen+bw_rescore_file+"."+args.source_lang,
                                                   pre_gen+bw_rescore_file+"."+args.target_lang,
                                                   pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
                                                   target_prefix_frac=args.target_prefix_frac)

                if args.source_prefix_frac is not None:
                    fw_rescore_file = rescore_file + "source_prefix_frac"+str(args.source_prefix_frac)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen+fw_rescore_file+"."+args.source_lang,
                                                   pre_gen+fw_rescore_file+"."+args.target_lang,
                                                   pre_gen+"/reference_file", bpe_symbol=args.remove_bpe,
                                                   source_prefix_frac=args.source_prefix_frac)

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                           pre_gen+"/right_to_left_rescore_data."+args.source_lang,
                                           pre_gen+"/right_to_left_rescore_data."+args.target_lang,
                                           pre_gen+"/right_to_left_reference_file",
                                           right_to_left=True, bpe_symbol=args.remove_bpe)

        print("STEP 3: binarize the translations")
        # Parentheses only spell out Python's existing and-before-or precedence.
        if not args.right_to_left1 \
                or (args.score_model2 is not None and not args.right_to_left2) \
                or not rerank1_is_gen:

            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                _binarize(["--source-lang", scorer1_src,
                           "--target-lang", scorer1_tgt,
                           "--trainpref", pre_gen+bw_rescore_file,
                           "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
                           "--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
                           "--destdir", backwards_preprocessed_dir])

            _binarize(["--source-lang", scorer1_src,
                       "--target-lang", scorer1_tgt,
                       "--trainpref", pre_gen+fw_rescore_file,
                       "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
                       "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
                       "--destdir", left_to_right_preprocessed_dir])

        if args.right_to_left1 or args.right_to_left2:
            _binarize(["--source-lang", scorer1_src,
                       "--target-lang", scorer1_tgt,
                       "--trainpref", pre_gen+"/right_to_left_rescore_data",
                       "--srcdict", args.score_dict_dir+"/dict."+scorer1_src+".txt",
                       "--tgtdict", args.score_dict_dir+"/dict."+scorer1_tgt+".txt",
                       "--destdir", right_to_left_preprocessed_dir])

    return gen_output


def cli_main():
    """Parse the reranking command-line options and run ``gen_and_reprocess_nbest``."""
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    gen_and_reprocess_nbest(args)


if __name__ == '__main__':
    cli_main()