# File size: 6,116 Bytes
# 828992f
import os
import math
import torch
import ujson
import traceback
from itertools import accumulate
from colbert.parameters import DEVICE
from colbert.utils.utils import print_message, dotdict, flatten
# Number of query--passage pairs scored per batch (1 << 14 == 16,384).
# IndexRanker also uses it as the row capacity of its preallocated gather buffers.
BSIZE = 1 << 14
class IndexRanker():
    """Score query--passage pairs against a flat tensor of document token embeddings.

    Documents are stored back to back along dim 0 of ``tensor``; ``doclens[pid]``
    is the number of rows belonging to document ``pid``.  To score a batch without
    padding every document individually, the ranker builds overlapping strided
    "views" of the tensor (one per stride length) so that ``view[offset]`` is the
    ``stride`` rows starting at ``offset``.  Each document is scored through the
    smallest stride that covers it, and rows past its real length are masked out.
    """

    def __init__(self, tensor, doclens):
        """Precompute prefix sums, stride lengths, strided views and gather buffers.

        Args:
            tensor: Embeddings of all documents, concatenated along dim 0.  The
                strided views treat it as a contiguous 2-D ``(rows, dim)`` matrix.
            doclens: Sequence with the number of rows of each document, indexed
                by pid.
        """
        self.tensor = tensor
        self.doclens = doclens

        self.maxsim_dtype = torch.float32

        # doclens_pfxsum[pid] is the row at which document `pid` starts.
        self.doclens_pfxsum = [0] + list(accumulate(self.doclens))

        self.doclens = torch.tensor(self.doclens)
        self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum)

        self.dim = self.tensor.size(-1)

        self.strides = [torch_percentile(self.doclens, p) for p in [90]]
        # The longest document must be covered by some stride.  Otherwise every
        # document above the 90th percentile is assigned to a group that does not
        # exist in rank() and is silently dropped from the output.
        self.strides.append(self.doclens.max().item())
        self.strides = sorted(set(self.strides))

        print_message(f"#> Using strides {self.strides}..")

        self.views = self._create_views(self.tensor)
        self.buffers = self._create_buffers(BSIZE, self.tensor.dtype, {'cpu', 'cuda:0'})

    def _create_views(self, tensor):
        """Return one overlapping window view of ``tensor`` per stride.

        For a stride ``s`` the view has shape ``(rows - s + 1, s, dim)`` and
        ``view[i]`` aliases rows ``i .. i + s - 1`` of ``tensor`` (no copy).

        Args:
            tensor: Contiguous ``(rows, dim)`` embedding matrix.

        Returns:
            List of views, in the same order as ``self.strides``.
        """
        views = []

        for stride in self.strides:
            outdim = tensor.size(0) - stride + 1
            view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1))
            views.append(view)

        return views

    def _create_buffers(self, max_bsize, dtype, devices):
        """Preallocate, per device, one ``(max_bsize, stride, dim)`` buffer per stride.

        rank() uses these as the ``out=`` target of ``index_select`` so that
        gathering documents does not allocate on every call.  CPU buffers are
        requested as pinned memory.

        Returns:
            Dict mapping each device string to a list of buffers ordered like
            ``self.strides``.
        """
        buffers = {}

        for device in devices:
            buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype,
                                           device=device, pin_memory=(device == 'cpu'))
                               for stride in self.strides]

        return buffers

    def rank(self, Q, pids, views=None, shift=0):
        """Score each passage in ``pids`` against its query and return the scores.

        Args:
            Q: Query embeddings with leading size 1 (one query shared by all
                pids) or ``len(pids)`` (one query per pid).  It is multiplied
                as ``D @ Q`` with ``D`` of shape ``(n, stride, dim)``, so it is
                expected to be laid out as ``(_, dim, query_len)``.
            pids: List or 1-D tensor of passage ids; must be non-empty.
            views: Strided views to gather from; defaults to the views over the
                full tensor built in ``__init__``.
            shift: Row offset of ``views`` within the full tensor, subtracted
                from each document's start row (used by batch_rank, which
                builds views over a slice of the tensor).

        Returns:
            List of float scores, aligned with the input order of ``pids``.
        """
        assert len(pids) > 0
        assert Q.size(0) in [1, len(pids)]

        Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype)

        views = self.views if views is None else views
        VIEWS_DEVICE = views[0].device

        D_buffers = self.buffers[str(VIEWS_DEVICE)]

        raw_pids = pids if isinstance(pids, list) else pids.tolist()
        pids = torch.tensor(pids) if isinstance(pids, list) else pids

        doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids]

        # Group index = number of strides the document is strictly longer than,
        # i.e. the index of the smallest stride that still covers it.
        assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1)

        one_to_n = torch.arange(len(raw_pids))
        output_pids, output_scores, output_permutation = [], [], []

        for group_idx, stride in enumerate(self.strides):
            locator = (assignments == group_idx)

            if not locator.any():
                continue

            group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator]
            group_Q = Q if Q.size(0) == 1 else Q[locator]

            # Translate absolute start rows into rows of the (possibly sliced) views.
            group_offsets = group_offsets.to(VIEWS_DEVICE) - shift

            # Adjacent repeats of the same passage are gathered once and expanded after.
            group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True)

            D_size = group_offsets_uniq.size(0)
            D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size])
            D = D.to(DEVICE)
            D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype)

            # Rows beyond a document's real length belong to the next document; zero them out.
            mask = torch.arange(stride, device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1)

            scores = (D @ group_Q) * mask.unsqueeze(-1)
            # Max over document rows, then sum over the last (query) dimension.
            scores = scores.max(1).values.sum(-1).cpu()

            output_pids.append(group_pids)
            output_scores.append(scores)
            output_permutation.append(one_to_n[locator])

        # Groups were emitted stride by stride; restore the caller's original order.
        output_permutation = torch.cat(output_permutation).sort().indices
        output_pids = torch.cat(output_pids)[output_permutation].tolist()
        output_scores = torch.cat(output_scores)[output_permutation].tolist()

        assert len(raw_pids) == len(output_pids)
        assert len(raw_pids) == len(output_scores)
        assert raw_pids == output_pids

        return output_scores

    def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids, sorted_pids):
        """Score a large list of query--passage pairs in sub-ranges of 50,000 pids.

        For each pid sub-range, the matching slice of the embedding tensor is
        moved to ``DEVICE`` once, and the pairs falling in that sub-range are
        scored in batches of ``BSIZE`` via rank().

        Args:
            all_query_embeddings: Query embeddings, indexable by the entries of
                ``all_query_indexes``.
            all_query_indexes: For each pair, the index of its query.
            all_pids: For each pair, its passage id, in ascending order.  It is
                compared elementwise and summed, so a tensor is expected.
            sorted_pids: Must be ``True``; confirms the caller sorted ``all_pids``.

        Returns:
            List of scores, one per pair, in the order of ``all_pids``.
        """
        assert sorted_pids is True

        scores = []
        range_start, range_end = 0, 0

        for pid_offset in range(0, len(self.doclens), 50_000):
            pid_endpos = min(pid_offset + 50_000, len(self.doclens))

            # all_pids is sorted, so both cursors only ever move forward.
            range_start = range_start + (all_pids[range_start:] < pid_offset).sum()
            range_end = range_end + (all_pids[range_end:] < pid_endpos).sum()

            pids = all_pids[range_start:range_end]
            query_indexes = all_query_indexes[range_start:range_end]

            print_message(f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}.")

            if len(pids) == 0:
                continue

            print_message(f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range.")

            tensor_offset = self.doclens_pfxsum[pid_offset].item()
            # NOTE(review): the extra 512 rows presumably keep the strided window
            # of the last document inside the slice -- confirm against the index layout.
            tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512

            collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE)
            views = self._create_views(collection)

            print_message(f"#> Ranking in batches of {BSIZE} query--passage pairs...")

            for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)):
                if batch_idx % 100 == 0:
                    print_message("#> Processing batch #{}..".format(batch_idx))

                endpos = offset + BSIZE
                batch_query_index, batch_pids = query_indexes[offset:endpos], pids[offset:endpos]

                Q = all_query_embeddings[batch_query_index]

                scores.extend(self.rank(Q, batch_pids, views, shift=tensor_offset))

        return scores
def torch_percentile(tensor, p):
    """Return the ``p``-th percentile of a 1-D tensor as a Python scalar.

    The result is the k-th smallest element with ``k = int(p * n / 100)``,
    where ``n`` is the number of elements.

    Args:
        tensor: 1-D tensor of values.
        p: Percentile, an integer in ``1..100``.

    Returns:
        The selected element, converted with ``.item()``.
    """
    assert p in range(1, 100+1)
    assert tensor.dim() == 1

    # kthvalue is 1-based; for short tensors the truncated rank can be 0
    # (e.g. one element at p=90), so clamp it to the smallest valid rank.
    k = max(1, int(p * tensor.size(0) / 100.0))

    return tensor.kthvalue(k).values.item()