Spaces:
Sleeping
Sleeping
from langchain import hub | |
from langchain.agents import AgentExecutor, create_openai_tools_agent | |
from langchain_openai import ChatOpenAI | |
from langchain_community.utilities import SQLDatabase | |
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit | |
import gradio as gr | |
from dotenv import load_dotenv | |
from langchain.schema import HumanMessage, AIMessage | |
from langchain.prompts import ChatPromptTemplate | |
import os.path | |
from init_db import db_description | |
# Load environment variables from a local .env file (if present).
load_dotenv()

# Single source of truth for the SQLite database file: used both for the
# existence check and for the SQLAlchemy connection URI below, so the two
# can never drift apart.
DB_PATH = "estate.db"

# Fail fast with an actionable message. SQLite creates a missing database
# file on connect, so without this check the agent would presumably end up
# querying an empty database. ``isfile`` (rather than ``exists``) also
# rejects a directory that happens to carry the same name.
if not os.path.isfile(DB_PATH):
    raise FileNotFoundError(
        f"Database file '{DB_PATH}' not found. Please run 'uv run init_db.py' first."
    )

# Initialize model and database.
model = ChatOpenAI(model="gpt-4o-2024-08-06", streaming=True)
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")

# Set up the SQL toolkit; the tools it returns give the agent access to ``db``
# (the toolkit also receives ``model``).
toolkit = SQLDatabaseToolkit(db=db, llm=model)
tools = toolkit.get_tools()
# Example query embedded in the agent's system prompt so the model can adapt
# it: the 5 projects nearest to a fixed reference point (lat 50.08804,
# lng 14.42076 — appears to be central Prague), using the spherical law of
# cosines with 6371000 (Earth's mean radius in metres), so ``distance`` is in
# metres. ``structure`` is read as JSON via SQLite's ``->>`` operator, and
# acos/cos/sin/radians need an SQLite build with math functions enabled.
# The acos argument is clamped to [-1, 1]: floating-point rounding can push it
# marginally past 1 (e.g. for a project exactly at the reference point), where
# acos yields NULL and that row would then sort first under ``ORDER BY ... ASC``.
sql_distance_query = """
SELECT
    id,
    url,
    structure->>'$.lat' AS lat,
    structure->>'$.lng' AS lng,
    (6371000 * acos(
        max(-1.0, min(1.0,
            cos(radians(50.08804)) * cos(radians(CAST(structure->>'$.lat' AS FLOAT))) *
            cos(radians(CAST(structure->>'$.lng' AS FLOAT)) - radians(14.42076)) +
            sin(radians(50.08804)) * sin(radians(CAST(structure->>'$.lat' AS FLOAT)))
        ))
    )) AS distance
FROM project
WHERE structure->>'$.lat' IS NOT NULL
    AND structure->>'$.lng' IS NOT NULL
ORDER BY distance ASC
LIMIT 5;
"""
def _escape_braces(text):
    """Return *text* with ``{`` and ``}`` doubled so a prompt template keeps them literal.

    ``ChatPromptTemplate`` message strings are f-string style templates: a raw
    brace in injected text (e.g. a JSON example inside the schema description)
    would be parsed as a template variable and make ``invoke`` fail with a
    missing-variable error. Doubled braces render back as single braces.
    """
    return text.replace("{", "{{").replace("}", "}}")


# System instructions: a Czech-speaking real-estate assistant, followed by the
# database description and an example distance query the model may adapt.
# Both injected texts are brace-escaped; blank lines keep the parts separated.
_system_message = (
    "You are a helpful assistant. You speak Czech. You can answer questions about "
    "real estate projects (novostavby) in Czech Republic. You have access to a "
    "database of real estate projects (translate output to Czech too). Some info "
    "about the structure of the database: "
    + _escape_braces(db_description)
    + "\n\nNote that some json values can be null so you must sometimes check with "
    "IS NOT NULL. For calculating distances you can adapt this query: "
    + _escape_braces(sql_distance_query)
)

# Get the OpenAI tools agent prompt: ``chat_history`` carries prior turns,
# ``input`` the current user message, and ``agent_scratchpad`` the agent's
# intermediate tool calls/results.
prompt = ChatPromptTemplate.from_messages([
    ("system", _system_message),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])
# Combine the chat model, the SQL tools and the prompt into an OpenAI
# tool-calling agent, then wrap it in an AgentExecutor, which drives the
# agent/tool interaction and is what the chat handler invokes.
agent = create_openai_tools_agent(
    llm=model,
    prompt=prompt,
    tools=tools,
)
agent_executor = AgentExecutor(tools=tools, agent=agent, verbose=True)  # verbose: print intermediate steps
def _to_langchain_history(history):
    """Convert Gradio ``type="messages"`` history into LangChain chat messages.

    Args:
        history: Iterable of dicts with ``"role"`` and ``"content"`` keys, or
            ``None``. Entries whose role is neither ``"user"`` nor
            ``"assistant"`` are skipped.

    Returns:
        A list of ``HumanMessage`` / ``AIMessage`` objects in the same order.
    """
    converted = []
    for msg in history or []:
        role = msg.get("role")
        if role == "user":
            converted.append(HumanMessage(content=msg["content"]))
        elif role == "assistant":
            converted.append(AIMessage(content=msg["content"]))
    return converted


def chat_with_sql(message, history):
    """Answer *message* with the SQL agent and stream the reply to Gradio.

    Args:
        message: The user's latest chat message.
        history: Prior turns as passed by ``gr.ChatInterface(type="messages")``
            (a list of ``{"role": ..., "content": ...}`` dicts); it does not
            include *message*.

    Yields:
        Progressively longer prefixes of the agent's final answer, so the UI
        renders it as if it were typed out (a single empty/falsy output is
        yielded once as-is), or a single ``"Error: ..."`` string if the agent
        call fails.
    """
    import logging

    try:
        response = agent_executor.invoke({
            "input": message,
            # Must match the prompt's ``{chat_history}`` placeholder; the
            # current message goes only in ``input`` so it isn't duplicated.
            "chat_history": _to_langchain_history(history),
        })
        output = response["output"]
    except Exception as e:  # UI boundary: report the failure in the chat instead of crashing
        logging.getLogger(__name__).exception("SQL agent invocation failed")
        yield f"Error: {str(e)}"
        return

    if not output:
        # Yield once so Gradio still receives a reply for an empty answer.
        yield output
        return
    for end in range(1, len(output) + 1):
        yield output[:end]
# Create the Gradio chat UI. ``type="messages"`` makes Gradio pass the
# conversation history to ``chat_with_sql`` as a list of
# ``{"role": ..., "content": ...}`` dicts.
demo = gr.ChatInterface(
    fn=chat_with_sql,
    title="Estate Chat",
    # Czech: "Ask me anything about new-build properties in the Czech Republic"
    description="Zeptej se mě na cokoli o novostavbách v ČR 🇨🇿",
    examples=[
        "Projekt s nejnižším vkladem",  # "Project with the lowest deposit"
        "Nejlevnější byt v Praze",  # "Cheapest flat in Prague"
    ],
    type="messages",
)
if __name__ == "__main__":
    # Serve the chat UI locally (no public share link) with Gradio's debug mode on.
    demo.launch(debug=True, share=False)