|
from typing import Generator |
|
from utils import get_all_groq_model, validate_api_key, get_info, validate_uri |
|
import streamlit as st |
|
from groq import Groq |
|
|
|
# Configure the page before anything else is rendered.
st.set_page_config(layout="wide")

# Create the per-session keys only when they are absent, so values stored by
# earlier runs of the script are kept.
for _state_key, _initial_value in (("messages", []), ("selected_model", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

st.markdown("# SQL Chat")

with st.sidebar:
    st.title("Settings")
    api_key = st.text_input("Groq API Key", type="password")

# Fallback model list; it is replaced further down only when the API key
# passes validation.
models = []
|
|
|
@st.cache_data
def get_text_models(api_key, excluded_keywords=("vision", "whisper")):
    """Return the Groq model IDs suitable for text chat.

    Any model ID containing one of ``excluded_keywords`` is dropped. By
    default these are the vision and audio (Whisper) models, which do not fit
    the plain-text messages this app sends.

    Results are memoised by ``st.cache_data`` per argument combination.

    Args:
        api_key: Groq API key, forwarded to ``get_all_groq_model``.
        excluded_keywords: Substrings that mark a model ID as not text-only.
            Must be hashable (a tuple) so ``st.cache_data`` can key on it.

    Returns:
        list: The remaining model IDs, in the order ``get_all_groq_model``
        returned them.
    """
    # Single pass: keep a model unless its ID mentions an excluded keyword.
    # This replaces building a separate "vision_audio" list and then scanning
    # it once per model.
    return [
        model_id
        for model_id in get_all_groq_model(api_key=api_key)
        if not any(keyword in model_id for keyword in excluded_keywords)
    ]
|
|
|
|
|
# Fetch the model catalogue only when validate_api_key accepts the key;
# otherwise ``models`` keeps its empty fallback value.
if validate_api_key(api_key):
    st.sidebar.success("API Key is valid")
    models = get_text_models(api_key)
else:
    st.sidebar.error("Enter valid API Key")

# With no options this yields None, which the chat section checks for.
model = st.sidebar.selectbox("Select Model", models)

# Picking a different model starts a fresh conversation.
if model != st.session_state.selected_model:
    st.session_state["selected_model"] = model
    st.session_state["messages"] = []
|
|
|
|
|
uri = st.sidebar.text_input("Enter SQL Database URI")

# Empty placeholders until a valid URI lets ``get_info`` describe the database.
db_info = dict.fromkeys(("sql_dialect", "tables", "tables_schema"), "")

# Markdown template, filled from ``db_info`` only after the URI validates.
# The filled text is shown in the expander and interpolated into the system
# prompt. If the URI is invalid, the raw ``{...}`` placeholders remain.
markdown_info = """
**SQL Dialect**: {sql_dialect}\n
**Tables**: {tables}\n
**Tables Schema**:
```sql
{tables_schema}
```
"""

if validate_uri(uri):
    st.sidebar.success("URI is valid")
    db_info = get_info(uri)
    markdown_info = markdown_info.format(**db_info)
    with st.expander("SQL Database Info"):
        st.markdown(markdown_info)
else:
    st.sidebar.error("Enter valid URI")
|
|
|
# System prompt sent as the first message of every chat request below.
# It interpolates ``markdown_info`` (dialect, table list and schema built from
# ``get_info``) so the model sees the actual database structure. The trailing
# backslashes inside the literal join the intro sentences into one paragraph.
# NOTE(review): if the URI failed validation, ``markdown_info`` still contains
# its raw ``{sql_dialect}``-style placeholders. The chat section only runs
# when ``validate_uri(uri)`` passes, so that text is not expected to reach the
# model.
system_prompt = f"""
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
correctness, efficiency, and security in your SQL queries.\

## SQL Database Info
{markdown_info}

---

## Query Generation Guidelines
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
6. **Commenting**: Include comments in complex queries to explain logic when needed.

---

## Expected Output Format

The SQL query should be returned as a formatted code block:

```sql
-- Get all completed orders with user details
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
FROM orders
JOIN users ON orders.user_id = users.id
WHERE orders.status = 'completed'
ORDER BY orders.created_at DESC;
```

If the user's request is ambiguous, ask clarifying questions before generating the query.
"""
|
|
|
# Avatar / icon emoji written as escapes so they survive any re-encoding of
# this file (the literal emoji here had been mangled into non-emoji text,
# which Streamlit does not accept as an avatar or icon).
ASSISTANT_AVATAR = "\U0001F916"  # robot face
USER_AVATAR = "\U0001F468\u200d\U0001F4BB"  # man technologist (ZWJ sequence)
ERROR_ICON = "\U0001F6A8"  # police car light


def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
    """Yield the text content of each chunk of a streamed Groq chat completion.

    Chunks that carry no choices, or whose ``delta.content`` is empty/None,
    are skipped, so the consumer only receives non-empty strings.
    """
    for chunk in chat_completion:
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content


def _build_request_messages(system_prompt, history):
    """Return the API message list: the system prompt followed by ``history``.

    Each history entry is copied down to its ``role`` and ``content`` keys.
    """
    return [{"role": "system", "content": system_prompt}] + [
        {"role": m["role"], "content": m["content"]} for m in history
    ]


if model is not None and validate_uri(uri):
    client = Groq(
        api_key=api_key,
    )

    # Redraw the stored conversation.
    for message in st.session_state.messages:
        avatar = ASSISTANT_AVATAR if message["role"] == "assistant" else USER_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])

    if prompt := st.chat_input("Enter your prompt here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("user", avatar=USER_AVATAR):
            st.markdown(prompt)

        # Initialised up front: if the request or the stream fails, this stays
        # None and no assistant turn is recorded (instead of an unbound name
        # crashing the code after the try block).
        full_response = None
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                messages=_build_request_messages(
                    system_prompt, st.session_state.messages
                ),
                max_tokens=3000,
                stream=True,
            )

            # Tokens are pulled from the API while st.write_stream consumes
            # the generator, so mid-stream failures are caught here as well.
            # Role name matches the one stored/replayed from session state.
            with st.chat_message("assistant", avatar=ASSISTANT_AVATAR):
                chat_responses_generator = generate_chat_responses(chat_completion)
                full_response = st.write_stream(chat_responses_generator)
        except Exception as e:
            st.error(e, icon=ERROR_ICON)

        if full_response is not None:
            # st.write_stream may return a list rather than a str; store a
            # single string either way.
            if not isinstance(full_response, str):
                full_response = "\n".join(str(item) for item in full_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response})