# Spaces: Running
import os
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
import streamlit as st
from st_aggrid import AgGrid
from streamlit_echarts import st_echarts
# API endpoint for SQL Agent.
# The default host looks like a free-tier ngrok tunnel, whose URL typically
# changes whenever the tunnel restarts. The SQL_AGENT_API_URL environment
# variable lets a deployment point at the current endpoint without editing
# the code; when it is unset, the original URL is used.
API_URL = os.environ.get(
    "SQL_AGENT_API_URL",
    "https://1a67-35-234-51-197.ngrok-free.app/query",
)
st.set_page_config(
    layout="wide",
    page_icon=":bar_chart:",
    page_title="SQL Agent with Streamlit",
)

# Sidebar: what the app does and who built it.
# Calling .write on st.sidebar renders exactly like st.write inside a
# `with st.sidebar:` block.
_sidebar = st.sidebar
_sidebar.write("# SQL Agent with Streamlit")
_sidebar.write(
    """This web app allows you to interact with your data warehouse (DWH) using natural language queries.
Simply enter a question, and the app will generate and execute the corresponding SQL query.
You can also modify the generated SQL before execution, making it easy to analyze your sales data and view visualizations.
Hosted using Streamlit and integrated with Azure SQL Database and OpenAI's GPT-3.5 model, this tool bridges the gap between data analysis and natural language understanding."""
)
_sidebar.write("## About Me")
_sidebar.write("**Mahmoud Hassanen**")
_sidebar.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")
# Main Page
st.title("SQL Agent with Streamlit")
st.header("Analyze Sales Data with Natural Language Queries")

# User Input
question = st.text_input("Enter your question:")

# Timeout (seconds) for the SQL-generation request, so that a stalled backend
# cannot hang the page forever. The value is a guess; tune it for the
# deployed backend.
_GENERATE_TIMEOUT_S = 120


def _generate_sql(text):
    """Ask the agent API to translate *text* into SQL and return the query.

    Raises:
        RuntimeError: with a message ready for ``st.error`` when the request
            fails, the API answers with a non-200 status, or the body is not
            JSON containing a "sql_query" field.
    """
    try:
        response = requests.post(
            API_URL, json={"question": text}, timeout=_GENERATE_TIMEOUT_S
        )
    except requests.RequestException as exc:
        # Connection refused, DNS failure, timeout, ...
        raise RuntimeError(f"API Error: {exc}") from exc
    if response.status_code != 200:
        raise RuntimeError(f"API Error: Status Code {response.status_code}")
    try:
        # ValueError: body is not JSON. TypeError: JSON is not an object.
        return response.json()["sql_query"]
    except (ValueError, KeyError, TypeError) as exc:
        raise RuntimeError(
            "API Error: response did not contain a 'sql_query' field"
        ) from exc


# Generate SQL Query
if st.button("Generate SQL"):
    if not question.strip():
        # Whitespace-only input is treated the same as an empty box.
        st.warning("Please enter a question.")
    else:
        try:
            generated_sql = _generate_sql(question)
        except RuntimeError as err:
            st.error(str(err))
        else:
            # Stored in session state so the edit/execute section below
            # survives later reruns.
            st.session_state.generated_sql = generated_sql
            st.write("### Generated SQL Query:")
            st.code(generated_sql, language="sql")
# Modify and Execute SQL Query

# Timeout (seconds) for the query-execution request. The value is a guess;
# tune it for the deployed backend.
_EXECUTE_TIMEOUT_S = 120


def _execute_sql(sql):
    """Send *sql* to the agent API for execution and return the rows.

    The response's "result" field is accepted either as a JSON string of
    records or as an already-decoded list of records.

    Raises:
        RuntimeError: with a user-facing message when the request fails, the
            API reports an error, or the payload cannot be turned into a
            DataFrame.
    """
    try:
        response = requests.post(
            API_URL, json={"sql_query": sql}, timeout=_EXECUTE_TIMEOUT_S
        )
    except requests.RequestException as exc:
        raise RuntimeError(str(exc)) from exc

    try:
        payload = response.json()
    except ValueError:
        # e.g. an HTML error page returned by a proxy or tunnel.
        payload = None

    if response.status_code != 200:
        # Prefer the API's own error text; otherwise fall back to the status.
        if isinstance(payload, dict) and payload.get("error"):
            raise RuntimeError(str(payload["error"]))
        raise RuntimeError(f"API returned status code {response.status_code}")

    if not isinstance(payload, dict) or "result" not in payload:
        raise RuntimeError("API response did not contain a 'result' field")

    raw = payload["result"]
    try:
        if isinstance(raw, str):
            # Wrapped in StringIO because passing literal JSON text to
            # read_json is deprecated since pandas 2.1.
            return pd.read_json(StringIO(raw), orient="records")
        return pd.DataFrame(raw)
    except ValueError as exc:
        raise RuntimeError(f"Could not parse query result: {exc}") from exc


def _to_json_list(series):
    """Return *series* as a list of JSON-serialisable Python values.

    ECharts options are serialised to JSON, so datetime values become
    strings and missing values (NaN/NaT/None) become ``None``. Plain NaN is
    not valid JSON, and pandas Timestamps are not serialisable at all.
    """
    if pd.api.types.is_datetime64_any_dtype(series):
        series = series.dt.strftime("%Y-%m-%d %H:%M:%S")
    return [
        None if pd.api.types.is_scalar(value) and pd.isna(value) else value
        for value in series.tolist()
    ]


