File size: 5,750 Bytes
eebea6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import re
import gradio as gr
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
import sqlite3

# Load environment variables from .env file
# NOTE(review): presumably this is what makes the OpenAI API key available to
# the ChatOpenAI client created further down — confirm against deployment.
load_dotenv()

# Set up the database connection
# The SQLite file is looked up next to this script under the name "chinook.db".
# db_path is reused later for direct sqlite3 schema introspection, while db is
# the LangChain wrapper used for query generation and execution.
db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# Function to get table info
def get_table_info(db_path):
    """Return a mapping of table name to column names for a SQLite database.

    Args:
        db_path: Path of the SQLite database, passed to ``sqlite3.connect``.

    Returns:
        dict: one key per row of type ``'table'`` in ``sqlite_master``; each
        value is the list of that table's column names in the order reported
        by ``PRAGMA table_info``.

    Raises:
        sqlite3.Error: if the database cannot be read. The connection is
        closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        table_info = {}
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier (doubling embedded quotes) to survive names that
            # contain spaces or other special characters.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            # Column 1 of each PRAGMA table_info row is the column name.
            table_info[table_name] = [column[1] for column in cursor.fetchall()]
        return table_info
    finally:
        # Always release the connection, even when a query above fails.
        conn.close()

# Get table info
# Computed once at import time: dict of table name -> list of column names.
table_info = get_table_info(db_path)

# Format table info for display
def format_table_info(table_info):
    """Render a table -> columns mapping as a plain-text schema summary.

    Args:
        table_info: Mapping whose keys are table names and whose values are
            iterables of column names.

    Returns:
        str: a header with the table count, followed by one section per table
        listing its columns as "  - <column>" lines, each section terminated
        by a blank line.
    """
    pieces = [
        f"Total number of tables: {len(table_info)}\n\n",
        "Tables and their columns:\n\n",
    ]
    for table, columns in table_info.items():
        pieces.append(f"{table}:\n")
        pieces.extend(f"  - {column}\n" for column in columns)
        pieces.append("\n")
    return "".join(pieces)

# Initialize the language model
# temperature=0 is passed explicitly; the same llm object is reused for table
# selection, SQL generation and the final answer.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Schema of the tool call used to name one relevant table. It is bound to the
# LLM with bind_tools() and parsed back with PydanticToolsParser further down.
# NOTE(review): the docstring and the Field description are presumably sent to
# the model as the tool/argument descriptions — treat them as runtime text.
class Table(BaseModel):
    """Table in SQL database."""
    name: str = Field(description="Name of table in SQL database.")

# Create the table selection prompt
# The usable table names are read from db once, at import time, and baked into
# the system message as a newline-separated list.
table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \

The tables are:



{table_names}



Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

# The human message carries the user's question via the "input" variable.
table_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{input}"),
])

# Expose Table as a tool to the model and parse its tool calls back into
# Table instances.
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

# prompt -> model with tools -> parsed Table objects
table_chain = table_prompt | llm_with_tools | output_parser

# Function to get table names from the output
def get_table_names(output: List[Table]) -> List[str]:
    """Extract the ``name`` attribute of every parsed table selection.

    Args:
        output: Table objects as produced by the table-selection chain.

    Returns:
        The table names, in the same order as ``output``.
    """
    return [selected.name for selected in output]

# Create the SQL query chain
query_chain = create_sql_query_chain(llm, db)

# Combine table selection and query generation
# The input is a dict with a "question" key. The table-selection chain is run
# on that question and the resulting names are added under
# "table_names_to_use" before the dict is handed to query_chain.
full_chain = (
    RunnablePassthrough.assign(
        table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
    )
    | query_chain
)

# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    """Remove Markdown code-fence markers from a generated SQL string.

    Args:
        text: The raw model output, possibly wrapped in a fenced code block
            with an optional ``sql`` language tag.

    Returns:
        str: ``text`` with the fence markers removed and leading/trailing
        whitespace stripped.
    """
    # Remove code block formatting. The language tag is matched
    # case-insensitively so that "```SQL" does not leave a stray "SQL" token
    # at the start of the query.
    text = re.sub(r'```sql\s*|\s*```', '', text, flags=re.IGNORECASE)
    # Remove any leading/trailing whitespace
    return text.strip()

# Function to execute SQL query
def execute_query(query: str) -> str:
    """Run a generated SQL query against the database and return the result as text.

    Args:
        query: SQL text, possibly still wrapped in Markdown code fences.

    Returns:
        The stringified result of ``db.run`` on the cleaned query. If cleaning
        or execution raises, the string "Error executing query: <message>" is
        returned instead, so the failure can be passed on to the answer prompt.
    """
    try:
        # Strip markdown formatting before executing
        return str(db.run(strip_markdown(query)))
    except Exception as exc:
        return f"Error executing query: {exc!s}"

# Create the answer generation prompt
# Template variables: {table_info} in the system message; {question}, {query}
# and {result} in the human message. The system message also tells the model
# to explain the failure when the SQL result is an error.
answer_prompt = ChatPromptTemplate.from_messages([
    ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    If there was an error in executing the SQL query, please explain the error and suggest a correction.

    Do not include any SQL code formatting or markdown in your response.

    

    Here is the database schema for reference:

    {table_info}"""),
    ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
])

# Assemble the final chain
# Expects an input dict carrying "question" and "table_info". Steps:
#   1. add "query"  — the SQL produced by full_chain for the input dict
#   2. add "result" — the output (or error text) of executing that query
#   3. fill answer_prompt, call the LLM, and parse the reply to a plain string
chain = (
    RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
    .assign(result=lambda x: execute_query(x["query"]))
    | answer_prompt
    | llm
    | StrOutputParser()
)

# Function to process user input and generate response
def process_input(message, history, table_info_str):
    """Answer one chat message by running it through the full Q&A chain.

    Args:
        message: The user's question.
        history: Previous chat turns; accepted but not used.
        table_info_str: Schema description forwarded to the answer prompt as
            "table_info".

    Returns:
        The chain's final answer.
    """
    payload = {"question": message, "table_info": table_info_str}
    return chain.invoke(payload)

# Formatted table info
# Plain-text rendering of table_info, used as the schema textbox's value.
formatted_table_info = format_table_info(table_info)

# Create Gradio interface
# process_input is the chat callback. The read-only "Database Schema" textbox
# is registered as an additional input.
# NOTE(review): its value is presumably what arrives as process_input's third
# argument (table_info_str) — confirm against the Gradio version in use.
iface = gr.ChatInterface(
    fn=process_input,
    title="SQL Q&A Chatbot for Chinook Database",
    description="Ask questions about the Chinook music store database and get answers!",
    examples=[
        ["Who are the top 5 artists with the most albums in the database?"],
        ["What is the total sales amount for each country?"],
        ["Which employee has made the highest total sales, and what is the amount?"],
        ["What are the top 10 longest tracks in the database, and who are their artists?"],
        ["How many customers are there in each country, and what is the total sales for each?"]
    ],
    additional_inputs=[
        gr.Textbox(
            label="Database Schema",
            value=formatted_table_info,
            lines=10,
            max_lines=20,
            interactive=False
        )
    ],
    theme="soft"
)

# Launch the interface
# Guarded so that importing this module does not call launch().
if __name__ == "__main__":
    iface.launch()