# SQL-LLM-Agent / app.py
# Last commit by Mhassanen: "Update app.py" (c43a503, verified)
import io
import math
import os

import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
import streamlit as st
from st_aggrid import AgGrid
from streamlit_echarts import st_echarts
# API endpoint of the SQL Agent backend. Free ngrok tunnel URLs typically change
# whenever the tunnel restarts, so the SQL_AGENT_API_URL environment variable can
# override the built-in default without a code change.
API_URL = os.environ.get(
    "SQL_AGENT_API_URL",
    "https://1a67-35-234-51-197.ngrok-free.app/query",
)

# Kept as the first st.* call, which Streamlit requires (at least in older versions).
st.set_page_config(page_title="SQL Agent with Streamlit", page_icon=":bar_chart:", layout="wide")
# Sidebar: a description of the app followed by the author's credits.
# Each entry below is rendered, in order, by its own st.write call.
_SIDEBAR_BLURB = """This web app allows you to interact with your data warehouse (DWH) using natural language queries.
Simply enter a question, and the app will generate and execute the corresponding SQL query.
You can also modify the generated SQL before execution, making it easy to analyze your sales data and view visualizations.
Hosted using Streamlit and integrated with Azure SQL Database and OpenAI's GPT-3.5 model, this tool bridges the gap between data analysis and natural language understanding."""

_SIDEBAR_ENTRIES = (
    "# SQL Agent with Streamlit",
    _SIDEBAR_BLURB,
    "## About Me",
    "**Mahmoud Hassanen**",
    "**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**",
)

with st.sidebar:
    for _entry in _SIDEBAR_ENTRIES:
        st.write(_entry)
# ---- Main page --------------------------------------------------------------
# Title and sub-heading of the main column.
_main_title, _main_subtitle = (
    "SQL Agent with Streamlit",
    "Analyze Sales Data with Natural Language Queries",
)
st.title(_main_title)
st.header(_main_subtitle)

# Natural-language question typed by the user.
question = st.text_input(label="Enter your question:")
# ---- Generate SQL -----------------------------------------------------------
# How long (seconds) to wait for the agent to turn a question into SQL.
# `requests` has no default timeout, so without one an unreachable tunnel would
# block this script run indefinitely.
GENERATE_TIMEOUT_S = 60


def _fetch_generated_sql(question_text):
    """Ask the agent API to translate *question_text* into SQL.

    Sends ``{"question": question_text}`` to ``API_URL`` and returns the
    ``sql_query`` field of the JSON reply.

    Raises:
        RuntimeError: with a user-facing message when the request fails at the
            network level, the API answers with a status other than 200, or the
            reply is not JSON containing ``sql_query``.
    """
    try:
        response = requests.post(
            API_URL, json={"question": question_text}, timeout=GENERATE_TIMEOUT_S
        )
    except requests.RequestException as exc:
        raise RuntimeError(f"API Error: {exc}") from exc
    if response.status_code != 200:
        raise RuntimeError(f"API Error: Status Code {response.status_code}")
    try:
        return response.json()["sql_query"]
    except (ValueError, KeyError, TypeError) as exc:
        # ValueError: body is not JSON; KeyError/TypeError: JSON of the wrong shape.
        raise RuntimeError("API Error: response did not contain 'sql_query'") from exc


if st.button("Generate SQL"):
    # Treat whitespace-only input the same as an empty box.
    cleaned_question = question.strip()
    if not cleaned_question:
        st.warning("Please enter a question.")
    else:
        try:
            generated_sql = _fetch_generated_sql(cleaned_question)
        except RuntimeError as err:
            st.error(str(err))
        else:
            # Persist across reruns so the edit/execute section can pick it up.
            st.session_state.generated_sql = generated_sql
            st.write("### Generated SQL Query:")
            st.code(generated_sql, language="sql")
# ---- Modify and execute SQL -------------------------------------------------
# How long (seconds) to wait for the backend to run a query. `requests` has no
# default timeout, so an unreachable tunnel would otherwise hang the script run.
EXECUTE_TIMEOUT_S = 120


def _results_to_frame(result):
    """Build a DataFrame from the ``result`` field of the execute response.

    A string is parsed as JSON row records. It is wrapped in ``StringIO``
    because passing literal JSON to ``pd.read_json`` is deprecated since
    pandas 2.1. Anything else (e.g. an already-decoded list of row dicts) is
    handed to ``pd.DataFrame`` directly.
    """
    if isinstance(result, str):
        return pd.read_json(io.StringIO(result), orient="records")
    return pd.DataFrame(result)


def _execute_sql(sql):
    """Run *sql* through the agent API and return the rows as a DataFrame.

    Raises:
        RuntimeError: when the request fails at the network level, the API
            answers with a status other than 200, or the reply has no
            ``result`` field. For non-200 replies the message is the backend's
            ``error`` field when present, otherwise the status code.
    """
    try:
        response = requests.post(
            API_URL, json={"sql_query": sql}, timeout=EXECUTE_TIMEOUT_S
        )
    except requests.RequestException as exc:
        raise RuntimeError(str(exc)) from exc

    try:
        payload = response.json()
    except ValueError:
        # Body is not JSON at all (e.g. an HTML error page from the tunnel).
        payload = None

    if response.status_code != 200:
        detail = payload.get("error") if isinstance(payload, dict) else None
        raise RuntimeError(detail or f"Status Code {response.status_code}")
    if not isinstance(payload, dict) or "result" not in payload:
        raise RuntimeError("response did not contain 'result'")
    return _results_to_frame(payload["result"])


