Xalt8 committed
Commit 466b7d1 · 1 Parent(s): 15c0646

reranking with chroma fixed

rag_app/loading_data/load_chroma_db_cross_platform.py ADDED
@@ -0,0 +1,55 @@
+ from pathlib import Path
+ import boto3
+ from botocore.client import Config
+ from botocore import UNSIGNED
+ from dotenv import load_dotenv
+ import os
+ import sys
+ import zipfile
+
+ load_dotenv()  # load S3_LOCATION from a local .env file
+ S3_LOCATION = os.getenv("S3_LOCATION")
+
+
+ def download_chroma_from_s3(s3_location:str,
+                             chroma_vs_name:str,
+                             vectorstore_folder:str,
+                             vs_save_name:str) -> None:
+     """
+     Downloads the Chroma DB from S3 storage to a local folder
+
+     Args
+         s3_location (str): The name of the S3 bucket
+         chroma_vs_name (str): The name of the file to download from S3
+         vectorstore_folder (str): The filepath to the vectorstore folder in the project dir
+         vs_save_name (str): The name of the vector store
+
+     """
+     vs_destination = Path()/vectorstore_folder/vs_save_name
+     vs_save_path = vs_destination.with_suffix('.zip')
+
+     try:
+         # Initialize an S3 client with unsigned configuration for public access
+         s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
+         s3.download_file(s3_location, chroma_vs_name, str(vs_save_path))
+
+         # Extract the zip file
+         with zipfile.ZipFile(file=str(vs_save_path), mode='r') as zip_ref:
+             zip_ref.extractall(path=vectorstore_folder)
+
+     except Exception as e:
+         print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
+
+     # Delete the zip file (if it was created)
+     vs_save_path.unlink(missing_ok=True)
+
+ if __name__ == "__main__":
+     chroma_vs_name = "vectorstores/chroma-zurich-mpnet-1500.zip"
+     project_dir = Path().cwd().parent
+     vs_destination = str(project_dir / 'vectorstore')
+     assert Path(vs_destination).is_dir(), "Cannot find vectorstore folder"
+
+     download_chroma_from_s3(s3_location=S3_LOCATION,
+                             chroma_vs_name=chroma_vs_name,
+                             vectorstore_folder=vs_destination,
+                             vs_save_name='chroma-zurich-mpnet-1500')
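
For reference, a minimal sketch of how the extracted store can be opened and smoke-tested with the same langchain_community wrappers used in `reranking.py`. It assumes the zip unpacks to a `chroma-zurich-mpnet-1500` directory inside `vectorstore/` and that the store was built with the `sentence-transformers/multi-qa-mpnet-base-dot-v1` embeddings; adjust both if the actual layout differs.

```python
from pathlib import Path
import os
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma

load_dotenv()

# Assumption: the zip extracted to vectorstore/chroma-zurich-mpnet-1500 and the
# store was built with this embedding model (the one used in reranking.py).
persist_dir = str(Path().cwd().parent / 'vectorstore/chroma-zurich-mpnet-1500')
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv('HUGGINGFACEHUB_API_TOKEN'),
    model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")

db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

# A single similarity search confirms the collection loads and is queryable
print(db.similarity_search("student insurance", k=1))
```
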
rag_app/reranking.py CHANGED
@@ -5,11 +5,13 @@ from dotenv import load_dotenv
  import os
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
  import requests
+ from langchain_community.vectorstores import Chroma
+
 
  load_dotenv()
 
 
- def get_reranked_docs(query:str,
+ def get_reranked_docs_faiss(query:str,
                        path_to_db:str,
                        embedding_model:str,
                        hf_api_key:str,
@@ -59,22 +61,72 @@ def get_reranked_docs(query:str,
      ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
      top_k_results = ranked_results[:num_docs]
      return [doc for doc, _, _ in top_k_results]
+
 
-
- if __name__ == "__main__":
+
+ def get_reranked_docs_chroma(query:str,
+                              path_to_db:str,
+                              embedding_model:str,
+                              hf_api_key:str,
+                              reranking_hf_url:str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
+                              num_docs:int=5) -> list:
+     """ Re-ranks the similarity search results and returns the top-k highest ranked docs
+
+     Args:
+         query (str): The search query
+         path_to_db (str): Path to the vectorstore database
+         embedding_model (str): Embedding model used in the vector store
+         num_docs (int): Number of documents to return
+
+     Returns: A list of documents with the highest rank
+     """
+     assert num_docs <= 10, "num_docs must not exceed the number of similarity search results"
 
-     HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-     EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
+     embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
+                                                    model_name=embedding_model)
+     # Load the vectorstore database
+     db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)
+
+     # Get 10 documents based on similarity search
+     sim_docs = db.similarity_search(query=query, k=10)
+
+     # Collect the page content of each retrieved document
+     passages = [doc.page_content for doc in sim_docs]
 
-     path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'
+     # Prepare the payload for the HF sentence-similarity endpoint
+     payload = {"inputs":
+                {"source_sentence": query,
+                 "sentences": passages}}
+
+
+     headers = {"Authorization": f"Bearer {hf_api_key}"}
 
-     query = "Ich möchte wissen, ob ich meine geriatrische Haustier-Eidechse versichern kann"
+     response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
+     if response.status_code != 200:
+         print(f"Reranking request failed with status code {response.status_code}")
+         return
+     similarity_scores = response.json()
+     ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
+     top_k_results = ranked_results[:num_docs]
+     return [doc for doc, _, _ in top_k_results]
 
-     top_5_docs = get_reranked_docs(query=query,
-                                    path_to_db=path_to_vector_db,
-                                    embedding_model=EMBEDDING_MODEL,
-                                    hf_api_key=HUGGINGFACEHUB_API_TOKEN,
-                                    num_docs=5)
+
+
+ if __name__ == "__main__":
 
-     for i, doc in enumerate(top_5_docs):
-         print(f"{i}: {doc}\n")
+     HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+     EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+
+     project_dir = Path().cwd().parent
+     path_to_vector_db = str(project_dir/'vectorstore/chroma-zurich-mpnet-1500')
+
+     query = "I'm looking for student insurance"
+
+
+     re_ranked_docs = get_reranked_docs_chroma(query=query,
+                                               path_to_db=path_to_vector_db,
+                                               embedding_model=EMBEDDING_MODEL,
+                                               hf_api_key=HUGGINGFACEHUB_API_TOKEN)
+
+
+     print(f"{re_ranked_docs=}")
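
The re-ranking step itself is a plain call to the Hugging Face Inference API's sentence-similarity task. Below is a self-contained sketch of the request that `get_reranked_docs_chroma` assembles, with two made-up passages standing in for the similarity-search results:

```python
import os
import requests
from dotenv import load_dotenv

load_dotenv()

# Same endpoint that get_reranked_docs_chroma uses by default
url = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2"
headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}

# Illustrative passages; in the app these come from db.similarity_search(...)
payload = {"inputs": {
    "source_sentence": "I'm looking for student insurance",
    "sentences": [
        "Our student package covers liability and household contents.",
        "Car insurance premiums depend on the vehicle model.",
    ],
}}

response = requests.post(url, headers=headers, json=payload)
# The endpoint returns one similarity score per sentence, in input order,
# e.g. [0.71, 0.12]; zipping these scores with the retrieved docs and sorting
# is what produces the re-ranked top-k list.
print(response.json())
```

Because the scores come back aligned with the input passages, the function only needs to zip them with `sim_docs`, sort by score, and slice off the top `num_docs`.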