Split retrieval and reranking into a two-step search against two separate indexes, which should make query latency much lower (faster)
- rag_agent.py +5 -3
- reranker.py +27 -12
- retriever.py +10 -14
- text_embedder_encoder.py +18 -19
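Before the per-file diffs, here is a minimal sketch of how the pieces are wired together after this commit. It assumes the three modules are importable and that the Pinecone credentials and indexes they reference already exist; the query string is a placeholder.

from text_embedder_encoder import TextEmbedder
from retriever import Retriever
from reranker import Reranker

text_embedder = TextEmbedder()
retriever = Retriever()
reranker = Reranker()

query = "..."  # hypothetical user question

# The query is embedded exactly once; both steps reuse the same vector.
query_vector = text_embedder.encode(query)

# Step 1: nearest-neighbour search over the questions index,
# which returns the ids of the corresponding answer records.
retrieved_answers_ids = retriever.search_similar(query_vector, top_k=50)

# Step 2: fetch the precomputed answer vectors by id from the answers index
# and rerank them by cosine similarity, with no second encoding pass.
context = reranker.rerank(query_vector, retrieved_answers_ids, top_n=5)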
rag_agent.py
CHANGED
@@ -5,6 +5,7 @@ import os
 
 from retriever import Retriever
 from reranker import Reranker
+from text_embedder_encoder import TextEmbedder, encoder_model_name
 
 
 retriever = Retriever()
@@ -27,15 +28,16 @@ class RAGAgent:
         self.model_name = model_name
         self.max_tokens = max_tokens
         self.temperature = temperature
+        self.text_embedder = TextEmbedder()
         self.conversation_summary = ""
         self.messages = []
 
     def get_context(self, query: str) -> List[str]:
         # Get initial candidates from retriever
-        …
-        …
+        query_vector = self.text_embedder.encode(query)
+        retrieved_answers_ids = self.retriever.search_similar(query_vector)
         # Rerank the candidates
-        context = self.reranker.rerank(
+        context = self.reranker.rerank(query_vector, retrieved_answers_ids)
 
         return context
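The latency saving comes from the reranker no longer encoding candidate documents at query time; it reuses embeddings already stored in the answers index. For context, a hypothetical call site; RAGAgent's constructor arguments are assumed from the attribute names visible above, and the values here are placeholders:

agent = RAGAgent(model_name="...", max_tokens=512, temperature=0.2)
context = agent.get_context("...")  # list of up to top_n=5 reranked answer strings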
reranker.py
CHANGED
@@ -1,22 +1,37 @@
+from pinecone import Pinecone
 from sklearn.metrics.pairwise import cosine_similarity
+import os
 
 
-from text_embedder_encoder import …
+from text_embedder_encoder import encoder_model_name
 
 
 class Reranker:
-    def __init__(self…
-        …
+    def __init__(self,
+                 pinecone_api_key=os.environ["pinecone_api_key"],
+                 answer_index_name=f"hebrew-dentist-answers-{encoder_model_name.replace('/', '-')}".lower()):
+        self.pc = Pinecone(api_key=pinecone_api_key)
+        self.answer_index_name = answer_index_name
 
-    def rerank(self, …
+    def rerank(self, query_vector, retrieved_answers_ids, top_n=5):
         # Encode query and documents
-        …
+        try:
+            index = self.pc.Index(self.answer_index_name)
+            fetch_response = index.fetch(ids=retrieved_answers_ids)
+
+            doc_embeddings = []
+            answers = []
+            for i in range(len(retrieved_answers_ids)):
+                doc_embeddings.append(fetch_response['vectors'][retrieved_answers_ids[i]]['values'])
+                answers.append(fetch_response['vectors'][retrieved_answers_ids[i]]['metadata']['answer'])
+
+            similarity_scores = cosine_similarity([query_vector], doc_embeddings)[0]
+            similarity_scores_with_idxes = list(zip(similarity_scores, range(len(similarity_scores))))
+            similarity_scores_with_idxes.sort(reverse=True)
+            similarity_scores_with_idxes_final = similarity_scores_with_idxes[:top_n]
+            reranked_answers = [answers[idx] for score, idx in similarity_scores_with_idxes_final if score >= 0.7]
+
+            return reranked_answers
+        except Exception as e:
+            print(f"Error performing rerank: {e}")
+            return []
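The selection logic (sort by score, keep top_n, drop anything under the 0.7 threshold) can be hard to read inline. Here is a standalone, runnable sketch of just that scoring step, using made-up 3-dimensional vectors in place of real Pinecone fetch results:

from sklearn.metrics.pairwise import cosine_similarity

query_vector = [1.0, 0.0, 0.0]
doc_embeddings = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.8, 0.0, 0.2]]
answers = ["answer A", "answer B", "answer C"]

# Score every candidate against the query, pair each score with its index,
# sort best-first, keep the top 5, then apply the 0.7 similarity floor.
scores = cosine_similarity([query_vector], doc_embeddings)[0]
ranked = sorted(zip(scores, range(len(scores))), reverse=True)[:5]
print([answers[i] for s, i in ranked if s >= 0.7])  # ['answer A', 'answer C']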
retriever.py
CHANGED
@@ -1,29 +1,25 @@
 from pinecone import Pinecone
 import os
 
-from text_embedder_encoder import …
+from text_embedder_encoder import encoder_model_name
 
 
 class Retriever:
     def __init__(self,
                  pinecone_api_key=os.environ["pinecone_api_key"],
-                 …
+                 question_index_name=f"hebrew-dentist-questions-{encoder_model_name.replace('/', '-')}".lower()):
         # Initialize Pinecone connection
         self.pc = Pinecone(api_key=pinecone_api_key)
-        self.…
-        self.text_embedder = TextEmbedder()
-        self.vector_dim = 768
+        self.question_index_name = question_index_name
 
-    def search_similar(self, …
+    def search_similar(self, query_vector, top_k=50):
         """
         Search for similar content using vector similarity in Pinecone
         """
         try:
-            # Generate embedding for query
-            query_vector = self.text_embedder.encode(query_text)
 
             # Get Pinecone index
-            index = self.pc.Index(self.…
+            index = self.pc.Index(self.question_index_name)
 
             # Execute search
             results = index.query(
@@ -32,12 +28,12 @@ class Retriever:
                 include_metadata=True,
             )
 
-            …
+            answers_records_ids = []
             for match in results['matches']:
-                …
-                …
+                answers_records_ids.append(
+                    ':'.join(match['id'].split(':')[:-1]) + ":" + str(int(match['metadata']['answer_id'])))
 
-            return …
+            return answers_records_ids
         except Exception as e:
-            print(f"Error performing …
+            print(f"Error performing retriever: {e}")
             return []
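The id manipulation in the loop is the subtle part: each question record's id apparently ends in a segment that gets replaced by the answer's id from the match metadata (Pinecone returns numeric metadata as floats, hence the int() cast). A worked example with a hypothetical id scheme, since the real one is not shown in this diff:

match = {'id': 'faq:123:4', 'metadata': {'answer_id': 7.0}}

# Drop the last ':'-separated segment of the question id and append the
# answer id, yielding the key of the answer record in the answers index.
answer_record_id = ':'.join(match['id'].split(':')[:-1]) + ":" + str(int(match['metadata']['answer_id']))
print(answer_record_id)  # faq:123:7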
text_embedder_encoder.py
CHANGED
@@ -1,7 +1,6 @@
 import torch
 import numpy as np
 from sentence_transformers import SentenceTransformer
-from typing import List
 
 
 encoder_model_name = 'MPA/sambert'
@@ -36,21 +35,21 @@ class TextEmbedder:
 
         return embeddings
 
-    def encode_many(self, texts: List[str]) -> np.ndarray:
-        """
-        Encode Hebrew text using LaBSE model with handling for texts longer than max_seq_length.
-
-        Args:
-            text (str): Hebrew text to encode
-            model_name (str): Name of the model to use
-            # max_seq_length (int): Maximum sequence length for the model
-            strategy (str): Strategy for combining sentence embeddings ('mean' or 'concat')
-
-        Returns:
-            numpy.ndarray: Text embedding
-        """
-        # Get embeddings for the text
-        embeddings = self.model.encode(texts)
-        embeddings = [[float(x) for x in embedding] for embedding in embeddings]
-
-        return embeddings
+    # def encode_many(self, texts: List[str]) -> np.ndarray:
+    #     """
+    #     Encode Hebrew text using LaBSE model with handling for texts longer than max_seq_length.
+    #
+    #     Args:
+    #         text (str): Hebrew text to encode
+    #         model_name (str): Name of the model to use
+    #         # max_seq_length (int): Maximum sequence length for the model
+    #         strategy (str): Strategy for combining sentence embeddings ('mean' or 'concat')
+    #
+    #     Returns:
+    #         numpy.ndarray: Text embedding
+    #     """
+    #     # Get embeddings for the text
+    #     embeddings = self.model.encode(texts)
+    #     embeddings = [[float(x) for x in embedding] for embedding in embeddings]
+    #
+    #     return embeddings