# NOTE(review): removed non-Python scraping artifacts (file-viewer chrome:
# "Spaces:", file size, commit hashes, line-number gutter) that preceded the
# imports — they were not valid Python source and broke parsing.
import json
from abc import abstractmethod, ABC
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain import hub
from langchain.agents import create_react_agent
from langchain.schema import SystemMessage
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langgraph.prebuilt import create_react_agent
class Database(ABC):
    """Abstract interface for a data source a ``Session`` can query.

    ``Session`` relies on three operations per datasource: building an agent
    for a given LLM (``create_agent``), converting a raw user message into the
    agent's input format (``process_message``), and extracting the final
    answer text from the agent's response (``postprocess``).  The original
    ABC declared only ``create_agent`` even though ``Session.invoke`` calls
    all three; the contract is completed here.  Both in-file subclasses
    already implement every method.
    """

    @abstractmethod
    def create_agent(self, llm):
        """Return a runnable agent backed by *llm* for this data source."""
        raise NotImplementedError

    @abstractmethod
    def process_message(self, message):
        """Convert a raw user *message* into the agent's expected input."""
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, response):
        """Extract the final answer text from an agent *response*."""
        raise NotImplementedError
class Session:
    """Routes user messages to registered data sources and their agents.

    One agent is built per registered :class:`Database` using this session's
    LLM; ``invoke``/``stream`` dispatch a message to the selected source.

    Parameters
    ----------
    llm : BaseChatModel
        Chat model used to build one agent per data source.
    datasources : iterable of Database, optional
        Data sources to register at construction time.
    """

    def __init__(self, llm: BaseChatModel, datasources=None):
        self.llm = llm
        self.datasources = datasources  # original argument, kept for compatibility
        self._datasources = []  # registered Database instances
        self._dataagents = []   # agent for each datasource, same order
        if self.datasources is not None:
            for datasource in self.datasources:
                self.add_datasource(datasource)

    def add_datasource(self, database: Database):
        """Register *database* and build its agent with this session's LLM."""
        agent = database.create_agent(self.llm)
        self._datasources.append(database)
        self._dataagents.append(agent)

    def get_relevant_source(self, message, datasource=None):
        """Return the ``(database, agent)`` pair for *datasource*.

        *datasource* is an index into the registered sources; when ``None``
        the first registered source is used.  *message* is currently unused
        but kept so a future implementation can route by message content.

        Raises ``ValueError`` when no sources are registered (previously a
        bare ``IndexError``).
        """
        if not self._datasources:
            raise ValueError("no datasources registered")
        if datasource is not None:
            return self._datasources[datasource], self._dataagents[datasource]
        return self._datasources[0], self._dataagents[0]

    def invoke(self, message, datasource=None):
        """Run *message* through the selected agent.

        Returns a ``(processed_answer, raw_response)`` tuple.
        """
        db, agent = self.get_relevant_source(message, datasource)
        processed_message = db.process_message(message)
        response = agent.invoke(processed_message)
        processed_response = db.postprocess(response)
        return processed_response, response

    def stream(self, message, stream_mode=None, datasource=None):
        """Stream the selected agent's response for *message*.

        Bug fix: the original called ``get_relevant_source(message)`` while
        ``datasource`` was a required parameter, so every ``stream`` call
        raised ``TypeError``.  ``datasource`` now defaults to ``None`` and is
        also exposed here, matching ``invoke``.
        """
        db, agent = self.get_relevant_source(message, datasource)
        return agent.stream(
            {"messages": [("user", message)]},
            stream_mode=stream_mode,
        )
class SQLDatabase(Database):
    """SQL-backed datasource wrapping a LangChain ``SQLDatabase`` connection.

    Builds a ReAct-style SQL agent from the connection's toolkit and adapts
    messages to/from the agent's ``{"messages": [...]}`` format.
    """

    def __init__(self, db):
        # `db` is a langchain_community SQLDatabase instance (see from_uri).
        self.db = db

    def create_agent(self, llm):
        """Build a ReAct SQL agent over this database using *llm*."""
        toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
        # Pull the stock SQL system prompt and specialize it for this dialect.
        template = hub.pull("langchain-ai/sql-agent-system-prompt")
        system_prompt = template.format(dialect="SQLite", top_k=5)
        return create_react_agent(llm, toolkit.get_tools(), prompt=system_prompt)

    def process_message(self, message):
        """Wrap a raw user string in the agent's message-dict input format."""
        return {"messages": [("user", message)]}

    def postprocess(self, response):
        """Return the text of the final message in the agent's response."""
        return response['messages'][-1].content

    @classmethod
    def from_uri(cls, database_uri, engine_args=None, **kwargs):
        """Alternate constructor: connect through a SQLAlchemy-style URI."""
        connection = LangchainSQLDatabase.from_uri(database_uri, engine_args, **kwargs)
        return cls(connection)
class DocumentDatabase(Database):
    """Datasource backed by an in-memory vector store loaded from a JSON dump.

    Embeds queries with a HuggingFace sentence-transformer and answers via a
    retrieve → format-prompt → LLM runnable pipeline (simple RAG).
    """

    def __init__(
        self,
        path: str,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        top_k: int = 3,
        model_kwargs=None,
        encode_kwargs=None,
    ):
        self.path = path
        self.model_name = model_name
        self.top_k = top_k
        # Fall back to CPU / small batches when the caller gives no overrides.
        self.model_kwargs = model_kwargs if model_kwargs is not None else {"device": "cpu"}
        self.encode_kwargs = encode_kwargs if encode_kwargs is not None else {"batch_size": 8}
        embeddings = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            show_progress=False,
        )
        self.vector_store = InMemoryVectorStore(embeddings)
        # NOTE(review): populates the store's internal dict straight from JSON;
        # presumably the file is a dump of a compatible store — confirm format.
        with open(path, 'rb') as f:
            self.vector_store.store = json.load(f)

    def create_agent(self, llm):
        """Build the RAG pipeline: retrieve docs, format a prompt, call *llm*."""

        def _retrieve(message):
            # Pair the original message with its top-k most similar documents.
            return message, self.vector_store.similarity_search(message, k=self.top_k)

        def _build_prompt(inputs):
            message, documents = inputs
            context = '\n\n'.join(doc.page_content for doc in documents)
            system_text = (
                "You are an assistant for question-answering tasks. "
                "Use the following pieces of retrieved context to answer "
                "the question. If you don't know the answer, say that you "
                "don't know. Use three sentences maximum and keep the "
                "answer concise."
                "\n\n" + context
            )
            return [SystemMessage(system_text), HumanMessage(message)]

        return (
            RunnablePassthrough()
            | RunnableLambda(_retrieve)
            | RunnableLambda(_build_prompt)
            | llm
        )

    def process_message(self, message):
        """Messages are fed to the pipeline unchanged."""
        return message

    def postprocess(self, response):
        """Return the LLM message's text content."""
        return response.content