Spaces:
Runtime error
Runtime error
import torch | |
from rank_bm25 import BM25Okapi | |
from sentence_transformers import SentenceTransformer | |
from chunker_final import chunk_documents_to_dict | |
import numpy as np | |
class Retriever:
    """Chunk-level retriever combining BM25 (lexical) and SBERT (semantic) search.

    Documents are split with ``chunk_documents_to_dict``. Every chunk is then
    indexed twice: with a ``BM25Okapi`` index over lower-cased whitespace
    tokens, and with an embedding from the
    ``sentence-transformers/all-distilroberta-v1`` model.
    """

    def __init__(self, docs: dict) -> None:
        """Chunk *docs* and build the BM25 index and the chunk embeddings.

        Args:
            docs: Passed unchanged to ``chunk_documents_to_dict``. Presumably
                a mapping of document id -> text (TODO confirm against the
                chunker).
        """
        # Assumes the chunker returns a mapping chunk_id -> chunk text; ids
        # and texts are kept in two parallel lists so a score index maps back
        # to both.
        self.chunked_docs = chunk_documents_to_dict(docs)
        self.chunk_ids = list(self.chunked_docs.keys())
        self.chunk_texts = list(self.chunked_docs.values())
        tokenized_chunks = [self._tokenize(text) for text in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        self.doc_embeddings = self.sbert.encode(self.chunk_texts)

    def get_docs(self, query, method, n=15, *, bm25_weight=0.3) -> dict:
        """Return the top-*n* chunks for *query*, best match first.

        Args:
            query: Free-text query.
            method: ``"BM25"``, ``"semantic"`` or ``"combined search"``.
            n: Maximum number of chunks to return.
            bm25_weight: Weight of the min-max-normalised BM25 score in
                ``"combined search"``. The semantic (cosine) score gets
                ``1 - bm25_weight``. The default keeps the original 0.3/0.7
                split.

        Returns:
            Dict of chunk_id -> chunk text, in descending score order.

        Raises:
            ValueError: If *method* is not one of the supported names.
        """
        if method == "BM25":
            scores = self._get_bm25_scores(query)
        elif method == "semantic":
            scores = self._get_semantic_scores(query)
        elif method == "combined search":
            # Raw BM25 scores are unbounded while cosine similarity lies in
            # [-1, 1]. Rescale BM25 first; otherwise it swamps the semantic
            # term and the weights have no real meaning.
            bm25_scores = self._min_max_normalize(self._get_bm25_scores(query))
            semantic_scores = self._get_semantic_scores(query)
            scores = bm25_weight * bm25_scores + (1 - bm25_weight) * semantic_scores
        else:
            raise ValueError(f"Invalid search method: {method}")
        top_indices = torch.argsort(scores, descending=True)[:n].tolist()
        return {self.chunk_ids[i]: self.chunk_texts[i] for i in top_indices}

    def rerank(self, query, retrieved_docs: dict) -> dict:
        """Reorder *retrieved_docs* by cosine similarity to *query*.

        Args:
            query: Free-text query.
            retrieved_docs: Mapping chunk_id -> chunk text, e.g. the output
                of :meth:`get_docs`.

        Returns:
            The same entries, ordered by descending similarity. Ties keep
            their input order.
        """
        if not retrieved_docs:
            return {}
        chunk_ids = list(retrieved_docs)
        query_embedding = self.sbert.encode(query)
        # One batched encode call instead of one model call per chunk.
        chunk_embeddings = self.sbert.encode([retrieved_docs[cid] for cid in chunk_ids])
        similarities = self._cosine(chunk_embeddings, query_embedding)
        order = sorted(range(len(chunk_ids)), key=lambda i: similarities[i], reverse=True)
        return {chunk_ids[i]: retrieved_docs[chunk_ids[i]] for i in order}

    @staticmethod
    def _tokenize(text):
        """Lower-case *text* and split it on any run of whitespace."""
        # Argument-less split() skips the empty tokens that split(" ")
        # produced for repeated spaces, and also splits on tabs and newlines.
        return text.lower().split()

    @staticmethod
    def _cosine(matrix, vector):
        """Return the cosine similarity of each row of *matrix* with *vector*."""
        matrix = np.asarray(matrix)
        vector = np.asarray(vector)
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
        # A zero-norm embedding would otherwise produce a NaN score and
        # corrupt the ranking; treat it as zero similarity instead.
        return (matrix @ vector) / np.where(denom == 0, 1.0, denom)

    @staticmethod
    def _min_max_normalize(scores):
        """Rescale a 1-D score tensor to [0, 1]; a constant vector maps to zeros."""
        if scores.numel() == 0:
            return scores
        low, high = scores.min(), scores.max()
        span = high - low
        if span == 0:
            return torch.zeros_like(scores)
        return (scores - low) / span

    def _get_bm25_scores(self, query):
        """Return BM25 scores of every chunk for *query* as a 1-D tensor."""
        return torch.tensor(self.bm25.get_scores(self._tokenize(query)))

    def _get_semantic_scores(self, query):
        """Return cosine similarities of every chunk to *query* as a 1-D tensor."""
        query_embedding = self.sbert.encode(query)
        return torch.tensor(self._cosine(self.doc_embeddings, query_embedding))