import os
import pickle
from typing import Any

from dotenv import load_dotenv
from haystack.nodes import (  # type: ignore
    AnswerParser,
    EmbeddingRetriever,
    PromptNode,
    PromptTemplate,
)
from haystack.pipelines import Pipeline

from src.document_store.document_store import get_document_store

load_dotenv()

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")


class RAGPipeline:
    """Retrieval-augmented generation pipeline: an embedding retriever feeding a GPT-4 prompt node."""

    def __init__(
        self,
        embedding_model: str,
        prompt_template: str,
    ):
        self.load_document_store()
        self.embedding_model = embedding_model
        self.prompt_template = prompt_template
        self.retriever_node = self.generate_retriever_node()
        self.prompt_node = self.generate_prompt_node()
        self.update_embeddings()
        self.pipe = self.build_pipeline()

    def run(self, prompt: str, filters: dict) -> Any:
        # Run the retrieve-then-generate pipeline; return None if any node raises.
        try:
            result = self.pipe.run(query=prompt, params={"filters": filters})
            return result
        except Exception as e:
            print(e)
            return None

    def build_pipeline(self):
        # Wire the query through the retriever and then into the prompt node.
        pipe = Pipeline()
        pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
        pipe.add_node(
            component=self.prompt_node,
            name="prompt_node",
            inputs=["retriever"],
        )
        return pipe

    def load_document_store(self):
        # Reuse the pickled document store if one exists; otherwise build a fresh one.
        if os.path.exists(os.path.join("database", "document_store.pkl")):
            with open(
                file=os.path.join("database", "document_store.pkl"), mode="rb"
            ) as f:
                self.document_store = pickle.load(f)
        else:
            self.document_store = get_document_store()

    def generate_retriever_node(self):
        retriever_node = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model=self.embedding_model,
            top_k=7,
        )
        return retriever_node

    def update_embeddings(self):
        # Compute and cache embeddings only when no pickled store exists yet.
        if not os.path.exists(os.path.join("database", "document_store.pkl")):
            self.document_store.update_embeddings(
                self.retriever_node, update_existing_embeddings=True
            )

            with open(
                file=os.path.join("database", "document_store.pkl"), mode="wb"
            ) as f:
                pickle.dump(self.document_store, f)

    def generate_prompt_node(self):
        # AnswerParser pulls "Document[n]" citations out of the generated answer.
        rag_prompt = PromptTemplate(
            prompt=self.prompt_template,
            output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
        )
        prompt_node = PromptNode(
            model_name_or_path="gpt-4",
            default_prompt_template=rag_prompt,
            api_key="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",
            max_length=3000,
            model_kwargs={"temperature": 0.2, "max_tokens": 4096},
        )
        return prompt_node
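

# Minimal usage sketch (not part of the original module). The embedding model,
# prompt template, filter field, and question below are illustrative assumptions;
# substitute whatever your document store and use case actually require.
if __name__ == "__main__":
    example_template = (
        "Answer the question using only the provided documents.\n"
        "Documents: {join(documents)}\n"
        "Question: {query}\n"
        "Answer:"
    )
    rag = RAGPipeline(
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # assumed local embedding model
        prompt_template=example_template,
    )
    # Filters follow Haystack's metadata-filter format; "source" is a hypothetical field.
    result = rag.run(
        prompt="What does the onboarding guide say about VPN access?",
        filters={"source": ["onboarding_guide"]},
    )
    print(result)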