# colbert/ranking/index_ranker.py
import os
import math
import torch
import ujson
import traceback
from itertools import accumulate
from colbert.parameters import DEVICE
from colbert.utils.utils import print_message, dotdict, flatten

BSIZE = 1 << 14  # 16384 query--passage pairs per scoring batch; also sizes the pre-allocated buffers


class IndexRanker():
    def __init__(self, tensor, doclens):
        # `tensor` holds all document token embeddings, flattened to (total_tokens, dim);
        # `doclens` gives the number of tokens of each passage, in pid order.
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32

        # Prefix sums of doclens: offset of each passage's first token row in `tensor`.
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        # Two strides (the 90th-percentile doclen and the maximum doclen) bucket passages
        # by length, so most passages are padded only up to the 90th percentile.
        self.strides = [torch_percentile(self.doclens, p) for p in [90]]
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(list(set(self.strides)))

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype, {'cpu', 'cuda:0'})

    def _create_views(self, tensor):
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1))
            views.append(view)

        return views
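
    # Note on the views built above: for a contiguous (total_tokens, dim) tensor, the
    # chosen strides make views[g][offset] exactly tensor[offset : offset + strides[g]]
    # without copying, so the embeddings of a passage starting at row `offset` can be
    # gathered with a single index_select into views[g] rather than a per-passage slice.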

    def _create_buffers(self, max_bsize, dtype, devices):
        # Pre-allocate one (max_bsize, stride, dim) scratch buffer per stride and device;
        # the CPU copies are pinned to speed up host-to-device transfers.
        buffers = {}

        for device in devices:
            buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype,
                                           device=device, pin_memory=(device == 'cpu'))
                               for stride in self.strides]

        return buffers

    def rank(self, Q, pids, views=None, shift=0):
        # Q: (num_queries, dim, query_maxlen) with num_queries in {1, len(pids)};
        # returns one MaxSim score per pid, in the order the pids were given.
        assert len(pids) > 0
        assert Q.size(0) in [1, len(pids)]

        Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype)

        views = self.views if views is None else views
        VIEWS_DEVICE = views[0].device

        D_buffers = self.buffers[str(VIEWS_DEVICE)]

        raw_pids = pids if type(pids) is list else pids.tolist()
        pids = torch.tensor(pids) if type(pids) is list else pids

        doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids]

        # Assign each pid to the smallest stride that covers its doclen.
        assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1)

        one_to_n = torch.arange(len(raw_pids))
        output_pids, output_scores, output_permutation = [], [], []

        for group_idx, stride in enumerate(self.strides):
            locator = (assignments == group_idx)

            if locator.sum() < 1e-5:
                continue

            group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator]
            group_Q = Q if Q.size(0) == 1 else Q[locator]

            group_offsets = group_offsets.to(VIEWS_DEVICE) - shift
            group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True)

            # Gather the (stride, dim) block of every passage in this group into the
            # pre-allocated buffer, then move it to the scoring device.
            D_size = group_offsets_uniq.size(0)
            D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size])
            D = D.to(DEVICE)
            D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype)

            # Zero out the padding rows beyond each passage's true length.
            mask = torch.arange(stride, device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1)

            # MaxSim: max over document tokens, summed over query tokens.
            scores = (D @ group_Q) * mask.unsqueeze(-1)
            scores = scores.max(1).values.sum(-1).cpu()

            output_pids.append(group_pids)
            output_scores.append(scores)
            output_permutation.append(one_to_n[locator])

        # Undo the per-group reordering so scores line up with the input pids.
        output_permutation = torch.cat(output_permutation).sort().indices
        output_pids = torch.cat(output_pids)[output_permutation].tolist()
        output_scores = torch.cat(output_scores)[output_permutation].tolist()

        assert len(raw_pids) == len(output_pids)
        assert len(raw_pids) == len(output_scores)
        assert raw_pids == output_pids

        return output_scores
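
    # Shape walk-through for one group of rank() above (illustrative numbers only):
    # with stride 180, dim 128 and query_maxlen 32, D is (n, 180, 128) and group_Q
    # broadcasts as (1 or n, 128, 32), so D @ group_Q is (n, 180, 32); max(1) keeps
    # the best document token per query token and sum(-1) adds those 32 maxima into
    # the final MaxSim score of each of the n passages.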

    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids, sorted_pids):
        # Scores a large, pid-sorted list of (query_index, pid) pairs by sweeping the
        # index 50k passages at a time, moving only that sub-range of the tensor to GPU.
        assert sorted_pids is True

        ######

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            # Advance the [range_start, range_end) window over the sorted all_pids so it
            # covers exactly the pids that fall in [pid_offset, pid_endpos).
            range_start = range_start + (all_pids[range_start:] < pid_offset).sum()
            range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}.")

            if len(pids) == 0:
                continue

            print_message(f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range.")

            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message("#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                scores.extend(self.rank(Q, batch_pids, views, shift=tensor_offset))

        return scores


def torch_percentile(tensor, p):
    # p-th percentile of a 1-D tensor, taken as the k-th smallest value.
    assert p in range(1, 100+1)
    assert tensor.dim() == 1

    return tensor.kthvalue(int(p * tensor.size(0) / 100.0)).values.item()
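

if __name__ == '__main__':
    # Hedged usage sketch (assumptions, not part of the original ColBERT code): builds a
    # tiny synthetic index and scores three passages against one query. The sizes and the
    # random data are made up for illustration; the (total_tokens, dim) index layout and
    # the (num_queries, dim, query_maxlen) query layout follow the conventions used by
    # rank() above. A CUDA device is required, since __init__ allocates buffers on 'cuda:0'.
    dim, query_maxlen = 8, 4
    doclens = [3, 5, 2, 7]                    # token counts of four synthetic passages
    index = torch.randn(sum(doclens), dim)    # flat (total_tokens, dim) embeddings

    ranker = IndexRanker(index, doclens)

    Q = torch.randn(1, dim, query_maxlen)     # one query, broadcast across all pids
    print(ranker.rank(Q, pids=[0, 2, 3]))     # one MaxSim score per pid, in order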