# SQL-LLM-Agent / app.py
# Author: Mhassanen (last commit 113110f, "Update app.py"; 5.81 kB)
import io
import os

import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
import streamlit as st
from st_aggrid import AgGrid
from streamlit_echarts import st_echarts
# API endpoint for SQL Agent.
# The default looks like a free ngrok tunnel URL, which usually changes between
# sessions, so it can be overridden with the SQL_AGENT_API_URL environment
# variable instead of editing the code.
API_URL = os.environ.get(
    "SQL_AGENT_API_URL", "https://b6ae-35-234-51-197.ngrok-free.app/query"
)
# Must remain the first Streamlit call in the script.
st.set_page_config(page_title="SQL Agent with Streamlit", page_icon=":bar_chart:", layout="wide")
# Sidebar: author credits, written one markdown line at a time into the
# sidebar container.
for _sidebar_line in (
    "## About Me",
    "**Mahmoud Hassanen**",
    "**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**",
):
    st.sidebar.write(_sidebar_line)
# Main page: heading followed by the natural-language question box.
st.title(body="SQL Agent with Streamlit")
st.header(body="Analyze Sales Data with Natural Language Queries")
# The question text typed by the user (read by the "Generate SQL" step).
question = st.text_input(label="Enter your question:")
# Generate SQL Query
# Send the question to the agent API. The returned SQL is stored in session
# state so the edit/execute section below keeps working on later reruns.
if st.button("Generate SQL"):
    if not question.strip():
        # Empty or whitespace-only input is not sent to the API.
        st.warning("Please enter a question.")
    else:
        try:
            # Bounded wait: without a timeout an unreachable endpoint would
            # block this script run indefinitely.
            response = requests.post(API_URL, json={"question": question}, timeout=120)
            if response.status_code == 200:
                generated_sql = response.json()["sql_query"]
                st.session_state.generated_sql = generated_sql
                st.write("### Generated SQL Query:")
                st.code(generated_sql, language="sql")
            else:
                st.error(f"API Error: Status Code {response.status_code}")
        except KeyError:
            st.error("API Error: response has no 'sql_query' field")
        except (requests.RequestException, ValueError) as e:
            # Connection failures, timeouts, or a body that is not valid JSON.
            st.error(f"API Error: {e}")
# Helpers for the "Modify and Execute" section below.


def _json_safe(series):
    """Return the values of *series* as a list ECharts can receive as JSON.

    Missing values (NaN/NaT/None) become None and date/time values become
    strings, because neither NaN nor Timestamp objects are valid JSON.
    """
    return [
        None if pd.isna(v) else (str(v) if hasattr(v, "isoformat") else v)
        for v in series.tolist()
    ]


