|
from langchain.retrievers import ContextualCompressionRetriever |
|
from langchain.retrievers.document_compressors import CrossEncoderReranker |
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder |
|
from utils.retriever import get_retriever |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cross-encoder used to score (query, document) pairs for reranking. It is
# instantiated at import time (presumably loading the model weights then).
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Passed to the reranker as ``top_n``: how many documents it keeps.
TOP_N = 5

# Single shared compressor handed to every compression retriever built here.
COMPRESSOR = CrossEncoderReranker(
    top_n=TOP_N,
    model=MODEL,
)
|
|
|
|
|
|
|
class RerankRetriever:
    """Factory for retrievers whose results are re-ranked by a cross-encoder.

    The base retriever comes from ``get_retriever``; reranking is done by the
    module-level ``COMPRESSOR``. Instances hold no state of their own.
    """

    def __init__(self):
        """Create the factory; there is nothing to configure per instance."""

    def get_base_retriever(self, **kwargs):
        """Return the underlying (non-reranked) retriever.

        All keyword arguments are forwarded unchanged to ``get_retriever``
        (they appear to be used as retrieval filters — confirm against
        ``utils.retriever``).
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Return a ``ContextualCompressionRetriever`` that reranks results.

        Keyword arguments are forwarded to ``get_base_retriever``. The
        resulting retriever uses the shared ``COMPRESSOR`` and is tagged
        ``"qa_retriever"`` and ``"rerank"``.
        """
        # Build the base retriever first, then wrap it with the reranker.
        base = self.get_base_retriever(**kwargs)
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=base,
            tags=["qa_retriever", "rerank"],
        )

    def pretty_print_docs(self, docs):
        """Format documents as numbered sections and return the string.

        Despite the name, nothing is printed. Each section reads
        "Document <n>:" followed by a blank line and the document's
        ``page_content``, with ``n`` counting from 1. Sections are separated
        by a line of 100 dashes. An empty ``docs`` yields an empty string.
        """
        rule = "\n" + "-" * 100 + "\n"
        sections = (
            f"Document {number}:\n\n" + doc.page_content
            for number, doc in enumerate(docs, start=1)
        )
        return rule.join(sections)