Spaces:
Build error
Build error
Upload 6 files
Browse files- database_table_descriptions.csv +9 -0
- examples.py +25 -0
- langchain_utils.py +71 -0
- main.py +125 -0
- prompts.py +39 -0
- table_details.py +42 -0
database_table_descriptions.csv
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Table,Description
|
2 |
+
productlines,"Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines."
|
3 |
+
products,"Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table."
|
4 |
+
offices,"Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code."
|
5 |
+
employees,"Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute."
|
6 |
+
customers,"Captures data on customers, including customer number, name, contact details, address, assigned sales rep, and credit limit. Central to managing customer relationships and sales processes."
|
7 |
+
payments,"Records payments made by customers, tracking the customer number, check number, payment date, and amount. Linked to the customers table for financial tracking and account management."
|
8 |
+
orders,"Details each sales order placed by customers, including order number, dates, status, comments, and customer number. Linked to the customers table, tracking sales transactions."
|
9 |
+
orderdetails,"Describes individual line items for each sales order, including order number, product code, quantity, price, and order line number. Links orders to products, detailing the items sold."
|
examples.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain_community.vectorstores import FAISS
|
3 |
+
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
|
4 |
+
from langchain_openai import OpenAIEmbeddings
|
5 |
+
|
6 |
+
# Example queries remain the same
|
7 |
+
# Few-shot (question, SQL) pairs that the example selector embeds and searches.
examples = [
    dict(
        input="Retrieve the user who have placed the highest total value of orders.",
        query="SELECT u.username, SUM(p.price) AS total_order_value FROM users u JOIN orders o ON u.user_id = o.user_id JOIN products p ON o.product_id = p.product_id GROUP BY u.username ORDER BY total_order_value DESC LIMIT 10;",
    ),
    # ... (rest of the examples)
]
|
14 |
+
|
15 |
+
@st.cache_resource
def get_example_selector(api_key):
    """Create the semantic-similarity selector over the module's ``examples``.

    Args:
        api_key: OpenAI API key handed to ``OpenAIEmbeddings``.

    Returns:
        A ``SemanticSimilarityExampleSelector`` built with a FAISS vector
        store, configured with ``k=2`` and keyed on the ``input`` variable.
    """
    return SemanticSimilarityExampleSelector.from_examples(
        examples,
        OpenAIEmbeddings(api_key=api_key),
        FAISS,
        k=2,
        input_variables=["input"],
    )
