|
import argparse |
|
import csv |
|
import json |
|
import os |
|
import time |
|
import pickle |
|
|
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
from rank_bm25 import BM25Okapi |
|
from sentence_transformers import SentenceTransformer |
|
|
|
def gtr_build_index(encoder, docs):
    """Encode ``docs`` into float16 embeddings and optionally cache them on disk.

    Args:
        encoder: Object exposing ``encode(docs, show_progress_bar=...,
            normalize_embeddings=...)`` whose result supports ``.astype``
            (presumably a ``SentenceTransformer`` returning a NumPy array --
            confirm against the caller).
        docs: The passages to embed; passed to ``encoder.encode`` unchanged.

    Returns:
        The embeddings produced by the encoder, cast to ``float16``.

    Side effects:
        If the ``GTR_EMB`` environment variable is set to a non-empty path,
        the embeddings are pickled to that path. If it is unset, nothing is
        written and the embeddings are only returned.
    """
    with torch.inference_mode():
        embs = encoder.encode(docs, show_progress_bar=True, normalize_embeddings=True)
        # Half precision halves the memory footprint of the index.
        embs = embs.astype("float16")

    save_path = os.environ.get("GTR_EMB")
    if save_path:
        with open(save_path, "wb") as f:
            pickle.dump(embs, f)
    else:
        # Previously open(None, "wb") raised TypeError here, discarding the
        # embeddings that had just been computed.
        print("GTR_EMB environment variable not set, embeddings will not be saved")
    return embs
|
|
|
|
|
class DPRRetriever:
    """Dense retriever ranking passages by dot product of sentence embeddings.

    Attributes (set in ``__init__``):
        device: ``"cuda"`` when available, otherwise ``"cpu"``.
        encoder: ``SentenceTransformer`` used to embed queries (and passages
            when the index has to be built).
        docs: One string per passage, stored as ``title + "\\n" + text``.
        gtr_emb: float16 tensor holding the passage embeddings on ``device``.
    """

    def __init__(self, DPR_WIKI_TSV, GTR_EMB=None, emb_model="sentence-transformers/gtr-t5-xxl") -> None:
        """Load the passages and their embeddings.

        Args:
            DPR_WIKI_TSV: Path to a tab-separated file with a header row;
                column 1 is used as the passage text and column 2 as the title.
            GTR_EMB: Path to a pickle file with precomputed embeddings. When
                falsy, the embeddings are built with ``gtr_build_index``
                (defined elsewhere in this module).
            emb_model: Model name passed to ``SentenceTransformer``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.encoder = SentenceTransformer(emb_model, device=device)
        self.docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            next(reader, None)  # skip the header row
            for row in reader:
                self.docs.append(row[2] + "\n" + row[1])

        if not GTR_EMB:
            print("gtr embeddings not found, building...")
            embs = gtr_build_index(self.encoder, self.docs)
        else:
            print("gtr embeddings found, loading...")
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point GTR_EMB at files from a trusted source.
            with open(GTR_EMB, "rb") as f:
                embs = pickle.load(f)

        self.gtr_emb = torch.tensor(embs, dtype=torch.float16, device=device)

    def retrieve(self, question, topk):
        """Return up to ``topk`` passages with the highest dot-product score.

        Args:
            question: Query passed to ``self.encoder.encode``. The matmul
                below expects a single 1-D query embedding, so presumably a
                single string -- confirm against callers.
            topk: Maximum number of passages to return. Values larger than
                the number of indexed passages are clamped; values <= 0 give
                an empty list.

        Returns:
            A list of dicts with keys ``"id"`` (1-based row position as a
            string), ``"title"``, ``"text"`` and ``"score"`` (float), ordered
            by decreasing score.
        """
        with torch.inference_mode():
            query = self.encoder.encode(question, batch_size=4, show_progress_bar=True, normalize_embeddings=True)
            # Cast to the index dtype so both matmul operands always agree.
            query = torch.tensor(query, dtype=self.gtr_emb.dtype, device=self.device)
            scores = torch.matmul(self.gtr_emb, query)
            # torch.topk raises when k exceeds the number of elements.
            k = max(0, min(topk, scores.size(0)))
            top_scores, top_idx = torch.topk(scores, k)

        ret = []
        # One tolist() per tensor instead of an .item() call per element.
        for doc_idx, doc_score in zip(top_idx.tolist(), top_scores.tolist()):
            # maxsplit=1: only the first newline separates title from text,
            # so passages whose text contains newlines no longer fail to unpack.
            title, text = self.docs[doc_idx].split("\n", 1)
            ret.append({"id": str(doc_idx + 1), "title": title, "text": text, "score": doc_score})

        return ret

    def __repr__(self) -> str:
        return 'DPR Retriever'

    def __str__(self) -> str:
        return repr(self)
|
|
|
class BM25Retriever:
    """Sparse retriever ranking passages with ``BM25Okapi`` scores.

    Attributes (set in ``__init__``):
        docs: One string per passage, stored as ``title + "\\n" + text``.
        tokenized_docs: Whitespace-split tokens of ``title + " " + text``.
        bm25: ``BM25Okapi`` index built over ``tokenized_docs``.
    """

    def __init__(self, DPR_WIKI_TSV):
        """Load the passages and build the BM25 index.

        Args:
            DPR_WIKI_TSV: Path to a tab-separated file with a header row;
                column 1 is used as the passage text and column 2 as the title.
        """
        self.docs = []
        self.tokenized_docs = []
        print("loading wikipedia file...")
        with open(DPR_WIKI_TSV) as f:
            reader = csv.reader(f, delimiter="\t")
            next(reader, None)  # skip the header row
            for row in reader:
                title, text = row[2], row[1]
                self.docs.append(title + "\n" + text)
                self.tokenized_docs.append((title + " " + text).split())

        print("BM25 index building...")
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def retrieve(self, question, topk):
        """Return up to ``topk`` passages with the highest BM25 score.

        Args:
            question: Query string; tokenized by whitespace.
            topk: Maximum number of passages to return. Values <= 0 give an
                empty list.

        Returns:
            A list of dicts with keys ``"id"`` (1-based row position as a
            string), ``"title"``, ``"text"`` and ``"score"`` (float), ordered
            by decreasing score.
        """
        # Guard first: scores.argsort()[-0:] is the whole array, so topk == 0
        # used to return every passage instead of none.
        if topk <= 0:
            return []

        query = question.split()
        scores = self.bm25.get_scores(query)
        # Ascending argsort; take the last topk entries and reverse them.
        topk_indices = scores.argsort()[-topk:][::-1]

        ret = []
        for idx in topk_indices:
            idx = int(idx)
            title, text = self.docs[idx].split("\n", 1)
            # Plain float keeps the result free of NumPy scalar types.
            ret.append({"id": str(idx + 1), "title": title, "text": text, "score": float(scores[idx])})

        return ret

    def __repr__(self) -> str:
        return 'BM25 Retriever'

    def __str__(self) -> str:
        return repr(self)
|
|