"""
This script runs the evaluation of an SBERT MS MARCO model on the
MS MARCO dev dataset and reports different performance metrics for cosine similarity & dot-product.

Usage:
python eval_msmarco.py model_name [max_corpus_size_in_thousands]
"""

from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation, util
import logging
import sys
import os
import tarfile
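

# Print debug information to stdout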
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
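
# Name of the SBERT model to evaluate (first command line argument)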
model_name = sys.argv[1]
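
# Optional second argument: approximate corpus size in thousands of passages; 0 (the default) means use the full collection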
corpus_max_size = int(sys.argv[2]) * 1000 if len(sys.argv) >= 3 else 0
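

# Load the SBERT model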
model = SentenceTransformer(model_name)
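

# Data files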
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)

collection_filepath = os.path.join(data_folder, 'collection.tsv')
dev_queries_file = os.path.join(data_folder, 'queries.dev.small.tsv')
qrels_filepath = os.path.join(data_folder, 'qrels.dev.tsv')
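
# Download the collection, dev queries, and qrels if they are not present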
if not os.path.exists(collection_filepath) or not os.path.exists(dev_queries_file):
    tar_filepath = os.path.join(data_folder, 'collectionandqueries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download: " + tar_filepath)
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

if not os.path.exists(qrels_filepath):
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv', qrels_filepath)
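

# Containers for the evaluation data:
#   corpus:       pid -> passage text
#   dev_queries:  qid -> query text
#   dev_rel_docs: qid -> set of relevant pids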
corpus = {}
dev_queries = {}
dev_rel_docs = {}
needed_pids = set()
needed_qids = set()
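
# Load the dev queries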
with open(dev_queries_file, encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        dev_queries[qid] = query.strip()
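
# Load which passages are relevant for which dev queries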
with open(qrels_filepath) as fIn:
    for line in fIn:
        qid, _, pid, _ = line.strip().split('\t')

        if qid not in dev_queries:
            continue

        if qid not in dev_rel_docs:
            dev_rel_docs[qid] = set()
        dev_rel_docs[qid].add(pid)

        needed_pids.add(pid)
        needed_qids.add(qid)
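
# Read the passages: keep every passage that is relevant for a dev query; if no
# corpus size limit was given, keep the full collection, otherwise add passages
# until the limit is reached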
with open(collection_filepath, encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")

        if pid in needed_pids or corpus_max_size <= 0 or len(corpus) <= corpus_max_size:
            corpus[pid] = passage.strip()
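

# Run the InformationRetrievalEvaluator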
logging.info("Queries: {}".format(len(dev_queries)))
logging.info("Corpus: {}".format(len(corpus)))

ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, corpus, dev_rel_docs,
                                                        show_progress_bar=True,
                                                        corpus_chunk_size=100000,
                                                        precision_recall_at_k=[10, 100],
                                                        name="msmarco dev")
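
# Encodes the queries and the corpus, then logs the retrieval metrics
# (e.g., MRR@k, NDCG@k, Precision/Recall@k) for cosine similarity and dot-product scoring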
ir_evaluator(model)