|
langchain_utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain_community.utilities.sql_database import SQLDatabase
|
3 |
+
from langchain.chains import create_sql_query_chain
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
6 |
+
from langchain.memory import ChatMessageHistory
|
7 |
+
from operator import itemgetter
|
8 |
+
from langchain_core.output_parsers import StrOutputParser
|
9 |
+
from langchain_core.runnables import RunnablePassthrough
|
10 |
+
from table_details import create_table_chain
|
11 |
+
from prompts import create_prompts
|
12 |
+
|
13 |
+
def get_db_uri(credentials):
    """Build a SQLAlchemy PostgreSQL (psycopg2) connection URI.

    Args:
        credentials: Mapping with the keys ``user``, ``password``, ``host``,
            ``port`` and ``database``.

    Returns:
        A ``postgresql+psycopg2://user:password@host:port/database`` string.
        The user name and password are percent-encoded so that reserved
        characters such as ``@``, ``:`` or ``/`` cannot corrupt the URI.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    from urllib.parse import quote

    # safe="" also encodes "/", which quote() leaves untouched by default.
    user = quote(str(credentials['user']), safe="")
    password = quote(str(credentials['password']), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )
|
15 |
+
|
16 |
+
@st.cache_resource
def _build_chain(db_uri, api_key):
    """Build the question -> SQL -> result -> answer chain.

    Both arguments take part in the cache key, so a different database URI
    or API key yields a different chain. Exceptions propagate to the caller
    and are therefore not stored in the cache.
    """
    db = SQLDatabase.from_uri(db_uri)
    llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

    # Get the table chain and prompts
    table_chain = create_table_chain(api_key)
    final_prompt, answer_prompt = create_prompts(api_key)

    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    return (
        RunnablePassthrough.assign(table_names_to_use=table_chain)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )


def get_chain(_db_uri, api_key):
    """Create the langchain with the provided credentials.

    Args:
        _db_uri: SQLAlchemy database URI. The leading underscore is kept for
            backward compatibility with existing callers.
        api_key: OpenAI API key.

    Returns:
        The runnable chain, or ``None`` when it could not be built (the
        error is shown with ``st.error``).
    """
    # Caching happens in _build_chain rather than here: Streamlit leaves
    # underscore-prefixed parameters out of the cache key, which would make
    # the chain for a previous database be reused after credentials change.
    try:
        return _build_chain(_db_uri, api_key)
    except Exception as e:
        st.error(f"Error creating chain: {str(e)}")
        return None


# Keep the cache-clearing hook the decorated function used to expose.
get_chain.clear = _build_chain.clear
|
42 |
+
|
43 |
+
def create_history(messages):
    """Convert role/content dicts into a ``ChatMessageHistory``.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            Entries whose role is ``"user"`` become user messages; every
            other role is recorded as an AI message.

    Returns:
        The populated ``ChatMessageHistory``.
    """
    chat_log = ChatMessageHistory()
    for entry in messages:
        is_user = entry["role"] == "user"
        record = chat_log.add_user_message if is_user else chat_log.add_ai_message
        record(entry["content"])
    return chat_log
|
51 |
+
|
52 |
+
def invoke_chain(question, messages, db_credentials, api_key):
    """Answer a question against the database described by the credentials.

    Args:
        question: The user's natural-language question.
        messages: Prior conversation as role/content dicts; converted with
            ``create_history`` and passed to the chain as ``messages``.
        db_credentials: Mapping accepted by ``get_db_uri``.
        api_key: OpenAI API key.

    Returns:
        The chain's answer, or a human-readable error string when the chain
        could not be created or raised during invocation.
    """
    try:
        chain = get_chain(get_db_uri(db_credentials), api_key)
        if chain is None:
            return "Sorry, I couldn't connect to the database. Please check your credentials."

        history = create_history(messages)
        # The history object is local to this call, so nothing is appended
        # to it afterwards; the caller owns the persistent conversation.
        return chain.invoke({
            "question": question,
            "top_k": 100,
            "messages": history.messages,
        })

    except Exception as e:
        # Boundary with the UI: surface the failure as chat text.
        return f"An error occurred: {str(e)}"
|
main.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain_utils import invoke_chain
|
4 |
+
from langchain_community.callbacks import get_openai_callback
|
5 |
+
|
6 |
+
def init_session_state():
    """Seed ``st.session_state`` with defaults for keys that are not set yet.

    ``messages`` starts as an empty list and ``connected`` as ``False``;
    values that already exist are left alone.
    """
    defaults = (("messages", []), ("connected", False))
    for key, initial in defaults:
        if key not in st.session_state:
            st.session_state[key] = initial
|
11 |
+
|
12 |
+
def create_sidebar():
    """Render the sidebar form that collects database and OpenAI credentials.

    When "Connect" is pressed and every required field is filled in, the
    values are stored in ``st.session_state`` (``db_credentials`` and
    ``api_key``) and ``connected`` is set to ``True``. No database
    connection is opened here; this only records the credentials. If a
    required field is blank, an error is shown and ``connected`` is set to
    ``False``.
    """
    with st.sidebar:
        st.title("Postgres Credentials")
        st.subheader("Enter your Credentials & Connect")

        # Database credentials
        host = st.text_input("Host", value="localhost")
        port = st.text_input("Port", value="5432")
        user = st.text_input("User", value="postgres")
        password = st.text_input("Password", type="password")
        database = st.text_input("Database")

        # OpenAI API key
        api_key = st.text_input("OpenAI API Key", type="password")

        # Connect button
        if not st.button("Connect", use_container_width=True):
            return

        # The password may legitimately be empty, so it is not required.
        required = {
            "Host": host,
            "Port": port,
            "User": user,
            "Database": database,
            "OpenAI API Key": api_key,
        }
        missing = [label for label, value in required.items() if not value.strip()]
        if missing:
            st.error(f"Connection failed: missing {', '.join(missing)}")
            st.session_state.connected = False
            return

        # Store credentials in session state; surrounding whitespace from
        # copy/paste is dropped everywhere except the password.
        st.session_state.db_credentials = {
            "host": host.strip(),
            "port": port.strip(),
            "user": user.strip(),
            "password": password,
            "database": database.strip(),
        }
        st.session_state.api_key = api_key.strip()
        st.session_state.connected = True
        st.success("Successfully connected!")
|
44 |
+
|
45 |
+
def main():
    """Run the chat page: sidebar, stored conversation, and one chat turn.

    Nothing beyond the title and an info hint is shown until
    ``st.session_state.connected`` is true. On a new prompt the question is
    sent through ``invoke_chain`` and both the prompt and the response are
    appended to ``st.session_state.messages``.
    """
    init_session_state()
    create_sidebar()

    st.title("Chat with Postgres DB")

    if not st.session_state.connected:
        st.info("Please enter your credentials in the sidebar and connect first.")
        return

    # Display the welcome message with the database icon using markdown
    st.markdown("""
        <div style='display: flex; align-items: center; gap: 10px;'>
        <span style='font-size: 24px;'>🗄️</span>
        <span>Hello! I'm a QualityKiosk's SQL assistant. Ask me anything about your database.</span>
        </div>
        """, unsafe_allow_html=True)

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("Type a message..."):
        # Snapshot the conversation before the new prompt is appended, so
        # the question is not passed to the chain twice (once as history
        # and once as the question itself).
        prior_messages = list(st.session_state.messages)

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                with get_openai_callback() as cb:
                    response = invoke_chain(
                        prompt,
                        prior_messages,
                        st.session_state.db_credentials,
                        st.session_state.api_key,
                    )
                    print(f"OpenAI Stats: {cb}")
                st.markdown(response)

        st.session_state.messages.append({"role": "assistant", "content": response})
|
87 |
+
|
88 |
+
if __name__ == "__main__":
    # Entries for the app menu (help link and "About" text).
    page_menu = {
        'Get Help': 'https://www.qualitykiosk.com',
        'About': "# Chat with Postgres DB\nA QualityKiosk's SQL Assistant",
    }

    # Custom CSS for the dark look of the app, sidebar, buttons and inputs.
    dark_css = """
        <style>
        .stApp {
        background-color: #1E1E1E;
        color: #FFFFFF;
        }
        .stSidebar {
        background-color: #262626;
        }
        .stButton>button {
        background-color: #0E86D4;
        color: white;
        }
        .stTextInput>div>div>input {
        background-color: #333333;
        color: white;
        }
        .stMarkdown {
        color: white;
        }
        </style>
        """

    # Page configuration is applied before any other Streamlit call.
    st.set_page_config(
        page_title="Chat with Postgres DB",
        layout="wide",
        initial_sidebar_state="expanded",
        menu_items=page_menu,
    )
    st.markdown(dark_css, unsafe_allow_html=True)

    main()
|
prompts.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from examples import get_example_selector
|
2 |
+
from langchain_core.prompts import (
|
3 |
+
ChatPromptTemplate,
|
4 |
+
MessagesPlaceholder,
|
5 |
+
FewShotChatMessagePromptTemplate,
|
6 |
+
PromptTemplate
|
7 |
+
)
|
8 |
+
|
9 |
+
# Template that renders one few-shot example as a human question followed by
# the AI's SQL answer.
_EXAMPLE_TURNS = [
    ("human", "{input}\nSQLQuery:"),
    ("ai", "{query}"),
]
example_prompt = ChatPromptTemplate.from_messages(_EXAMPLE_TURNS)
|
13 |
+
|
14 |
+
def create_prompts(api_key):
    """Build the SQL-generation prompt and the answer-rephrasing prompt.

    Args:
        api_key: OpenAI API key, forwarded to ``get_example_selector``.

    Returns:
        A ``(final_prompt, answer_prompt)`` tuple. ``final_prompt`` is a chat
        prompt using ``table_info``, ``top_k``, ``messages`` and ``input``;
        ``answer_prompt`` is a text prompt using ``question``, ``query`` and
        ``result``.
    """
    example_selector = get_example_selector(api_key)

    few_shot_prompt = FewShotChatMessagePromptTemplate(
        example_prompt=example_prompt,
        example_selector=example_selector,
        input_variables=["input", "top_k"],
    )

    # The row limit is spelled out with {top_k}; previously the sentence
    # stopped at "Unless otherwise specified." and the limit supplied by the
    # caller never appeared in the prompt text.
    final_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a Postgres SQL expert. Given an input question, create a syntactically correct Postgres SQL query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ])

    answer_prompt = PromptTemplate.from_template(
        """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
    )

    return final_prompt, answer_prompt
