"""RAG agent for a Hebrew-speaking dental Q&A assistant built on Claude."""

import os
from typing import List, Optional

from anthropic import Anthropic

from retriever import Retriever
from reranker import Reranker
from text_embedder_encoder import TextEmbedder


# Shared instances, used as the default retriever/reranker for RAGAgent.
retriever = Retriever()
reranker = Reranker()


class RAGAgent:
    def __init__(
            self,
            retriever=retriever,
            reranker=reranker,
            anthropic_api_key: Optional[str] = None,
            model_name: str = "claude-3-5-sonnet-20241022",
            max_tokens: int = 1024,
            temperature: float = 0.0,
    ):
        # Read the API key lazily so that merely importing this module does
        # not raise a KeyError when the environment variable is missing.
        if anthropic_api_key is None:
            anthropic_api_key = os.environ["anthropic_api_key"]
        self.retriever = retriever
        self.reranker = reranker
        self.client = Anthropic(api_key=anthropic_api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.text_embedder = TextEmbedder()
        self.conversation_summary = ""
        self.messages = []

    def get_context(self, query: str) -> List[str]:
        """Embed the query, retrieve candidate answers, and rerank them."""
        # Get initial candidates from the retriever
        query_vector = self.text_embedder.encode(query)
        retrieved_answers_ids = self.retriever.search_similar(query_vector)
        # Rerank the candidates
        context = self.reranker.rerank(query_vector, retrieved_answers_ids)
        return context

    def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
        context_text = "\n".join(context)
        # "Summary of the conversation so far:" followed by the running summary
        summary_context = f"\nืกื™ื›ื•ื ื”ืฉื™ื—ื” ืขื“ ื›ื”:\n{conversation_summary}" if conversation_summary else ""

        # Hebrew system prompt. Translation: "You are a dentist who speaks
        # Hebrew only. You are called 'the first Hebrew electronic dentist'.
        # Answer the patient's question based on the following context:
        # {context}. Include as many details as possible and make sure the
        # phrasing is correct and clear. Stop once you feel you have covered
        # everything. Do not make things up, and do not answer in any
        # language other than Hebrew."
        prompt = (
            f"ืืชื” ืจื•ืคื ืฉื™ื ื™ื™ื, ื“ื•ื‘ืจ ืขื‘ืจื™ืช ื‘ืœื‘ื“. ืงื•ืจืื™ื ืœืš 'ืจื•ืคื ื”ืฉื™ื ื™ื™ื ื”ืืœืงื˜ืจื•ื ื™ ื”ืขื‘ืจื™ ื”ืจืืฉื•ืŸ'.{summary_context}\n"
            f"ืขื ื” ืœืžื˜ื•ืคืœ ืขืœ ื”ืฉืืœื” ืฉืœื• ืขืœ ืกืžืš ื”ืงื•ื ื˜ืงืก ื”ื‘ื: {context_text}.\n"
            "ื”ื•ืกืฃ ื›ืžื” ืฉื™ื•ืชืจ ืคืจื˜ื™ื, ื•ื“ืื’ ืฉื”ืชื—ื‘ื™ืจ ื™ื”ื™ื” ืชืงื™ืŸ ื•ื™ืคื”.\n"
            "ืชืขืฆื•ืจ ื›ืฉืืชื” ืžืจื’ื™ืฉ ืฉืžื™ืฆื™ืช ืืช ืขืฆืžืš. ืืœ ืชืžืฆื™ื ื“ื‘ืจื™ื.\n"
            "ื•ืืœ ืชืขื ื” ื‘ืฉืคื•ืช ืฉื”ืŸ ืœื ืขื‘ืจื™ืช."
        )
        return prompt

    def update_summary(self, question: str, answer: str) -> str:
        """Update the conversation summary with the new interaction."""
        # Hebrew summarization prompt. Translation: "Summarize the
        # conversation in Hebrew. Here is the summary of the conversation so
        # far: <summary, or 'No previous conversation.'> New interaction:
        # patient's question / dentist's answer. Please provide an updated
        # summary that keeps the medical information from the previous
        # summary and emphasizes the new interaction. The summary should be
        # concise, up to 100 words. Drop information from earlier summaries
        # that is no longer relevant."
        summary_prompt = {
            "model": self.model_name,
            "max_tokens": 500,
            "temperature": 0.0,
            "messages": [
                {
                    "role": "user",
                    "content": f"""ืกื›ื ืืช ื”ืฉื™ื—ื” ื‘ืขื‘ืจื™ืช, ื”ื ื” ืกื™ื›ื•ื ื”ืฉื™ื—ื” ืขื“ ื›ื”:
{self.conversation_summary if self.conversation_summary else "ืื™ืŸ ืฉื™ื—ื” ืงื•ื“ืžืช."}

ืื™ื ื˜ืจืืงืฆื™ื” ื—ื“ืฉื”:
ืฉืืœืช ื”ืžื˜ื•ืคืœ: {question}
ืชืฉื•ื‘ืช ื”ืจื•ืคื: {answer}

ืื ื ืกืคืง ืกื™ื›ื•ื ืžืขื•ื“ื›ืŸ ืฉื›ื•ืœืœ ืืช ื”ืžื™ื“ืข ื”ืจืคื•ืื™ ืžื”ืกื™ื›ื•ื ื”ืงื•ื“ื ื‘ื ื•ืกืฃ ืœื“ื’ืฉ ืขืœ ื”ืื™ื ื˜ืจืืงืฆื™ื” ื”ื—ื“ืฉื”. ื”ืกื™ื›ื•ื ืฆืจื™ืš ืœื”ื™ื•ืช ืชืžืฆื™ืชื™ ืขื“ 100 ืžื™ืœื”.
ื•ืชืจ ืขืœ ืžื™ื“ืข ืœื ืจืœื•ื•ื ื˜ื™ ืžื”ืกื™ื›ื•ืžื™ื ื”ืงื•ื“ืžื™ื."""
                }
            ]
        }

        try:
            response = self.client.messages.create(**summary_prompt)
            self.conversation_summary = response.content[0].text
            return self.conversation_summary
        except Exception as e:
            print(f"Error updating summary: {e}")
            return self.get_basic_summary()

    def get_basic_summary(self) -> str:
        """Fallback method for basic summary"""
        summary = []
        for i in range(0, len(self.messages), 2):
            if i + 1 < len(self.messages):
                summary.append(f"ืฉืืœืช ื”ืžื˜ื•ืคืœ: {self.messages[i]['content']}")
                summary.append(f"ืชืฉื•ื‘ืช ื”ืจื•ืคื ืฉื™ื ื™ื™ื: {self.messages[i + 1]['content']}\n")
        return "\n".join(summary)

    def get_response(self, question: str) -> str:
        # Get relevant context; include the running summary in the search
        # query, separated by a newline so the two texts do not run together
        context = self.get_context(f"{question}\n{self.conversation_summary}")

        # Generate prompt with context and current conversation summary
        prompt = self.generate_prompt(context, self.conversation_summary)

        # Get response from Claude. The instructions go in the `system`
        # parameter: the Messages API requires the first entry in `messages`
        # to have the "user" role, so a leading assistant message would fail.
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=prompt,
            messages=[
                {"role": "user", "content": question}
            ]
        )

        answer = response.content[0].text

        # Store messages for history
        self.messages.extend([
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ])

        # Update conversation summary
        self.update_summary(question, answer)

        return answer
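

# A minimal usage sketch, assuming the project-local Retriever, Reranker, and
# TextEmbedder modules are importable and the `anthropic_api_key` environment
# variable is set. The sample questions are hypothetical.
if __name__ == "__main__":
    agent = RAGAgent()
    # "I have a toothache, what should I do?"
    print(agent.get_response("ื™ืฉ ืœื™ ื›ืื‘ ืฉื™ื ื™ื™ื, ืžื” ืœืขืฉื•ืช?"))
    # Follow-up questions reuse the running conversation summary.
    # "Should I take painkillers?"
    print(agent.get_response("ื”ืื ื›ื“ืื™ ืœืงื—ืช ืžืฉื›ื›ื™ ื›ืื‘ื™ื?"))
    print("Summary:", agent.conversation_summary)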