# SQLchat / app.py
from typing import Generator
from utils import get_all_groq_model, validate_api_key, get_info, validate_uri
import streamlit as st
from groq import Groq
st.set_page_config(layout="wide")
# Initialize chat history and selected model
if "messages" not in st.session_state:
st.session_state.messages = []
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
st.markdown("# SQL Chat")
st.sidebar.title("Settings")
api_key = st.sidebar.text_input("Groq API Key", type="password")
models = []
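# Cache the model list per API key so Streamlit reruns do not refetch it from the Groq API.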
@st.cache_data
def get_text_models(api_key):
    models = get_all_groq_model(api_key=api_key)
    vision_audio = [model for model in models if 'vision' in model or 'whisper' in model]
    models = [model for model in models if model not in vision_audio]
    return models
# Validate the API key before fetching the model list
if not validate_api_key(api_key):
    st.sidebar.error("Enter valid API Key")
else:
    st.sidebar.success("API Key is valid")
    models = get_text_models(api_key)
model = st.sidebar.selectbox("Select Model", models)
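# Clear the conversation whenever the user switches to a different model.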
if st.session_state.selected_model != model:
    st.session_state.messages = []
    st.session_state.selected_model = model
uri = st.sidebar.text_input("Enter SQL Database URI")
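# Placeholder schema info, filled in below once the database URI validates.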
db_info = {'sql_dialect': '', 'tables': '', 'tables_schema': ''}
markdown_info = """
**SQL Dialect**: {sql_dialect}\n
**Tables**: {tables}\n
**Tables Schema**:
```sql
{tables_schema}
```
"""
if not validate_uri(uri):
    st.sidebar.error("Enter valid URI")
else:
    st.sidebar.success("URI is valid")
    db_info = get_info(uri)
markdown_info = markdown_info.format(**db_info)
with st.expander("SQL Database Info"):
    st.markdown(markdown_info)
system_prompt = f"""
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
correctness, efficiency, and security in your SQL queries.
## SQL Database Info
{markdown_info}
---
## Query Generation Guidelines
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
2. **Optimize Performance**: Prefer indexed columns for filtering, and avoid `SELECT *` when specific columns suffice.
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
6. **Commenting**: Include comments in complex queries to explain logic when needed.
---
## Expected Output Format
The SQL query should be returned as a formatted code block:
```sql
-- Get all completed orders with user details
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
FROM orders
JOIN users ON orders.user_id = users.id
WHERE orders.status = 'completed'
ORDER BY orders.created_at DESC;
```
If the user's request is ambiguous, ask clarifying questions before generating the query.
"""
if model is not None and validate_uri(uri):
    client = Groq(
        api_key=api_key,
    )
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        avatar = '🤖' if message["role"] == "assistant" else '👨‍💻'
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
        """Yield chat response content from the Groq API response."""
        for chunk in chat_completion:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
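    # Handle a new user prompt: echo it, send the system prompt plus chat history to Groq, and stream the reply.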
    if prompt := st.chat_input("Enter your prompt here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("user", avatar='👨‍💻'):
            st.markdown(prompt)
        # Fetch response from Groq API
        full_response = None  # so a failed request does not break the history append below
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                ] + [
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                max_tokens=3000,
                stream=True,
            )
            # Use the generator function with st.write_stream
            with st.chat_message("SQL Assistant", avatar="🤖"):
                chat_responses_generator = generate_chat_responses(chat_completion)
                full_response = st.write_stream(chat_responses_generator)
        except Exception as e:
            st.error(e, icon="🚨")
        # Append the full response to session_state.messages
        if isinstance(full_response, str):
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response})
        elif full_response is not None:
            # st.write_stream may return a list of rendered objects; join them into one string
            combined_response = "\n".join(str(item) for item in full_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": combined_response})