def _find_time_column(df):
    """Return the column to use as a time axis, or None if there is none.

    Columns literally named "date" or "timestamp" win; otherwise the first
    datetime-typed column is used.
    """
    for name in ("date", "timestamp"):
        if name in df.columns:
            return name
    for name in df.columns:
        if pd.api.types.is_datetime64_any_dtype(df[name]):
            return name
    return None


def auto_generate_chart(df):
    """Pick an ECharts chart for *df* from its column types.

    Rules, in priority order:
      1. a time column plus a numeric column      -> line chart
      2. a text column and exactly one numeric one -> bar chart
      3. exactly two numeric columns               -> scatter plot
      4. a text column and several numeric ones    -> stacked bar chart
      5. exactly one numeric column                -> pie chart (slices named
         by the first text column, or by row index when there is none)

    Returns:
        An ECharts ``options`` dict, or None when *df* is empty or no rule
        matches.
    """
    if df.empty:
        return None

    num_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(include=["object"]).columns.tolist()
    time_col = _find_time_column(df)
    # Never plot the time axis against itself (e.g. a numeric epoch column).
    value_cols = [col for col in num_cols if col != time_col]

    if time_col is not None and value_cols:
        return {
            "title": {"text": "Time-Series Data"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _to_json_list(df[time_col])},
            "yAxis": {"type": "value"},
            "series": [{"name": "Value", "type": "line",
                        "data": _to_json_list(df[value_cols[0]])}],
        }
    if cat_cols and len(num_cols) == 1:
        return {
            "title": {"text": "Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _to_json_list(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [{"name": num_cols[0], "type": "bar",
                        "data": _to_json_list(df[num_cols[0]])}],
        }
    if len(num_cols) == 2:
        x_vals = _to_json_list(df[num_cols[0]])
        y_vals = _to_json_list(df[num_cols[1]])
        return {
            "title": {"text": "Scatter Plot"},
            "tooltip": {},
            "xAxis": {"type": "value"},
            "yAxis": {"type": "value"},
            "series": [{"name": "Data", "type": "scatter",
                        "data": [[x, y] for x, y in zip(x_vals, y_vals)]}],
        }
    if cat_cols and len(num_cols) > 1:
        return {
            "title": {"text": "Stacked Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _to_json_list(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [{"name": col, "type": "bar", "stack": "stack",
                        "data": _to_json_list(df[col])} for col in num_cols],
        }
    if len(num_cols) == 1:
        # When a text column exists, rule 2 has already matched, so here the
        # slices are normally named by row index.
        if cat_cols:
            labels = _to_json_list(df[cat_cols[0]])
        else:
            labels = [str(idx) for idx in df.index]
        values = _to_json_list(df[num_cols[0]])
        return {
            "title": {"text": "Pie Chart"},
            "tooltip": {},
            "series": [{
                "name": num_cols[0],
                "type": "pie",
                "radius": "50%",
                "data": [{"name": name, "value": val}
                         for name, val in zip(labels, values)],
            }],
        }
    return None


def _selected_or_all(grid_response, result_df):
    """Return the rows selected in the grid, or *result_df* if none are.

    Depending on the installed streamlit-aggrid version, "selected_rows" may
    be a list of dicts, a DataFrame or None (verify against the installed
    version); all three are handled here.
    """
    selected = grid_response["selected_rows"]
    if selected is None or len(selected) == 0:
        return result_df
    if isinstance(selected, pd.DataFrame):
        selected_df = selected
    else:
        selected_df = pd.DataFrame(selected)
    # Keep only the query's own columns. Some versions may add bookkeeping
    # fields to selected rows, which would otherwise be charted.
    keep = [col for col in result_df.columns if col in selected_df.columns]
    return selected_df[keep]


def _render_results(result_df):
    """Show *result_df* in an AgGrid table, then an automatically picked chart."""
    st.write("### Query Results:")
    # NOTE(review): no row-selection options are passed to AgGrid here, so the
    # chart may always use all rows -- confirm whether selection is enabled.
    grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)
    # Auto-generate charts
    chart_options = auto_generate_chart(_selected_or_all(grid_response, result_df))
    if chart_options is not None:
        st_echarts(options=chart_options, height="400px")
    else:
        st.warning("No suitable chart found for this data.")


if "generated_sql" in st.session_state:
    modified_sql = st.text_area(
        "Modify the SQL query (if needed):",
        st.session_state.generated_sql,
        height=200,
    )
    if st.button("Execute Modified Query"):
        # NOTE(review): the edited SQL is sent to the backend unchanged;
        # presumably the backend restricts what it will run -- confirm.
        try:
            st.session_state["query_result"] = _execute_sql(modified_sql)
            st.session_state["query_result_source"] = st.session_state.generated_sql
        except RuntimeError as err:
            st.session_state.pop("query_result", None)
            st.error(f"Error executing SQL: {err}")

    # The results are kept in session state because st.button is True only
    # on the run right after the click. Interacting with the grid triggers a
    # rerun, which would otherwise wipe the table and chart. Results are
    # hidden once a different SQL query has been generated.
    if (
        "query_result" in st.session_state
        and st.session_state.get("query_result_source") == st.session_state.generated_sql
    ):
        try:
            _render_results(st.session_state["query_result"])
        except Exception as exc:  # UI boundary: report instead of crashing the page
            st.error(f"Error displaying results: {exc}")