import os
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from core.services.vector_db.qdrent.upload_document import answer_query_from_existing_collection
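# JinaRerank reads its key from JINA_API_KEY; map it from the JINA_API variable used by this deployment.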
os.environ["JINA_API_KEY"] = os.getenv("JINA_API")
class AnswerQuery:
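    """Answer queries over an existing vector store collection.

    Retrieves with hybrid (dense + sparse) MMR search, reranks with Jina,
    answers with a Groq-hosted LLM, and keeps per-collection chat history
    plus a follow-up-question chain.
    """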
def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
self.chat_history_store = {}
self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
self.vector_embed = vector_embedding
self.sparse_embed = sparse_embedding
self.prompt = prompt
self.follow_up_prompt = follow_up_prompt
self.json_parser = json_parser
    def format_docs(self, docs: list) -> str:
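        # sources and temp_context are module-level globals so that answer_query()
        # and answer_query_stream() can read them after the chain has run.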
global sources
global temp_context
sources = []
context = ""
for doc in docs:
context += f"{doc.page_content}\n\n\n"
            sources.append(doc.metadata["source"])
        if context == "":
            context = "No context found"
sources = list(set(sources))
temp_context = context
return context
def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
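        """Answer `query` from the named collection; returns (answer, follow_up_questions, sources)."""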
global sources
global temp_context
vector_store_name = vectorstore
vector_store = answer_query_from_existing_collection(vector_embed=self.vector_embed,
sparse_embed=self.sparse_embed,
vectorstore=vectorstore)
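        # MMR retrieval fetches 20 candidates and keeps the 10 selected for relevance and
        # diversity; the Jina reranker then compresses them before they reach the prompt.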
retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
compression_retriever = ContextualCompressionRetriever(
base_compressor=self.compressor, base_retriever=retriever
)
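        # RAG chain: route the question through retrieval + reranking to build the context,
        # then fill the prompt and call the Groq LLM.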
        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"]) | compression_retriever | RunnableLambda(
                    self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"])
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
            | StrOutputParser()
        )
message_chain = RunnableWithMessageHistory(
brain_chain,
self.get_session_history,
input_messages_key="question",
history_messages_key="chatHistory"
)
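        # Trim the stored history down to the last message before every invocation.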
chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
follow_up_chain = self.follow_up_prompt | ChatGroq(model_name="llama-3.3-70b-versatile",
temperature=0) | self.json_parser
output = chain.invoke(
{"question": query},
{"configurable": {"session_id": vector_store_name}}
)
follow_up_questions = follow_up_chain.invoke({"context": temp_context})
return output, follow_up_questions, sources
    async def answer_query_stream(self, query: str, vectorstore: str, llmModel: str):
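        """Stream the answer in chunks, then yield follow-up questions and the sources used."""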
global sources
global temp_context
vector_store_name = vectorstore
vector_store = answer_query_from_existing_collection(
vector_embed=self.vector_embed,
sparse_embed=self.sparse_embed,
vectorstore=vectorstore
)
retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
compression_retriever = ContextualCompressionRetriever(
base_compressor=self.compressor,
base_retriever=retriever
)
brain_chain = (
{
"context": RunnableLambda(lambda x: x["question"]) | compression_retriever | RunnableLambda(
self.format_docs),
"question": RunnableLambda(lambda x: x["question"]),
"chatHistory": RunnableLambda(lambda x: x["chatHistory"])
}
| self.prompt
| ChatGroq(
model=llmModel,
temperature=0.75,
max_tokens=512,
streaming=True
)
| StrOutputParser()
)
message_chain = RunnableWithMessageHistory(
brain_chain,
self.get_session_history,
input_messages_key="question",
history_messages_key="chatHistory"
)
chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
async for chunk in chain.astream(
{"question": query},
{"configurable": {"session_id": vector_store_name}}
):
yield {
"type": "main_response",
"content": chunk
}
follow_up_chain = self.follow_up_prompt | ChatGroq(
model_name="llama-3.3-70b-versatile",
temperature=0
) | self.json_parser
follow_up_questions = await follow_up_chain.ainvoke({"context": temp_context})
yield {
"type": "follow_up_questions",
"content": follow_up_questions
}
yield {
"type": "sources",
"content": sources
}
def trim_messages(self, chain_input):
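        """Keep only the most recent message in every stored session history."""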
for store_name in self.chat_history_store:
messages = self.chat_history_store[store_name].messages
            if len(messages) > 1:
                self.chat_history_store[store_name].clear()
                for message in messages[-1:]:
                    self.chat_history_store[store_name].add_message(message)
return True
def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
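        """Return the chat history for `session_id`, creating it on first use."""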
if session_id not in self.chat_history_store:
self.chat_history_store[session_id] = ChatMessageHistory()
return self.chat_history_store[session_id]