sango07 committed on
Commit
16601c8
·
verified ·
1 Parent(s): 5dc1c7b

Upload 6 files

Browse files
Files changed (6) hide show
  1. database_table_descriptions.csv +9 -0
  2. examples.py +25 -0
  3. langchain_utils.py +71 -0
  4. main.py +125 -0
  5. prompts.py +39 -0
  6. table_details.py +42 -0
database_table_descriptions.csv ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Table,Description
2
+ productlines,"Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines."
3
+ products,"Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table."
4
+ offices,"Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code."
5
+ employees,"Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute."
6
+ customers,"Captures data on customers, including customer number, name, contact details, address, assigned sales rep, and credit limit. Central to managing customer relationships and sales processes."
7
+ payments,"Records payments made by customers, tracking the customer number, check number, payment date, and amount. Linked to the customers table for financial tracking and account management."
8
+ orders,"Details each sales order placed by customers, including order number, dates, status, comments, and customer number. Linked to the customers table, tracking sales transactions."
9
+ orderdetails,"Describes individual line items for each sales order, including order number, product code, quantity, price, and order line number. Links orders to products, detailing the items sold."
examples.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_community.vectorstores import FAISS
3
+ from langchain_core.example_selectors import SemanticSimilarityExampleSelector
4
+ from langchain_openai import OpenAIEmbeddings
5
+
6
# Few-shot examples: each entry pairs a natural-language question ("input")
# with the SQL statement that answers it ("query"). The list is passed to
# SemanticSimilarityExampleSelector.from_examples() in get_example_selector().
# NOTE(review): the sample query targets users/orders/products tables keyed by
# user_id/product_id -- confirm it matches the schema of the connected database.
examples = [
    {
        "input": "Retrieve the user who have placed the highest total value of orders.",
        "query": "SELECT u.username, SUM(p.price) AS total_order_value FROM users u JOIN orders o ON u.user_id = o.user_id JOIN products p ON o.product_id = p.product_id GROUP BY u.username ORDER BY total_order_value DESC LIMIT 10;"
    },
    # ... (rest of the examples)
]
14
+
15
@st.cache_resource
def get_example_selector(api_key):
    """Build the few-shot example selector over the module-level ``examples``.

    The function is wrapped in ``st.cache_resource`` and takes ``api_key`` as
    its only argument.

    Args:
        api_key: OpenAI API key handed to ``OpenAIEmbeddings``.

    Returns:
        The selector produced by
        ``SemanticSimilarityExampleSelector.from_examples`` with ``FAISS`` as
        the vector store class, ``k=2`` and ``"input"`` as the input variable.
    """
    return SemanticSimilarityExampleSelector.from_examples(
        examples,
        OpenAIEmbeddings(api_key=api_key),
        FAISS,
        k=2,
        input_variables=["input"],
    )
