import os
from pathlib import Path

import requests
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import FAISS, Chroma

load_dotenv()
def get_reranked_docs_faiss(
    query: str,
    path_to_db: str,
    embedding_model: str,
    hf_api_key: str,
    num_docs: int = 5,
) -> list:
    """Re-ranks the similarity search results and returns the top-k highest ranked docs.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token used for the Inference API calls
        num_docs (int): Number of documents to return

    Returns:
        A list of documents with the highest rank
    """
    assert num_docs <= 10, "num_docs must not exceed the number of similarity search results (10)"

    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=hf_api_key,
        model_name=embedding_model,
    )
    # Load the local FAISS index
    db = FAISS.load_local(
        folder_path=path_to_db,
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
    )

    # Retrieve 10 candidate documents via similarity search
    docs = db.similarity_search(query=query, k=10)

    # Concatenate page_content, title and description into one passage per document
    passages = [
        doc.page_content + "\n" + doc.metadata.get('title', "") + "\n" + doc.metadata.get('description', "")
        for doc in docs
    ]

    # Prepare the payload: one (query, passage) pair per candidate
    inputs = [{"text": query, "text_pair": passage} for passage in passages]

    API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
    headers = {"Authorization": f"Bearer {hf_api_key}"}
    response = requests.post(API_URL, headers=headers, json=inputs)
    scores = response.json()

    try:
        relevance_scores = [item[1]['score'] for item in scores]
    except (TypeError, KeyError, IndexError):
        print('Could not get the relevance_scores -> something might be wrong with the json output')
        return

    if relevance_scores:
        ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
        top_k_results = ranked_results[:num_docs]
        return [doc for doc, _, _ in top_k_results]
def get_reranked_docs_chroma(
    query: str,
    path_to_db: str,
    embedding_model: str,
    hf_api_key: str,
    reranking_hf_url: str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
    num_docs: int = 5,
) -> list:
    """Re-ranks the similarity search results and returns the top-k highest ranked docs.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token used for the Inference API calls
        reranking_hf_url (str): Inference API endpoint of the model used for re-ranking
        num_docs (int): Number of documents to return

    Returns:
        A list of documents with the highest rank
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=hf_api_key,
        model_name=embedding_model,
    )

    # Load the persisted Chroma vectorstore
    db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)

    # Retrieve 10 candidate documents via similarity search
    sim_docs = db.similarity_search(query=query, k=10)
    passages = [doc.page_content for doc in sim_docs]

    # Prepare the payload: sentence-similarity between the query and every passage
    payload = {
        "inputs": {
            "source_sentence": query,
            "sentences": passages,
        }
    }
    headers = {"Authorization": f"Bearer {hf_api_key}"}
    response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
    print(f'{response = }')
    if response.status_code != 200:
        print('Something went wrong with the response')
        return

    similarity_scores = response.json()
    ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
    top_k_results = ranked_results[:num_docs]
    return [doc for doc, _, _ in top_k_results]
if __name__ == "__main__":
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

    project_dir = Path.cwd().parent
    path_to_vector_db = str(project_dir / 'vectorstore/chroma-zurich-mpnet-1500')
    assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db"

    query = "I'm looking for student insurance"
    re_ranked_docs = get_reranked_docs_chroma(
        query=query,
        path_to_db=path_to_vector_db,
        embedding_model=EMBEDDING_MODEL,
        hf_api_key=HUGGINGFACEHUB_API_TOKEN,
    )
    print(f"{re_ranked_docs=}")