Invicto69 committed on
Commit
ddb520c
·
verified ·
1 Parent(s): 3b021a8

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (3) hide show
  1. app.py +13 -6
  2. utils.py +101 -1
  3. var.py +1 -1
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Generator
2
- from utils import validate_api_key, get_info, validate_uri, extract_code_blocks
3
  from langchain_community.utilities import SQLDatabase
4
  from var import system_prompt, markdown_info, query_output, groq_models
5
  import streamlit as st
@@ -10,6 +10,7 @@ st.set_page_config(layout="wide")
10
  # Initialize chat history and selected model
11
  if "messages" not in st.session_state:
12
  st.session_state.messages = []
 
13
 
14
  if "selected_model" not in st.session_state:
15
  st.session_state.selected_model = None
@@ -29,6 +30,7 @@ else:
29
 
30
  if st.session_state.selected_model != model:
31
  st.session_state.messages = []
 
32
  st.session_state.selected_model = model
33
 
34
  uri = st.sidebar.text_input("Enter SQL Database URI")
@@ -37,7 +39,7 @@ if not validate_uri(uri):
37
  st.sidebar.error("Enter valid URI")
38
  else:
39
  st.sidebar.success("URI is valid")
40
- db_info = get_info(uri)
41
  markdown_info = markdown_info.format(**db_info)
42
  with st.expander("SQL Database Info"):
43
  st.markdown(markdown_info)
@@ -53,9 +55,12 @@ if validate_api_key(api_key) and validate_uri(uri):
53
  avatar = {"user": 'πŸ‘¨β€πŸ’»', "assistant": 'πŸ€–', "executor": 'πŸ›’'}
54
 
55
  # Display chat messages from history on app rerun
56
- for message in st.session_state.messages:
57
  with st.chat_message(message["role"], avatar=avatar[message["role"]]):
58
  st.markdown(message["content"])
 
 
 
59
 
60
 
61
  def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
@@ -109,17 +114,19 @@ if validate_api_key(api_key) and validate_uri(uri):
109
  else:
110
  query_output_truncated = query_output.format(result=result)
111
 
 
 
112
  # Append the llm response to session_state.messages
113
  if isinstance(llm_response, str):
114
  st.session_state.messages.append(
115
- {"role": "assistant", "content": llm_response + query_output_truncated})
116
  else:
117
  # Handle the case where llm_response is not a string
118
  combined_response = "\n".join(str(item) for item in llm_response)
119
  st.session_state.messages.append(
120
- {"role": "assistant", "content": combined_response + query_output_truncated})
121
 
122
- st.sidebar.button("Clear Chat History", on_click=lambda: st.session_state.messages.clear())
123
 
124
  else:
125
  st.error("Please enter valid Groq API Key and URI in the sidebar.")
 
1
  from typing import Generator
2
+ from utils import validate_api_key, get_info, validate_uri, extract_code_blocks, get_info_sqlalchemy
3
  from langchain_community.utilities import SQLDatabase
4
  from var import system_prompt, markdown_info, query_output, groq_models
5
  import streamlit as st
 
10
  # Initialize chat history and selected model
11
  if "messages" not in st.session_state:
12
  st.session_state.messages = []
13
+ st.session_state.sql_result = []
14
 
15
  if "selected_model" not in st.session_state:
16
  st.session_state.selected_model = None
 
30
 
31
  if st.session_state.selected_model != model:
32
  st.session_state.messages = []
33
+ st.session_state.sql_result = []
34
  st.session_state.selected_model = model
35
 
36
  uri = st.sidebar.text_input("Enter SQL Database URI")
 
39
  st.sidebar.error("Enter valid URI")
40
  else:
41
  st.sidebar.success("URI is valid")
42
+ db_info = get_info_sqlalchemy(uri)
43
  markdown_info = markdown_info.format(**db_info)
44
  with st.expander("SQL Database Info"):
