File size: 4,803 Bytes
ced6b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b07ca6f
ced6b34
 
b07ca6f
 
 
 
 
 
 
ced6b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
from abc import abstractmethod, ABC

from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel

from langchain_huggingface import HuggingFaceEmbeddings

from langchain import hub
from langchain.agents import create_react_agent
from langchain.schema import SystemMessage
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

from langgraph.prebuilt import create_react_agent

class Database(ABC):
	"""Abstract datasource wrapper.

	Concrete subclasses (see the SQL and document variants in this module)
	must know how to build an agent bound to a given chat model.
	"""

	@abstractmethod
	def create_agent(self, llm):
		"""Return an agent for this datasource driven by *llm*."""
		raise NotImplementedError

class Session:
	"""Dispatches user messages to agents built over registered datasources.

	Each added ``Database`` contributes one agent (built via its
	``create_agent`` with this session's LLM); ``invoke`` and ``stream``
	select a datasource by index and run its agent with that datasource's
	message pre/post-processing.
	"""

	# Annotations are string (lazy) forms so this class has no import-time
	# dependency on the langchain types; the signature is unchanged.
	def __init__(self, llm: "BaseChatModel", datasources=None):
		self.llm = llm
		self.datasources = datasources  # constructor argument, kept for callers
		self._datasources = []  # registered Database wrappers
		self._dataagents = []   # agents, index-aligned with _datasources

		if self.datasources is not None:
			for datasource in self.datasources:
				self.add_datasource(datasource)

	def add_datasource(self, database: "Database"):
		"""Register *database* and build its agent with this session's LLM."""
		agent = database.create_agent(self.llm)
		self._datasources.append(database)
		self._dataagents.append(agent)

	def get_relevant_source(self, message, datasource=None):
		"""Return the ``(database, agent)`` pair for *datasource*.

		*datasource* is an index into the registered sources. ``None``
		(now the default — previously the argument was required, which made
		``stream`` raise ``TypeError``) falls back to the first source.
		*message* is currently unused but kept for interface stability.

		Raises:
			ValueError: if no datasource has been registered.
		"""
		if not self._datasources:
			raise ValueError("no datasource registered; call add_datasource() first")
		if datasource is not None:
			return self._datasources[datasource], self._dataagents[datasource]
		return self._datasources[0], self._dataagents[0]

	def invoke(self, message, datasource=None):
		"""Run *message* through the selected datasource's agent.

		Returns:
			``(processed_response, raw_response)`` — the datasource's
			post-processed answer plus the agent's raw response.
		"""
		db, agent = self.get_relevant_source(message, datasource)
		processed_message = db.process_message(message)
		response = agent.invoke(processed_message)
		return db.postprocess(response), response

	def stream(self, message, datasource=None, stream_mode=None):
		"""Stream the agent's response for *message*.

		Fixes vs. the original: ``get_relevant_source`` is called with an
		explicit datasource (it previously omitted the then-required
		argument, so every call crashed), and the payload is built via
		``db.process_message`` — consistent with ``invoke`` — so non-SQL
		datasources receive the input shape their agent expects.
		"""
		db, agent = self.get_relevant_source(message, datasource)
		return agent.stream(
			db.process_message(message),
			stream_mode=stream_mode,
		)

class SQLDatabase(Database):
	"""SQL datasource: wraps a LangChain ``SQLDatabase`` and builds a
	ReAct agent over its SQL toolkit."""

	def __init__(self, db, top_k: int = 5):
		# db: a langchain_community SQLDatabase instance.
		# top_k: row-limit hint injected into the agent's system prompt
		# (previously hard-coded to 5; default preserves old behavior).
		self.db = db
		self.top_k = top_k

	def create_agent(self, llm):
		"""Build a ReAct agent equipped with the SQL toolkit for this database."""
		toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
		prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
		# Use the wrapped database's actual dialect instead of the previously
		# hard-coded "SQLite", so non-SQLite engines get a correct prompt.
		system_message = prompt_template.format(dialect=self.db.dialect, top_k=self.top_k)
		return create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

	def process_message(self, message):
		"""Wrap a raw user string in the chat payload the agent expects."""
		return {"messages": [("user", message)]}

	def postprocess(self, response):
		"""Extract the final assistant message's text from the agent response."""
		return response['messages'][-1].content

	@classmethod
	def from_uri(cls, database_uri, engine_args=None, **kwargs):
		"""Alternate constructor: build the wrapped database from a SQLAlchemy URI."""
		db = LangchainSQLDatabase.from_uri(database_uri, engine_args, **kwargs)
		return cls(db)


class DocumentDatabase(Database):
	"""Document (RAG) datasource.

	Loads a JSON-serialized in-memory vector store from *path* and answers
	questions with a retrieve → format-prompt → LLM pipeline.
	"""

	def __init__(
			self,
			path: str,
			model_name: str = "sentence-transformers/all-mpnet-base-v2",
			top_k: int = 3,
			model_kwargs = None,
			encode_kwargs = None,
		):
		# path: JSON dump consumed below; top_k: documents retrieved per query.
		self.path = path
		self.model_name = model_name
		self.top_k = top_k
		self.model_kwargs = model_kwargs if model_kwargs is not None else {"device": "cpu"}
		self.encode_kwargs = encode_kwargs if encode_kwargs is not None else {"batch_size": 8}

		embedder = HuggingFaceEmbeddings(
			model_name=self.model_name,
			model_kwargs=self.model_kwargs,
			encode_kwargs=self.encode_kwargs,
			show_progress=False,
		)
		self.vector_store = InMemoryVectorStore(embedder)
		# NOTE(review): this replaces the store's internal dict with raw JSON —
		# presumably the file was dumped from a compatible store; confirm.
		with open(path, 'rb') as handle:
			self.vector_store.store = json.load(handle)

	def create_agent(self, llm):
		"""Build a runnable pipeline: retrieve top-k docs for the question,
		format a grounded prompt, then invoke *llm*."""

		def _retrieve(message):
			# Keep the raw question alongside its nearest documents.
			hits = self.vector_store.similarity_search(message, k=self.top_k)
			return (message, hits)

		def _to_prompt(pair):
			message, hits = pair
			context_block = '\n\n'.join(hit.page_content for hit in hits)
			system_text = (
				"You are an assistant for question-answering tasks. "
				"Use the following pieces of retrieved context to answer "
				"the question. If you don't know the answer, say that you "
				"don't know. Use three sentences maximum and keep the "
				"answer concise."
				"\n\n"
			) + context_block
			return [SystemMessage(system_text), HumanMessage(message)]

		return (
			RunnablePassthrough()
			| RunnableLambda(_retrieve)
			| RunnableLambda(_to_prompt)
			| llm
		)

	def process_message(self, message):
		"""The document pipeline consumes the raw string unchanged."""
		return message

	def postprocess(self, response):
		"""Return only the text content of the LLM's chat message."""
		return response.content