sabazo committed on
Commit
ee48a1f
·
unverified ·
2 Parent(s): 15c0646 1dcc70b

Merge pull request #45 from almutareb/reranking_chroma

Browse files
rag_app/loading_data/load_chroma_db_cross_platform.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import boto3
3
+ from botocore.client import Config
4
+ from botocore import UNSIGNED
5
+ from dotenv import load_dotenv
6
+ import os
7
+ import sys
8
+ import zipfile
9
+
10
+
11
def download_chroma_from_s3(s3_location: str,
                            chroma_vs_name: str,
                            vectorstore_folder: str,
                            vs_save_name: str) -> None:
    """
    Download a zipped Chroma DB from S3 and extract it into a local folder.

    Args:
        s3_location (str): The name of the S3 bucket
        chroma_vs_name (str): The key of the zip file to download from S3
        vectorstore_folder (str): Path to the vectorstore folder in the project dir
        vs_save_name (str): The name of the vector store

    Returns:
        None. Progress is printed; errors are reported on stderr, not raised.
    """
    vs_destination = Path() / vectorstore_folder / vs_save_name
    vs_save_path = vs_destination.with_suffix('.zip')

    try:
        # Unsigned config: the bucket is public, no credentials required
        s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
        # boto3's download_file documents a string Filename, so convert the Path
        s3.download_file(s3_location, chroma_vs_name, str(vs_save_path))
        print('Downloaded file from S3')

        # Extract the zip file into the vectorstore folder
        with zipfile.ZipFile(file=str(vs_save_path), mode='r') as zip_ref:
            zip_ref.extractall(path=vectorstore_folder)
        print("Extracted zip file")

    except Exception as e:
        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
    finally:
        # Clean up the zip archive. missing_ok avoids a FileNotFoundError when
        # the download itself failed and no zip was ever written (the original
        # unconditional unlink() crashed in that case). Announce before deleting.
        print("Deleting zip file")
        vs_save_path.unlink(missing_ok=True)
45
+
46
if __name__ == "__main__":

    # Original bug: load_dotenv was imported but never called, so S3_LOCATION
    # defined only in a .env file came back as None.
    load_dotenv()

    S3_LOCATION = os.getenv("S3_LOCATION")
    assert S3_LOCATION, "S3_LOCATION is not set in the environment or .env file"

    # Key of the zipped Chroma store inside the bucket
    chroma_vs_name = "vectorstores/chroma-zurich-mpnet-1500.zip"

    # Project root is assumed to be two levels above the current working
    # directory — TODO confirm this holds when run from other locations.
    project_dir = Path().cwd().parent.parent
    vs_destination = str(project_dir / 'vectorstore')
    assert Path(vs_destination).is_dir(), "Cannot find vectorstore folder"

    download_chroma_from_s3(s3_location=S3_LOCATION,
                            chroma_vs_name=chroma_vs_name,
                            vectorstore_folder=vs_destination,
                            vs_save_name='chroma-zurich-mpnet-1500')
rag_app/reranking.py CHANGED
@@ -5,11 +5,13 @@ from dotenv import load_dotenv
5
  import os
6
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
  import requests
 
 
8
 
9
  load_dotenv()
10
 
11
 
12
- def get_reranked_docs(query:str,
13
  path_to_db:str,
14
  embedding_model:str,
15
  hf_api_key:str,
@@ -59,22 +61,71 @@ def get_reranked_docs(query:str,
59
  ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
60
  top_k_results = ranked_results[:num_docs]
61
  return [doc for doc, _, _ in top_k_results]
 
62
 
63
-
64
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
67
- EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
 
 
68
 
69
- path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'
 
 
 
 
 
70
 
71
- query = "Ich möchte wissen, ob ich meine geriatrische Haustier-Eidechse versichern kann"
 
 
 
 
72
 
73
- top_5_docs = get_reranked_docs(query=query,
74
- path_to_db=path_to_vector_db,
75
- embedding_model=EMBEDDING_MODEL,
76
- hf_api_key=HUGGINGFACEHUB_API_TOKEN,
77
- num_docs=5)
78
 
79
- for i, doc in enumerate(top_5_docs):
80
- print(f"{i}: {doc}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
7
  import requests
8
+ from langchain_community.vectorstores import Chroma
9
+
10
 
11
  load_dotenv()
12
 
13
 
14
+ def get_reranked_docs_faiss(query:str,
15
  path_to_db:str,
16
  embedding_model:str,
17
  hf_api_key:str,
 
61
  ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
62
  top_k_results = ranked_results[:num_docs]
63
  return [doc for doc, _, _ in top_k_results]
64
+
65
 
66
+
67
def get_reranked_docs_chroma(query: str,
                             path_to_db: str,
                             embedding_model: str,
                             hf_api_key: str,
                             reranking_hf_url: str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
                             num_docs: int = 5) -> list:
    """Re-rank Chroma similarity-search results and return the top-k docs.

    Args:
        query (str): The search query
        path_to_db (str): Path to the Chroma vectorstore directory
        embedding_model (str): Embedding model used in the vector store
        hf_api_key (str): HuggingFace API token, used both for the embeddings
            client and as the Bearer token for the reranking endpoint
        reranking_hf_url (str): HF inference endpoint that scores each passage
            against the query (sentence-similarity task)
        num_docs (int): Number of documents to return

    Returns:
        list: Documents ordered by descending relevance score, or None when
        the reranking request fails.
    """
    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
                                                   model_name=embedding_model)
    # Load the vectorstore database
    db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)

    # Over-fetch candidates so the reranker has a pool to reorder
    sim_docs = db.similarity_search(query=query, k=10)
    passages = [doc.page_content for doc in sim_docs]

    # Sentence-similarity payload: score every passage against the query
    payload = {"inputs": {"source_sentence": query,
                          "sentences": passages}}
    headers = {"Authorization": f"Bearer {hf_api_key}"}

    response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
    if response.status_code != 200:
        # Keep the original contract (return None on failure) but surface the
        # status code and response body instead of a generic message; the
        # leftover debug print of the raw response object is removed.
        print(f"Reranking request failed with status {response.status_code}: "
              f"{response.text}")
        return

    similarity_scores = response.json()
    # Sort (doc, passage, score) triples by score, highest first
    ranked_results = sorted(zip(sim_docs, passages, similarity_scores),
                            key=lambda x: x[2], reverse=True)
    top_k_results = ranked_results[:num_docs]
    return [doc for doc, _, _ in top_k_results]
 
110
 
111
+
112
+
113
if __name__ == "__main__":

    # Smoke test: rerank docs from a local Chroma store for a sample query.
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    # Without this check a missing token produces "Authorization: Bearer None"
    # and an opaque remote failure.
    assert HUGGINGFACEHUB_API_TOKEN, "HUGGINGFACEHUB_API_TOKEN is not set"

    EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

    # Vectorstore is expected one level above the current working directory —
    # TODO confirm when running from a different location.
    project_dir = Path().cwd().parent
    path_to_vector_db = str(project_dir / 'vectorstore/chroma-zurich-mpnet-1500')
    assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db "

    query = "I'm looking for student insurance"

    re_ranked_docs = get_reranked_docs_chroma(query=query,
                                              path_to_db=path_to_vector_db,
                                              embedding_model=EMBEDDING_MODEL,
                                              hf_api_key=HUGGINGFACEHUB_API_TOKEN)

    print(f"{re_ranked_docs=}")