import os
import math
import torch
import ujson
import traceback

from itertools import accumulate
from colbert.parameters import DEVICE
from colbert.utils.utils import print_message, dotdict, flatten

# Maximum number of query--passage pairs scored per call to ``rank`` from
# ``batch_rank``; also the row capacity of the preallocated gather buffers.
BSIZE = 1 << 14


class IndexRanker():
    """Score query--passage pairs with MaxSim over a flat embedding tensor.

    ``tensor`` stores the token embeddings of every document concatenated
    along dimension 0, and ``doclens[i]`` is the number of rows that belong
    to document ``i``. Documents are bucketed by length into a few "strides";
    each bucket is gathered as a dense ``(n, stride, dim)`` block through an
    ``as_strided`` view, so no per-document padding copy is needed.
    """

    def __init__(self, tensor, doclens):
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32
        # Exclusive prefix sum: doclens_pfxsum[i] is the first row of document i.
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        # At most two length buckets: the 90th-percentile length and the max.
        self.strides = [torch_percentile(self.doclens, p) for p in [90]]
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(list(set(self.strides)))

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)
        # NOTE(review): allocates on 'cuda:0' unconditionally, and pin_memory
        # for the CPU buffers also needs CUDA; this class assumes a GPU.
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype, {'cpu', 'cuda:0'})

    def _create_views(self, tensor):
        """Build one sliding-window view of ``tensor`` per stride.

        Row ``i`` of the view for stride ``s`` aliases rows ``i .. i+s-1`` of
        ``tensor``; nothing is copied. Assumes ``tensor`` is contiguous with a
        row stride of ``self.dim`` (TODO confirm for callers passing slices).
        """
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1))
            views.append(view)

        return views

    def _create_buffers(self, max_bsize, dtype, devices):
        """Preallocate, per device, one ``(max_bsize, stride, dim)`` gather buffer per stride."""
        buffers = {}

        for device in devices:
            buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype, device=device,
                                           pin_memory=(device == 'cpu'))
                               for stride in self.strides]

        return buffers

    def rank(self, Q, pids, views=None, shift=0):
        """Return MaxSim scores for ``pids`` against ``Q``, in the order of ``pids``.

        Args:
            Q: query embeddings, either a single query shared by all pids
                (``Q.size(0) == 1``) or one query per pid. Given ``D @ group_Q``
                below, presumably shaped ``(B, dim, query_len)``; TODO confirm.
            pids: list or 1-D tensor of document ids.
            views: strided views to gather from; defaults to ``self.views``
                (built over the whole ``self.tensor``).
            shift: row offset subtracted from document offsets, for views built
                over a slice of ``self.tensor`` starting at row ``shift``.

        Returns:
            A list of float scores, one per entry of ``pids``.

        Note: gathered rows are written into buffers of ``BSIZE`` rows, so the
        number of distinct documents per length bucket must not exceed
        ``BSIZE`` (``batch_rank`` guarantees this by batching).
        """
        assert len(pids) > 0
        assert Q.size(0) in [1, len(pids)]

        Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype)

        views = self.views if views is None else views
        VIEWS_DEVICE = views[0].device

        D_buffers = self.buffers[str(VIEWS_DEVICE)]

        raw_pids = pids if isinstance(pids, list) else pids.tolist()
        pids = torch.tensor(pids) if isinstance(pids, list) else pids

        doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids]

        # Bucket index = number of strides strictly shorter than the document,
        # i.e. the first stride that can hold all of its tokens.
        assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1)

        one_to_n = torch.arange(len(raw_pids))
        output_pids, output_scores, output_permutation = [], [], []

        for group_idx, stride in enumerate(self.strides):
            locator = (assignments == group_idx)

            if not locator.any():
                continue

            group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator]
            group_Q = Q if Q.size(0) == 1 else Q[locator]

            group_offsets = group_offsets.to(VIEWS_DEVICE) - shift
            # Only *consecutive* duplicates are collapsed, so repeated pids that
            # arrive adjacent (e.g. sorted input) are gathered once.
            group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True)

            D_size = group_offsets_uniq.size(0)
            D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size])
            D = D.to(DEVICE)
            D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype)

            # Zero out window rows beyond each document's true length.
            mask = torch.arange(stride, device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1)

            # NOTE(review): masked positions score 0, so they can win the max
            # if every real token score is negative; confirm this is intended.
            scores = (D @ group_Q) * mask.unsqueeze(-1)
            scores = scores.max(1).values.sum(-1).cpu()

            output_pids.append(group_pids)
            output_scores.append(scores)
            output_permutation.append(one_to_n[locator])

        # Undo the bucketing so results follow the caller's pid order.
        output_permutation = torch.cat(output_permutation).sort().indices
        output_pids = torch.cat(output_pids)[output_permutation].tolist()
        output_scores = torch.cat(output_scores)[output_permutation].tolist()

        assert len(raw_pids) == len(output_pids)
        assert len(raw_pids) == len(output_scores)
        assert raw_pids == output_pids

        return output_scores

    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids, sorted_pids):
        """Score many (query, pid) pairs given as parallel arrays.

        ``all_pids`` must be sorted ascending (``sorted_pids`` must be True):
        the pids are consumed in windows of 50,000 document ids. Only each
        window's slice of the embedding tensor is moved to ``DEVICE``, plus 512
        extra rows (presumably so the strided views can reach the last
        documents in the window; TODO confirm). ``all_query_embeddings`` is
        indexed by ``all_query_indexes`` to get the query for each pair.

        Returns:
            A flat list of scores in the same order as ``all_pids``.
        """
        assert sorted_pids is True

        ######

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            # Advance the [range_start, range_end) window over the sorted pids
            # so it covers exactly the pids in [pid_offset, pid_endpos).
            # int() keeps the counters (and the log lines) as plain integers.
            range_start = range_start + int((all_pids[range_start:] < pid_offset).sum())
            range_end = range_end + int((all_pids[range_end:] < pid_endpos).sum())

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}.")

            if len(pids) == 0:
                continue

            print_message(f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range.")

            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message("#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                # Offsets are relative to the start of `collection`, hence the shift.
                scores.extend(self.rank(Q, batch_pids, views, shift=tensor_offset))

        return scores


def torch_percentile(tensor, p):
    """Return the ``p``-th percentile (1..100) of a 1-D tensor via ``kthvalue``."""
    assert p in range(1, 100+1)
    assert tensor.dim() == 1

    # kthvalue is 1-indexed and rejects k == 0; without the clamp a tensor
    # with one element (a single-document collection) would crash here.
    k = max(1, int(p * tensor.size(0) / 100.0))
    return tensor.kthvalue(k).values.item()