# mColBERT: colbert/ranking/batch_reranking.py
import os
import time
import torch
import queue
import threading

from collections import defaultdict

from colbert.utils.runs import Run
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.utils.utils import print_message, flatten, zipstar
from colbert.indexing.loaders import get_parts
from colbert.ranking.index_part import IndexPart

MAX_DEPTH_LOGGED = 1000  # TODO: Use args.depth

def prepare_ranges(index_path, dim, step, part_range):
    print_message("#> Launching a separate thread to load index parts asynchronously.")
    parts, _, _ = get_parts(index_path)

    # Split the index parts into windows of `step` parts each.
    positions = [(offset, offset + step) for offset in range(0, len(parts), step)]

    if part_range is not None:
        positions = positions[part_range.start: part_range.stop]

    # A bounded queue, so the loader stays at most two windows ahead of the consumer.
    loaded_parts = queue.Queue(maxsize=2)

    def _loader_thread(index_path, dim, positions):
        for offset, endpos in positions:
            index = IndexPart(index_path, dim=dim, part_range=range(offset, endpos), verbose=True)
            loaded_parts.put(index, block=True)

    thread = threading.Thread(target=_loader_thread, args=(index_path, dim, positions,))
    thread.start()

    return positions, loaded_parts, thread
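
# Usage sketch (editor's illustration; the argument values below are hypothetical):
# prepare_ranges() returns immediately, while the background loader thread fills
# `loaded_parts`. The bounded queue (maxsize=2) applies backpressure, so at most two
# IndexPart windows are resident at once and disk I/O overlaps with scoring:
#
#   positions, loaded_parts, thread = prepare_ranges('/path/to/index', dim=128, step=1, part_range=None)
#   ...call loaded_parts.get() once per entry in positions, e.g. via score_by_range()...
#   thread.join()
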
def score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids):
    print_message("#> Sorting by PID..")
    all_query_indexes, all_pids = zipstar(all_pids)
    sorting_pids = torch.tensor(all_pids).sort()
    all_query_indexes, all_pids = torch.tensor(all_query_indexes)[sorting_pids.indices], sorting_pids.values

    range_start, range_end = 0, 0

    for offset, endpos in positions:
        print_message(f"#> Fetching parts {offset}--{endpos} from queue..")
        index = loaded_parts.get()

        print_message(f"#> Filtering PIDs to the range {index.pids_range}..")

        # Because all_pids is sorted, the PIDs covered by this index part form a
        # contiguous window; advance the two cursors to find its boundaries.
        range_start = range_start + (all_pids[range_start:] < index.pids_range.start).sum()
        range_end = range_end + (all_pids[range_end:] < index.pids_range.stop).sum()

        pids = all_pids[range_start:range_end]
        query_indexes = all_query_indexes[range_start:range_end]

        print_message(f"#> Got {len(pids)} query--passage pairs in this range.")

        if len(pids) == 0:
            continue

        print_message(f"#> Ranking in batches the pairs #{range_start} through #{range_end}...")
        scores = index.batch_rank(all_query_embeddings, query_indexes, pids, sorted_pids=True)

        # all_query_rankings holds two parallel dicts: query_index -> PIDs and query_index -> scores.
        for query_index, pid, score in zip(query_indexes.tolist(), pids.tolist(), scores):
            all_query_rankings[0][query_index].append(pid)
            all_query_rankings[1][query_index].append(score)
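
# Editor's illustration (not part of the original API): a minimal pure-Python sketch of
# the cursor arithmetic above, assuming `sorted_pids` is ascending and an index part
# covers the half-open PID range [start, stop). Handy for unit-testing the window logic.
def _window_for_part_sketch(sorted_pids, start, stop, range_start=0, range_end=0):
    # Advance each cursor past the PIDs that fall below the part's boundaries.
    while range_start < len(sorted_pids) and sorted_pids[range_start] < start:
        range_start += 1
    range_end = max(range_end, range_start)
    while range_end < len(sorted_pids) and sorted_pids[range_end] < stop:
        range_end += 1
    return range_start, range_end

# Example: _window_for_part_sketch([3, 7, 7, 12, 20], start=5, stop=15) returns (1, 4),
# i.e., PIDs 7, 7, 12 fall inside this part's range.
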
def batch_rerank(args):
    positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim, args.step, args.part_range)

    inference = ModelInference(args.colbert, amp=args.amp)
    queries, topK_pids = args.queries, args.topK_pids

    with torch.no_grad():
        queries_in_order = list(queries.values())

        print_message(f"#> Encoding all {len(queries_in_order)} queries in batches...")

        all_query_embeddings = inference.queryFromText(queries_in_order, bsize=512, to_cpu=True)
        all_query_embeddings = all_query_embeddings.to(dtype=torch.float16).permute(0, 2, 1).contiguous()

    # Since topK_pids is a defaultdict, make sure each qid *has* actual PID information (even if empty).
    for qid in queries:
        assert qid in topK_pids, qid

    all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]] for query_index, qid in enumerate(queries)])
    all_query_rankings = [defaultdict(list), defaultdict(list)]

    print_message(f"#> Will process {len(all_pids)} query--document pairs in total.")

    with torch.no_grad():
        score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids)

    ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores)

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid))

                pids = all_query_rankings[0][query_index]
                scores = all_query_rankings[1][query_index]

                # Keep only the top-K highest-scoring passages for logging.
                K = min(MAX_DEPTH_LOGGED, len(scores))

                if K == 0:
                    continue

                scores_topk = torch.tensor(scores).topk(K, largest=True, sorted=True)

                pids, scores = torch.tensor(pids)[scores_topk.indices].tolist(), scores_topk.values.tolist()

                ranking = [(score, pid, None) for pid, score in zip(pids, scores)]
                assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking), MAX_DEPTH_LOGGED)

                rlogger.log(qid, ranking, is_ranked=True, print_positions=[1, 2] if query_index % 100 == 0 else [])

    print('\n\n')
    print(ranking_logger.filename)
    print_message('#> Done.\n')

    thread.join()
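
# Note (editor's sketch of the expected interface): batch_rerank() assumes `args`
# provides at least the fields used above -- index_path, dim, step, part_range,
# colbert (the loaded model), amp, queries (an ordered mapping qid -> query text),
# topK_pids (a mapping qid -> candidate PIDs), and log_scores. How these fields are
# assembled by the surrounding entry point is outside this file and is an assumption here.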