import os

from langchain.memory import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import JinaRerank
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq

from core.services.vector_db.qdrent.upload_document import answer_query_from_existing_collection

# JinaRerank reads its key from JINA_API_KEY; mirror it from the JINA_API variable when that is set.
if os.getenv("JINA_API"):
    os.environ["JINA_API_KEY"] = os.getenv("JINA_API")


class AnswerQuery:
    """Answers queries against an existing vector collection using dense + sparse retrieval,
    Jina reranking and Groq-hosted LLMs, with per-collection chat history."""

    def __init__(self, prompt, vector_embedding, sparse_embedding, follow_up_prompt, json_parser):
        self.chat_history_store = {}  # session_id -> ChatMessageHistory
        self.compressor = JinaRerank(model="jina-reranker-v2-base-multilingual")
        self.vector_embed = vector_embedding
        self.sparse_embed = sparse_embedding
        self.prompt = prompt
        self.follow_up_prompt = follow_up_prompt
        self.json_parser = json_parser
        # Populated by format_docs on each query and read back when building the final response.
        self.sources = []
        self.temp_context = ""

    def format_docs(self, docs: list):
        """Concatenate retrieved documents into one context string and record their sources."""
        sources = []
        context = ""
        for doc in docs:
            context += f"{doc.page_content}\n\n\n"
            sources.append(doc.metadata["source"])
        if context == "":
            context = "No context found"
        # Deduplicate sources and keep the context around for the follow-up question chain.
        self.sources = list(set(sources))
        self.temp_context = context
        return context

    def answer_query(self, query: str, vectorstore: str, llmModel: str = "llama-3.3-70b-versatile"):
        """Answer a query against an existing collection; returns (answer, follow_up_questions, sources)."""
        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore,
        )

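        # Retrieve candidates with MMR search, then rerank them with the Jina reranker.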
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor, base_retriever=retriever
        )
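        # Main RAG chain: reranked context + question + chat history -> prompt -> Groq LLM -> plain text.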
        brain_chain = (
            {
                "context": (
                    RunnableLambda(lambda x: x["question"])
                    | compression_retriever
                    | RunnableLambda(self.format_docs)
                ),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512)
            | StrOutputParser()
        )
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory"
        )
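        # Trim the stored history to the last message before each call, then run the history-aware chain.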
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain
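        # Separate chain that generates follow-up questions from the retrieved context.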
        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )

        output = chain.invoke(
            {"question": query},
            {"configurable": {"session_id": vector_store_name}}
        )
        follow_up_questions = follow_up_chain.invoke({"context": self.temp_context})

        return output, follow_up_questions, self.sources

    async def answer_query_stream(self, query: str, vectorstore: str, llmModel: str):
        """Stream the answer in chunks, then yield follow-up questions and sources."""

        vector_store_name = vectorstore
        vector_store = answer_query_from_existing_collection(
            vector_embed=self.vector_embed,
            sparse_embed=self.sparse_embed,
            vectorstore=vectorstore
        )

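        # Same retrieval setup as answer_query: MMR search followed by Jina reranking.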
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 20})
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=retriever
        )

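        # Streaming variant of the RAG chain; the Groq model streams tokens back.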
        brain_chain = (
            {
                "context": (
                    RunnableLambda(lambda x: x["question"])
                    | compression_retriever
                    | RunnableLambda(self.format_docs)
                ),
                "question": RunnableLambda(lambda x: x["question"]),
                "chatHistory": RunnableLambda(lambda x: x["chatHistory"]),
            }
            | self.prompt
            | ChatGroq(model=llmModel, temperature=0.75, max_tokens=512, streaming=True)
            | StrOutputParser()
        )

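        # Wrap the chain with per-session chat history keyed by the collection name.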
        message_chain = RunnableWithMessageHistory(
            brain_chain,
            self.get_session_history,
            input_messages_key="question",
            history_messages_key="chatHistory"
        )

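        # Trim the per-session history before routing the input through the message chain.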
        chain = RunnablePassthrough.assign(messages_trimmed=self.trim_messages) | message_chain

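        # Yield the main answer incrementally as the chain streams chunks.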
        async for chunk in chain.astream(
                {"question": query},
                {"configurable": {"session_id": vector_store_name}}
        ):
            yield {
                "type": "main_response",
                "content": chunk
            }

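        # Once the answer has streamed, generate follow-up questions from the same context.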
        follow_up_chain = (
            self.follow_up_prompt
            | ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
            | self.json_parser
        )

        follow_up_questions = await follow_up_chain.ainvoke({"context": self.temp_context})

        yield {
            "type": "follow_up_questions",
            "content": follow_up_questions
        }

        yield {
            "type": "sources",
            "content": sources
        }

    def trim_messages(self, chain_input):
        """Keep only the most recent message in every stored chat history to bound prompt size."""
        for store_name in self.chat_history_store:
            messages = self.chat_history_store[store_name].messages
            if len(messages) > 1:
                # Re-create the history with just the last message.
                self.chat_history_store[store_name].clear()
                for message in messages[-1:]:
                    self.chat_history_store[store_name].add_message(message)
        return True

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
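        """Return the chat history for a session, creating it on first use."""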
        if session_id not in self.chat_history_store:
            self.chat_history_store[session_id] = ChatMessageHistory()
        return self.chat_history_store[session_id]