Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| from qdrant_client.http import models as rest | |
| from auditqa.process_chunks import getconfig | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| from langchain.retrievers.document_compressors import CrossEncoderReranker | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| model_config = getconfig("model_params.cfg") | |
def create_filter(reports: list = None, sources: str = None,
                  subtype: str = None, year: str = None):
    """Build the Qdrant metadata filter that restricts which chunks are searched.

    Two modes are supported:

    * ``reports`` is non-empty: match any chunk whose ``metadata.filename``
      is one of the given report names. ``sources`` and ``subtype`` are
      ignored in this mode.
    * ``reports`` is empty or ``None``: require ``metadata.source`` to equal
      ``sources`` AND ``metadata.filename`` to be one of the ``subtype``
      values.

    Args:
        reports: Report filenames to restrict the search to. ``None`` is
            treated the same as an empty list.
        sources: Value that ``metadata.source`` must equal (only used when
            ``reports`` is empty).
        subtype: Value(s) matched against ``metadata.filename`` with
            ``MatchAny``. A single string is wrapped in a one-element list
            because ``MatchAny`` takes a list of candidates.
        year: Accepted for interface compatibility but currently NOT applied;
            the ``metadata.year`` condition is disabled.

    Returns:
        A ``rest.Filter`` whose ``must`` clause holds the conditions above.
    """
    if reports:
        print("defining filter for allreports:", reports)
        return rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.filename",
                    match=rest.MatchAny(any=reports),
                ),
            ]
        )

    print("defining filter for sources:{},subtype:{}".format(sources, subtype))

    # MatchAny takes a list of candidate values; accept a lone string too so
    # the `str` annotation on `subtype` is honoured instead of failing
    # validation inside the client model.
    subtype_values = [subtype] if isinstance(subtype, str) else subtype

    conditions = [
        rest.FieldCondition(
            key="metadata.source",
            match=rest.MatchValue(value=sources),
        ),
        rest.FieldCondition(
            key="metadata.filename",
            match=rest.MatchAny(any=subtype_values),
        ),
        # NOTE: filtering on "metadata.year" (MatchAny over `year`) is
        # intentionally not applied yet; add the condition here to enable it.
    ]
    return rest.Filter(must=conditions)
# Rerankers already built in this process, keyed by (model name, top_n), so
# the cross-encoder is instantiated once instead of on every query.
_RERANKER_CACHE = {}


def _get_reranker(model_name, top_n):
    """Return the CrossEncoderReranker for (model_name, top_n), building it once."""
    cache_key = (model_name, top_n)
    if cache_key not in _RERANKER_CACHE:
        cross_encoder = HuggingFaceCrossEncoder(model_name=model_name)
        _RERANKER_CACHE[cache_key] = CrossEncoderReranker(
            model=cross_encoder, top_n=top_n
        )
    return _RERANKER_CACHE[cache_key]


def get_context(vectorstore, query, reports, sources, subtype):
    """Retrieve and re-rank the context paragraphs for ``query``.

    The search is restricted with the metadata filter produced by
    ``create_filter``, run as a similarity search with a score threshold,
    and the hits are then re-ranked by a cross-encoder.

    Args:
        vectorstore: Vector store exposing ``as_retriever``.
        query: The user query passed to the retriever.
        reports: Report filenames forwarded to ``create_filter``.
        sources: Source value forwarded to ``create_filter``.
        subtype: Subtype value(s) forwarded to ``create_filter``.

    Returns:
        Whatever ``ContextualCompressionRetriever.invoke`` returns for the
        query (the re-ranked documents).
    """
    # create metadata filter
    metadata_filter = create_filter(
        reports=reports, sources=sources, subtype=subtype
    )

    # first-stage retrieval: similarity search, dropping hits scored below 0.4
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={
            "score_threshold": 0.4,
            "k": int(model_config.get('retriever', 'TOP_K')),
            "filter": metadata_filter,
        },
    )

    # second stage: re-rank the retrieved results; the config is read on each
    # call, while the model itself is reused through the cache
    compressor = _get_reranker(
        model_config.get('ranker', 'MODEL'),
        int(model_config.get('ranker', 'TOP_K')),
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    context_retrieved = compression_retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved