Invicto69 commited on
Commit
66b6353
Β·
verified Β·
1 Parent(s): e3f67b5

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (3) hide show
  1. app.py +159 -0
  2. requirements.txt +5 -0
  3. utils.py +48 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generator
2
+ from utils import get_all_groq_model, validate_api_key, get_info, validate_uri
3
+ import streamlit as st
4
+ from groq import Groq
5
+
6
+ st.set_page_config(layout="wide")
7
+
8
+ # Initialize chat history and selected model
9
+ if "messages" not in st.session_state:
10
+ st.session_state.messages = []
11
+
12
+ if "selected_model" not in st.session_state:
13
+ st.session_state.selected_model = None
14
+
15
+ st.markdown("# SQL Chat")
16
+
17
+ st.sidebar.title("Settings")
18
+ api_key = st.sidebar.text_input("Groq API Key", type="password")
19
+
20
+ models = []
21
+
22
+ @st.cache_data
23
+ def get_text_models(api_key):
24
+ models = get_all_groq_model(api_key=api_key)
25
+ vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
26
+ models = [model for model in models if model not in vision_audio]
27
+ return models
28
+
29
+ # validating api_key
30
+ if not validate_api_key(api_key):
31
+ st.sidebar.error("Enter valid API Key")
32
+ else:
33
+ st.sidebar.success("API Key is valid")
34
+ models = get_text_models(api_key)
35
+
36
+ model = st.sidebar.selectbox("Select Model", models)
37
+
38
+ if st.session_state.selected_model != model:
39
+ st.session_state.messages = []
40
+ st.session_state.selected_model = model
41
+
42
+
43
+ uri = st.sidebar.text_input("Enter SQL Database URI")
44
+ db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
45
+ markdown_info = """
46
+ **SQL Dialect**: {sql_dialect}\n
47
+ **Tables**: {tables}\n
48
+ **Tables Schema**:
49
+ ```sql
50
+ {tables_schema}
51
+ ```
52
+ """
53
+
54
+ if not validate_uri(uri):
55
+ st.sidebar.error("Enter valid URI")
56
+ else:
57
+ st.sidebar.success("URI is valid")
58
+ db_info = get_info(uri)
59
+ markdown_info = markdown_info.format(**db_info)
60
+ with st.expander("SQL Database Info"):
61
+ st.markdown(markdown_info)
62
+
63
+ system_prompt = f"""
64
+ You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
65
+ You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
66
+ correctness, efficiency, and security in your SQL queries.\
67
+
68
+ ## SQL Database Info
69
+ {markdown_info}
70
+
71
+ ---
72
+
73
+ ## Query Generation Guidelines
74
+ 1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
75
+ 2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
76
+ 3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
77
+ 4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
78
+ 5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
79
+ 6. **Commenting**: Include comments in complex queries to explain logic when needed.
80
+
81
+ ---
82
+
83
+ ## Expected Output Format
84
+
85
+ The SQL query should be returned as a formatted code block:
86
+
87
+ ```sql
88
+ -- Get all completed orders with user details
89
+ SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
90
+ FROM orders
91
+ JOIN users ON orders.user_id = users.id
92
+ WHERE orders.status = 'completed'
93
+ ORDER BY orders.created_at DESC;
94
+ ```
95
+
96
+ If the user's request is ambiguous, ask clarifying questions before generating the query.
97
+ """
98
+
99
+ if model is not None and validate_uri(uri):
100
+ client = Groq(
101
+ api_key=api_key,
102
+ )
103
+
104
+ # Display chat messages from history on app rerun
105
+ for message in st.session_state.messages:
106
+ avatar = 'πŸ€–' if message["role"] == "assistant" else 'πŸ‘¨β€πŸ’»'
107
+ with st.chat_message(message["role"], avatar=avatar):
108
+ st.markdown(message["content"])
109
+
110
+
111
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
112
+ """Yield chat response content from the Groq API response."""
113
+ for chunk in chat_completion:
114
+ if chunk.choices[0].delta.content:
115
+ yield chunk.choices[0].delta.content
116
+
117
+
118
+ if prompt := st.chat_input("Enter your prompt here..."):
119
+ st.session_state.messages.append({"role": "user", "content": prompt})
120
+
121
+ with st.chat_message("user", avatar='πŸ‘¨β€πŸ’»'):
122
+ st.markdown(prompt)
123
+
124
+ # Fetch response from Groq API
125
+ try:
126
+ chat_completion = client.chat.completions.create(
127
+ model=model,
128
+ messages=[{
129
+ "role": "system",
130
+ "content": system_prompt
131
+ },
132
+ ]+
133
+ [
134
+ {
135
+ "role": m["role"],
136
+ "content": m["content"]
137
+ }
138
+ for m in st.session_state.messages
139
+ ],
140
+ max_tokens=3000,
141
+ stream=True
142
+ )
143
+
144
+ # Use the generator function with st.write_stream
145
+ with st.chat_message("SQL Assistant", avatar="πŸ€–"):
146
+ chat_responses_generator = generate_chat_responses(chat_completion)
147
+ full_response = st.write_stream(chat_responses_generator)
148
+ except Exception as e:
149
+ st.error(e, icon="🚨")
150
+
151
+ # Append the full response to session_state.messages
152
+ if isinstance(full_response, str):
153
+ st.session_state.messages.append(
154
+ {"role": "assistant", "content": full_response})
155
+ else:
156
+ # Handle the case where full_response is not a string
157
+ combined_response = "\n".join(str(item) for item in full_response)
158
+ st.session_state.messages.append(
159
+ {"role": "assistant", "content": combined_response})
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ groq
2
+ langchain
3
+ langchain[groq]
4
+ streamlit
5
+ langchain_community
utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from langchain_community.utilities import SQLDatabase
3
+ from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool
4
+
5
+ def get_all_groq_model(api_key:str=None) -> list:
6
+ if api_key is None:
7
+ raise ValueError("API key is required")
8
+ url = "https://api.groq.com/openai/v1/models"
9
+
10
+ headers = {
11
+ "Authorization": f"Bearer {api_key}",
12
+ "Content-Type": "application/json"
13
+ }
14
+
15
+ response = requests.get(url, headers=headers)
16
+
17
+ data = response.json()['data']
18
+ model_ids = [model['id'] for model in data]
19
+
20
+ return model_ids
21
+
22
+ def validate_api_key(api_key:str) -> bool:
23
+ if len(api_key) == 0:
24
+ return False
25
+ try:
26
+ get_all_groq_model(api_key=api_key)
27
+ return True
28
+ except Exception as e:
29
+ return False
30
+
31
+ def validate_uri(uri:str) -> bool:
32
+ try:
33
+ SQLDatabase.from_uri(uri)
34
+ return True
35
+ except Exception as e:
36
+ return False
37
+
38
+ def get_info(uri:str) -> dict[str, str] | None:
39
+ db = SQLDatabase.from_uri(uri)
40
+ dialect = db.dialect
41
+ # List all the tables accessible to the user.
42
+ access_tables = ListSQLDatabaseTool(db=db).invoke("")
43
+ # List the table schemas of all the accessible tables.
44
+ tables_schemas = InfoSQLDatabaseTool(db=db).invoke(access_tables)
45
+ return {'sql_dialect': dialect, 'tables': access_tables, 'tables_schema': tables_schemas}
46
+
47
+ if __name__ == "__main__":
48
+ print(get_all_groq_model())