|
""" |
|
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820 |
|
|
|
TREC 2019 DL is based on the MS MARCO corpus. MS MARCO provides only sparse annotations, i.e., usually just a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and are therefore
counted as errors when a model ranks them highly.
|
|
|
TREC DL instead annotated up to 200 passages per query for their relevance to that query. It is therefore better suited
to estimating model performance on the reranking task in Information Retrieval.
|
|
|
Run: |
|
python eval_cross-encoder-trec-dl.py cross-encoder-model-name |
|
|
|
""" |
|
import gzip |
|
from collections import defaultdict |
|
import logging |
|
import tqdm |
|
import numpy as np |
|
import sys |
|
import pytrec_eval |
|
from sentence_transformers import SentenceTransformer, util, CrossEncoder |
|
import os |
|
|
|
# Emit INFO-level records. Without this the root logger stays at its default
# WARNING level and the logging.info() progress messages used throughout this
# script are silently dropped.
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)

# Directory that holds the TREC 2019 DL data files; created if it does not exist.
data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)
|
|
|
|
|
# Load the test queries as a mapping of query id -> query text. The gzipped
# TSV is downloaded into data_folder only when it is not already present.
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
queries_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz'

if not os.path.exists(queries_filepath):
    logging.info("Download " + os.path.basename(queries_filepath))
    util.http_get(queries_url, queries_filepath)

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    # Each row must split into exactly two tab-separated fields (id, text);
    # dict() raises ValueError otherwise, just like a two-name unpacking would.
    queries = dict(row.strip().split("\t") for row in fIn)
|
|
|
|
|
# Relevance judgments (qrels): query id -> {passage id -> integer score}.
# Nested defaultdicts let an entry be created on first assignment.
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')

# Download the qrels file only when it is not already present locally.
if not os.path.exists(qrels_filepath):
    logging.info("Download "+os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)

# Explicit encoding, consistent with the other file readers in this script,
# so parsing does not depend on the platform's default locale encoding.
with open(qrels_filepath, encoding='utf8') as fIn:
    for line in fIn:
        fields = line.split()
        if not fields:
            # Tolerate blank lines instead of failing on the unpacking below.
            continue

        # Four whitespace-separated columns; the second one is not used here.
        qid, _, pid, score = fields
        score = int(score)

        # Only judgments with a positive score are stored.
        if score > 0:
            relevant_docs[qid][pid] = score
|
|
|
|
|
# Queries that have at least one stored (positively scored) judgment; only
# these are scored and evaluated below.
# ``.get`` is used instead of indexing because ``relevant_docs`` is a
# defaultdict: ``relevant_docs[qid]`` would insert an empty entry for every
# query without judgments as a side effect of merely checking it.
relevant_qid = [qid for qid in queries if relevant_docs.get(qid)]
|
|
|
|
|
|
|
# Candidate passages to re-rank, grouped per query:
# query id -> list of [passage id, passage text] pairs, in file order.
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')
passage_url = 'https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz'

# Fetch the candidate file only when it is not already present locally.
if not os.path.exists(passage_filepath):
    logging.info("Download " + os.path.basename(passage_filepath))
    util.http_get(passage_url, passage_filepath)

passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for row in fIn:
        # Four tab-separated columns; the query text (third column) is not
        # needed here because queries are looked up by id elsewhere.
        qid, pid, _, passage = row.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])
|
|
|
logging.info("Queries: {}".format(len(queries)))

# The model name or path is the single command-line argument. Exit with a
# usage hint instead of an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python eval_cross-encoder-trec-dl.py cross-encoder-model-name")

# Not read anywhere in the visible script; kept so the module namespace is unchanged.
queries_result_list = []

# query id -> {passage id -> score}; this is what gets handed to the evaluator.
run = {}
model = CrossEncoder(sys.argv[1], max_length=512)

# Loop-invariant: when the model has more than one output label, predict() is
# asked for softmax-normalised outputs and column 1 is taken as the score.
multi_label = model.config.num_labels > 1

for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    # Split the candidate [pid, passage] pairs into parallel lists.
    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    # One [query, passage] input pair per candidate.
    cross_inp = [[query, sent] for sent in corpus_sentences]

    if multi_label:
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        cross_scores = model.predict(cross_inp).tolist()

    # pids and cross_scores are parallel (one score per input pair); store
    # plain Python floats keyed by passage id.
    run[qid] = {pid: float(score) for pid, score in zip(pids, cross_scores)}
|
|
|
|
|
# Evaluate the run against the relevance judgments, requesting the
# 'ndcg_cut.10' measure, then report the mean of the per-query 'ndcg_cut_10'
# values scaled by 100.
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10'})
scores = evaluator.evaluate(run)

print("Queries:", len(relevant_qid))

ndcg_per_query = [per_query["ndcg_cut_10"] for per_query in scores.values()]
mean_ndcg = np.mean(ndcg_per_query)
print("NDCG@10: {:.2f}".format(mean_ndcg*100))
|
|
|
|
|
|
|
|
|
|