Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,36 +3,36 @@ import requests
|
|
3 |
import pandas as pd
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
|
|
|
|
6 |
|
7 |
-
|
|
|
8 |
|
9 |
-
st.set_page_config(
|
10 |
-
page_title="SQL Agent with Streamlit",
|
11 |
-
page_icon=":bar_chart:",
|
12 |
-
layout="wide"
|
13 |
-
)
|
14 |
|
|
|
15 |
with st.sidebar:
|
16 |
st.write("## About Me")
|
17 |
st.write("**Mahmoud Hassanen**")
|
18 |
st.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")
|
19 |
|
20 |
-
# Main
|
21 |
st.title("SQL Agent with Streamlit")
|
22 |
st.header("Analyze Sales Data with Natural Language Queries")
|
23 |
|
24 |
-
# Input
|
25 |
question = st.text_input("Enter your question:")
|
26 |
|
|
|
27 |
if st.button("Generate SQL"):
|
28 |
if question:
|
29 |
-
# API to generate SQL
|
30 |
response = requests.post(API_URL, json={"question": question})
|
31 |
|
32 |
if response.status_code == 200:
|
33 |
data = response.json()
|
34 |
generated_sql = data["sql_query"]
|
35 |
-
st.session_state.generated_sql = generated_sql
|
36 |
st.write("### Generated SQL Query:")
|
37 |
st.code(generated_sql, language="sql")
|
38 |
else:
|
@@ -40,27 +40,90 @@ if st.button("Generate SQL"):
|
|
40 |
else:
|
41 |
st.warning("Please enter a question.")
|
42 |
|
43 |
-
#
|
44 |
if "generated_sql" in st.session_state:
|
45 |
modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
|
46 |
-
|
47 |
if st.button("Execute Modified Query"):
|
48 |
try:
|
49 |
-
|
50 |
-
response = requests.post(API_URL, json={"sql_query": modified_sql}) # Send the modified SQL to the API
|
51 |
if response.status_code == 200:
|
52 |
data = response.json()
|
53 |
result_df = pd.read_json(data["result"], orient='records')
|
|
|
54 |
st.write("### Query Results:")
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
if 'region' in result_df.columns and 'total_sales' in result_df.columns:
|
59 |
st.write("### Total Sales by Region")
|
60 |
fig, ax = plt.subplots()
|
61 |
sns.barplot(x='region', y='total_sales', data=result_df, ax=ax)
|
62 |
st.pyplot(fig)
|
|
|
63 |
else:
|
64 |
st.error(f"Error executing SQL: {response.json().get('error')}")
|
65 |
except Exception as e:
|
66 |
-
st.error(f"Error executing SQL: {e}")
|
|
|
3 |
import pandas as pd
|
4 |
import matplotlib.pyplot as plt
|
5 |
import seaborn as sns
|
6 |
+
from st_aggrid import AgGrid
|
7 |
+
from streamlit_echarts import st_echarts
|
8 |
|
9 |
+
# API endpoint for SQL Agent
|
10 |
+
API_URL = "https://b6ae-35-234-51-197.ngrok-free.app/query"
|
11 |
|
12 |
+
st.set_page_config(page_title="SQL Agent with Streamlit", page_icon=":bar_chart:", layout="wide")
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
# Sidebar Information
|
15 |
with st.sidebar:
|
16 |
st.write("## About Me")
|
17 |
st.write("**Mahmoud Hassanen**")
|
18 |
st.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")
|
19 |
|
20 |
+
# Main Page
|
21 |
st.title("SQL Agent with Streamlit")
|
22 |
st.header("Analyze Sales Data with Natural Language Queries")
|
23 |
|
24 |
+
# User Input
|
25 |
question = st.text_input("Enter your question:")
|
26 |
|
27 |
+
# Generate SQL Query
|
28 |
if st.button("Generate SQL"):
|
29 |
if question:
|
|
|
30 |
response = requests.post(API_URL, json={"question": question})
|
31 |
|
32 |
if response.status_code == 200:
|
33 |
data = response.json()
|
34 |
generated_sql = data["sql_query"]
|
35 |
+
st.session_state.generated_sql = generated_sql
|
36 |
st.write("### Generated SQL Query:")
|
37 |
st.code(generated_sql, language="sql")
|
38 |
else:
|
|
|
40 |
else:
|
41 |
st.warning("Please enter a question.")
|
42 |
|
43 |
+
# Modify and Execute SQL Query
|
44 |
if "generated_sql" in st.session_state:
|
45 |
modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
|
46 |
+
|
47 |
if st.button("Execute Modified Query"):
|
48 |
try:
|
49 |
+
response = requests.post(API_URL, json={"sql_query": modified_sql})
|
|
|
50 |
if response.status_code == 200:
|
51 |
data = response.json()
|
52 |
result_df = pd.read_json(data["result"], orient='records')
|
53 |
+
|
54 |
st.write("### Query Results:")
|
55 |
+
grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)
|
56 |
+
selected_rows = grid_response['selected_rows']
|
57 |
+
df_selected = pd.DataFrame(selected_rows) if selected_rows else result_df
|
58 |
+
|
59 |
+
# 🚀 Auto-Generate Charts
|
60 |
+
def auto_generate_chart(df):
|
61 |
+
num_cols = df.select_dtypes(include=['number']).columns.tolist()
|
62 |
+
cat_cols = df.select_dtypes(include=['object']).columns.tolist()
|
63 |
+
|
64 |
+
if "date" in df.columns or "timestamp" in df.columns:
|
65 |
+
options = {
|
66 |
+
"title": {"text": "Time-Series Data"},
|
67 |
+
"tooltip": {},
|
68 |
+
"xAxis": {"type": "category", "data": df["date"].tolist()},
|
69 |
+
"yAxis": {"type": "value"},
|
70 |
+
"series": [{"name": "Value", "type": "line", "data": df[num_cols[0]].tolist()}],
|
71 |
+
}
|
72 |
+
elif len(cat_cols) > 0 and len(num_cols) == 1:
|
73 |
+
options = {
|
74 |
+
"title": {"text": "Bar Chart"},
|
75 |
+
"tooltip": {},
|
76 |
+
"xAxis": {"type": "category", "data": df[cat_cols[0]].tolist()},
|
77 |
+
"yAxis": {"type": "value"},
|
78 |
+
"series": [{"name": num_cols[0], "type": "bar", "data": df[num_cols[0]].tolist()}],
|
79 |
+
}
|
80 |
+
elif len(num_cols) == 2:
|
81 |
+
options = {
|
82 |
+
"title": {"text": "Scatter Plot"},
|
83 |
+
"tooltip": {},
|
84 |
+
"xAxis": {"type": "value"},
|
85 |
+
"yAxis": {"type": "value"},
|
86 |
+
"series": [{"name": "Data", "type": "scatter", "data": df[num_cols].values.tolist()}],
|
87 |
+
}
|
88 |
+
elif len(cat_cols) > 0 and len(num_cols) > 1:
|
89 |
+
options = {
|
90 |
+
"title": {"text": "Stacked Bar Chart"},
|
91 |
+
"tooltip": {},
|
92 |
+
"xAxis": {"type": "category", "data": df[cat_cols[0]].tolist()},
|
93 |
+
"yAxis": {"type": "value"},
|
94 |
+
"series": [{"name": col, "type": "bar", "stack": "stack", "data": df[col].tolist()} for col in num_cols],
|
95 |
+
}
|
96 |
+
elif len(num_cols) == 1:
|
97 |
+
options = {
|
98 |
+
"title": {"text": "Pie Chart"},
|
99 |
+
"tooltip": {},
|
100 |
+
"series": [{
|
101 |
+
"name": num_cols[0],
|
102 |
+
"type": "pie",
|
103 |
+
"radius": "50%",
|
104 |
+
"data": [{"name": cat, "value": val} for cat, val in zip(df[cat_cols[0]], df[num_cols[0]])],
|
105 |
+
}],
|
106 |
+
}
|
107 |
+
else:
|
108 |
+
return None
|
109 |
+
|
110 |
+
return options
|
111 |
+
|
112 |
+
chart_options = auto_generate_chart(df_selected)
|
113 |
|
114 |
+
if chart_options:
|
115 |
+
st_echarts(options=chart_options, height="400px")
|
116 |
+
else:
|
117 |
+
st.warning("No suitable chart found for this data.")
|
118 |
+
|
119 |
+
# Static Seaborn Visualization (Optional)
|
120 |
if 'region' in result_df.columns and 'total_sales' in result_df.columns:
|
121 |
st.write("### Total Sales by Region")
|
122 |
fig, ax = plt.subplots()
|
123 |
sns.barplot(x='region', y='total_sales', data=result_df, ax=ax)
|
124 |
st.pyplot(fig)
|
125 |
+
|
126 |
else:
|
127 |
st.error(f"Error executing SQL: {response.json().get('error')}")
|
128 |
except Exception as e:
|
129 |
+
st.error(f"Error executing SQL: {e}")
|