# Harry_Potter / retriever_reranker_final.py
# Sonja-Subt's picture
# Upload folder using huggingface_hub
# c55e75f verified
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from chunker_final import chunk_documents_to_dict
import numpy as np
class Retriever:
    """Retrieve and rerank document chunks with BM25 and sentence embeddings.

    On construction the input documents are chunked with
    ``chunk_documents_to_dict``; the chunk texts are indexed with BM25 and
    embedded once with a SentenceTransformer model. ``get_docs`` ranks the
    chunks for a query, and ``rerank`` reorders an already retrieved set of
    chunks by cosine similarity to the query.
    """

    # Weights of the lexical and the semantic score in "combined search".
    BM25_WEIGHT = 0.3
    SEMANTIC_WEIGHT = 0.7

    def __init__(self, docs: dict) -> None:
        """Chunk ``docs`` and build the BM25 index and the chunk embeddings.

        Args:
            docs: Documents handed to ``chunk_documents_to_dict``, which is
                expected to return a mapping of chunk id to chunk text.
        """
        self.chunked_docs = chunk_documents_to_dict(docs)
        self.chunk_ids = list(self.chunked_docs.keys())
        self.chunk_texts = list(self.chunked_docs.values())
        # Corpus and queries go through the same tokenizer so that their
        # tokens are comparable.
        tokenized_chunks = [self._tokenize(text) for text in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.doc_embeddings = self.sbert.encode(self.chunk_texts)

    def get_docs(self, query: str, method: str, n: int = 15) -> dict:
        """Return the ``n`` best-scoring chunks for ``query``.

        Args:
            query: Search query text.
            method: One of ``"BM25"``, ``"semantic"`` or ``"combined search"``.
            n: Maximum number of chunks to return.

        Returns:
            A dict mapping chunk id to chunk text, ordered from the highest
            to the lowest score.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == "BM25":
            scores = self._get_bm25_scores(query)
        elif method == "semantic":
            scores = self._get_semantic_scores(query)
        elif method == "combined search":
            # NOTE(review): the two score vectors are summed as-is, without
            # rescaling. BM25 scores are not confined to the [-1, 1] range of
            # cosine similarity, so the lexical term may dominate -- confirm
            # that this weighting is intended.
            scores = (
                self.BM25_WEIGHT * self._get_bm25_scores(query)
                + self.SEMANTIC_WEIGHT * self._get_semantic_scores(query)
            )
        else:
            raise ValueError(f"Invalid search method: {method}")

        # Convert to plain ints so the Python lists are not indexed with
        # 0-dim tensors.
        top_indices = scores.argsort(descending=True)[:n].tolist()
        return {self.chunk_ids[i]: self.chunk_texts[i] for i in top_indices}

    def rerank(self, query: str, retrieved_docs: dict) -> dict:
        """Reorder ``retrieved_docs`` by cosine similarity to ``query``.

        Args:
            query: Search query text.
            retrieved_docs: Mapping of chunk id to chunk text, e.g. the
                result of ``get_docs``.

        Returns:
            A dict with the same items as ``retrieved_docs``, ordered from
            the most to the least similar chunk. Chunks with equal
            similarity keep their input order.
        """
        if not retrieved_docs:
            return {}

        chunk_ids = list(retrieved_docs)
        query_embedding = self.sbert.encode(query)
        # Encode all chunks in one batched call instead of one call per chunk.
        chunk_embeddings = self.sbert.encode(
            [retrieved_docs[chunk_id] for chunk_id in chunk_ids]
        )
        similarities = self._cosine_similarity(
            chunk_embeddings, query_embedding
        ).tolist()

        ranked = sorted(
            zip(chunk_ids, similarities),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return {chunk_id: retrieved_docs[chunk_id] for chunk_id, _ in ranked}

    @staticmethod
    def _tokenize(text: str) -> list:
        """Lowercase ``text`` and split it on any run of whitespace.

        ``str.split()`` without an argument never yields empty tokens and
        also separates on newlines and tabs, unlike ``split(" ")``.
        """
        return text.lower().split()

    @staticmethod
    def _cosine_similarity(embeddings, query_embedding):
        """Return the cosine similarity of each row of ``embeddings`` to
        ``query_embedding``.

        Args:
            embeddings: 2-D array with one embedding per row.
            query_embedding: 1-D array of the same width.

        Returns:
            A 1-D NumPy array with one similarity per row. A pair that
            involves a zero-norm vector scores 0.0 instead of NaN, so the
            result can always be sorted.
        """
        embeddings = np.asarray(embeddings)
        query_embedding = np.asarray(query_embedding)
        dots = np.dot(embeddings, query_embedding)
        denominators = (
            np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        is_zero = denominators == 0
        # Divide by 1 where the denominator is zero, then overwrite those
        # entries, so no division-by-zero warning or NaN is produced.
        safe_denominators = np.where(is_zero, 1.0, denominators)
        return np.where(is_zero, 0.0, dots / safe_denominators)

    def _get_bm25_scores(self, query: str):
        """Return a tensor with the BM25 score of every chunk for ``query``."""
        return torch.tensor(self.bm25.get_scores(self._tokenize(query)))

    def _get_semantic_scores(self, query: str):
        """Return a tensor with the cosine similarity of every chunk
        embedding to the embedding of ``query``."""
        query_embedding = self.sbert.encode(query)
        return torch.tensor(
            self._cosine_similarity(self.doc_embeddings, query_embedding)
        )