langchain_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_community.utilities.sql_database import SQLDatabase
3
+ from langchain.chains import create_sql_query_chain
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
6
+ from langchain.memory import ChatMessageHistory
7
+ from operator import itemgetter
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.runnables import RunnablePassthrough
10
+ from table_details import create_table_chain
11
+ from prompts import create_prompts
12
+
13
def get_db_uri(credentials):
    """Build the SQLAlchemy connection URI for a PostgreSQL database.

    Args:
        credentials: Mapping with the keys ``user``, ``password``, ``host``,
            ``port`` and ``database``.

    Returns:
        A ``postgresql+psycopg2://user:password@host:port/database`` string.
        The user and password are percent-encoded so that characters with a
        meaning inside a URL (``@``, ``/``, ``:``, ``%``, spaces, ...) cannot
        corrupt the URI or be mistaken for part of the host.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    from urllib.parse import quote

    # safe="" also encodes "/", which quote() leaves untouched by default.
    user = quote(str(credentials["user"]), safe="")
    password = quote(str(credentials["password"]), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )
15
+
16
@st.cache_resource
def _build_chain(db_uri, api_key):
    """Assemble the question -> SQL -> answer chain; raises on any failure.

    Both arguments take part in the cache key, so a different database or a
    different API key yields a different cached chain. Because errors
    propagate instead of being returned, a failed attempt is never cached.
    """
    db = SQLDatabase.from_uri(db_uri)
    llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

    # Get the table chain and prompts
    table_chain = create_table_chain(api_key)
    final_prompt, answer_prompt = create_prompts(api_key)

    generate_query = create_sql_query_chain(llm, db, final_prompt)
    execute_query = QuerySQLDataBaseTool(db=db)
    rephrase_answer = answer_prompt | llm | StrOutputParser()

    # 1. pick the relevant tables, 2. generate the SQL, 3. run it,
    # 4. turn question + query + result into a natural-language answer.
    return (
        RunnablePassthrough.assign(table_names_to_use=table_chain)
        | RunnablePassthrough.assign(query=generate_query).assign(
            result=itemgetter("query") | execute_query
        )
        | rephrase_answer
    )


def get_chain(_db_uri, api_key):
    """Create the langchain with the provided credentials.

    The chain is cached per ``(_db_uri, api_key)`` pair. Previously the
    underscore-prefixed ``_db_uri`` was excluded from Streamlit's cache key,
    so switching databases kept returning the chain of the first connection,
    and a ``None`` produced by a failed attempt was cached as well.

    Args:
        _db_uri: SQLAlchemy database URI (name kept for compatibility).
        api_key: OpenAI API key used for the LLM, the table chain and prompts.

    Returns:
        The runnable chain, or ``None`` if it could not be created; in that
        case the error is shown through ``st.error``.
    """
    try:
        return _build_chain(_db_uri, api_key)
    except Exception as e:
        st.error(f"Error creating chain: {str(e)}")
        return None


# Keep the cache-management hook callers had on the decorated function.
get_chain.clear = _build_chain.clear
42
+
43
def create_history(messages):
    """Convert role/content dicts into a ``ChatMessageHistory``.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        A new ``ChatMessageHistory``. Entries whose role is ``"user"`` are
        added as user messages; every other role is added as an AI message.
    """
    history = ChatMessageHistory()
    for entry in messages:
        is_user = entry["role"] == "user"
        add_message = history.add_user_message if is_user else history.add_ai_message
        add_message(entry["content"])
    return history
51
+
52
def invoke_chain(question, messages, db_credentials, api_key):
    """Answer ``question`` against the configured database.

    Args:
        question: The user's current question.
        messages: Chat transcript as dicts with ``"role"`` and ``"content"``.
            It may already end with the current question; that entry is left
            out of the history sent to the model.
        db_credentials: Mapping accepted by ``get_db_uri``.
        api_key: OpenAI API key.

    Returns:
        The chain's answer, or a human-readable error message. This function
        does not raise; any exception is converted into the returned string.
    """
    try:
        chain = get_chain(get_db_uri(db_credentials), api_key)
        if chain is None:
            return "Sorry, I couldn't connect to the database. Please check your credentials."

        # The Streamlit page appends the question to its transcript before
        # calling this function. The prompt already receives the question as
        # its own input, so keeping it in the history would send it twice.
        prior_messages = list(messages)
        if (
            prior_messages
            and prior_messages[-1]["role"] == "user"
            and prior_messages[-1]["content"] == question
        ):
            prior_messages.pop()

        history = create_history(prior_messages)
        # The history object is local to this call, so nothing is appended to
        # it afterwards; the caller owns the persistent transcript.
        return chain.invoke({
            "question": question,
            "top_k": 100,
            "messages": history.messages,
        })

    except Exception as e:
        return f"An error occurred: {str(e)}"
main.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_openai import ChatOpenAI
3
+ from langchain_utils import invoke_chain
4
+ from langchain_community.callbacks import get_openai_callback
5
+
6
def init_session_state():
    """Seed ``st.session_state`` with defaults for keys that are not set yet.

    ``messages`` starts as an empty list and ``connected`` as ``False``;
    existing values are left untouched.
    """
    # Factories rather than values so every session gets its own fresh list.
    default_factories = (("messages", list), ("connected", bool))
    for key, make_default in default_factories:
        if key not in st.session_state:
            st.session_state[key] = make_default()
11
+
12
def create_sidebar():
    """Render the credentials form in the sidebar and store it on "Connect".

    When the button is pressed, the database fields are saved to
    ``st.session_state.db_credentials``, the key to
    ``st.session_state.api_key``, and ``st.session_state.connected`` is set
    to ``True``.

    NOTE(review): no database or OpenAI call is made here; ``connected`` only
    records that the form was submitted.
    """
    with st.sidebar:
        st.title("Postgres Credentials")
        st.subheader("Enter your Credentials & Connect")

        # Database fields; the dict literal is evaluated top to bottom, which
        # keeps the widgets in their on-screen order.
        entered_credentials = {
            "host": st.text_input("Host", value="localhost"),
            "port": st.text_input("Port", value="5432"),
            "user": st.text_input("User", value="postgres"),
            "password": st.text_input("Password", type="password"),
            "database": st.text_input("Database"),
        }

        # OpenAI API key
        api_key = st.text_input("OpenAI API Key", type="password")

        # Nothing more to do until the button is clicked.
        if not st.button("Connect", use_container_width=True):
            return

        try:
            st.session_state.db_credentials = entered_credentials
            st.session_state.api_key = api_key
            st.session_state.connected = True
            st.success("Successfully connected!")
        except Exception as e:
            st.error(f"Connection failed: {str(e)}")
            st.session_state.connected = False
44
+
45
def main():
    """Render the chat page: sidebar, greeting, transcript and chat input.

    Returns early with an info box while ``st.session_state.connected`` is
    false. Each submitted question is appended to
    ``st.session_state.messages``, answered through ``invoke_chain`` and the
    answer is appended as the assistant's message.
    """
    init_session_state()
    create_sidebar()

    st.title("Chat with Postgres DB")

    if not st.session_state.connected:
        st.info("Please enter your credentials in the sidebar and connect first.")
        return

    # Greeting banner with the database icon, rendered as raw HTML.
    st.markdown("""
