Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,923 Bytes
5d4054c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import pickle
from typing import Any
from dotenv import load_dotenv
from haystack.nodes import ( # type: ignore
AnswerParser,
EmbeddingRetriever,
PromptNode,
PromptTemplate,
)
from haystack.pipelines import Pipeline
from src.document_store.document_store import get_document_store
load_dotenv()
# BUG FIX: the env var was read under the misspelled name "OPEN_API_KEY".
# Prefer the canonical "OPENAI_API_KEY"; fall back to the legacy misspelling
# so existing .env files keep working.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("OPEN_API_KEY")
class RAGPipeline:
    """Retrieval-augmented generation pipeline built on Haystack.

    Wires an ``EmbeddingRetriever`` over a pickle-cached document store into
    a GPT-4 ``PromptNode`` and exposes a single :meth:`run` entry point.
    """

    # Single source of truth for the pickled document-store cache location
    # (was duplicated as a literal in three methods).
    _STORE_PATH = os.path.join("database", "document_store.pkl")

    def __init__(
        self,
        embedding_model: str,
        prompt_template: str,
    ):
        """Build the full pipeline.

        Args:
            embedding_model: Name/path of the retriever's embedding model.
            prompt_template: Prompt template string for the generator node.
        """
        self.load_document_store()
        self.embedding_model = embedding_model
        self.prompt_template = prompt_template
        self.retriever_node = self.generate_retriever_node()
        self.prompt_node = self.generate_prompt_node()
        self.update_embeddings()
        self.pipe = self.build_pipeline()

    def run(self, prompt: str, filters: dict) -> Any:
        """Run the pipeline for *prompt*, restricting retrieval by *filters*.

        Returns:
            The Haystack result dict, or ``None`` on any failure
            (deliberate best-effort behaviour preserved from the original).
        """
        try:
            return self.pipe.run(query=prompt, params={"filters": filters})
        except Exception as e:  # broad by design: top-level boundary, degrade to None
            print(e)
            return None

    def build_pipeline(self):
        """Assemble the Query -> retriever -> prompt_node pipeline."""
        pipe = Pipeline()
        pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
        pipe.add_node(
            component=self.prompt_node,
            name="prompt_node",
            inputs=["retriever"],
        )
        return pipe

    def load_document_store(self):
        """Load the document store from the pickle cache, or build it fresh.

        NOTE(security): ``pickle.load`` can execute arbitrary code from the
        file; acceptable only because the cache is written locally by
        :meth:`update_embeddings`, never from untrusted input.
        """
        if os.path.exists(self._STORE_PATH):
            with open(file=self._STORE_PATH, mode="rb") as f:
                self.document_store = pickle.load(f)
        else:
            self.document_store = get_document_store()

    def generate_retriever_node(self):
        """Create the embedding retriever (returns the top 7 documents)."""
        return EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model=self.embedding_model,
            top_k=7,
        )

    def update_embeddings(self):
        """Compute embeddings once and persist the store to the pickle cache.

        Skipped when the cache already exists: a cached store was saved with
        its embeddings in place by a previous run.
        """
        if not os.path.exists(self._STORE_PATH):
            self.document_store.update_embeddings(
                self.retriever_node, update_existing_embeddings=True
            )
            # Robustness: the cache directory may not exist on first run.
            os.makedirs(os.path.dirname(self._STORE_PATH), exist_ok=True)
            with open(file=self._STORE_PATH, mode="wb") as f:
                pickle.dump(self.document_store, f)

    def generate_prompt_node(self):
        """Create the GPT-4 prompt node using the configured template."""
        rag_prompt = PromptTemplate(
            prompt=self.prompt_template,
            output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
        )
        # SECURITY FIX: an OpenAI API key was hard-coded in source here.
        # The leaked key must be revoked; the key is now read from the
        # environment (populated by load_dotenv() at module import).
        return PromptNode(
            model_name_or_path="gpt-4",
            default_prompt_template=rag_prompt,
            api_key=os.environ.get("OPENAI_API_KEY"),
            max_length=3000,
            model_kwargs={"temperature": 0.2, "max_tokens": 4096},
        )
|