from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from utils.retriever import get_retriever

# -- HuggingFaceCrossEncoder
# https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/

# ---- Configurations for model ----
# NOTE: these objects are constructed when the module is imported and are
# shared by every RerankRetriever instance.
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
TOP_N = 5  # passed to CrossEncoderReranker as ``top_n``
COMPRESSOR = CrossEncoderReranker(model=MODEL, top_n=TOP_N)


# ---- Reranker Retriever ----
class RerankRetriever:
    """Build retrievers whose results are reranked by the shared COMPRESSOR."""

    def __init__(self):
        """Create an instance; no per-instance state is kept."""

    def get_base_retriever(self, **kwargs):
        """Return the retriever produced by ``get_retriever``.

        Args:
            **kwargs: Forwarded unchanged to ``utils.retriever.get_retriever``
                (presumably search filters; confirm against that helper).

        Returns:
            Whatever ``get_retriever`` returns for the given keyword arguments.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Return a ``ContextualCompressionRetriever`` around the base retriever.

        Args:
            **kwargs: Forwarded to ``get_base_retriever``.

        Returns:
            A ``ContextualCompressionRetriever`` that uses the module-level
            ``COMPRESSOR`` and is tagged ``"qa_retriever"`` and ``"rerank"``.
        """
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=self.get_base_retriever(**kwargs),
            tags=["qa_retriever", "rerank"],
        )

    def pretty_print_docs(self, docs):
        """Format documents as one string for display.

        Despite the name, nothing is printed: the formatted text is returned.

        Args:
            docs: Iterable of objects exposing a string ``page_content``
                attribute.

        Returns:
            Each document rendered as ``"Document <n>:"`` (numbered from 1),
            a blank line and its ``page_content``, with a line of 100 dashes
            between consecutive documents.
        """
        divider = f"\n{'-' * 100}\n"
        sections = [
            f"Document {number}:\n\n" + doc.page_content
            for number, doc in enumerate(docs, start=1)
        ]
        return divider.join(sections)