<div style='display: flex; align-items: center; gap: 10px;'>
<span style='font-size: 24px;'>🗄️</span>
<span>Hello! I'm a QualityKiosk's SQL assistant. Ask me anything about your database.</span>
</div>
""", unsafe_allow_html=True)

    # Replay the stored transcript.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    # Handle a newly submitted question, if any.
    if user_text := st.chat_input("Type a message..."):
        st.session_state.messages.append({"role": "user", "content": user_text})
        with st.chat_message("user"):
            st.markdown(user_text)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."), get_openai_callback() as usage:
                response = invoke_chain(
                    user_text,
                    st.session_state.messages,
                    st.session_state.db_credentials,
                    st.session_state.api_key,
                )
                print(f"OpenAI Stats: {usage}")
            st.markdown(response)

        st.session_state.messages.append({"role": "assistant", "content": response})
87
+
88
if __name__ == "__main__":
    # Page configuration: title, wide layout, expanded sidebar and menu
    # entries. This is the first Streamlit call made by the script.
    # No theme is set here; the dark look comes from the CSS injected below.
    st.set_page_config(
        page_title="Chat with Postgres DB",
        layout="wide",
        initial_sidebar_state="expanded",
        # Entries for the app menu
        menu_items={
            'Get Help': 'https://www.qualitykiosk.com',
            'About': "# Chat with Postgres DB\nA QualityKiosk's SQL Assistant"
        }
    )

    # Custom CSS giving the app, sidebar, buttons, text inputs and markdown
    # text a dark colour scheme.
    st.markdown("""
