Tai Truong
fix readme
d202ada
from langchain.agents import AgentExecutor
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langflow.base.agents.agent import LCAgentComponent
from langflow.inputs import HandleInput, MessageTextInput
class SQLAgentComponent(LCAgentComponent):
    """Component that builds a LangChain SQL agent.

    The agent is assembled from a language model (``llm``), a database
    opened from ``database_uri``, and an optional list of ``extra_tools``.
    """

    display_name = "SQLAgent"
    description = "Construct an SQL agent from an LLM and tools."
    name = "SQLAgent"
    icon = "LangChain"
    inputs = [
        # Inputs reused from LCAgentComponent come first.
        *LCAgentComponent._base_inputs,
        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
        MessageTextInput(name="database_uri", display_name="Database URI", required=True),
        HandleInput(
            name="extra_tools",
            display_name="Extra Tools",
            input_types=["Tool", "BaseTool"],
            is_list=True,
            advanced=True,
        ),
    ]

    def build_agent(self) -> AgentExecutor:
        """Create the SQL agent executor for the configured database.

        Opens the database at ``self.database_uri``, wraps it in a
        ``SQLDatabaseToolkit`` together with ``self.llm``, and passes both to
        ``create_sql_agent`` along with the keyword arguments returned by
        ``self.get_agent_kwargs()``.

        Returns:
            The object returned by ``create_sql_agent``.

        Raises:
            KeyError: If ``get_agent_kwargs()`` does not provide
                ``agent_executor_kwargs["max_iterations"]``.
        """
        database = SQLDatabase.from_uri(self.database_uri)
        sql_toolkit = SQLDatabaseToolkit(db=database, llm=self.llm)

        kwargs = self.get_agent_kwargs()
        # Move ``max_iterations`` out of the nested executor kwargs so it is
        # passed to ``create_sql_agent`` as a top-level keyword only.
        # NOTE(review): presumably this avoids supplying the value twice to
        # the underlying executor -- confirm against create_sql_agent.
        kwargs["max_iterations"] = kwargs["agent_executor_kwargs"].pop("max_iterations")

        # Fall back to an empty list when no extra tools are connected.
        additional_tools = self.extra_tools or []
        return create_sql_agent(
            llm=self.llm,
            toolkit=sql_toolkit,
            extra_tools=additional_tools,
            **kwargs,
        )