"""
This file evaluates CrossEncoder on the TREC 2019 Deep Learning (DL) Track: https://arxiv.org/abs/2003.07820
TREC 2019 DL is based on the corpus of MS Marco. MS Marco provides a sparse annotation, i.e., usually only a single
passage is marked as relevant for a given query. Many other highly relevant passages are not annotated and hence are treated
as an error if a model ranks those high.
TREC DL instead annotated up to 200 passages per query for their relevance to a given query. It is better suited to estimate
the model performance for the task of reranking in Information Retrieval.
Run:
python eval_cross-encoder-trec-dl.py cross-encoder-model-name
"""
import gzip
import logging
import os
import sys
from collections import defaultdict

import numpy as np
import pytrec_eval
import tqdm

from sentence_transformers import CrossEncoder, util

# Make the logging.info() download messages visible
logging.basicConfig(level=logging.INFO)

data_folder = 'trec2019-data'
os.makedirs(data_folder, exist_ok=True)

# Read test queries
queries = {}
queries_filepath = os.path.join(data_folder, 'msmarco-test2019-queries.tsv.gz')
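# The queries file is tab-separated: <qid>\t<query text>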
if not os.path.exists(queries_filepath):
    logging.info("Download " + os.path.basename(queries_filepath))
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2019-queries.tsv.gz', queries_filepath)

with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query

# Read which passages are relevant
relevant_docs = defaultdict(lambda: defaultdict(int))
qrels_filepath = os.path.join(data_folder, '2019qrels-pass.txt')
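# Standard TREC qrels format, whitespace-separated: <qid> <iteration> <pid> <relevance>.
# TREC 2019 DL uses graded relevance judgments (0-3); only positive grades are kept below.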
if not os.path.exists(qrels_filepath):
    logging.info("Download " + os.path.basename(qrels_filepath))
    util.http_get('https://trec.nist.gov/data/deep/2019qrels-pass.txt', qrels_filepath)

with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            relevant_docs[qid][pid] = score

# Only use queries that have at least one relevant passage.
# Membership test instead of len(relevant_docs[qid]): indexing the defaultdict would
# silently insert empty judgment entries for queries without any relevant passage.
relevant_qid = [qid for qid in queries if qid in relevant_docs]

# Read the top 1000 passages that are supposed to be re-ranked
passage_filepath = os.path.join(data_folder, 'msmarco-passagetest2019-top1000.tsv.gz')
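# The candidates file is tab-separated: <qid>\t<pid>\t<query>\t<passage>,
# with up to 1000 first-stage candidate passages per query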
if not os.path.exists(passage_filepath):
    logging.info("Download " + os.path.basename(passage_filepath))
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-passagetest2019-top1000.tsv.gz', passage_filepath)

passage_cand = {}
with gzip.open(passage_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        if qid not in passage_cand:
            passage_cand[qid] = []
        passage_cand[qid].append([pid, passage])

logging.info("Queries: {}".format(len(queries)))
queries_result_list = []
run = {}
model = CrossEncoder(sys.argv[1], max_length=512)
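
# Re-rank the candidate passages for each evaluated query; max_length=512 above
# truncates query+passage pairs that exceed the model's input limit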
for qid in tqdm.tqdm(relevant_qid):
    query = queries[qid]

    cand = passage_cand[qid]
    pids = [c[0] for c in cand]
    corpus_sentences = [c[1] for c in cand]

    # Pair the query with every candidate passage
    cross_inp = [[query, sent] for sent in corpus_sentences]

    if model.config.num_labels > 1:
        # Cross-encoder with more than one output label: apply softmax and use the
        # probability of label 1 ("relevant") as the ranking score
        cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
    else:
        # Cross-encoder with a single output logit: use the raw score directly
        cross_scores = model.predict(cross_inp).tolist()
    # Score each candidate pid; only the relative order matters for ranking metrics
    run[qid] = {pid: float(score) for pid, score in zip(pids, cross_scores)}
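
# `run` now maps qid -> {pid: score}, the run format that pytrec_eval expects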
evaluator = pytrec_eval.RelevanceEvaluator(relevant_docs, {'ndcg_cut.10'})
scores = evaluator.evaluate(run)
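# evaluate() returns per-query results keyed by measure name ('ndcg_cut_10'); average over queries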
print("Queries:", len(relevant_qid))
print("NDCG@10: {:.2f}".format(np.mean([ele["ndcg_cut_10"] for ele in scores.values()])*100))