import os
import time
import torch
import queue
import threading

from collections import defaultdict

from colbert.utils.runs import Run
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.utils.utils import print_message, flatten, zipstar
from colbert.indexing.loaders import get_parts
from colbert.ranking.index_part import IndexPart

MAX_DEPTH_LOGGED = 1000  # TODO: Use args.depth
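

# prepare_ranges() groups the index parts into ranges of `step` parts each and
# starts a background loader thread, so index I/O overlaps with scoring.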
def prepare_ranges(index_path, dim, step, part_range):
    print_message("#> Launching a separate thread to load index parts asynchronously.")
    parts, _, _ = get_parts(index_path)

    positions = [(offset, offset + step) for offset in range(0, len(parts), step)]

    if part_range is not None:
        positions = positions[part_range.start: part_range.stop]
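
    # Bounded queue: at most two loaded ranges wait in memory at any time,
    # which caps peak memory use while still overlapping I/O with scoring.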
    loaded_parts = queue.Queue(maxsize=2)

    def _loader_thread(index_path, dim, positions):
        for offset, endpos in positions:
            index = IndexPart(index_path, dim=dim, part_range=range(offset, endpos), verbose=True)
            loaded_parts.put(index, block=True)

    thread = threading.Thread(target=_loader_thread, args=(index_path, dim, positions,))
    thread.start()

    return positions, loaded_parts, thread
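

# score_by_range() consumes index parts from the loader queue in order and
# scores, for each part, the slice of (sorted) candidate pairs whose PIDs
# fall inside that part's range.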
def score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids):
    print_message("#> Sorting by PID..")
    all_query_indexes, all_pids = zipstar(all_pids)
    sorting_pids = torch.tensor(all_pids).sort()
    all_query_indexes, all_pids = torch.tensor(all_query_indexes)[sorting_pids.indices], sorting_pids.values

    range_start, range_end = 0, 0

    for offset, endpos in positions:
        print_message(f"#> Fetching parts {offset}--{endpos} from queue..")
        index = loaded_parts.get()

        print_message(f"#> Filtering PIDs to the range {index.pids_range}..")
        range_start = range_start + (all_pids[range_start:] < index.pids_range.start).sum()
        range_end = range_end + (all_pids[range_end:] < index.pids_range.stop).sum()

        pids = all_pids[range_start:range_end]
        query_indexes = all_query_indexes[range_start:range_end]

        print_message(f"#> Got {len(pids)} query--passage pairs in this range.")

        if len(pids) == 0:
            continue

        print_message(f"#> Ranking in batches the pairs #{range_start} through #{range_end}...")
        scores = index.batch_rank(all_query_embeddings, query_indexes, pids, sorted_pids=True)

        for query_index, pid, score in zip(query_indexes.tolist(), pids.tolist(), scores):
            all_query_rankings[0][query_index].append(pid)
            all_query_rankings[1][query_index].append(score)

def batch_rerank(args):
    positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim, args.step, args.part_range)

    inference = ModelInference(args.colbert, amp=args.amp)
    queries, topK_pids = args.queries, args.topK_pids

    with torch.no_grad():
        queries_in_order = list(queries.values())

        print_message(f"#> Encoding all {len(queries_in_order)} queries in batches...")

        all_query_embeddings = inference.queryFromText(queries_in_order, bsize=512, to_cpu=True)
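        # Cast to half precision to keep the full CPU-resident tensor compact,
        # and permute each query's embeddings from [query_maxlen, dim] to
        # [dim, query_maxlen], the orientation batch_rank() consumes above.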
        all_query_embeddings = all_query_embeddings.to(dtype=torch.float16).permute(0, 2, 1).contiguous()

    # Since topK_pids is a defaultdict, make sure each qid *has* actual PID
    # information (even if empty).
    for qid in queries:
        assert qid in topK_pids, qid

    all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]] for query_index, qid in enumerate(queries)])
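
    # Two parallel per-query accumulators: all_query_rankings[0][qidx] collects
    # PIDs and all_query_rankings[1][qidx] the corresponding scores.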
    all_query_rankings = [defaultdict(list), defaultdict(list)]

    print_message(f"#> Will process {len(all_pids)} query--document pairs in total.")

    with torch.no_grad():
        score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids)

    ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores)

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid))

                pids = all_query_rankings[0][query_index]
                scores = all_query_rankings[1][query_index]

                K = min(MAX_DEPTH_LOGGED, len(scores))

                if K == 0:
                    continue
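
                # Keep only the top-K highest scores, in descending order, and
                # reorder the PIDs to match before logging.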
                scores_topk = torch.tensor(scores).topk(K, largest=True, sorted=True)

                pids, scores = torch.tensor(pids)[scores_topk.indices].tolist(), scores_topk.values.tolist()

                ranking = [(score, pid, None) for pid, score in zip(pids, scores)]
                assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking), MAX_DEPTH_LOGGED)

                rlogger.log(qid, ranking, is_ranked=True, print_positions=[1, 2] if query_index % 100 == 0 else [])

    print('\n\n')
    print(ranking_logger.filename)
    print_message('#> Done.\n')

    thread.join()