import os

from langchain.memory import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq

from core.services.vector_db.qdrent.upload_document import answer_query_from_existing_collection

# Jina's client reads JINA_API_KEY; this project stores the key under JINA_API.
os.environ["JINA_API_KEY"] = os.getenv("JINA_API", "")


class AnswerQuery:
    """Answers questions against an existing Qdrant collection using MMR retrieval,
    Jina reranking, per-collection chat history, and follow-up question generation."""

    def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
        self.chat_history_store = {}
        self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
        self.vector_embed = vector_embedding
        self.sparse_embed = sparse_embedding
        self.prompt = prompt
        self.follow_up_prompt = follow_up_prompt
        self.json_parser = json_parser
        # Set by format_docs during retrieval and read after the chain runs
        # (kept on the instance instead of module-level globals; concurrent
        # queries on one instance can still race on these fields).
        self.sources = []
        self.temp_context = ""

    def format_docs(self, docs) -> str:
        """Concatenate retrieved documents into one context string and record
        their deduplicated source names."""
        self.sources = []
        context = ""
        for doc in docs:
            context += f"{doc.page_content}\n\n\n"
            self.sources.append(doc.metadata["source"])
        if not context:
            context = "No context found"
        self.sources = list(set(self.sources))
        self.temp_context = context
        return context

    def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore,
        )
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=retriever,
        )
        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"]) | compression_retriever | RunnableLambda(self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory",
        )
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )

        output = chain.invoke(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}},
        )
        follow_up_questions = follow_up_chain.invoke({"context": self.temp_context})
        return output, follow_up_questions, self.sources

    async def answer_query_stream(self, query: str, vectorstore: str, llmModel: str):
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore,
        )
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=retriever,
        )
        brain_chain = (
            {
                "context": RunnableLambda(lambda x: x["question"]) | compression_retriever | RunnableLambda(self.format_docs),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512, streaming=True)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory",
        )
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain

        # Stream the main answer first, then emit follow-up questions and sources.
        async for chunk in chain.astream(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}},
        ):
            yield {"type": "main_response", "content": chunk}

        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )
        follow_up_questions = await follow_up_chain.ainvoke({"context": self.temp_context})
        yield {"type": "follow_up_questions", "content": follow_up_questions}
        yield {"type": "sources", "content": self.sources}

    def trim_messages(self, chain_input):
        """Keep only the most recent message in every session history."""
        for store_name in self.chat_history_store:
            messages = self.chat_history_store[store_name].messages
            if len(messages) > 1:
                self.chat_history_store[store_name].clear()
                for message in messages[-1:]:
                    self.chat_history_store[store_name].add_message(message)
        return True

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        if session_id not in self.chat_history_store:
            self.chat_history_store[session_id] = ChatMessageHistory()
        return self.chat_history_store[session_id]
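

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The prompt templates, parser, and
# collection name below are assumptions for demonstration, not part of this
# module; vector_embedding / sparse_embedding must be replaced with whatever
# dense/sparse embedding objects answer_query_from_existing_collection expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.output_parsers import JsonOutputParser
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    # The answer prompt must expose the "context", "question", and
    # "chatHistory" variables that brain_chain fills in.
    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the question using only this context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chatHistory"),
        ("human", "{question}"),
    ])
    # The follow-up prompt only receives "context" and should request JSON
    # that json_parser can parse.
    follow_up_prompt = ChatPromptTemplate.from_template(
        "Given this context, propose three follow-up questions as a JSON list:\n\n{context}"
    )

    answerer = AnswerQuery(
        prompt=answer_prompt,
        vector_embedding=None,   # placeholder: your dense embedding model
        sparse_embedding=None,   # placeholder: your sparse embedding model
        follow_up_prompt=follow_up_prompt,
        json_parser=JsonOutputParser(),
    )
    answer, follow_ups, cited_sources = answerer.answer_query(
        query="What does the document say about pricing?",
        vectorstore="my-collection",  # hypothetical Qdrant collection name
    )
    print(answer, follow_ups, cited_sources, sep="\n\n")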