<style>
.stApp {
background-color: #1E1E1E;
color: #FFFFFF;
}
.stSidebar {
background-color: #262626;
}
.stButton>button {
background-color: #0E86D4;
color: white;
}
.stTextInput>div>div>input {
background-color: #333333;
color: white;
}
.stMarkdown {
color: white;
}
</style>
""", unsafe_allow_html=True)

    main()
prompts.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from examples import get_example_selector
2
+ from langchain_core.prompts import (
3
+ ChatPromptTemplate,
4
+ MessagesPlaceholder,
5
+ FewShotChatMessagePromptTemplate,
6
+ PromptTemplate
7
+ )
8
+
9
# Template for one few-shot example: the human turn is the example's "input"
# followed by "SQLQuery:", and the AI turn is the example's "query".
example_prompt = ChatPromptTemplate.from_messages([
    ("human", "{input}\nSQLQuery:"),
    ("ai", "{query}"),
])
13
+
14
def create_prompts(api_key):
    """Build the prompts for SQL generation and for phrasing the answer.

    Args:
        api_key: OpenAI API key, forwarded to ``get_example_selector``.

    Returns:
        A ``(final_prompt, answer_prompt)`` tuple.

        ``final_prompt`` is a chat prompt made of a system message (which
        uses ``{table_info}`` and ``{top_k}``), the few-shot examples, the
        ``messages`` placeholder and the human ``{input}``.

        ``answer_prompt`` is a plain template over ``{question}``,
        ``{query}`` and ``{result}``.
    """
    few_shot_prompt = FewShotChatMessagePromptTemplate(
        example_prompt=example_prompt,
        example_selector=get_example_selector(api_key),
        input_variables=["input", "top_k"],
    )

    # The system message used to stop at the fragment "Unless otherwise
    # specified." and never mentioned the row limit, although top_k is a
    # declared input variable. The sentence is now completed with {top_k}.
    final_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a Postgres SQL expert. Given an input question, create a syntactically correct Postgres SQL query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ])

    answer_prompt = PromptTemplate.from_template(
        """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
    )

    return final_prompt, answer_prompt
table_details.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ from operator import itemgetter
4
+ from langchain.chains.openai_tools import create_extraction_chain_pydantic
5
+ from langchain_core.pydantic_v1 import BaseModel, Field
6
+ from langchain_openai import ChatOpenAI
7
+ from typing import List
8
+
9
# Extraction schema passed to create_extraction_chain_pydantic() in
# create_table_chain(); get_tables() reads the ``name`` of each instance.
# NOTE(review): the class docstring and the Field description are presumably
# part of the schema sent to the model -- keep their wording unless that is
# intended to change.
class Table(BaseModel):
    """Table in SQL database."""
    # Name of one table selected for the user's question.
    name: str = Field(description="Name of table in SQL database.")
12
+
13
def get_tables(tables: List[Table]) -> List[str]:
    """Return the ``name`` attribute of every table, in input order.

    Args:
        tables: Objects exposing a ``name`` attribute.

    Returns:
        A new list with one name per input element.
    """
    names: List[str] = []
    for table in tables:
        names.append(table.name)
    return names
15
+
16
@st.cache_data
def get_table_details():
    """Load the table descriptions and format them for the table-picker prompt.

    Reads ``database_table_descriptions.csv`` (columns ``Table`` and
    ``Description``), which is the file shipped with the app. If that file
    does not exist, the previously expected
    ``database_table_descriptions.xlsx`` workbook is tried instead.

    Returns:
        One ``Table Name:...`` / ``Table Description:...`` paragraph per row,
        concatenated into a single string, or ``""`` if the descriptions
        could not be read (the error is shown through ``st.error``).
    """
    try:
        try:
            table_description = pd.read_csv("database_table_descriptions.csv")
        except FileNotFoundError:
            # Backward compatibility with deployments that use the workbook.
            table_description = pd.read_excel("database_table_descriptions.xlsx")

        return "".join(
            f"Table Name:{name}\nTable Description:{description}\n\n"
            for name, description in zip(
                table_description["Table"], table_description["Description"]
            )
        )
    except Exception as e:
        st.error(f"Error reading table descriptions: {str(e)}")
        return ""
27
+
28
def create_table_chain(api_key):
    """Build the chain that picks the tables relevant to a question.

    Args:
        api_key: OpenAI API key for the ``gpt-3.5-turbo`` chat model.

    Returns:
        A runnable that takes a mapping with a ``"question"`` key, passes the
        question as ``input`` to a ``Table`` extraction chain whose system
        message lists the table descriptions, and maps the extracted tables
        through ``get_tables``.
    """
    system_message = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{get_table_details()}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

    chat_model = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)
    extract_tables = create_extraction_chain_pydantic(
        Table, chat_model, system_message=system_message
    )
    question_as_input = {"input": itemgetter("question")}
    return question_as_input | extract_tables | get_tables