import os

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq

from core.services.vector_db.qdrent.upload_document import answer_query_from_existing_collection

# JinaRerank reads its key from JINA_API_KEY; this project stores it under JINA_API.
# Fail early with a clear message instead of assigning None into os.environ.
jina_api_key = os.getenv("JINA_API")
if not jina_api_key:
    raise EnvironmentError("JINA_API environment variable is not set")
os.environ["JINA_API_KEY"] = jina_api_key
|
|
class AnswerQuery:
    """Answer questions against an existing vector-store collection using hybrid
    (dense + sparse) retrieval, Jina reranking, and a Groq-hosted LLM, keeping a
    short per-collection chat history and suggesting follow-up questions."""

    def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
        self.chat_history_store = {}
        self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
        self.vector_embed = vector_embedding
        self.sparse_embed = sparse_embedding
        self.prompt = prompt
        self.follow_up_prompt = follow_up_prompt
        self.json_parser = json_parser
        # Populated by format_docs on every retrieval; kept on the instance
        # rather than in module-level globals.
        self.sources = []
        self.temp_context = ""
|
    def format_docs(self, docs: list) -> str:
        """Concatenate retrieved documents into one context string and record
        their unique sources on the instance."""
        self.sources = []
        context = ""
        for doc in docs:
            context += f"{doc.page_content}\n\n\n"
            self.sources.append(doc.metadata["source"])
        if context == "":
            context = "No context found"
        self.sources = list(set(self.sources))
        self.temp_context = context
        return context
|
    def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
        """Answer a single query against the named collection and return the
        answer, suggested follow-up questions, and the list of sources."""
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(vector_embed=self.vector_embed,
                                                             sparse_embed=self.sparse_embed,
                                                             vectorstore=vectorstore)

        # MMR retrieval followed by Jina reranking of the candidate documents.
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=retriever
        )

        brain_chain = (
            {"context": RunnableLambda(lambda x: x["question"]) | compression_retriever
                        | RunnableLambda(self.format_docs),
             "question": RunnableLambda(lambda x: x["question"]),
             "chatHistory": RunnableLambda(lambda x: x["chatHistory"])}
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory"
        )
        # Trim the stored history to the latest message before each call.
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
        follow_up_chain = self.follow_up_prompt | ChatGroq(model="llama-3.3-70b-versatile",
                                                           temperature=0) | self.json_parser

        output = chain.invoke(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}}
        )
        follow_up_questions = follow_up_chain.invoke({"context": self.temp_context})

        return output, follow_up_questions, self.sources
|
    async def answer_query_stream(self, query: str, vectorstore: str, llmModel: str):
        """Stream the answer for a query as typed events: main_response chunks,
        then follow_up_questions, then sources."""
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore
        )

        # MMR retrieval followed by Jina reranking of the candidate documents.
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=retriever
        )

        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"]) | compression_retriever
                           | RunnableLambda(self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"])
            }
            | self.prompt
            | ChatGroq(
                model=llmModel,
                temperature=0.75,
                max_tokens=512,
                streaming=True
            )
            | StrOutputParser()
        )

        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory"
        )

        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain

        # Stream the model's answer chunk by chunk.
        async for chunk in chain.astream(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}}
        ):
            yield {
                "type": "main_response",
                "content": chunk
            }

        follow_up_chain = self.follow_up_prompt | ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0
        ) | self.json_parser

        follow_up_questions = await follow_up_chain.ainvoke({"context": self.temp_context})

        yield {
            "type": "follow_up_questions",
            "content": follow_up_questions
        }

        yield {
            "type": "sources",
            "content": self.sources
        }
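    # A minimal consumption sketch (assuming an existing `aq = AnswerQuery(...)`
    # instance and a collection named "my_docs", both illustrative). Each yielded
    # event is a dict with "type" and "content" keys, emitted in the order
    # main_response* -> follow_up_questions -> sources:
    #
    #     async for event in aq.answer_query_stream("What is X?", "my_docs",
    #                                                "llama-3.3-70b-versatile"):
    #         if event["type"] == "main_response":
    #             print(event["content"], end="", flush=True)
    #         elif event["type"] == "follow_up_questions":
    #             print("\nFollow-ups:", event["content"])
    #         elif event["type"] == "sources":
    #             print("Sources:", event["content"])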
|
|
|
    def trim_messages(self, chain_input):
        # Keep only the most recent message in every stored history so the
        # prompt receives a single turn of prior context.
        for store_name in self.chat_history_store:
            messages = self.chat_history_store[store_name].messages
            if len(messages) > 1:
                self.chat_history_store[store_name].clear()
                for message in messages[-1:]:
                    self.chat_history_store[store_name].add_message(message)
        return True
|
    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        # One in-memory history per session id (the collection name is used
        # as the session id).
        if session_id not in self.chat_history_store:
            self.chat_history_store[session_id] = ChatMessageHistory()
        return self.chat_history_store[session_id]
|
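# A minimal construction sketch. The embedding models, prompt wording, and
# collection name below are illustrative assumptions, not part of this module;
# the real prompts, parser, and embeddings are supplied by the caller.
#
#     from langchain_core.output_parsers import JsonOutputParser
#     from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
#     from langchain_huggingface import HuggingFaceEmbeddings
#     from langchain_qdrant import FastEmbedSparse
#
#     answer_prompt = ChatPromptTemplate.from_messages([
#         ("system", "Answer using only this context:\n\n{context}"),
#         MessagesPlaceholder("chatHistory"),
#         ("human", "{question}"),
#     ])
#     follow_up_prompt = ChatPromptTemplate.from_template(
#         "Suggest three follow-up questions as a JSON list for this context:\n\n{context}"
#     )
#
#     aq = AnswerQuery(
#         prompt=answer_prompt,
#         vector_embedding=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
#         sparse_embedding=FastEmbedSparse(model_name="Qdrant/bm25"),
#         follow_up_prompt=follow_up_prompt,
#         json_parser=JsonOutputParser(),
#     )
#     answer, follow_ups, sources = aq.answer_query("What is X?", "my_docs")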