import os
import time
import faiss
import random
import torch
import itertools

from colbert.utils.runs import Run
from multiprocessing import Pool
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger

from colbert.utils.utils import print_message, batch
from colbert.ranking.rankers import Ranker


def retrieve(args):
    """Rank every query in ``args.queries`` and log the results to 'ranking.tsv'.

    Queries are processed in batches of 100 ids. Within a batch, each query
    is encoded and ranked individually while the elapsed wall-clock time is
    accumulated; afterwards the first ``args.depth`` (pid, score) pairs of
    each query, in the order returned by ``ranker.rank``, are written through
    the ranking logger as ``(score, pid, None)`` tuples.

    Args:
        args: Configuration object. This function reads ``args.colbert``,
            ``args.amp``, ``args.faiss_depth``, ``args.queries`` (a mapping
            from qid to query text) and ``args.depth``. The whole object is
            also handed to ``Ranker``, which may read further attributes.

    Returns:
        None. Progress, the log file name and a completion banner are
        printed to stdout.
    """
    encoder = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, encoder, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    elapsed_ms = 0  # total ranking time over all queries so far

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        ordered_qids = list(queries.keys())

        for qoffset, qbatch in batch(ordered_qids, 100, provide_offset=True):
            # Phase 1: encode + rank each query of the batch, one at a time.
            rankings = []

            for local_idx, qid in enumerate(qbatch):
                query_text = queries[qid]
                global_idx = qoffset + local_idx

                # Wait for pending GPU work so it is not billed to this query.
                # NOTE(review): this call names 'cuda:0' while the closing
                # call uses the current device -- confirm this is intended.
                torch.cuda.synchronize('cuda:0')
                started = time.time()

                Q = ranker.encode([query_text])
                pids, scores = ranker.rank(Q)

                # Wait for the ranking kernels to finish before stopping the clock.
                torch.cuda.synchronize()
                elapsed_ms += (time.time() - started) * 1000.0

                if len(pids) > 0:
                    # Last numeric column is the running mean latency per query.
                    mean_ms = elapsed_ms / (global_idx + 1)
                    print(global_idx, query_text, len(scores), len(pids),
                          scores[0], pids[0], mean_ms, 'ms')

                # Kept as a lazy zip; it is consumed exactly once in phase 2.
                rankings.append(zip(pids, scores))

            # Phase 2: write the batch's rankings out, truncated to args.depth.
            for query_number, (qid, pid_score_pairs) in enumerate(zip(qbatch, rankings),
                                                                  start=qoffset):
                if query_number % 100 == 0:
                    print_message(f"#> Logging query #{query_number} (qid {qid}) now...")

                truncated = itertools.islice(pid_score_pairs, args.depth)
                rows = [(score, pid, None) for pid, score in truncated]
                rlogger.log(qid, rows, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')