"""Cross-encoder reranking on top of the project's base retriever.

Reference for the reranker integration:
https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/
"""

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from .retriever import get_retriever

# ---- Configuration for the reranking model ----
# NOTE: these objects are built at import time, so importing this module
# instantiates the cross-encoder (presumably loading its weights -- confirm
# the cost before importing this module on a hot path).
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Value handed to the reranker as ``top_n``. The original inline note next to
# this constant read "7" -- presumably an earlier/alternative setting.
TOP_N = 15

COMPRESSOR = CrossEncoderReranker(model=MODEL, top_n=TOP_N)


class RerankRetriever:
    """Build retrievers whose results are reranked by ``COMPRESSOR``.

    The class holds no instance state; every method works only from its
    arguments and the module-level ``COMPRESSOR``.
    """

    def get_base_retriever(self, **kwargs):
        """Return the retriever produced by ``get_retriever``.

        All keyword arguments are forwarded to ``get_retriever`` unchanged.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs) -> ContextualCompressionRetriever:
        """Return a ``ContextualCompressionRetriever`` around the base retriever.

        The keyword arguments are passed on to ``get_base_retriever``; the
        resulting retriever is combined with the shared ``COMPRESSOR``.
        """
        base = self.get_base_retriever(**kwargs)
        # NOTE(review): a commented-out earlier copy of this module also passed
        # tags=["qa_retriever", "rerank"] here. It is not passed in the live
        # code -- confirm whether dropping it was intentional.
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=base,
        )

    def pretty_print_docs(self, docs) -> str:
        """Format ``docs`` as one human-readable string.

        Each item must expose a string ``page_content`` attribute. Items are
        labelled ``Document 1``, ``Document 2``, ... and separated by a line
        of 100 dashes. An empty ``docs`` yields an empty string.
        """
        separator = f"\n{'-' * 100}\n"
        return separator.join(
            f"Document {position}:\n\n" + doc.page_content
            for position, doc in enumerate(docs, start=1)
        )