SUMANA SUMANAKUL (ING)
first commit
30adccc
raw
history blame
2.82 kB
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from .retriever import get_retriever
# ---- Reranker configuration (built once, when this module is imported) ----

# Number of documents the reranker keeps, passed as ``top_n`` below.
# NOTE(review): the original trailing comment "# 7" suggests 7 was an
# earlier setting -- confirm before relying on it.
TOP_N = 15

# Cross-encoder model handed to the reranker as its scoring model.
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Shared compressor instance used by RerankRetriever.get_compression_retriever.
COMPRESSOR = CrossEncoderReranker(
    model=MODEL,
    top_n=TOP_N,
)
class RerankRetriever:
    """Build retrievers whose results are reranked by the module-level COMPRESSOR.

    The class holds no instance state; every method works purely from its
    arguments and the module-level configuration.
    """

    def __init__(self):
        """Create a stateless helper; there is nothing to initialise."""

    def get_base_retriever(self, **kwargs):
        """Return the retriever produced by ``get_retriever``.

        All keyword arguments are forwarded to ``get_retriever`` unchanged.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs) -> "ContextualCompressionRetriever":
        """Return a ``ContextualCompressionRetriever`` wrapping the base retriever.

        The keyword arguments are passed on to ``get_base_retriever``; the
        resulting retriever is combined with the shared ``COMPRESSOR``.
        """
        base_retriever = self.get_base_retriever(**kwargs)
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=base_retriever,
        )

    def pretty_print_docs(self, docs) -> str:
        """Format ``docs`` as one human-readable string.

        Each document is rendered as ``"Document <n>:"`` (numbered from 1),
        a blank line, then its ``page_content``. Consecutive documents are
        separated by a line of 100 dashes. Despite the method name, the
        string is returned rather than printed; an empty ``docs`` yields
        an empty string.
        """
        separator = f"\n{'-' * 100}\n"
        rendered = [
            f"Document {position}:\n\n" + doc.page_content
            for position, doc in enumerate(docs, start=1)
        ]
        return separator.join(rendered)
# from langchain.retrievers import ContextualCompressionRetriever
# from langchain.retrievers.document_compressors import CrossEncoderReranker
# from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# from .retriever import get_retriever
# # -- HuggingFaceCrossEncoder
# # https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/
# # ---- Configurations for model ----
# MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
# TOP_N = 15
# COMPRESSOR = CrossEncoderReranker(model=MODEL,
# top_n=TOP_N)
# # ---- Reranker Retriever ----
# class RerankRetriever:
# def __init__(self):
# pass
# def get_base_retriever(self, **kwargs):
# filters = {**kwargs}
# retriever = get_retriever(**filters)
# return retriever
# def get_compression_retriever(self, **kwargs):
# # ---- get_base_retriever ----
# base_retriever_used = self.get_base_retriever(**kwargs)
# # ---- Instantiate compression retriever ----
# compression_retriever = ContextualCompressionRetriever(
# base_compressor=COMPRESSOR,
# base_retriever=base_retriever_used,
# tags=["qa_retriever", "rerank"]
# )
# return compression_retriever
# def pretty_print_docs(self, docs):
# return(
# f"\n{'-' * 100}\n".join(
# [f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]
# )
# )