File size: 2,822 Bytes
30adccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from .retriever import get_retriever

# How many documents the reranker keeps after scoring.
# NOTE(review): the original carried a trailing "# 7" here, presumably an
# earlier or alternative value - confirm before changing.
TOP_N = 15

# Cross-encoder used to score documents. It is instantiated at import time,
# so importing this module constructs the model object.
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Module-level reranking compressor; RerankRetriever passes this same object
# to every compression retriever it builds.
COMPRESSOR = CrossEncoderReranker(
    model=MODEL,
    top_n=TOP_N,
)

class RerankRetriever:
    """Build retrievers whose results are reranked by the module COMPRESSOR.

    The class holds no state of its own; every method works purely from its
    arguments and the module-level ``get_retriever`` / ``COMPRESSOR`` names.
    """

    def __init__(self):
        """Create the helper; there is nothing to initialise."""

    def get_base_retriever(self, **kwargs):
        """Return the retriever produced by ``get_retriever``.

        All keyword arguments are forwarded to ``get_retriever`` unchanged.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Return a ``ContextualCompressionRetriever`` around the base retriever.

        The keyword arguments are handed to ``get_base_retriever``; the
        resulting retriever is paired with the module-level ``COMPRESSOR``.
        """
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=self.get_base_retriever(**kwargs),
        )

    def pretty_print_docs(self, docs):
        """Format ``docs`` as a single string and return it (nothing is printed).

        Each entry is rendered as a ``Document N:`` header (numbered from 1),
        a blank line, and the entry's ``page_content``; entries are separated
        by a line of 100 dashes. An empty ``docs`` yields an empty string.
        """
        divider = f"\n{'-' * 100}\n"
        sections = []
        for number, doc in enumerate(docs, start=1):
            sections.append(f"Document {number}:\n\n" + doc.page_content)
        return divider.join(sections)

# from langchain.retrievers import ContextualCompressionRetriever
# from langchain.retrievers.document_compressors import CrossEncoderReranker
# from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# from .retriever import get_retriever

# # -- HuggingFaceCrossEncoder
# # https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/


# # ---- Configurations for model ----
# MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
# TOP_N = 15
# COMPRESSOR = CrossEncoderReranker(model=MODEL, 
#                                   top_n=TOP_N)


# # ---- Reranker Retriever ----
# class RerankRetriever:
#     def __init__(self):
#         pass
    
#     def get_base_retriever(self, **kwargs):
#         filters = {**kwargs}
#         retriever = get_retriever(**filters)
#         return retriever

#     def get_compression_retriever(self, **kwargs):

#         # ---- get_base_retriever ----
#         base_retriever_used = self.get_base_retriever(**kwargs)

#         # ---- Instantiate compression retriever ----
#         compression_retriever = ContextualCompressionRetriever(
#             base_compressor=COMPRESSOR,
#             base_retriever=base_retriever_used,
#             tags=["qa_retriever", "rerank"]
#         )

#         return compression_retriever

#     def pretty_print_docs(self, docs):
#         return(
#             f"\n{'-' * 100}\n".join(
#                 [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
#             )
#         )