File size: 1,488 Bytes
8e5a9dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from utils.retriever import get_retriever

# -- HuggingFaceCrossEncoder
# https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/


# ---- Reranker configuration (module-level, built once at import time) ----
# TOP_N is handed to CrossEncoderReranker as ``top_n``; per the LangChain
# documentation this is the number of documents kept after reranking.
TOP_N = 5

# Cross-encoder loaded by model name; it supplies the scores used for reranking.
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Shared compressor instance reused by every compression retriever built below.
COMPRESSOR = CrossEncoderReranker(model=MODEL, top_n=TOP_N)


# ---- Reranker Retriever ----
class RerankRetriever:
    """Build retrievers that are wrapped with the module-level ``COMPRESSOR``.

    The class keeps no instance state; its methods are thin helpers around
    ``get_retriever`` and ``ContextualCompressionRetriever``.
    """

    def __init__(self):
        """Create the helper; there is nothing to initialise."""

    def get_base_retriever(self, **kwargs):
        """Return the retriever produced by ``get_retriever``.

        Args:
            **kwargs: Forwarded unchanged to ``get_retriever``.

        Returns:
            Whatever ``get_retriever`` returns for the given keyword arguments.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Return a ``ContextualCompressionRetriever`` over the base retriever.

        Args:
            **kwargs: Forwarded unchanged to ``get_base_retriever``.

        Returns:
            A ``ContextualCompressionRetriever`` that uses ``COMPRESSOR`` as its
            base compressor and is tagged ``"qa_retriever"`` and ``"rerank"``.
        """
        # Resolve the underlying retriever first, then wrap it.
        underlying = self.get_base_retriever(**kwargs)
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=underlying,
            tags=["qa_retriever", "rerank"],
        )

    def pretty_print_docs(self, docs):
        """Format documents as one human-readable string.

        Despite the name, nothing is printed: the formatted text is returned.

        Args:
            docs: Iterable of objects exposing a string ``page_content``.

        Returns:
            Each document rendered as ``"Document <n>:"`` (numbered from 1),
            a blank line, and its content, with entries separated by a line
            of 100 dashes. An empty input yields an empty string.
        """
        divider = f"\n{'-' * 100}\n"
        sections = (
            f"Document {number}:\n\n" + doc.page_content
            for number, doc in enumerate(docs, start=1)
        )
        return divider.join(sections)