File size: 3,387 Bytes
78f9503
 
 
 
 
 
 
 
d2d3ca7
78f9503
d2d3ca7
78f9503
 
 
 
 
 
 
 
 
 
d2d3ca7
78f9503
 
 
 
 
 
d2d3ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78f9503
d2d3ca7
 
 
 
 
 
78f9503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e41e611
d2d3ca7
78f9503
d2d3ca7
 
78f9503
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
import gradio as gr
from dotenv import load_dotenv
from langchain.schema import HumanMessage, AIMessage
from langchain.prompts import ChatPromptTemplate
import os.path
from init_db import db_description
# Populate the process environment via python-dotenv before any client is built.
load_dotenv()

# Fail fast with an actionable message when the SQLite file is not in the
# current working directory; everything below depends on it.
if not os.path.exists('estate.db'):
    raise FileNotFoundError("Database file 'estate.db' not found. Please run 'uv run init_db.py' first.")

# Chat model (streaming enabled) and a handle on the local SQLite database.
model = ChatOpenAI(
    streaming=True,
    model="gpt-4o-2024-08-06",
)
db = SQLDatabase.from_uri(
    "sqlite:///estate.db"
)

# The toolkit wraps the database (and the model) and exposes the SQL tools
# that the agent is allowed to call.
toolkit = SQLDatabaseToolkit(llm=model, db=db)
tools = toolkit.get_tools()

# Example SQL that is appended verbatim to the system prompt so the model can
# adapt it for distance questions. It reads `lat`/`lng` out of the JSON
# `structure` column of `project`, computes a great-circle distance from the
# fixed reference point (50.08804, 14.42076) with the spherical law of
# cosines, and returns the 5 closest rows.
# NOTE(review): 6371000 is presumably Earth's mean radius in metres, making
# `distance` a value in metres — confirm.
sql_distance_query = """
SELECT 
    id, 
    url,
    structure->>'$.lat' AS lat,
    structure->>'$.lng' AS lng,
    (6371000 * acos(
        cos(radians(50.08804)) * cos(radians(CAST(structure->>'$.lat' AS FLOAT))) * 
        cos(radians(CAST(structure->>'$.lng' AS FLOAT)) - radians(14.42076)) + 
        sin(radians(50.08804)) * sin(radians(CAST(structure->>'$.lat' AS FLOAT)))
    )) AS distance
FROM project
WHERE structure->>'$.lat' IS NOT NULL 
AND structure->>'$.lng' IS NOT NULL
ORDER BY distance ASC
LIMIT 5;
"""

# Prompt for the OpenAI-tools agent: a system message (instructions, schema
# description and a sample distance query), the optional earlier turns, the
# current user input, and the agent's scratchpad.
# NOTE(review): the concatenated system text is parsed as a template, so any
# literal curly braces inside db_description would be treated as variables —
# confirm it contains none.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. You speak Czech. "
            "You can answer questions about real estate projects (novostavby) in Czech Republic. "
            "You have access to a database of real estate projects (translate output to Czech too). "
            "Some info about the structure of the database: "
            + db_description
            + "Note that some json values can be null so you must sometimes check with IS NOT NULL. "
            "For calculating you can adapt this query: "
            + sql_distance_query,
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)

# Combine prompt, SQL tools and model into an OpenAI-tools style agent.
agent = create_openai_tools_agent(prompt=prompt, tools=tools, llm=model)

# The executor drives the agent/tool loop; verbose=True logs each step.
agent_executor = AgentExecutor(tools=tools, agent=agent, verbose=True)

def chat_with_sql(message, history):
    """Answer one chat turn with the SQL agent and stream the reply.

    Args:
        message: Text of the user's latest message.
        history: Earlier turns as dicts with ``role`` and ``content`` keys
            (the interface is created with ``type="messages"``). Entries
            whose role is neither ``"user"`` nor ``"assistant"`` are skipped.

    Yields:
        Successively longer prefixes of the agent's answer (one more
        character each time) so the reply can be rendered incrementally.
        If anything raises, a single ``"Error: ..."`` string is yielded
        instead of propagating the exception to the UI.
    """
    try:
        # Convert the UI's role/content dicts into LangChain message objects.
        chat_history = []
        for turn in history:
            role = turn['role']
            if role == "user":
                chat_history.append(HumanMessage(content=turn['content']))
            elif role == "assistant":
                chat_history.append(AIMessage(content=turn['content']))

        # The key must be "chat_history" to fill the prompt's
        # "{chat_history}" placeholder; under any other name the earlier
        # turns are silently dropped. The current message is NOT added to
        # the history because the prompt's "{input}" slot already carries it.
        response = agent_executor.invoke(
            {"input": message, "chat_history": chat_history}
        )

        answer = response["output"]
        for end in range(1, len(answer) + 1):
            yield answer[:end]
    except Exception as e:
        # UI boundary: surface the failure as a chat message.
        yield f"Error: {str(e)}"

# Chat UI: every submitted message is handled by chat_with_sql, with the
# history delivered as role/content dicts (type="messages").
demo = gr.ChatInterface(
    type="messages",
    title="Estate Chat",
    description="Zeptej se me na cokoli o novostavbách v ČR 🇨🇿",
    examples=["Projekt s nejnižším vkladem", "Nejlevnejsi byt v Praze"],
    fn=chat_with_sql,
)

# Start the app only when executed as a script, without a public share link.
if __name__ == "__main__":
    demo.launch(debug=True, share=False)