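# Page-header residue from where this file was copied ("Spaces: Sleeping");
# kept only as a comment because it is not part of the module.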
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

from .retriever import get_retriever
# ---- Cross-encoder reranker configuration ----

# top_n handed to CrossEncoderReranker: how many of the highest-scoring
# documents it keeps.
# NOTE(review): the original carried a trailing "# 7", which suggests an
# earlier value of 7 — confirm that 15 is the intended setting.
TOP_N = 15

# Built at import time, so importing this module constructs the cross-encoder.
MODEL = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base",
)

# Shared reranker used by RerankRetriever.get_compression_retriever.
COMPRESSOR = CrossEncoderReranker(
    top_n=TOP_N,
    model=MODEL,
)
class RerankRetriever:
    """Build retrievers whose results are reranked by the module-level COMPRESSOR.

    The class keeps no per-instance state; every method works only from its
    arguments and from module-level names.
    """

    def __init__(self):
        """Create an instance; there is nothing to initialise."""

    def get_base_retriever(self, **kwargs):
        """Return the underlying (non-reranked) retriever.

        Every keyword argument is forwarded unchanged to ``get_retriever``.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Return a ContextualCompressionRetriever wrapping a base retriever.

        ``kwargs`` are forwarded as-is to ``get_base_retriever`` to build the
        base retriever; the module-level ``COMPRESSOR`` is used as the
        compressor that reranks its results.
        """
        inner = self.get_base_retriever(**kwargs)
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=inner,
        )

    def pretty_print_docs(self, docs):
        """Format ``docs`` as a single display string and return it.

        Each document is rendered as ``"Document <n>:"`` (numbered from 1), a
        blank line, then its ``page_content``. Consecutive documents are
        separated by a line of 100 dashes. Despite the name, nothing is
        printed. An empty ``docs`` yields ``""``.
        """
        divider = f"\n{'-' * 100}\n"
        sections = [
            f"Document {number}:\n\n" + doc.page_content
            for number, doc in enumerate(docs, start=1)
        ]
        return divider.join(sections)
# Earlier version of this module, kept for reference only. Apart from comments
# and layout, its sole functional difference from the live code above is that it
# passed tags=["qa_retriever", "rerank"] to ContextualCompressionRetriever.
#
# from langchain.retrievers import ContextualCompressionRetriever
# from langchain.retrievers.document_compressors import CrossEncoderReranker
# from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# from .retriever import get_retriever
#
# # -- HuggingFaceCrossEncoder
# # https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/
# # ---- Configurations for model ----
# MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
# TOP_N = 15
# COMPRESSOR = CrossEncoderReranker(model=MODEL,
#                                   top_n=TOP_N)
#
# # ---- Reranker Retriever ----
# class RerankRetriever:
#     def __init__(self):
#         pass
#
#     def get_base_retriever(self, **kwargs):
#         filters = {**kwargs}
#         retriever = get_retriever(**filters)
#         return retriever
#
#     def get_compression_retriever(self, **kwargs):
#         # ---- get_base_retriever ----
#         base_retriever_used = self.get_base_retriever(**kwargs)
#         # ---- Instantiate compression retriever ----
#         compression_retriever = ContextualCompressionRetriever(
#             base_compressor=COMPRESSOR,
#             base_retriever=base_retriever_used,
#             tags=["qa_retriever", "rerank"]
#         )
#         return compression_retriever
#
#     def pretty_print_docs(self, docs):
#         return(
#             f"\n{'-' * 100}\n".join(
#                 [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
#             )
#         )