import os
from pathlib import Path

import requests
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import FAISS, Chroma

load_dotenv()

def get_reranked_docs_faiss(
        query: str,
        path_to_db: str,
        embedding_model: str,
        hf_api_key: str,
        num_docs: int = 5,
) -> list:
    """Re-ranks the similarity search results and returns the top-k highest ranked docs.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token used for the embedding and re-ranking calls
        num_docs (int): Number of documents to return (at most 10)

    Returns:
        A list of documents with the highest rank
    """
    assert num_docs <= 10, "num_docs must not exceed the 10 documents returned by the similarity search"

    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=hf_api_key,
        model_name=embedding_model,
    )
    # Load the vectorstore database
    db = FAISS.load_local(
        folder_path=path_to_db,
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
    )

    # Get 10 documents based on similarity search
    docs = db.similarity_search(query=query, k=10)

    # Concatenate the page_content, title and description for each document
    passages = [
        doc.page_content + "\n" + doc.metadata.get('title', "") + "\n" + doc.metadata.get('description', "")
        for doc in docs
    ]

    # Prepare the payload for the re-ranking endpoint
    inputs = [{"text": query, "text_pair": passage} for passage in passages]

    API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    response = requests.post(API_URL, headers=headers, json=inputs)
    scores = response.json()

    try:
        relevance_scores = [item[1]['score'] for item in scores]
    except (TypeError, KeyError, IndexError):
        print('Could not get the relevance_scores -> something might be wrong with the json output')
        return

    if relevance_scores:
        ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
        top_k_results = ranked_results[:num_docs]
        return [doc for doc, _, _ in top_k_results]

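
# Example usage of the FAISS re-ranker (a minimal sketch): the index path below is
# hypothetical and assumes a local FAISS index built with the same embedding model and
# saved via FAISS.save_local(); adjust the path, model and token to your own setup.
#
#   faiss_docs = get_reranked_docs_faiss(
#       query="I'm looking for student insurance",
#       path_to_db="vectorstore/faiss-zurich-mpnet-1500",  # hypothetical path
#       embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
#       hf_api_key=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
#       num_docs=5,
#   )
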
def get_reranked_docs_chroma(
        query: str,
        path_to_db: str,
        embedding_model: str,
        hf_api_key: str,
        reranking_hf_url: str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
        num_docs: int = 5,
) -> list:
    """Re-ranks the similarity search results and returns the top-k highest ranked docs.

    Args:
        query (str): The search query
        path_to_db (str): Path to the vectorstore database
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): Hugging Face API token used for the embedding and re-ranking calls
        reranking_hf_url (str): Inference API endpoint of the sentence-similarity model used for re-ranking
        num_docs (int): Number of documents to return

    Returns:
        A list of documents with the highest rank
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=hf_api_key,
        model_name=embedding_model,
    )
    # Load the vectorstore database
    db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)

    # Get 10 documents based on similarity search
    sim_docs = db.similarity_search(query=query, k=10)
    passages = [doc.page_content for doc in sim_docs]

    # Prepare the payload for the sentence-similarity endpoint
    payload = {
        "inputs": {
            "source_sentence": query,
            "sentences": passages,
        }
    }
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
    if response.status_code != 200:
        print(f'Something went wrong with the response: {response.status_code} {response.text}')
        return

    # The endpoint returns one similarity score per passage
    similarity_scores = response.json()
    ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
    top_k_results = ranked_results[:num_docs]
    return [doc for doc, _, _ in top_k_results]

if __name__ == "__main__":
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

    project_dir = Path().cwd().parent
    path_to_vector_db = str(project_dir / 'vectorstore/chroma-zurich-mpnet-1500')
    assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db"

    query = "I'm looking for student insurance"

    re_ranked_docs = get_reranked_docs_chroma(
        query=query,
        path_to_db=path_to_vector_db,
        embedding_model=EMBEDDING_MODEL,
        hf_api_key=HUGGINGFACEHUB_API_TOKEN,
    )
    print(f"{re_ranked_docs=}")