"""
Batch reranking for ColBERT: a background thread loads index parts
asynchronously while the main thread scores candidate (query, pid) pairs chunk
by chunk and logs the top-K ranking for each query.
"""

import torch
import queue
import threading

from collections import defaultdict

from colbert.utils.runs import Run
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger

from colbert.utils.utils import print_message, flatten, zipstar
from colbert.indexing.loaders import get_parts
from colbert.ranking.index_part import IndexPart

# Log at most this many (pid, score) pairs per query.
MAX_DEPTH_LOGGED = 1000


def prepare_ranges(index_path, dim, step, part_range):
    print_message("#> Launching a separate thread to load index parts asynchronously.")
    parts, _, _ = get_parts(index_path)

    # Group the saved index parts into chunks of `step` consecutive parts.
    positions = [(offset, offset + step) for offset in range(0, len(parts), step)]

    if part_range is not None:
        positions = positions[part_range.start: part_range.stop]

    # Bound the queue so at most two chunks are resident in memory at once.
    loaded_parts = queue.Queue(maxsize=2)

    def _loader_thread(index_path, dim, positions):
        for offset, endpos in positions:
            index = IndexPart(index_path, dim=dim, part_range=range(offset, endpos), verbose=True)
            loaded_parts.put(index, block=True)

    thread = threading.Thread(target=_loader_thread, args=(index_path, dim, positions,))
    thread.start()

    return positions, loaded_parts, thread
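
# A toy illustration (hypothetical values, not part of the original module) of
# the chunking in prepare_ranges:
#
#     >>> step = 2
#     >>> parts = [0, 1, 2, 3, 4]   # five saved index parts
#     >>> [(offset, offset + step) for offset in range(0, len(parts), step)]
#     [(0, 2), (2, 4), (4, 6)]
#
# Note the final chunk's endpoint may overshoot len(parts); IndexPart is
# assumed to tolerate this by loading only the parts that exist.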


def score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids):
    print_message("#> Sorting by PID..")
    all_query_indexes, all_pids = zipstar(all_pids)
    sorting_pids = torch.tensor(all_pids).sort()
    all_query_indexes, all_pids = torch.tensor(all_query_indexes)[sorting_pids.indices], sorting_pids.values

    # Two pointers into the sorted PID list. Each index chunk covers a
    # contiguous, increasing PID range, so both pointers only move forward.
    range_start, range_end = 0, 0

    for offset, endpos in positions:
        print_message(f"#> Fetching parts {offset}--{endpos} from queue..")
        index = loaded_parts.get()

        print_message(f"#> Filtering PIDs to the range {index.pids_range}..")
        range_start = range_start + (all_pids[range_start:] < index.pids_range.start).sum()
        range_end = range_end + (all_pids[range_end:] < index.pids_range.stop).sum()

        pids = all_pids[range_start:range_end]
        query_indexes = all_query_indexes[range_start:range_end]

        print_message(f"#> Got {len(pids)} query--passage pairs in this range.")

        if len(pids) == 0:
            continue

        print_message(f"#> Ranking in batches the pairs #{range_start} through #{range_end}...")
        scores = index.batch_rank(all_query_embeddings, query_indexes, pids, sorted_pids=True)

        for query_index, pid, score in zip(query_indexes.tolist(), pids.tolist(), scores):
            all_query_rankings[0][query_index].append(pid)
            all_query_rankings[1][query_index].append(score)
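
# A toy walkthrough (hypothetical values) of the two-pointer filter above: with
# all_pids sorted as tensor([3, 7, 12, 15]) and a chunk covering
# pids_range = range(5, 13), range_start advances by 1 (only pid 3 is < 5) and
# range_end advances by 3 (pids 3, 7, 12 are < 13), so the slice [1:3] selects
# tensor([7, 12]) -- exactly the PIDs held by this chunk.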


def batch_rerank(args):
    positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim, args.step, args.part_range)

    inference = ModelInference(args.colbert, amp=args.amp)
    queries, topK_pids = args.queries, args.topK_pids

    with torch.no_grad():
        queries_in_order = list(queries.values())

        print_message(f"#> Encoding all {len(queries_in_order)} queries in batches...")

        all_query_embeddings = inference.queryFromText(queries_in_order, bsize=512, to_cpu=True)
        # Cast to fp16 and transpose each query's embedding matrix to
        # (dim, query_length) for the batched scoring that follows.
        all_query_embeddings = all_query_embeddings.to(dtype=torch.float16).permute(0, 2, 1).contiguous()

    # Since topK_pids is a defaultdict, make sure each qid *has* actual PID
    # information (even if empty) before we start indexing into it.
    for qid in queries:
        assert qid in topK_pids, qid

    all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]] for query_index, qid in enumerate(queries)])
    all_query_rankings = [defaultdict(list), defaultdict(list)]

    print_message(f"#> Will process {len(all_pids)} query--document pairs in total.")

    with torch.no_grad():
        score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids)

    ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores)

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid))

                pids = all_query_rankings[0][query_index]
                scores = all_query_rankings[1][query_index]

                K = min(MAX_DEPTH_LOGGED, len(scores))

                if K == 0:
                    continue

                scores_topk = torch.tensor(scores).topk(K, largest=True, sorted=True)

                pids, scores = torch.tensor(pids)[scores_topk.indices].tolist(), scores_topk.values.tolist()

                ranking = [(score, pid, None) for pid, score in zip(pids, scores)]
                assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking), MAX_DEPTH_LOGGED)

                rlogger.log(qid, ranking, is_ranked=True, print_positions=[1, 2] if query_index % 100 == 0 else [])

    print('\n\n')
    print(ranking_logger.filename)
    print_message('#> Done.\n')

    thread.join()
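
# A hedged usage sketch (not part of the original module). The attribute names
# below simply mirror those consumed by batch_rerank above; in ColBERT proper,
# `args` is assembled by the repository's argument parser and run context.
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         index_path='/path/to/index',  # hypothetical path
#         dim=128, step=1, part_range=None,
#         colbert=colbert_model,        # a loaded ColBERT checkpoint
#         amp=True, log_scores=True,
#         queries=queries,              # OrderedDict: qid -> query text
#         topK_pids=topK_pids,          # dict: qid -> candidate PIDs to rerank
#     )
#     batch_rerank(args)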