Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 -u | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Generate n-best translations using a trained model. | |
""" | |
from contextlib import redirect_stdout | |
import os | |
import subprocess | |
import rerank_utils | |
from examples.noisychannel import rerank_options | |
from fairseq import options | |
import generate | |
import preprocess | |
def gen_and_reprocess_nbest(args):
    """Generate an n-best list and prepare it for rescoring.

    The work is split into steps, each skipped when its output is already
    on disk:

    1. Generate ``args.num_rescore`` hypotheses per sentence with
       ``args.gen_model`` into ``<pre_gen>/generate_output_bpe.txt``
       (skipped when ``args.nbest_list`` supplies a ready-made list).
    2. Write plain-text source / hypothesis / reference files for the
       rescoring models (optionally re-applying a different BPE code,
       truncating to a prefix, or reversing for right-to-left models).
    3. Binarize those text files with ``preprocess.main``.

    Args:
        args: Parsed reranking options. ``args.score_dict_dir`` is set to
            ``args.data`` in place when it is ``None``.

    Returns:
        The ``rerank_utils.BitextOutputFromGen`` object built from the
        generation output (or the supplied n-best list).

    Raises:
        AssertionError: If incompatible options are combined (prefix length
            with right-to-left models, an n-best list with a second scoring
            model, backwards together with right-to-left, or both a prefix
            length and a target prefix fraction).
    """
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert args.right_to_left1 is False, "prefix length not compatible with right to left models"
        assert args.right_to_left2 is False, "prefix length not compatible with right to left models"

    if args.nbest_list is not None:
        assert args.score_model2 is None

    # A "backwards" scorer models p(S|T), so its source/target are swapped.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    # exist_ok avoids a check-then-create race when several runs (e.g.
    # different shards) try to create the same directory at once.
    store_data = os.path.dirname(__file__) + "/rerank_data/" + args.data_dir_name
    os.makedirs(store_data, exist_ok=True)

    pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
        backwards_preprocessed_dir, lm_preprocessed_dir = \
        rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                     args.gen_model_name, args.shard_id, args.num_shards,
                                     args.sampling, args.prefix_len, args.target_prefix_frac,
                                     args.source_prefix_frac)

    assert not (args.right_to_left1 and args.backwards1), "backwards right to left not supported"
    assert not (args.right_to_left2 and args.backwards2), "backwards right to left not supported"
    assert not (args.prefix_len is not None and args.target_prefix_frac is not None), \
        "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    os.makedirs(pre_gen, exist_ok=True)

    # A scorer that is the generation model itself (and sees the full source)
    # does not need separately preprocessed data.
    rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
    rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprocessed nbest list for reranking
    for directory in (left_to_right_preprocessed_dir, right_to_left_preprocessed_dir,
                      lm_preprocessed_dir, backwards_preprocessed_dir):
        os.makedirs(directory, exist_ok=True)

    score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                 target_prefix_frac=args.target_prefix_frac,
                                                 source_prefix_frac=args.source_prefix_frac,
                                                 backwards=args.backwards1)
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards2)

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list
    elif not os.path.isfile(predictions_bpe_file):
        print("STEP 1: generate predictions using the p(T|S) model with bpe")
        print(args.data)
        param1 = [args.data,
                  "--path", args.gen_model,
                  "--shard-id", str(args.shard_id),
                  "--num-shards", str(args.num_shards),
                  "--nbest", str(args.num_rescore),
                  "--batch-size", str(args.batch_size),
                  "--beam", str(args.num_rescore),
                  "--max-sentences", str(args.num_rescore),
                  "--gen-subset", args.gen_subset,
                  "--source-lang", args.source_lang,
                  "--target-lang", args.target_lang]
        if args.sampling:
            param1 += ["--sampling"]

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, param1)

        print(input_args)
        # generate.main prints its results; capture stdout into the file.
        with open(predictions_bpe_file, 'w') as f:
            with redirect_stdout(f):
                generate.main(input_args)

    gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                  nbest=using_nbest, prefix_len=args.prefix_len,
                                                  target_prefix_frac=args.target_prefix_frac)

    if args.diff_bpe:
        # The rescoring model uses a different BPE code: dump the de-BPE'd
        # text and re-encode it into <pre_gen>/rescore_data.<lang>.
        rerank_utils.write_reprocessed(gen_output.no_bpe_source, gen_output.no_bpe_hypo,
                                       gen_output.no_bpe_target, pre_gen + "/source_gen_bpe." + args.source_lang,
                                       pre_gen + "/target_gen_bpe." + args.target_lang,
                                       pre_gen + "/reference_gen_bpe." + args.target_lang)
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = ["-c", bitext_bpe,
                         "--input", pre_gen + "/source_gen_bpe." + args.source_lang,
                         "--output", pre_gen + "/rescore_data." + args.source_lang]
        bpe_tgt_param = ["-c", bitext_bpe,
                         "--input", pre_gen + "/target_gen_bpe." + args.target_lang,
                         "--output", pre_gen + "/rescore_data." + args.target_lang]

        apply_bpe_script = os.path.join(os.path.dirname(__file__),
                                        "subword-nmt/subword_nmt/apply_bpe.py")
        for bpe_param in (bpe_src_param, bpe_tgt_param):
            subprocess.call(["python", apply_bpe_script] + bpe_param, shell=False)

    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or \
            (args.score_model2 is not None and not os.path.isfile(score2_file) and not rerank2_is_gen):
        print("STEP 2: process the output of generate.py so we have clean text files with the translations")

        rescore_file = "/rescore_data"
        # Default the binarization inputs to the plain rescore files so that
        # STEP 3 never reads an unbound name when the branches below are
        # skipped (e.g. with diff_bpe, where the BPE step above has already
        # written <pre_gen>/rescore_data.<lang>).
        bw_rescore_file = rescore_file
        fw_rescore_file = rescore_file

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                               pre_gen + rescore_file + "." + args.source_lang,
                                               pre_gen + rescore_file + "." + args.target_lang,
                                               pre_gen + "/reference_file", bpe_symbol=args.remove_bpe)
                if args.prefix_len is not None:
                    bw_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen + bw_rescore_file + "." + args.source_lang,
                                                   pre_gen + bw_rescore_file + "." + args.target_lang,
                                                   pre_gen + "/reference_file", prefix_len=args.prefix_len,
                                                   bpe_symbol=args.remove_bpe)
                elif args.target_prefix_frac is not None:
                    bw_rescore_file = rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen + bw_rescore_file + "." + args.source_lang,
                                                   pre_gen + bw_rescore_file + "." + args.target_lang,
                                                   pre_gen + "/reference_file", bpe_symbol=args.remove_bpe,
                                                   target_prefix_frac=args.target_prefix_frac)

                if args.source_prefix_frac is not None:
                    fw_rescore_file = rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
                    rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                                   pre_gen + fw_rescore_file + "." + args.source_lang,
                                                   pre_gen + fw_rescore_file + "." + args.target_lang,
                                                   pre_gen + "/reference_file", bpe_symbol=args.remove_bpe,
                                                   source_prefix_frac=args.source_prefix_frac)

        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(gen_output.source, gen_output.hypo, gen_output.target,
                                           pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                                           pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                                           pre_gen + "/right_to_left_reference_file",
                                           right_to_left=True, bpe_symbol=args.remove_bpe)

        print("STEP 3: binarize the translations")

        # Parenthesized to make the original precedence (and binds tighter
        # than or) explicit.
        if (not args.right_to_left1
                or (args.score_model2 is not None and not args.right_to_left2)
                or not rerank1_is_gen):

            if args.backwards1 or args.backwards2:
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                bw_preprocess_param = ["--source-lang", scorer1_src,
                                       "--target-lang", scorer1_tgt,
                                       "--trainpref", pre_gen + bw_rescore_file,
                                       "--srcdict", bw_dict + "/dict." + scorer1_src + ".txt",
                                       "--tgtdict", bw_dict + "/dict." + scorer1_tgt + ".txt",
                                       "--destdir", backwards_preprocessed_dir]
                preprocess_parser = options.get_preprocessing_parser()
                input_args = preprocess_parser.parse_args(bw_preprocess_param)
                preprocess.main(input_args)

            preprocess_param = ["--source-lang", scorer1_src,
                                "--target-lang", scorer1_tgt,
                                "--trainpref", pre_gen + fw_rescore_file,
                                "--srcdict", args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                                "--tgtdict", args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                                "--destdir", left_to_right_preprocessed_dir]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

        if args.right_to_left1 or args.right_to_left2:
            preprocess_param = ["--source-lang", scorer1_src,
                                "--target-lang", scorer1_tgt,
                                "--trainpref", pre_gen + "/right_to_left_rescore_data",
                                "--srcdict", args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                                "--tgtdict", args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                                "--destdir", right_to_left_preprocessed_dir]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

    return gen_output
def cli_main():
    """Parse command-line options with the reranking parser and run
    ``gen_and_reprocess_nbest`` on them."""
    args = options.parse_args_and_arch(rerank_options.get_reranking_parser())
    gen_and_reprocess_nbest(args)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    cli_main()