import os
import time
import faiss
import random
import torch

from multiprocessing import Pool

from colbert.modeling.inference import ModelInference
from colbert.utils.utils import print_message, flatten, batch
from colbert.indexing.loaders import load_doclens
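

# FaissIndex maps ColBERT query embeddings to candidate passage IDs (PIDs):
# it searches a FAISS index over the token-level document embeddings, then
# translates the returned embedding rows back to PIDs via an emb2pid tensor.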
class FaissIndex():
    def __init__(self, index_path, faiss_index_path, nprobe, part_range=None):
        print_message("#> Loading the FAISS index from", faiss_index_path, "..")

        # A sharded FAISS index encodes which document parts it covers in its
        # filename (second-to-last dot-separated field, e.g., "<name>.0-3.faiss").
        faiss_part_range = os.path.basename(faiss_index_path).split('.')[-2].split('-')

        if len(faiss_part_range) == 2:
            faiss_part_range = range(*map(int, faiss_part_range))
            assert part_range[0] in faiss_part_range, (part_range, faiss_part_range)
            assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range)
        else:
            faiss_part_range = None

        self.part_range = part_range
        self.faiss_part_range = faiss_part_range

        self.faiss_index = faiss.read_index(faiss_index_path)
        # nprobe sets how many IVF cells FAISS scans per query; larger values
        # trade search speed for recall.
        self.faiss_index.nprobe = nprobe

        print_message("#> Building the emb2pid mapping..")
        all_doclens = load_doclens(index_path, flatten=False)

        pid_offset = 0
        if faiss_part_range is not None:
            print_message(f"#> Restricting all_doclens to the range {faiss_part_range}.")
            pid_offset = len(flatten(all_doclens[:faiss_part_range.start]))
            all_doclens = all_doclens[faiss_part_range.start:faiss_part_range.stop]

        # Translate the requested part range into a PID range relative to this
        # shard; retrieve() uses it to filter out-of-range PIDs.
        self.relative_range = None
        if self.part_range is not None:
            start = self.faiss_part_range.start if self.faiss_part_range is not None else 0
            a = len(flatten(all_doclens[:self.part_range.start - start]))
            b = len(flatten(all_doclens[:self.part_range.stop - start]))
            self.relative_range = range(a, b)
            print_message(f"self.relative_range = {self.relative_range}")

        all_doclens = flatten(all_doclens)

        total_num_embeddings = sum(all_doclens)
        self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

        # Map every embedding row in the FAISS index back to the PID of the
        # passage it came from: passage pid owns dlength consecutive rows.
        offset_doclens = 0
        for pid, dlength in enumerate(all_doclens):
            self.emb2pid[offset_doclens: offset_doclens + dlength] = pid_offset + pid
            offset_doclens += dlength
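        # Illustrative example: doclens [3, 2] with pid_offset 0 yield
        # emb2pid = [0, 0, 0, 1, 1], so a hit on embedding row 4 maps to PID 1.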

        print_message("len(self.emb2pid) =", len(self.emb2pid))

        self.parallel_pool = Pool(16)

    def retrieve(self, faiss_depth, Q, verbose=False):
        embedding_ids = self.queries_to_embedding_ids(faiss_depth, Q, verbose=verbose)
        pids = self.embedding_ids_to_pids(embedding_ids, verbose=verbose)

        if self.relative_range is not None:
            pids = [[pid for pid in pids_ if pid in self.relative_range] for pids_ in pids]

        return pids

    def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True):
        # Flatten the queries into one matrix of embeddings for the FAISS search.
        num_queries, embeddings_per_query, dim = Q.size()
        Q_faiss = Q.view(num_queries * embeddings_per_query, dim).cpu().contiguous()

        print_message("#> Search in batches with faiss. \t\t",
                      f"Q.size() = {Q.size()}, Q_faiss.size() = {Q_faiss.size()}",
                      condition=verbose)

        embeddings_ids = []
        faiss_bsize = embeddings_per_query * 5000
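        # Search in batches of 5,000 queries' worth of rows so each FAISS
        # call's result buffer stays bounded.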
        for offset in range(0, Q_faiss.size(0), faiss_bsize):
            endpos = min(offset + faiss_bsize, Q_faiss.size(0))

            print_message(f"#> Searching from {offset} to {endpos}...", condition=verbose)

            some_Q_faiss = Q_faiss[offset:endpos].float().numpy()
            _, some_embedding_ids = self.faiss_index.search(some_Q_faiss, faiss_depth)
            embeddings_ids.append(torch.from_numpy(some_embedding_ids))

        embedding_ids = torch.cat(embeddings_ids)

        # Reshape to (num_queries, non-unique embedding IDs per query): each
        # query contributed embeddings_per_query rows of faiss_depth hits each.
        embedding_ids = embedding_ids.view(num_queries, embeddings_per_query * embedding_ids.size(1))

        return embedding_ids

    def embedding_ids_to_pids(self, embedding_ids, verbose=True):
        # Find the (unique) PIDs hit by each query.
        print_message("#> Lookup the PIDs..", condition=verbose)
        all_pids = self.emb2pid[embedding_ids]

        print_message(f"#> Converting to a list [shape = {all_pids.size()}]..", condition=verbose)
        all_pids = all_pids.tolist()

        print_message("#> Removing duplicates (in parallel if large enough)..", condition=verbose)

        if len(all_pids) > 5000:
            all_pids = list(self.parallel_pool.map(uniq, all_pids))
        else:
            all_pids = list(map(uniq, all_pids))

        print_message("#> Done with embedding_ids_to_pids().", condition=verbose)

        return all_pids


def uniq(l):
    return list(set(l))
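

# Minimal usage sketch (illustrative; not part of the original module). The
# paths below are hypothetical placeholders, and the (4, 32, 128) query shape
# assumes an index built over 128-dim embeddings with 32 embeddings per query.
if __name__ == '__main__':
    index = FaissIndex(index_path='/path/to/index',
                       faiss_index_path='/path/to/index/ivfpq.faiss',
                       nprobe=32)

    Q = torch.randn(4, 32, 128)  # (num_queries, embeddings_per_query, dim)
    pids = index.retrieve(faiss_depth=1024, Q=Q, verbose=True)
    print(f"#> Query 0 matched {len(pids[0])} candidate PIDs.")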