# Snapshot metadata: file size 5,459 bytes, commit 66b6353.
from typing import Generator
from utils import get_all_groq_model, validate_api_key, get_info, validate_uri
import streamlit as st
from groq import Groq
st.set_page_config(layout="wide")

# Seed per-session state only on the first run; Streamlit reruns this script
# on every interaction, and existing values must survive those reruns.
for _state_key, _initial in (("messages", []), ("selected_model", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

st.markdown("# SQL Chat")

# Sidebar: credentials come first. The model list starts empty and is only
# populated after the API key has been validated further down.
st.sidebar.title("Settings")
api_key = st.sidebar.text_input("Groq API Key", type="password")

models = []
@st.cache_data
def get_text_models(api_key):
    """Return the Groq model IDs that are usable for text chat.

    Any model ID containing "vision" or "whisper" (image and speech models,
    judging by the names) is dropped. The remaining IDs keep the order
    returned by ``get_all_groq_model``. The result is memoised by
    ``st.cache_data`` for each distinct ``api_key``.
    """
    excluded_markers = ("vision", "whisper")
    return [
        name
        for name in get_all_groq_model(api_key=api_key)
        if not any(marker in name for marker in excluded_markers)
    ]
# validating api_key: only a valid key unlocks the model list.
if validate_api_key(api_key):
    st.sidebar.success("API Key is valid")
    models = get_text_models(api_key)
else:
    st.sidebar.error("Enter valid API Key")

# With an empty `models` list the selectbox yields None.
model = st.sidebar.selectbox("Select Model", models)

# Switching to a different model starts a fresh conversation.
if model != st.session_state.selected_model:
    st.session_state.messages = []
    st.session_state.selected_model = model
uri = st.sidebar.text_input("Enter SQL Database URI")

# Markdown template for the schema summary; filled from `db_info` below and
# later embedded in the system prompt.
markdown_info = """
**SQL Dialect**: {sql_dialect}\n
**Tables**: {tables}\n
**Tables Schema**:
```sql
{tables_schema}
```
"""
# Empty placeholder values, used until a valid URI supplies real ones.
db_info = dict.fromkeys(("sql_dialect", "tables", "tables_schema"), "")

if validate_uri(uri):
    st.sidebar.success("URI is valid")
    db_info = get_info(uri)
else:
    st.sidebar.error("Enter valid URI")

# Render the summary; placeholders stay blank when the URI is invalid.
markdown_info = markdown_info.format(**db_info)
with st.expander("SQL Database Info"):
    st.markdown(markdown_info)
# System prompt sent as the first message of every chat request. It embeds
# the rendered schema summary (`markdown_info`) so the model can ground its
# SQL in the connected database.
# The trailing backslashes join the intro sentences into one paragraph. That
# paragraph must end with a real newline: otherwise the "## SQL Database Info"
# heading is glued onto "...SQL queries." and stops being a heading.
system_prompt = f"""
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
correctness, efficiency, and security in your SQL queries.

## SQL Database Info
{markdown_info}

---

## Query Generation Guidelines
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
6. **Commenting**: Include comments in complex queries to explain logic when needed.

---

## Expected Output Format
The SQL query should be returned as a formatted code block:
```sql
-- Get all completed orders with user details
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
FROM orders
JOIN users ON orders.user_id = users.id
WHERE orders.status = 'completed'
ORDER BY orders.created_at DESC;
```
If the user's request is ambiguous, ask clarifying questions before generating the query.
"""
# Chat UI: active only once a model is selected and the database URI validates.
if model is not None and validate_uri(uri):
    client = Groq(
        api_key=api_key,
    )

    # Avatars and icons must be genuine emoji (or image URLs/paths): Streamlit
    # rejects any other string, so mis-encoded emoji break rendering.
    ASSISTANT_AVATAR = "🤖"
    USER_AVATAR = "👨‍💻"
    ERROR_ICON = "🚨"

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        avatar = ASSISTANT_AVATAR if message["role"] == "assistant" else USER_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])

    def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
        """Yield chat response content from the Groq API response.

        Chunks that have no choices, or whose delta carries no text, are skipped.
        """
        for chunk in chat_completion:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    if prompt := st.chat_input("Enter your prompt here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(prompt)

        # Fetch response from Groq API: the system prompt first, then the
        # whole conversation so far (including the prompt just added).
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                messages=[{"role": "system", "content": system_prompt}]
                + [
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                max_tokens=3000,
                stream=True,
            )
            # Use the generator function with st.write_stream
            with st.chat_message("SQL Assistant", avatar=ASSISTANT_AVATAR):
                full_response = st.write_stream(
                    generate_chat_responses(chat_completion)
                )
        except Exception as e:
            # UI boundary: surface the error. Do not fall through to the
            # history append, because `full_response` is unbound here.
            st.error(e, icon=ERROR_ICON)
        else:
            # Append the full response to session_state.messages only when
            # streaming succeeded.
            if isinstance(full_response, str):
                content = full_response
            else:
                # Handle the case where full_response is not a string
                # (st.write_stream may return a list of written items).
                content = "\n".join(str(item) for item in full_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": content}
            )