|
table_details.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
from operator import itemgetter
|
4 |
+
from langchain.chains.openai_tools import create_extraction_chain_pydantic
|
5 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
6 |
+
from langchain_openai import ChatOpenAI
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
# Extraction schema: one instance per table name returned by the model.
# NOTE(review): the class docstring and the field description are presumably
# forwarded to the model as part of the tool schema - confirm before rewording.
class Table(BaseModel):
    """Table in SQL database."""

    # Required string field holding the table's name.
    name: str = Field(..., description="Name of table in SQL database.")
|
12 |
+
|
13 |
+
def get_tables(tables: List[Table]) -> List[str]:
    """Return the ``name`` of each given ``Table``, preserving order."""
    names = []
    for entry in tables:
        names.append(entry.name)
    return names
|
15 |
+
|
16 |
+
@st.cache_data
def get_table_details():
    """Load the table descriptions and format them as prompt text.

    Reads ``database_table_descriptions.csv`` (columns ``Table`` and
    ``Description``), which is the file shipped with the app; the previous
    code looked for an ``.xlsx`` file of the same base name.

    Returns:
        One ``Table Name:.../Table Description:...`` paragraph per row,
        concatenated into a single string. If the file cannot be read or
        parsed, the error is shown with ``st.error`` and ``""`` is returned.
    """
    try:
        table_description = pd.read_csv("database_table_descriptions.csv")
        return "".join(
            f"Table Name:{row['Table']}\nTable Description:{row['Description']}\n\n"
            for _, row in table_description.iterrows()
        )
    except Exception as e:
        st.error(f"Error reading table descriptions: {str(e)}")
        return ""
|
27 |
+
|
28 |
+
def create_table_chain(api_key):
    """Build the chain that picks the tables relevant to a question.

    Args:
        api_key: OpenAI API key for the ``ChatOpenAI`` model.

    Returns:
        A runnable that takes a dict with a ``"question"`` key, runs the
        ``Table`` extraction chain on it, and passes the result through
        ``get_tables``.
    """
    system_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{get_table_details()}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

    extractor = create_extraction_chain_pydantic(
        Table,
        ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key),
        system_message=system_prompt,
    )
    # The extraction chain reads the question from the "input" key.
    question_as_input = {"input": itemgetter("question")}
    return question_as_input | extractor | get_tables
|