# hebrew-dentsit / rag_agent.py
# Commit a983ce0 (borodache): "Change the retrieval and reranking into two
# steps search with two different indexes - which is supposed to make the
# latency much lower (faster)"
from anthropic import Anthropic
from typing import List
import os
from retriever import Retriever
from reranker import Reranker
from text_embedder_encoder import TextEmbedder, encoder_model_name
retriever = Retriever()
reranker = Reranker()
class RAGAgent:
    """Hebrew dental-assistant RAG agent.

    Embeds the user's question, retrieves candidate answers, reranks them,
    and asks an Anthropic Claude model to answer in Hebrew using the
    reranked context. A rolling Hebrew summary of the conversation is
    maintained and folded into both retrieval and prompting.
    """

    def __init__(
        self,
        retriever=retriever,
        reranker=reranker,
        anthropic_api_key=None,
        model_name: str = "claude-3-5-sonnet-20241022",
        max_tokens: int = 1024,
        temperature: float = 0.0,
    ):
        """Create the agent.

        Args:
            retriever: vector-search component (defaults to the module singleton).
            reranker: reranking component (defaults to the module singleton).
            anthropic_api_key: API key; when None, read from the
                ``anthropic_api_key`` environment variable at construction time.
            model_name: Claude model identifier.
            max_tokens: response token budget for answers.
            temperature: sampling temperature (0.0 = deterministic).
        """
        self.retriever = retriever
        self.reranker = reranker
        # Resolve the key lazily here instead of in the default expression:
        # the original default os.environ["anthropic_api_key"] was evaluated
        # at import time and raised KeyError when the env var was missing,
        # even for callers that passed an explicit key.
        if anthropic_api_key is None:
            anthropic_api_key = os.environ["anthropic_api_key"]
        self.client = Anthropic(api_key=anthropic_api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.text_embedder = TextEmbedder()
        # Rolling Hebrew summary of the dialogue so far; empty until the
        # first exchange completes.
        self.conversation_summary = ""
        # Raw user/assistant message history (fallback summary source).
        self.messages = []

    def get_context(self, query: str) -> List[str]:
        """Return reranked context passages for *query*.

        Embeds the query once, fetches candidate answer ids from the
        retriever, then reranks those candidates with the same query vector.
        """
        query_vector = self.text_embedder.encode(query)
        retrieved_answers_ids = self.retriever.search_similar(query_vector)
        context = self.reranker.rerank(query_vector, retrieved_answers_ids)
        return context

    def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
        """Build the Hebrew system prompt from *context* passages and the
        optional running *conversation_summary*."""
        context = "\n".join(context)
        summary_context = f"\nסיכום השיחה עד כה:\n{conversation_summary}" if conversation_summary else ""
        prompt = f"""
אתה רופא שיניים, דובר עברית בלבד. קוראים לך 'רופא השיניים האלקטרוני העברי הראשון'.{summary_context}
ענה למטופל על השאלה שלו על סמך הקונטקס הבא: {context}.
הוסף כמה שיותר פרטים, ודאג שהתחביר יהיה תקין ויפה.
תעצור כשאתה מרגיש שמיצית את עצמך. אל תמציא דברים.
ואל תענה בשפות שהן לא עברית.
"""
        return prompt

    def update_summary(self, question: str, answer: str) -> str:
        """Ask the model for an updated Hebrew conversation summary that
        merges the previous summary with the new question/answer pair.

        Falls back to :meth:`get_basic_summary` on any API failure so a
        summarization hiccup never breaks the main answer flow.
        """
        summary_prompt = {
            "model": self.model_name,
            # Summaries are capped well below answer length on purpose.
            "max_tokens": 500,
            "temperature": 0.0,
            "messages": [
                {
                    "role": "user",
                    "content": f"""סכם את השיחה בעברית, הנה סיכום השיחה עד כה:
{self.conversation_summary if self.conversation_summary else "אין שיחה קודמת."}
אינטראקציה חדשה:
שאלת המטופל: {question}
תשובת הרופא: {answer}
אנא ספק סיכום מעודכן שכולל את המידע הרפואי מהסיכום הקודם בנוסף לדגש על האינטרקציה החדשה. הסיכום צריך להיות תמציתי עד 100 מילה.
ותר על מידע לא רלוונטי מהסיכומים הקודמים"""
                }
            ]
        }
        try:
            response = self.client.messages.create(**summary_prompt)
            self.conversation_summary = response.content[0].text
            return self.conversation_summary
        except Exception as e:
            print(f"Error updating summary: {e}")
            return self.get_basic_summary()

    def get_basic_summary(self) -> str:
        """Fallback summary: replay stored Q/A pairs verbatim from history."""
        summary = []
        # messages holds alternating user/assistant entries; walk them in pairs.
        for i in range(0, len(self.messages), 2):
            if i + 1 < len(self.messages):
                summary.append(f"שאלת המטופל: {self.messages[i]['content']}")
                summary.append(f"תשובת הרופא שיניים: {self.messages[i + 1]['content']}\n")
        return "\n".join(summary)

    def get_response(self, question: str) -> str:
        """Answer *question*: retrieve + rerank context, query Claude, record
        the exchange, and refresh the conversation summary."""
        # Separate question and summary with a newline so the embedded text
        # does not glue the question's last word onto the summary's first.
        context = self.get_context(f"{question}\n{self.conversation_summary}")
        prompt = self.generate_prompt(context, self.conversation_summary)
        # The instruction prompt goes in the dedicated `system` parameter:
        # the Messages API requires the first message to be user-role, so the
        # previous leading {"role": "assistant"} message was rejected.
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=prompt,
            messages=[
                {"role": "user", "content": f"{question}"}
            ]
        )
        answer = response.content[0].text
        # Keep raw history for the fallback summary.
        self.messages.extend([
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ])
        self.update_summary(question, answer)
        return answer