borodache committed on
Commit
a983ce0
·
verified ·
1 Parent(s): 4f1cbcc

Split retrieval and reranking into a two-step search using two separate indexes, which should substantially reduce latency.

Browse files
Files changed (4) hide show
  1. rag_agent.py +5 -3
  2. reranker.py +27 -12
  3. retriever.py +10 -14
  4. text_embedder_encoder.py +18 -19
rag_agent.py CHANGED
@@ -5,6 +5,7 @@ import os
5
 
6
  from retriever import Retriever
7
  from reranker import Reranker
 
8
 
9
 
10
  retriever = Retriever()
@@ -27,15 +28,16 @@ class RAGAgent:
27
  self.model_name = model_name
28
  self.max_tokens = max_tokens
29
  self.temperature = temperature
 
30
  self.conversation_summary = ""
31
  self.messages = []
32
 
33
  def get_context(self, query: str) -> List[str]:
34
  # Get initial candidates from retriever
35
- retrieved_docs = self.retriever.search_similar(query)
36
-
37
  # Rerank the candidates
38
- context = self.reranker.rerank(query, retrieved_docs)
39
 
40
  return context
41
 
 
5
 
6
  from retriever import Retriever
7
  from reranker import Reranker
8
+ from text_embedder_encoder import TextEmbedder, encoder_model_name
9
 
10
 
11
  retriever = Retriever()
 
28
  self.model_name = model_name
29
  self.max_tokens = max_tokens
30
  self.temperature = temperature
31
+ self.text_embedder = TextEmbedder()
32
  self.conversation_summary = ""
33
  self.messages = []
34
 
35
  def get_context(self, query: str) -> List[str]:
36
  # Get initial candidates from retriever
37
+ query_vector = self.text_embedder.encode(query)
38
+ retrieved_answers_ids = self.retriever.search_similar(query_vector)
39
  # Rerank the candidates
40
+ context = self.reranker.rerank(query_vector, retrieved_answers_ids)
41
 
42
  return context
43
 
reranker.py CHANGED
@@ -1,22 +1,37 @@
 
1
  from sklearn.metrics.pairwise import cosine_similarity
 
2
 
3
 
4
- from text_embedder_encoder import TextEmbedder
5
 
6
 
7
  class Reranker:
8
- def __init__(self):
9
- self.text_embedder = TextEmbedder()
 
 
 
10
 
11
- def rerank(self, query, retrieved_docs, top_n=5):
12
  # Encode query and documents
13
- query_embedding = self.text_embedder.encode(query)
14
- doc_embeddings = self.text_embedder.encode_many(retrieved_docs)
15
- similarity_scores = cosine_similarity([query_embedding], doc_embeddings)[0]
16
 
17
- similarity_scores_with_idxes = list(zip(similarity_scores, range(len(similarity_scores))))
18
- similarity_scores_with_idxes.sort(reverse=True)
19
- similarity_scores_with_idxes_final = similarity_scores_with_idxes[:top_n]
20
- reranked_docs = [retrieved_docs[idx] for score, idx in similarity_scores_with_idxes_final if score >= 0.7]
 
21
 
22
- return reranked_docs
 
 
 
 
 
 
 
 
 
 
1
+ from pinecone import Pinecone
2
  from sklearn.metrics.pairwise import cosine_similarity
3
+ import os
4
 
5
 
6
+ from text_embedder_encoder import encoder_model_name
7
 
8
 
9
  class Reranker:
10
+ def __init__(self,
11
+ pinecone_api_key=os.environ["pinecone_api_key"],
12
+ answer_index_name=f"hebrew-dentist-answers-{encoder_model_name.replace('/', '-')}".lower()):
13
+ self.pc = Pinecone(api_key=pinecone_api_key)
14
+ self.answer_index_name = answer_index_name
15
 
16
+ def rerank(self, query_vector, retrieved_answers_ids, top_n=5):
17
  # Encode query and documents
18
+ try:
19
+ index = self.pc.Index(self.answer_index_name)
20
+ fetch_response = index.fetch(ids=retrieved_answers_ids)
21
 
22
+ doc_embeddings = []
23
+ answers = []
24
+ for i in range(len(retrieved_answers_ids)):
25
+ doc_embeddings.append(fetch_response['vectors'][retrieved_answers_ids[i]]['values'])
26
+ answers.append(fetch_response['vectors'][retrieved_answers_ids[i]]['metadata']['answer'])
27
 
28
+ similarity_scores = cosine_similarity([query_vector], doc_embeddings)[0]
29
+ similarity_scores_with_idxes = list(zip(similarity_scores, range(len(similarity_scores))))
30
+ similarity_scores_with_idxes.sort(reverse=True)
31
+ similarity_scores_with_idxes_final = similarity_scores_with_idxes[:top_n]
32
+ reranked_answers = [answers[idx] for score, idx in similarity_scores_with_idxes_final if score >= 0.7]
33
+
34
+ return reranked_answers
35
+ except Exception as e:
36
+ print(f"Error performing rerank: {e}")
37
+ return []
retriever.py CHANGED
@@ -1,29 +1,25 @@
1
  from pinecone import Pinecone
