import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from chunker_final import chunk_documents_to_dict
import numpy as np

class Retriever:
    """Retrieves document chunks by BM25 keyword scoring, SBERT semantic
    similarity, or a weighted combination of the two."""

    def __init__(self, docs: dict) -> None:
        self.chunked_docs = chunk_documents_to_dict(docs)
        self.chunk_ids = list(self.chunked_docs.keys())
        self.chunk_texts = list(self.chunked_docs.values())
        # split() without an argument also handles newlines and runs of spaces
        tokenized_chunks = [text.lower().split() for text in self.chunk_texts]
        self.bm25 = BM25Okapi(tokenized_chunks)
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        # Embed every chunk once up front so semantic search is a single matrix product
        self.doc_embeddings = self.sbert.encode(self.chunk_texts)
    def get_docs(self, query, method, n=15) -> dict:
        """Return the top-n chunks for a query, scored by the given method."""
        if method == "BM25":
            scores = self._get_bm25_scores(query)
        elif method == "semantic":
            scores = self._get_semantic_scores(query)
        elif method == "combined search":
            bm25_scores = self._get_bm25_scores(query)
            semantic_scores = self._get_semantic_scores(query)
            # BM25 scores are unbounded while cosine similarities lie in [-1, 1],
            # so min-max normalize both to [0, 1] before taking the weighted sum
            bm25_scores = (bm25_scores - bm25_scores.min()) / (bm25_scores.max() - bm25_scores.min() + 1e-8)
            semantic_scores = (semantic_scores - semantic_scores.min()) / (semantic_scores.max() - semantic_scores.min() + 1e-8)
            scores = 0.3 * bm25_scores + 0.7 * semantic_scores
        else:
            raise ValueError(f"Invalid search method: {method}")

        sorted_indices = scores.argsort(descending=True)
        return {self.chunk_ids[i]: self.chunk_texts[i] for i in sorted_indices[:n].tolist()}

    def rerank(self, query, retrieved_docs: dict) -> dict:
        """Re-order retrieved chunks by cosine similarity to the query."""
        query_embedding = self.sbert.encode(query)
        # Encode all retrieved chunks in one batch instead of one encode()
        # call per chunk, then score every chunk in a single matrix product
        chunk_ids = list(retrieved_docs.keys())
        chunk_embeddings = self.sbert.encode(list(retrieved_docs.values()))
        similarities = np.dot(chunk_embeddings, query_embedding) / (
            np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding)
        )

        ranked = sorted(zip(chunk_ids, similarities), key=lambda x: x[1], reverse=True)
        return {chunk_id: retrieved_docs[chunk_id] for chunk_id, _ in ranked}

    def _get_bm25_scores(self, query):
        # Tokenize the query the same way the chunks were tokenized at index time
        tokenized_query = query.lower().split()
        return torch.tensor(self.bm25.get_scores(tokenized_query))

    def _get_semantic_scores(self, query):
        query_embedding = self.sbert.encode(query)
        # Cosine similarity between the query and every precomputed chunk embedding
        scores = np.dot(self.doc_embeddings, query_embedding) / (
            np.linalg.norm(self.doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        )
        return torch.tensor(scores)
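

# A minimal usage sketch. Assumption (not confirmed by this file): `docs` maps
# document IDs to raw text, which is what chunk_documents_to_dict appears to
# take; the example strings and IDs below are illustrative only.
if __name__ == "__main__":
    docs = {
        "doc1": "BM25 is a bag-of-words ranking function based on term frequencies.",
        "doc2": "Sentence embeddings capture semantic similarity between texts.",
    }
    retriever = Retriever(docs)
    query = "what is semantic similarity?"
    hits = retriever.get_docs(query, method="combined search", n=5)
    for chunk_id in retriever.rerank(query, hits):
        print(chunk_id)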