""" This script runs the evaluation of an SBERT msmarco model on the MS MARCO dev dataset and reports different performances metrices for cossine similarity & dot-product. Usage: python eval_msmarco.py model_name [max_corpus_size_in_thousands] """ from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation, util, models import logging import sys import os import tarfile #### Just some code to print debug information to stdout logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO, handlers=[LoggingHandler()]) #### /print debug information to stdout #Name of the SBERT model model_name = sys.argv[1] # You can limit the approx. max size of the corpus. Pass 100 as second parameter and the corpus has a size of approx 100k docs corpus_max_size = int(sys.argv[2])*1000 if len(sys.argv) >= 3 else 0 #### Load model model = SentenceTransformer(model_name) ### Data files data_folder = 'msmarco-data' os.makedirs(data_folder, exist_ok=True) collection_filepath = os.path.join(data_folder, 'collection.tsv') dev_queries_file = os.path.join(data_folder, 'queries.dev.small.tsv') qrels_filepath = os.path.join(data_folder, 'qrels.dev.tsv') ### Download files if needed if not os.path.exists(collection_filepath) or not os.path.exists(dev_queries_file): tar_filepath = os.path.join(data_folder, 'collectionandqueries.tar.gz') if not os.path.exists(tar_filepath): logging.info("Download: "+tar_filepath) util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: tar.extractall(path=data_folder) if not os.path.exists(qrels_filepath): util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv', qrels_filepath) ### Load data corpus = {} #Our corpus pid => passage dev_queries = {} #Our dev queries. qid => query dev_rel_docs = {} #Mapping qid => set with relevant pids needed_pids = set() #Passage IDs we need needed_qids = set() #Query IDs we need # Load the 6980 dev queries with open(dev_queries_file, encoding='utf8') as fIn: for line in fIn: qid, query = line.strip().split("\t") dev_queries[qid] = query.strip() # Load which passages are relevant for which queries with open(qrels_filepath) as fIn: for line in fIn: qid, _, pid, _ = line.strip().split('\t') if qid not in dev_queries: continue if qid not in dev_rel_docs: dev_rel_docs[qid] = set() dev_rel_docs[qid].add(pid) needed_pids.add(pid) needed_qids.add(qid) # Read passages with open(collection_filepath, encoding='utf8') as fIn: for line in fIn: pid, passage = line.strip().split("\t") passage = passage if pid in needed_pids or corpus_max_size <= 0 or len(corpus) <= corpus_max_size: corpus[pid] = passage.strip() ## Run evaluator logging.info("Queries: {}".format(len(dev_queries))) logging.info("Corpus: {}".format(len(corpus))) ir_evaluator = evaluation.InformationRetrievalEvaluator(dev_queries, corpus, dev_rel_docs, show_progress_bar=True, corpus_chunk_size=100000, precision_recall_at_k=[10, 100], name="msmarco dev") ir_evaluator(model)