def _json_list(series):
    """Return *series* as a list that can be JSON-encoded for ECharts.

    Datetime columns are converted to strings (Timestamps are not
    JSON-serializable) and float NaN becomes ``None`` (JSON ``null``).
    """
    if pd.api.types.is_datetime64_any_dtype(series):
        series = series.astype(str)
    return [
        None if isinstance(value, float) and math.isnan(value) else value
        for value in series.tolist()
    ]


def auto_generate_chart(df):
    """Choose an ECharts chart for *df* based on its column types.

    Rules, checked in order:

    1. A ``date`` or ``timestamp`` column plus at least one other numeric
       column -> line chart of the first such numeric column over time.
    2. At least one object (text) column and exactly one numeric column
       -> bar chart.
    3. Exactly two numeric columns -> scatter plot.
    4. At least one object column and several numeric columns -> stacked bars.
    5. Exactly one numeric column and no object column -> pie chart whose
       slices are labelled by row index.

    Returns:
        An options dict for ``st_echarts``, or ``None`` if no rule applies.
    """
    num_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(include=["object"]).columns.tolist()

    # Prefer "date", fall back to "timestamp"; index whichever one exists.
    time_col = next((c for c in ("date", "timestamp") if c in df.columns), None)
    # Exclude the time column itself in case it came back as a number.
    value_cols = [c for c in num_cols if c != time_col]

    if time_col is not None and value_cols:
        return {
            "title": {"text": "Time-Series Data"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_list(df[time_col])},
            "yAxis": {"type": "value"},
            "series": [
                {"name": "Value", "type": "line", "data": _json_list(df[value_cols[0]])}
            ],
        }
    if cat_cols and len(num_cols) == 1:
        return {
            "title": {"text": "Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_list(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [
                {"name": num_cols[0], "type": "bar", "data": _json_list(df[num_cols[0]])}
            ],
        }
    if len(num_cols) == 2:
        x_col, y_col = num_cols
        points = [list(p) for p in zip(_json_list(df[x_col]), _json_list(df[y_col]))]
        return {
            "title": {"text": "Scatter Plot"},
            "tooltip": {},
            "xAxis": {"type": "value"},
            "yAxis": {"type": "value"},
            "series": [{"name": "Data", "type": "scatter", "data": points}],
        }
    if cat_cols and len(num_cols) > 1:
        return {
            "title": {"text": "Stacked Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_list(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [
                {"name": col, "type": "bar", "stack": "stack", "data": _json_list(df[col])}
                for col in num_cols
            ],
        }
    if len(num_cols) == 1:
        # Only reached without an object column (rule 2 covers the other case),
        # so there is no category column to name slices; use the row index.
        labels = [str(label) for label in df.index]
        values = _json_list(df[num_cols[0]])
        return {
            "title": {"text": "Pie Chart"},
            "tooltip": {},
            "series": [{
                "name": num_cols[0],
                "type": "pie",
                "radius": "50%",
                "data": [{"name": n, "value": v} for n, v in zip(labels, values)],
            }],
        }
    return None


def _pick_chart_frame(grid_response, result_df):
    """Return the grid rows the user selected, or *result_df* if none are.

    Depending on the st_aggrid version, ``selected_rows`` may be a list of row
    dicts or a DataFrame/None. A DataFrame cannot be used in a plain boolean
    test, so both shapes are handled explicitly.
    """
    selected = grid_response["selected_rows"]
    if isinstance(selected, pd.DataFrame):
        return result_df if selected.empty else selected
    return pd.DataFrame(selected) if selected else result_df


if "generated_sql" in st.session_state:
    modified_sql = st.text_area(
        "Modify the SQL query (if needed):", st.session_state.generated_sql, height=200
    )
    # NOTE(review): the edited SQL is sent to the backend verbatim; this assumes
    # the backend restricts what it may do (e.g. read-only DB user) -- confirm.

    if st.button("Execute Modified Query"):
        try:
            # Stored in session state because st.button is True only on the run
            # right after the click; output drawn only inside this branch would
            # vanish on the next widget interaction (e.g. with the grid).
            st.session_state.query_result = {
                "sql": modified_sql,
                "df": _execute_sql(modified_sql),
            }
        except Exception as exc:  # UI boundary: report instead of a traceback.
            st.session_state.pop("query_result", None)
            st.error(f"Error executing SQL: {exc}")

    last_run = st.session_state.get("query_result")
    # Only show results that belong to the SQL currently in the editor.
    if last_run is not None and last_run["sql"] == modified_sql:
        result_df = last_run["df"]
        st.write("### Query Results:")
        grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)

        # Auto-generate a chart from the selected rows (or all rows).
        chart_options = auto_generate_chart(_pick_chart_frame(grid_response, result_df))
        if chart_options:
            st_echarts(options=chart_options, height="400px")
        else:
            st.warning("No suitable chart found for this data.")