from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
import gradio as gr
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, AIMessage
import os.path

# Load environment variables (ChatOpenAI expects OPENAI_API_KEY, e.g. from a local .env file)
load_dotenv()

# Check if database exists
if not os.path.exists('estate.db'):
    raise FileNotFoundError(
        "Database file 'estate.db' not found. Please run 'uv run init_db.py' first."
    )
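# NOTE: init_db.py (not shown here) is assumed to create estate.db with the
# project/deposit data that the example question in the UI below relies on;
# see that script for the actual schema.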

# Initialize the chat model and the SQLite database connection.
# Note: streaming=True only streams tokens inside the LLM call itself; the Gradio
# handler below simulates streaming by yielding growing slices of the final answer.
model = ChatOpenAI(model="gpt-4", streaming=True)
db = SQLDatabase.from_uri("sqlite:///estate.db")

# Set up the SQL toolkit; it provides tools for listing tables, inspecting
# schemas, checking queries, and executing SQL against the database
toolkit = SQLDatabaseToolkit(db=db, llm=model)
tools = toolkit.get_tools()

# Pull the standard OpenAI tools agent prompt from LangChain Hub; it includes an
# optional chat_history placeholder plus the required agent_scratchpad
prompt = hub.pull("hwchase17/openai-tools-agent")

# Create the agent with OpenAI tools format
agent = create_openai_tools_agent(
    llm=model,
    tools=tools,
    prompt=prompt
)

# Create agent executor
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True
)

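# Gradio's messages-format history arrives as a list of {"role", "content"} dicts.
# The handler below converts it to LangChain message objects so the prompt's
# chat_history placeholder can consume it, then yields the answer incrementally.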
def chat_with_sql(message, history):
    try:
        # Convert Gradio's dict-based history into LangChain messages.
        history_langchain_format = []
        for msg in history:
            if msg["role"] == "user":
                history_langchain_format.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                history_langchain_format.append(AIMessage(content=msg["content"]))
        # The hub prompt expects prior turns under the "chat_history" key; the
        # current message goes in as "input", so it is not appended to the history.
        response = agent_executor.invoke(
            {"input": message, "chat_history": history_langchain_format}
        )
        # Simulate streaming in the UI by yielding progressively longer prefixes
        # of the finished answer.
        for i in range(len(response["output"])):
            yield response["output"][: i + 1]
    except Exception as e:
        yield f"Error: {str(e)}"

# Create the Gradio interface
demo = gr.ChatInterface(
    fn=chat_with_sql,
    title="SQL Chat Assistant",
    description="Ask questions about your SQLite database!",
    examples=[
        "What is project with the lowest deposit?",
    ],
    type="messages"
)

if __name__ == "__main__":
    demo.launch(share=False, debug=True)