2
  import os
3
 
4
- from text_embedder_encoder import TextEmbedder, encoder_model_name
5
 
6
 
7
  class Retriever:
8
  def __init__(self,
9
  pinecone_api_key=os.environ["pinecone_api_key"],
10
- index_name=f"hebrew-dentist-qa-{encoder_model_name.replace('/', '-')}".lower()):
11
  # Initialize Pinecone connection
12
  self.pc = Pinecone(api_key=pinecone_api_key)
13
- self.index_name = index_name
14
- self.text_embedder = TextEmbedder()
15
- self.vector_dim = 768
16
 
17
- def search_similar(self, query_text, top_k=50):
18
  """
19
  Search for similar content using vector similarity in Pinecone
20
  """
21
  try:
22
- # Generate embedding for query
23
- query_vector = self.text_embedder.encode(query_text)
24
 
25
  # Get Pinecone index
26
- index = self.pc.Index(self.index_name)
27
 
28
  # Execute search
29
  results = index.query(
@@ -32,12 +28,12 @@ class Retriever:
32
  include_metadata=True,
33
  )
34
 
35
- answers = []
36
  for match in results['matches']:
37
- answer = match['metadata']['answer']
38
- answers.append(answer)
39
 
40
- return answers
41
  except Exception as e:
42
- print(f"Error performing similarity search: {e}")
43
  return []
 
1
  from pinecone import Pinecone
2
  import os
3
 
4
+ from text_embedder_encoder import encoder_model_name
5
 
6
 
7
  class Retriever:
8
  def __init__(self,
9
  pinecone_api_key=os.environ["pinecone_api_key"],
10
+ question_index_name=f"hebrew-dentist-questions-{encoder_model_name.replace('/', '-')}".lower()):
11
  # Initialize Pinecone connection
12
  self.pc = Pinecone(api_key=pinecone_api_key)
13
+ self.question_index_name = question_index_name
 
 
14
 
15
+ def search_similar(self, query_vector, top_k=50):
16
  """
17
  Search for similar content using vector similarity in Pinecone
18
  """
19
  try:
 
 
20
 
21
  # Get Pinecone index
22
+ index = self.pc.Index(self.question_index_name)
23
 
24
  # Execute search
25
  results = index.query(
 
28
  include_metadata=True,
29
  )
30
 
31
+ answers_records_ids = []
32
  for match in results['matches']:
33
+ answers_records_ids.append(
34
+ ':'.join(match['id'].split(':')[:-1]) + ":" + str(int(match['metadata']['answer_id'])))
35
 
36
+ return answers_records_ids
37
  except Exception as e:
38
+ print(f"Error performing retriever: {e}")
39
  return []
text_embedder_encoder.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
- from typing import List
5
 
6
 
7
  encoder_model_name = 'MPA/sambert'
@@ -36,21 +35,21 @@ class TextEmbedder:
36
 
37
  return embeddings
38
 
39
- def encode_many(self, texts: List[str]) -> np.ndarray:
40
- """
41
- Encode Hebrew text using LaBSE model with handling for texts longer than max_seq_length.
42
-
43
- Args:
44
- text (str): Hebrew text to encode
45
- model_name (str): Name of the model to use
46
- # max_seq_length (int): Maximum sequence length for the model
47
- strategy (str): Strategy for combining sentence embeddings ('mean' or 'concat')
48
-
49
- Returns:
50
- numpy.ndarray: Text embedding
51
- """
52
- # Get embeddings for the text
53
- embeddings = self.model.encode(texts)
54
- embeddings = [[float(x) for x in embedding] for embedding in embeddings]
55
-
56
- return embeddings
 
1
  import torch
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
 
4
 
5
 
6
  encoder_model_name = 'MPA/sambert'
 
35
 
36
  return embeddings
37
 
38
+ # def encode_many(self, texts: List[str]) -> np.ndarray:
39
+ # """
40
+ # Encode Hebrew text using LaBSE model with handling for texts longer than max_seq_length.
41
+ #
42
+ # Args:
43
+ # text (str): Hebrew text to encode
44
+ # model_name (str): Name of the model to use
45
+ # # max_seq_length (int): Maximum sequence length for the model
46
+ # strategy (str): Strategy for combining sentence embeddings ('mean' or 'concat')
47
+ #
48
+ # Returns:
49
+ # numpy.ndarray: Text embedding
50
+ # """
51
+ # # Get embeddings for the text
52
+ # embeddings = self.model.encode(texts)
53
+ # embeddings = [[float(x) for x in embedding] for embedding in embeddings]
54
+ #
55
+ # return embeddings