File size: 5,459 Bytes
66b6353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Generator
from utils import get_all_groq_model, validate_api_key, get_info, validate_uri
import streamlit as st
from groq import Groq

st.set_page_config(layout="wide")

# Seed the per-session state on the first run only: the chat transcript and
# the model that transcript belongs to.
for state_key, initial_value in (("messages", []), ("selected_model", None)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

st.markdown("# SQL Chat")

st.sidebar.title("Settings")
api_key = st.sidebar.text_input("Groq API Key", type="password")

# Model choices offered in the sidebar; empty unless replaced further down.
models = []

@st.cache_data
def get_text_models(api_key):
    """Return the Groq models that are usable for text chat.

    Fetches the full model list through ``get_all_groq_model`` and drops
    every entry containing ``'vision'`` or ``'whisper'``. The function is
    wrapped in ``st.cache_data``.

    Args:
        api_key: Groq API key forwarded to ``get_all_groq_model``.

    Returns:
        list: The remaining entries, in the order they were returned.
    """
    excluded_markers = ('vision', 'whisper')
    # Single pass: test each entry directly instead of first collecting the
    # excluded entries and then doing a list-membership check per model.
    return [
        name
        for name in get_all_groq_model(api_key=api_key)
        if not any(marker in name for marker in excluded_markers)
    ]

# Only fetch the model list once the API key has been accepted.
if validate_api_key(api_key):
    st.sidebar.success("API Key is valid")
    models = get_text_models(api_key)
else:
    st.sidebar.error("Enter valid API Key")

model = st.sidebar.selectbox("Select Model", models)

# Picking a different model starts a fresh conversation.
if model != st.session_state.selected_model:
    st.session_state.selected_model = model
    st.session_state.messages = []


uri = st.sidebar.text_input("Enter SQL Database URI")

# Empty placeholders, replaced by get_info(uri) once the URI validates.
db_info = dict.fromkeys(('sql_dialect', 'tables', 'tables_schema'), '')

# Markdown template for the database summary; filled in below via str.format.
markdown_info = """
**SQL Dialect**: {sql_dialect}\n
**Tables**: {tables}\n
**Tables Schema**:
```sql
{tables_schema}
```
"""

if validate_uri(uri):
    st.sidebar.success("URI is valid")
    db_info = get_info(uri)
    # From here on markdown_info holds the rendered summary, not the template.
    markdown_info = markdown_info.format(**db_info)
    with st.expander("SQL Database Info"):
        st.markdown(markdown_info)
else:
    st.sidebar.error("Enter valid URI")

# System prompt placed first in the message list of every completion request.
# It embeds `markdown_info` so the model sees the database description. When
# the URI has not validated, `markdown_info` is still the unformatted
# template, but the prompt is only sent when `validate_uri(uri)` passes.
# A trailing backslash inside the literal joins that line with the next one.
system_prompt = f"""
You are an AI assistant specialized in generating optimized SQL queries based on user instructions. \
You have access to the database schema provided in a structured Markdown format. Use this schema to ensure \
correctness, efficiency, and security in your SQL queries.\

## SQL Database Info
{markdown_info}

---

## Query Generation Guidelines
1. **Ensure Query Validity**: Use only the tables and columns defined in the schema.
2. **Optimize Performance**: Prefer indexed columns for filtering, avoid `SELECT *` where specific columns suffice.
3. **Security Best Practices**: Always use parameterized queries or placeholders instead of direct user inputs.
4. **Context Awareness**: Understand the intent behind the query and generate the most relevant SQL statement.
5. **Formatting**: Return queries in a clean, well-structured format with appropriate indentation.
6. **Commenting**: Include comments in complex queries to explain logic when needed.

---

## Expected Output Format

The SQL query should be returned as a formatted code block:

```sql
-- Get all completed orders with user details
SELECT orders.id, users.name, users.email, orders.amount, orders.created_at
FROM orders
JOIN users ON orders.user_id = users.id
WHERE orders.status = 'completed'
ORDER BY orders.created_at DESC;
```

If the user's request is ambiguous, ask clarifying questions before generating the query.
"""

if model is not None and validate_uri(uri):
    # The chat UI is only active once a model is selected and the URI validates.
    client = Groq(
        api_key=api_key,
    )

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        # The avatar literals had been corrupted into multi-character mojibake.
        # They are written as escapes so the emoji survive any re-encoding:
        # robot face for the assistant, "man technologist" (man + ZWJ + laptop)
        # for the user.
        if message["role"] == "assistant":
            avatar = "\U0001F916"
        else:
            avatar = "\U0001F468\u200D\U0001F4BB"
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])


    def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
        """Yield the non-empty text fragments of a streamed chat completion.

        Args:
            chat_completion: Iterable of streamed chunks; each chunk is read
                through ``chunk.choices[0].delta.content``.

        Yields:
            Every truthy ``content`` value, in arrival order. Chunks with an
            empty ``choices`` list or with empty/``None`` content are skipped.
        """
        for chunk in chat_completion:
            # A chunk without choices carries no text; skip it instead of
            # letting ``choices[0]`` raise IndexError in the middle of a stream.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content


    if prompt := st.chat_input("Enter your prompt here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})

        # "Man technologist" emoji (man + ZWJ + laptop), written as escapes
        # because the original literal had been corrupted into mojibake.
        with st.chat_message("user", avatar="\U0001F468\u200D\U0001F4BB"):
            st.markdown(prompt)

        # Stays None when the request or the streaming fails, so the history
        # update below is skipped instead of raising NameError.
        full_response = None

        # Fetch response from Groq API
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                # The system prompt always goes first, followed by the whole
                # conversation so far (including the prompt just appended).
                messages=[
                    {"role": "system", "content": system_prompt},
                    *(
                        {"role": m["role"], "content": m["content"]}
                        for m in st.session_state.messages
                    ),
                ],
                max_tokens=3000,
                stream=True
            )

            # Use the generator function with st.write_stream
            with st.chat_message("SQL Assistant", avatar="\U0001F916"):
                chat_responses_generator = generate_chat_responses(chat_completion)
                full_response = st.write_stream(chat_responses_generator)
        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the app.
            st.error(e, icon="🚨")

        # Append the full response to session_state.messages
        if isinstance(full_response, str):
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response})
        elif full_response is not None:
            # st.write_stream returned a non-string result; flatten it to text
            # so the stored history stays a plain string.
            combined_response = "\n".join(str(item) for item in full_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": combined_response})