"""Document-level COMET evaluation over sliding sentence contexts."""

import argparse
import json
import os

import numpy as np
from comet import download_model, load_from_checkpoint
from transformers import AutoTokenizer

# COMET models that require a reference, and reference-free (QE) models.
COMET_REF_MODELS = ["wmt20-comet-da", "wmt21-comet-mqm", "wmt22-comet-da"]
COMET_SRC_MODELS = ["wmt20-comet-qe-da", "wmt21-comet-qe-mqm", "wmt22-cometkiwi-da"]

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Tokenizer of the encoder underlying the COMET models above, used only to
# estimate input length against the model's token limit.
tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-large")


def _is_doc_boundary(doc_ids, idx):
    """Return True if index `idx` is the last sentence of its document."""
    after_idx = min(len(doc_ids) - 1, idx + 1)
    return doc_ids[after_idx] != doc_ids[idx] or idx == len(doc_ids) - 1


def _build_context(doc, current_idx, context_window, start_left=True):
    """Select a window of up to `context_window` sentences around `current_idx`.

    With `start_left=True` the window begins at the current sentence and
    extends to the right; any remaining budget spills to the left. Returns
    the window and the position of the current sentence within it.
    """
    balance = context_window
    low = current_idx if start_left else max(0, current_idx - context_window // 2)
    balance -= current_idx - low
    high = min(len(doc), current_idx + balance)
    balance -= high - current_idx
    low = max(0, low - balance)
    pos = current_idx - low
    return doc[low:high], pos
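# A quick worked example of _build_context with the default start_left=True
# (values checked by hand): the window starts at the current sentence and
# extends right, with leftover budget spilling left near the document end.
#
#   doc = ["s0", "s1", "s2", "s3", "s4"]
#   _build_context(doc, 0, 3) == (["s0", "s1", "s2"], 0)
#   _build_context(doc, 4, 3) == (["s2", "s3", "s4"], 2)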
def _check_max_tokens(src_context, mt_context, ref_context=None, max_tokens=512):
    """Check that the concatenated segments stay under the encoder's token limit."""
    src = " ".join(src_context).strip()
    mt = " ".join(mt_context).strip()
    if ref_context:
        ref = " ".join(ref_context).strip()
        full_input = tokenizer(" ".join([src, mt, ref])).input_ids
    else:
        full_input = tokenizer(" ".join([src, mt])).input_ids
    return len(full_input) < max_tokens


def _calculate_doc_comet(args, model, src_docs, hyp_docs, ref_docs=None):
    """Score each document with COMET over sliding sentence contexts."""
    scores, doc_lengths = [], []
    if ref_docs:
        for s, h, r in zip(src_docs, hyp_docs, ref_docs):
            data_for_eval = []
            # If the doc is shorter than the context length, score it whole
            if len(s) <= args.context_length:
                data_for_eval.append({"src": " ".join(s).strip(),
                                      "mt": " ".join(h).strip(),
                                      "ref": " ".join(r).strip()})
            else:
                prev_context_src, prev_context_mt, prev_context_ref = [], [], []
                for i in range(len(s)):
                    src_context, _ = _build_context(s, i, args.context_length)
                    mt_context, _ = _build_context(h, i, args.context_length)
                    ref_context, _ = _build_context(r, i, args.context_length)
                    # Shrink the window until max_tokens is respected
                    reduce = 1
                    while (not _check_max_tokens(src_context, mt_context, ref_context=ref_context)
                           and args.context_length - reduce > 1):
                        src_context, _ = _build_context(s, i, args.context_length - reduce)
                        mt_context, _ = _build_context(h, i, args.context_length - reduce)
                        ref_context, _ = _build_context(r, i, args.context_length - reduce)
                        reduce += 1
                    # Ensure the same context is not evaluated twice
                    if (src_context != prev_context_src and mt_context != prev_context_mt
                            and ref_context != prev_context_ref):
                        src = " ".join(src_context).strip()
                        mt = " ".join(mt_context).strip()
                        ref = " ".join(ref_context).strip()
                        data_for_eval.append({"src": src, "mt": mt, "ref": ref})
                        prev_context_src, prev_context_mt, prev_context_ref = \
                            src_context, mt_context, ref_context
            # Compute the document score
            pred = model.predict(data_for_eval, batch_size=8, gpus=1)
            scores.append(pred.system_score)
            doc_lengths.append(len(s))
    else:
        for s, h in zip(src_docs, hyp_docs):
            data_for_eval = []
            # If the doc is shorter than the context length, score it whole
            if len(s) <= args.context_length:
                data_for_eval.append({"src": " ".join(s).strip(),
                                      "mt": " ".join(h).strip()})
            else:
                prev_context_src, prev_context_mt = [], []
                for i in range(len(s)):
                    src_context, _ = _build_context(s, i, args.context_length)
                    mt_context, _ = _build_context(h, i, args.context_length)
                    # Shrink the window until max_tokens is respected
                    reduce = 1
                    while (not _check_max_tokens(src_context, mt_context)
                           and args.context_length - reduce > 1):
                        src_context, _ = _build_context(s, i, args.context_length - reduce)
                        mt_context, _ = _build_context(h, i, args.context_length - reduce)
                        reduce += 1
                    # Ensure the same context is not evaluated twice
                    if src_context != prev_context_src and mt_context != prev_context_mt:
                        src = " ".join(src_context).strip()
                        mt = " ".join(mt_context).strip()
                        data_for_eval.append({"src": src, "mt": mt})
                        prev_context_src, prev_context_mt = src_context, mt_context
            # Compute the document score
            pred = model.predict(data_for_eval, batch_size=8, gpus=1)
            scores.append(pred.system_score)
            doc_lengths.append(len(s))
    return scores, doc_lengths


def _load_data(args):
    """Group source/hypothesis (and optionally reference) sentences into documents."""
    with open(args.sources_file, 'r') as src_file, \
            open(args.hypotheses_file, 'r') as hyp_file, \
            open(args.docids_file, 'r') as docids_file:
        sources = src_file.readlines()
        hypotheses = hyp_file.readlines()
        docids = docids_file.readlines()

    src_docs, hyp_docs, ref_docs = [], [], None
    current_src_doc, current_hyp_doc = [], []
    for i in range(len(docids)):
        current_src_doc.append(sources[i].strip())
        current_hyp_doc.append(hypotheses[i].strip())
        if _is_doc_boundary(docids, i):
            src_docs.append(current_src_doc)
            hyp_docs.append(current_hyp_doc)
            current_src_doc, current_hyp_doc = [], []

    if args.references_file:
        # Load and group the reference file the same way
        with open(args.references_file, 'r') as ref_file:
            references = ref_file.readlines()
        ref_docs, current_ref_doc = [], []
        for i in range(len(docids)):
            current_ref_doc.append(references[i].strip())
            if _is_doc_boundary(docids, i):
                ref_docs.append(current_ref_doc)
                current_ref_doc = []

    return src_docs, hyp_docs, ref_docs


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--sources-file', '-src', type=str, required=True,
                        help='Path to the source file')
    parser.add_argument('--hypotheses-file', '-hyp', type=str, required=True,
                        help='Path to the model output file')
    parser.add_argument('--references-file', '-ref', type=str, required=False,
                        help='Path to the reference file')
    parser.add_argument('--docids-file', '-doc', type=str, required=True,
                        help='Path to the doc-ids file')
    parser.add_argument('--model', type=str, required=True,
                        help='The COMET model name used for automatic evaluation')
    # NOTE: the stride is recorded in the output, but the scoring loop always
    # advances one sentence at a time and skips duplicate contexts instead.
    parser.add_argument('--sliding-window', type=int, required=False, default=1,
                        help='The stride step over the document')
    parser.add_argument('--context-length', type=int, required=False, default=4,
                        help='The number of sentences in a single context')
    args = parser.parse_args()

    comet_model_path = download_model(args.model)
    model = load_from_checkpoint(comet_model_path)

    if args.references_file:
        assert args.model in COMET_REF_MODELS, \
            f"Reference files should not be passed when evaluating {COMET_SRC_MODELS}"
    else:
        assert args.model not in COMET_REF_MODELS, \
            f"Reference files are required for evaluating {COMET_REF_MODELS}"

    src_docs, mt_docs, ref_docs = _load_data(args)
    scores, _ = _calculate_doc_comet(args, model, src_docs, mt_docs, ref_docs)

    ret = {
        'model': args.model,
        'sources_file': args.sources_file,
        'mt_file': args.hypotheses_file,
        'sliding_window': args.sliding_window,
        'context_length': args.context_length,
        # Cast to a plain float: np.float64 is not JSON-serializable
        'score': float(np.mean(scores)),
    }
    print(json.dumps(ret, indent=2))


if __name__ == "__main__":
    main()
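# ---------------------------------------------------------------------------
# Example invocations (a sketch; the script name and file paths are
# hypothetical):
#
#   # Reference-based scoring with a model from COMET_REF_MODELS
#   python doc_comet.py -src test.en -hyp test.de.hyp -ref test.de \
#       -doc test.docids --model wmt22-comet-da
#
#   # Reference-free (QE) scoring: omit -ref and use a COMET_SRC_MODELS entry
#   python doc_comet.py -src test.en -hyp test.de.hyp \
#       -doc test.docids --model wmt22-cometkiwi-da
# ---------------------------------------------------------------------------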