import math
from multiprocessing import Pool

import numpy as np

from fairseq import bleu, options
from fairseq.data import dictionary

from examples.noisychannel import rerank_options

import rerank_utils
import rerank_generate
import rerank_score_bw
import rerank_score_lm


def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):
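    """Rescore the n-best lists with one (weight1, weight2, weight3, lenpen)
    setting, select the best hypothesis per sentence, and return the corpus
    BLEU of the selection against the reference targets.

    If write_hypos is set, the matched targets and hypotheses are also
    written to target_outfile and hypo_outfile.
    """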
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) | |
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) | |
dict = dictionary.Dictionary() | |
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) | |
ordered_hypos = {} | |
ordered_targets = {} | |
    for shard_id in range(len(bitext1_lst)):

        bitext1 = bitext1_lst[shard_id]
        bitext2 = bitext2_lst[shard_id]
        gen_output = gen_output_lst[shard_id]
        lm_res = lm_res_lst[shard_id]

        total = len(bitext1.rescore_source.keys())
        source_lst = []
        hypo_lst = []
        score_lst = []
        reference_lst = []
        j = 1
        best_score = -math.inf
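
        # j tracks the position within the current sentence's n-best list; when it
        # reaches the list size (or args.num_rescore), the best hypothesis seen so
        # far is committed and the counter resets for the next sentence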
        for i in range(total):
            # length is measured in words, not BPE tokens, since the models may not share the same BPE
            target_len = len(bitext1.rescore_hypo[i].split())

            if lm_res is not None:
                lm_score = lm_res.score[i]
            else:
                lm_score = 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(a, b, c, target_len,
                                           bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
                                           lenpen=lenpen, src_len=bitext1.source_lengths[i],
                                           tgt_len=bitext1.target_lengths[i], bitext1_backwards=bitext1.backwards,
                                           bitext2_backwards=bitext2_backwards, normalize=normalize)
            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                j = 1
                hypo_lst.append(best_hypo)
                score_lst.append(best_score)
                source_lst.append(bitext1.rescore_source[i])
                reference_lst.append(bitext1.rescore_target[i])

                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1
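
        # BLEU is computed on the text with BPE removed, matched back to the
        # original sentence ids from generation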
        gen_keys = list(sorted(gen_output.no_bpe_target.keys()))

        for key in range(len(gen_keys)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                    "pred and rescore hypo mismatch: i: " + str(key) + ", "
                    + str(hypo_lst[key]) + str(gen_keys[key])
                    + str(gen_output.no_bpe_hypo[gen_keys[key]])
                )
                sys_tok = tgt_dict.encode_line(hypo_lst[key])
                ref_tok = tgt_dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)
            else:
                full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                sys_tok = tgt_dict.encode_line(full_hypo)
                ref_tok = tgt_dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)
        # if only one set of hyperparameters is provided, write the predictions to a file
        if write_hypos:
            # recover the original ids from n-best list generation
            for key in range(len(gen_output.no_bpe_target)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
                        "pred and rescore hypo mismatch: i: " + str(key) + ", "
                        + str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[gen_keys[key]])
                    )
                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
                else:
                    full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                    ordered_hypos[gen_keys[key]] = full_hypo
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
    # write the hypos in the original order from n-best list generation
    if args.num_shards == len(bitext1_lst):
        with open(target_outfile, 'w') as t:
            with open(hypo_outfile, 'w') as h:
                for key in range(len(ordered_hypos)):
                    t.write(ordered_targets[key])
                    h.write(ordered_hypos[key])
    res = scorer.result_string(4)
    if write_hypos:
        print(res)

    score = rerank_utils.parse_bleu_scoring(res)
    return score


def match_target_hypo(args, target_outfile, hypo_outfile):
    """Combine scores from the LM and bitext models, and write the top-scoring hypotheses to a file."""
    if len(args.weight1) == 1:
        res = score_target_hypo(args, args.weight1[0], args.weight2[0],
                                args.weight3[0], args.lenpen[0], target_outfile,
                                hypo_outfile, True, args.normalize)
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(score_target_hypo,
                                      [(args, args.weight1[i], args.weight2[i], args.weight3[i],
                                        args.lenpen[i], target_outfile, hypo_outfile,
                                        False, args.normalize) for i in range(len(args.weight1))])
    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
        best_score = rerank_scores[best_index]
        print("best score", best_score)
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return args.lenpen[best_index], args.weight1[best_index], \
            args.weight2[best_index], args.weight3[best_index], best_score
    else:
        return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]


def load_score_files(args):
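    """Load the generation output plus the rescoring-model and LM scores,
    returning one entry per shard in each of the four lists.
    """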
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    gen_output_lst = []
    bitext1_lst = []
    bitext2_lst = []
    lm_res1_lst = []

    for shard_id in shard_ids:
        using_nbest = args.nbest_list is not None

        pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
            backwards_preprocessed_dir, lm_preprocessed_dir = \
            rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                         args.gen_model_name, shard_id, args.num_shards, args.sampling,
                                         args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
        rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
        rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None
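
        # when a rescoring model is the generating model itself (and no source prefix
        # is used), its scores can be taken directly from the generation output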
        score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards1)
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                         target_prefix_frac=args.target_prefix_frac,
                                                         source_prefix_frac=args.source_prefix_frac,
                                                         backwards=args.backwards2)
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)
        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list

        gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                      nbest=using_nbest, prefix_len=args.prefix_len,
                                                      target_prefix_frac=args.target_prefix_frac)
        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
                                                args.remove_bpe, args.prefix_len, args.target_prefix_frac,
                                                args.source_prefix_frac)

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
                                                    args.remove_bpe, args.prefix_len, args.target_prefix_frac,
                                                    args.source_prefix_frac)

                assert bitext2.source_lengths == bitext1.source_lengths, \
                    "source lengths for rescoring models do not match"
                assert bitext2.target_lengths == bitext1.target_lengths, \
                    "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None
        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
                                            args.remove_bpe, args.target_prefix_frac)
        else:
            lm_res1 = None

        gen_output_lst.append(gen_output)
        bitext1_lst.append(bitext1)
        bitext2_lst.append(bitext2)
        lm_res1_lst.append(lm_res1)

    return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst


def rerank(args):
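    """Run the full reranking pipeline: generate and reprocess the n-best
    lists, rescore them with the rescoring models and the language model,
    then search the given (lenpen, weight1, weight2, weight3) grid for the
    best-scoring combination.
    """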
    if not isinstance(args.lenpen, list):
        args.lenpen = [args.lenpen]
    if not isinstance(args.weight1, list):
        args.weight1 = [args.weight1]
    if not isinstance(args.weight2, list):
        args.weight2 = [args.weight2]
    if not isinstance(args.weight3, list):
        args.weight3 = [args.weight3]

    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]
    for shard_id in shard_ids:
        pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
            backwards_preprocessed_dir, lm_preprocessed_dir = \
            rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                         args.gen_model_name, shard_id, args.num_shards, args.sampling,
                                         args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)

        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset
    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
        match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    rerank(args)


if __name__ == '__main__':
    cli_main()