# law_poc/utils/reranker.py
# SUMANA SUMANAKUL (ING)
# commit 8e5a9dd
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from utils.retriever import get_retriever
# -- HuggingFaceCrossEncoder
# https://python.langchain.com/docs/integrations/document_transformers/cross_encoder_reranker/
# ---- Reranker configuration ----
# Value handed to the reranker as its ``top_n`` argument.
TOP_N = 5
# Cross-encoder instance; it is created when this module is imported.
MODEL = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
# Shared compressor built from the two settings above.
COMPRESSOR = CrossEncoderReranker(model=MODEL, top_n=TOP_N)
# ---- Reranker Retriever ----
class RerankRetriever:
    """Factory for retrievers that pair a base retriever with ``COMPRESSOR``.

    Instances hold no state; every method works purely from its arguments
    and the module-level ``COMPRESSOR`` / ``get_retriever`` names.
    """

    def __init__(self):
        """Create an instance; there is nothing to initialise."""

    def get_base_retriever(self, **kwargs):
        """Return the result of ``get_retriever`` called with ``kwargs``.

        Args:
            **kwargs: Forwarded unchanged to ``utils.retriever.get_retriever``.

        Returns:
            Whatever ``get_retriever`` returns for those keyword arguments.
        """
        return get_retriever(**kwargs)

    def get_compression_retriever(self, **kwargs):
        """Build a ``ContextualCompressionRetriever`` around the base retriever.

        Args:
            **kwargs: Forwarded to ``get_base_retriever`` to obtain the base
                retriever.

        Returns:
            A ``ContextualCompressionRetriever`` configured with the
            module-level ``COMPRESSOR``, the base retriever, and the tags
            ``"qa_retriever"`` and ``"rerank"``.
        """
        # Go through the method (not get_retriever directly) so that a
        # subclass overriding get_base_retriever is still honoured.
        underlying = self.get_base_retriever(**kwargs)
        return ContextualCompressionRetriever(
            base_compressor=COMPRESSOR,
            base_retriever=underlying,
            tags=["qa_retriever", "rerank"],
        )

    def pretty_print_docs(self, docs):
        """Format ``docs`` as one human-readable string (nothing is printed).

        Args:
            docs: Iterable of objects exposing a string ``page_content``
                attribute.

        Returns:
            Each document rendered as ``"Document <k>:"`` (numbered from 1),
            a blank line, then its ``page_content``; entries are separated by
            a line of 100 dashes. An empty ``docs`` yields ``""``.
        """
        divider = f"\n{'-' * 100}\n"
        sections = []
        for number, doc in enumerate(docs, start=1):
            sections.append(f"Document {number}:\n\n" + doc.page_content)
        return divider.join(sections)