45
  st.markdown(markdown_info)
 
55
  avatar = {"user": 'πŸ‘¨β€πŸ’»', "assistant": 'πŸ€–', "executor": 'πŸ›’'}
56
 
57
  # Display chat messages from history on app rerun
58
+ for i, message in enumerate(st.session_state.messages):
59
  with st.chat_message(message["role"], avatar=avatar[message["role"]]):
60
  st.markdown(message["content"])
61
+ if (i+1)%2 == 0:
62
+ with st.chat_message("SQL Executor", avatar=avatar["executor"]):
63
+ st.markdown(st.session_state.sql_result[i//2])
64
 
65
 
66
  def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
 
114
  else:
115
  query_output_truncated = query_output.format(result=result)
116
 
117
+ st.session_state.sql_result.append(query_output_truncated)
118
+
119
  # Append the llm response to session_state.messages
120
  if isinstance(llm_response, str):
121
  st.session_state.messages.append(
122
+ {"role": "assistant", "content": llm_response})
123
  else:
124
  # Handle the case where llm_response is not a string
125
  combined_response = "\n".join(str(item) for item in llm_response)
126
  st.session_state.messages.append(
127
+ {"role": "assistant", "content": combined_response})
128
 
129
# Sidebar button that wipes both the chat transcript and the stored SQL
# results. list.clear() returns None, so chaining the two calls with `and`
# would short-circuit after the first one and leave sql_result untouched;
# building a tuple evaluates both calls unconditionally.
st.sidebar.button(
    "Clear Chat History",
    on_click=lambda: (
        st.session_state.messages.clear(),
        st.session_state.sql_result.clear(),
    ),
)
130
 
131
  else:
132
  st.error("Please enter valid Groq API Key and URI in the sidebar.")
utils.py CHANGED
@@ -1,9 +1,21 @@
1
  import requests
2
  from langchain_community.utilities import SQLDatabase
3
  from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
 
 
 
 
 
 
 
 
 
 
 
4
  import re
5
 
6
  def get_all_groq_model(api_key:str=None) -> list:
 
7
  if api_key is None:
8
  raise ValueError("API key is required")
9
  url = "https://api.groq.com/openai/v1/models"
@@ -21,6 +33,7 @@ def get_all_groq_model(api_key:str=None) -> list:
21
  return model_ids
22
 
23
  def validate_api_key(api_key:str) -> bool:
 
24
  if len(api_key) == 0:
25
  return False
26
  try:
@@ -30,6 +43,7 @@ def validate_api_key(api_key:str) -> bool:
30
  return False
31
 
32
  def validate_uri(uri:str) -> bool:
 
33
  try:
34
  SQLDatabase.from_uri(uri)
35
  return True
@@ -37,6 +51,7 @@ def validate_uri(uri:str) -> bool:
37
  return False
38
 
39
  def get_info(uri:str) -> dict[str, str] | None:
 
40
  db = SQLDatabase.from_uri(uri)
41
  dialect = db.dialect
42
  # List all the tables accessible to the user.
@@ -45,10 +60,95 @@ def get_info(uri:str) -> dict[str, str] | None:
45
  tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
46
  return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def extract_code_blocks(text):
49
  pattern = r"```(?:\w+)?\n(.*?)\n```"
50
  matches = re.findall(pattern, text, re.DOTALL)
51
  return matches
52
 
53
  if __name__ == "__main__":
54
- pass
 
 
 
 
 
 
1
  import requests
2
  from langchain_community.utilities import SQLDatabase
3
  from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
4
+ from sqlalchemy import (
5
+ create_engine,
6
+ MetaData,
7
+ inspect,
8
+ Table,
9
+ select,
10
+ distinct
11
+ )
12
+ from sqlalchemy.schema import CreateTable
13
+ from sqlalchemy.exc import ProgrammingError
14
+ from sqlalchemy.engine import Engine
15
  import re
16
 
17
  def get_all_groq_model(api_key:str=None) -> list:
18
+ """Uses Groq API to fetch all the available models."""
19
  if api_key is None:
20
  raise ValueError("API key is required")
21
  url = "https://api.groq.com/openai/v1/models"
 
33
  return model_ids
34
 
35
  def validate_api_key(api_key:str) -> bool:
36
+ """Validates the Groq API key using the get_all_groq_model function."""
37
  if len(api_key) == 0:
38
  return False
39
  try:
 
43
  return False
44
 
45
  def validate_uri(uri:str) -> bool:
46
+ """Validates the SQL Database URI using the SQLDatabase.from_uri function."""
47
  try:
48
  SQLDatabase.from_uri(uri)
49
  return True
 
51
  return False
52
 
53
  def get_info(uri:str) -> dict[str, str] | None:
54
+ """Gets the dialect name, accessible tables and table schemas using the SQLDatabase toolkit"""
55
  db = SQLDatabase.from_uri(uri)
56
  dialect = db.dialect
57
  # List all the tables accessible to the user.
 
60
  tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
61
  return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
62
 
63
def get_sample_rows(engine:Engine, table:Table, row_count: int = 3) -> str:
    """Render the first ``row_count`` rows of ``table`` as tab-separated text.

    The returned text consists of a ``"<row_count> rows from <table> table:"``
    line, a line with the tab-separated column names, and then one line per
    fetched row. Every cell is converted with ``str`` and cut to at most 100
    characters. If the query raises ``ProgrammingError`` the rows section is
    left empty.
    """
    header = "\t".join(column.name for column in table.columns)
    query = select(table).limit(row_count)

    try:
        with engine.connect() as conn:
            # Stringify and shorten every cell while the connection is open.
            shortened = [
                [str(cell)[:100] for cell in record]
                for record in conn.execute(query)  # type: ignore
            ]
        body = "\n".join("\t".join(cells) for cells in shortened)
    except ProgrammingError:
        # Some dialects raise ProgrammingError instead of returning an
        # empty result when the table has no rows.
        body = ""

    return f"{row_count} rows from {table.name} table:\n{header}\n{body}"
93
+
94
def get_unique_values(engine:Engine, table:Table) -> str:
    """Summarise the distinct values of every column of ``table``.

    Issues one ``SELECT DISTINCT <column>`` per column and builds a text block
    made of a heading line followed by one line per column. Each column line
    states how many distinct values were found and lists up to the first 20
    of them, separated by spaces; ``", ...."`` is appended when the list was
    truncated.

    Args:
        engine: SQLAlchemy engine used to open one connection per column query.
        table: Table whose columns are inspected.

    Returns:
        The summary text, terminated by a newline. Columns whose query raises
        ``ProgrammingError`` are omitted from the summary.
    """
    max_shown = 20
    unique_values: dict[str, list[str]] = {}
    for column in table.c:
        command = select(distinct(column))
        try:
            with engine.connect() as connection:
                result = connection.execute(command)  # type: ignore
                # Every result row is a one-element row object; unwrap it so
                # the summary shows "value" instead of the repr "('value',)".
                unique_values[column.name] = [str(row[0]) for row in result]
        except ProgrammingError:
            # Some dialects raise ProgrammingError when the table has no
            # rows; such a column is simply left out of the summary.
            continue

    lines = [f"Unique values of each column in {table.name}: "]
    for name, values in unique_values.items():
        # Joined outside the f-string: reusing the same quote character inside
        # a replacement field is a syntax error before Python 3.12.
        shown = " ".join(values[:max_shown])
        line = f"{name} has {len(values)} unique values: {shown}"
        if len(values) > max_shown:
            line += ", ...."
        lines.append(line)

    return "\n".join(lines) + "\n"
122
+
123
def get_info_sqlalchemy(uri:str) -> dict[str, str] | None:
    """Collect the dialect name, table names and annotated schemas for ``uri``.

    For every reflected table the schema text is its ``CREATE TABLE``
    statement followed by a ``/* ... */`` comment holding the output of
    ``get_sample_rows`` and ``get_unique_values`` for that table.

    Args:
        uri: Database URI passed to ``create_engine``.

    Returns:
        A dict with the keys ``'sql_dialect'`` (dialect name), ``'tables'``
        (comma-separated table names) and ``'tables_schema'`` (the per-table
        schema texts joined by blank lines).
    """
    engine = create_engine(uri)
    try:
        # Get dialect name using inspector
        dialect = inspect(engine).dialect.name

        # Metadata for tables and columns
        metadata = MetaData()
        metadata.reflect(engine)

        tables = {}
        for table in metadata.tables.values():
            ddl = str(CreateTable(table).compile(engine)).rstrip()
            tables[table.name] = (
                f"{ddl}\n\n/*"
                f"\n{get_sample_rows(engine, table)}\n"
                f"\n{get_unique_values(engine, table)}\n"
                "*/"
            )
    finally:
        # A new engine (and connection pool) is created on every call, so
        # release it explicitly instead of leaving pooled connections open.
        engine.dispose()

    return {
        'sql_dialect': dialect,
        'tables': ", ".join(tables.keys()),
        'tables_schema': "\n\n".join(tables.values()),
    }
142
+
143
def extract_code_blocks(text):
    """Return the contents of every triple-backtick fenced block in ``text``.

    The opening fence may carry an optional language tag (any run of
    non-whitespace, non-backtick characters, e.g. ``sql`` or ``c++``),
    optionally followed by spaces or tabs. Both ``\\n`` and ``\\r\\n`` line
    endings are accepted around the block body.

    Args:
        text: The text to scan, e.g. a markdown-formatted model response.

    Returns:
        A list with the body of each fenced block in order of appearance,
        without the fences or the language tag. Empty when nothing matches.
    """
    # [^\s`]*   optional language tag (not limited to \w, so "c++" works)
    # [ \t]*    tolerate trailing blanks after the tag
    # \r?\n     accept Windows line endings at both fences
    pattern = r"```[^\s`]*[ \t]*\r?\n(.*?)\r?\n```"
    # DOTALL lets the lazy body group span multiple lines.
    return re.findall(pattern, text, re.DOTALL)
147
 
148
if __name__ == "__main__":
    # Manual smoke test: print the collected database info for the URI stored
    # in the POSTGRES_URI environment variable (optionally loaded from .env).
    from dotenv import load_dotenv
    import os
    load_dotenv()

    uri = os.getenv("POSTGRES_URI")
    if not uri:
        # os.getenv returns None when the variable is missing; stop with a
        # clear message instead of passing None on to create_engine.
        raise SystemExit("POSTGRES_URI is not set; define it in the environment or in a .env file.")
    print(get_info_sqlalchemy(uri))
var.py CHANGED
@@ -30,7 +30,7 @@ correctness, efficiency, and security in your SQL queries.\
30
  4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
31
  5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
32
  6. **Commenting**: Include comments in complex queries to explain logic when needed.
33
- 7. **Result**: Don't return the result of the query, just the SQL query.
34
  8. **Optimal**: Try to generate query which is optimal and not brute force.
35
  9. **Single query**: Generate a best single SQL query for the user input.'
36
  10. **Comment**: Include comments in the query to explain the logic behind it.
 
30
  4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
31
  5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
32
  6. **Commenting**: Include comments in complex queries to explain logic when needed.
33
+ 7. **Result**: Don't return the result of the query, return only the SQL query.
34
  8. **Optimal**: Try to generate query which is optimal and not brute force.
35
  9. **Single query**: Generate a best single SQL query for the user input.'
36
  10. **Comment**: Include comments in the query to explain the logic behind it.