def auto_generate_chart(df):
    """Choose an ECharts ``options`` dict that fits the columns of *df*.

    The first matching rule wins:

    1. a ``date`` (preferred) or ``timestamp`` column plus a numeric column
       -> line chart;
    2. object columns and exactly one numeric column -> bar chart;
    3. exactly two numeric columns -> scatter plot;
    4. object columns and more than two numeric columns -> stacked bar chart;
    5. one numeric column and no object column -> pie chart labelled by the
       row index.

    Returns None when no rule applies (e.g. there is no numeric column).
    """
    time_col = next((c for c in ("date", "timestamp") if c in df.columns), None)
    # The time column is the x-axis, never a plotted value, even if numeric.
    num_cols = [c for c in df.select_dtypes(include=["number"]).columns if c != time_col]
    cat_cols = df.select_dtypes(include=["object"]).columns.tolist()

    if time_col is not None and num_cols:
        return {
            "title": {"text": "Time-Series Data"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_safe(df[time_col])},
            "yAxis": {"type": "value"},
            "series": [{"name": "Value", "type": "line", "data": _json_safe(df[num_cols[0]])}],
        }
    if cat_cols and len(num_cols) == 1:
        return {
            "title": {"text": "Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_safe(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [{"name": num_cols[0], "type": "bar", "data": _json_safe(df[num_cols[0]])}],
        }
    if len(num_cols) == 2:
        x_col, y_col = num_cols
        points = [list(p) for p in zip(_json_safe(df[x_col]), _json_safe(df[y_col]))]
        return {
            "title": {"text": "Scatter Plot"},
            "tooltip": {},
            "xAxis": {"type": "value"},
            "yAxis": {"type": "value"},
            "series": [{"name": "Data", "type": "scatter", "data": points}],
        }
    if cat_cols and len(num_cols) > 1:
        return {
            "title": {"text": "Stacked Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_safe(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [
                {"name": col, "type": "bar", "stack": "stack", "data": _json_safe(df[col])}
                for col in num_cols
            ],
        }
    if len(num_cols) == 1:
        # Reached only when there is no object column (rule 2 would have
        # matched), so slices are labelled by the row index.
        labels = [str(i) for i in df.index]
        return {
            "title": {"text": "Pie Chart"},
            "tooltip": {},
            "series": [{
                "name": num_cols[0],
                "type": "pie",
                "radius": "50%",
                "data": [
                    {"name": name, "value": value}
                    for name, value in zip(labels, _json_safe(df[num_cols[0]]))
                ],
            }],
        }
    return None


def _result_to_frame(result):
    """Build a DataFrame from the API's ``result`` field.

    A JSON string of records (``orient='records'``) is parsed via StringIO,
    since passing literal JSON to ``read_json`` is deprecated in pandas.
    Any other value (e.g. a list of record dicts) goes to ``pd.DataFrame``.
    NOTE(review): the API contract is not visible here; confirm which form
    the server actually returns.
    """
    if isinstance(result, str):
        return pd.read_json(io.StringIO(result), orient="records")
    return pd.DataFrame(result)


def _api_error_message(response):
    """Return readable error text for a non-200 *response* from the API.

    Uses the JSON body's ``error`` field when present; otherwise falls back to
    the status code, so a non-JSON error page no longer hides the failure.
    """
    try:
        body = response.json()
    except ValueError:
        body = None
    if isinstance(body, dict) and body.get("error"):
        return body["error"]
    return f"Status Code {response.status_code}"


def _render_results(result_df):
    """Show *result_df* in a grid, an auto-chosen chart, and an optional plot."""
    st.write("### Query Results:")
    grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)
    selected_rows = grid_response["selected_rows"]
    # Older st_aggrid releases return a list of dicts here; newer ones may
    # return a DataFrame or None, on which a bare `if selected_rows` raises.
    if selected_rows is not None and len(selected_rows) > 0:
        df_selected = pd.DataFrame(selected_rows)
    else:
        df_selected = result_df

    # Auto-generate a chart that fits the selected (or full) data.
    chart_options = auto_generate_chart(df_selected)
    if chart_options:
        st_echarts(options=chart_options, height="400px")
    else:
        st.warning("No suitable chart found for this data.")

    # Static Seaborn visualization, only for the region/total_sales shape.
    if "region" in result_df.columns and "total_sales" in result_df.columns:
        st.write("### Total Sales by Region")
        fig, ax = plt.subplots()
        sns.barplot(x="region", y="total_sales", data=result_df, ax=ax)
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed, and this runs on
        # every rerun.
        plt.close(fig)


# Modify and Execute SQL Query
if "generated_sql" in st.session_state:
    modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
    if st.button("Execute Modified Query"):
        # Drop any previous result first so a failed run never shows stale rows.
        st.session_state.pop("query_result", None)
        try:
            response = requests.post(API_URL, json={"sql_query": modified_sql}, timeout=120)
            if response.status_code == 200:
                st.session_state.query_result = _result_to_frame(response.json()["result"])
            else:
                st.error(f"Error executing SQL: {_api_error_message(response)}")
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            st.error(f"Error executing SQL: {e}")

    # st.button is True only on the rerun right after the click. Rendering
    # from session state keeps the results visible when other widgets (e.g.
    # the grid) trigger a rerun.
    if "query_result" in st.session_state:
        _render_results(st.session_state.query_result)