Mhassanen commited on
Commit
113110f
·
verified ·
1 Parent(s): 2925a40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -17
app.py CHANGED
@@ -3,36 +3,36 @@ import requests
3
  import pandas as pd
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
 
 
6
 
7
- API_URL = "https://5415-34-127-95-74.ngrok-free.app/query"
 
8
 
9
- st.set_page_config(
10
- page_title="SQL Agent with Streamlit",
11
- page_icon=":bar_chart:",
12
- layout="wide"
13
- )
14
 
 
15
  with st.sidebar:
16
  st.write("## About Me")
17
  st.write("**Mahmoud Hassanen**")
18
  st.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")
19
 
20
- # Main content
21
  st.title("SQL Agent with Streamlit")
22
  st.header("Analyze Sales Data with Natural Language Queries")
23
 
24
- # Input for the question
25
  question = st.text_input("Enter your question:")
26
 
 
27
  if st.button("Generate SQL"):
28
  if question:
29
- # API to generate SQL
30
  response = requests.post(API_URL, json={"question": question})
31
 
32
  if response.status_code == 200:
33
  data = response.json()
34
  generated_sql = data["sql_query"]
35
- st.session_state.generated_sql = generated_sql # Store the generated SQL in session state
36
  st.write("### Generated SQL Query:")
37
  st.code(generated_sql, language="sql")
38
  else:
@@ -40,27 +40,90 @@ if st.button("Generate SQL"):
40
  else:
41
  st.warning("Please enter a question.")
42
 
43
- # Allow the user to modify the SQL query
44
  if "generated_sql" in st.session_state:
45
  modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
46
-
47
  if st.button("Execute Modified Query"):
48
  try:
49
- # Execute the modified SQL query
50
- response = requests.post(API_URL, json={"sql_query": modified_sql}) # Send the modified SQL to the API
51
  if response.status_code == 200:
52
  data = response.json()
53
  result_df = pd.read_json(data["result"], orient='records')
 
54
  st.write("### Query Results:")
55
- st.dataframe(result_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Visualize the data (if applicable)
 
 
 
 
 
58
  if 'region' in result_df.columns and 'total_sales' in result_df.columns:
59
  st.write("### Total Sales by Region")
60
  fig, ax = plt.subplots()
61
  sns.barplot(x='region', y='total_sales', data=result_df, ax=ax)
62
  st.pyplot(fig)
 
63
  else:
64
  st.error(f"Error executing SQL: {response.json().get('error')}")
65
  except Exception as e:
66
- st.error(f"Error executing SQL: {e}")
 
3
  import pandas as pd
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
+ from st_aggrid import AgGrid
7
+ from streamlit_echarts import st_echarts
8
 
9
+ # API endpoint for SQL Agent
10
+ API_URL = "https://b6ae-35-234-51-197.ngrok-free.app/query"
11
 
12
+ st.set_page_config(page_title="SQL Agent with Streamlit", page_icon=":bar_chart:", layout="wide")
 
 
 
 
13
 
14
+ # Sidebar Information
15
  with st.sidebar:
16
  st.write("## About Me")
17
  st.write("**Mahmoud Hassanen**")
18
  st.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")
19
 
20
+ # Main Page
21
  st.title("SQL Agent with Streamlit")
22
  st.header("Analyze Sales Data with Natural Language Queries")
23
 
24
+ # User Input
25
  question = st.text_input("Enter your question:")
26
 
27
+ # Generate SQL Query
28
  if st.button("Generate SQL"):
29
  if question:
 
30
  response = requests.post(API_URL, json={"question": question})
31
 
32
  if response.status_code == 200:
33
  data = response.json()
34
  generated_sql = data["sql_query"]
35
+ st.session_state.generated_sql = generated_sql
36
  st.write("### Generated SQL Query:")
37
  st.code(generated_sql, language="sql")
38
  else:
 
40
  else:
41
  st.warning("Please enter a question.")
42
 
43
+ # Modify and Execute SQL Query
44
  if "generated_sql" in st.session_state:
45
  modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
46
+
47
  if st.button("Execute Modified Query"):
48
  try:
49
+ response = requests.post(API_URL, json={"sql_query": modified_sql})
 
50
  if response.status_code == 200:
51
  data = response.json()
52
  result_df = pd.read_json(data["result"], orient='records')
53
+
54
  st.write("### Query Results:")
55
+ grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)
56
+ selected_rows = grid_response['selected_rows']
57
+ df_selected = pd.DataFrame(selected_rows) if selected_rows else result_df
58
+
59
+ # 🚀 Auto-Generate Charts
60
+ def auto_generate_chart(df):
61
+ num_cols = df.select_dtypes(include=['number']).columns.tolist()
62
+ cat_cols = df.select_dtypes(include=['object']).columns.tolist()
63
+
64
+ if "date" in df.columns or "timestamp" in df.columns:
65
+ options = {
66
+ "title": {"text": "Time-Series Data"},
67
+ "tooltip": {},
68
+ "xAxis": {"type": "category", "data": df["date"].tolist()},
69
+ "yAxis": {"type": "value"},
70
+ "series": [{"name": "Value", "type": "line", "data": df[num_cols[0]].tolist()}],
71
+ }
72
+ elif len(cat_cols) > 0 and len(num_cols) == 1:
73
+ options = {
74
+ "title": {"text": "Bar Chart"},
75
+ "tooltip": {},
76
+ "xAxis": {"type": "category", "data": df[cat_cols[0]].tolist()},
77
+ "yAxis": {"type": "value"},
78
+ "series": [{"name": num_cols[0], "type": "bar", "data": df[num_cols[0]].tolist()}],
79
+ }
80
+ elif len(num_cols) == 2:
81
+ options = {
82
+ "title": {"text": "Scatter Plot"},
83
+ "tooltip": {},
84
+ "xAxis": {"type": "value"},
85
+ "yAxis": {"type": "value"},
86
+ "series": [{"name": "Data", "type": "scatter", "data": df[num_cols].values.tolist()}],
87
+ }
88
+ elif len(cat_cols) > 0 and len(num_cols) > 1:
89
+ options = {
90
+ "title": {"text": "Stacked Bar Chart"},
91
+ "tooltip": {},
92
+ "xAxis": {"type": "category", "data": df[cat_cols[0]].tolist()},
93
+ "yAxis": {"type": "value"},
94
+ "series": [{"name": col, "type": "bar", "stack": "stack", "data": df[col].tolist()} for col in num_cols],
95
+ }
96
+ elif len(num_cols) == 1:
97
+ options = {
98
+ "title": {"text": "Pie Chart"},
99
+ "tooltip": {},
100
+ "series": [{
101
+ "name": num_cols[0],
102
+ "type": "pie",
103
+ "radius": "50%",
104
+ "data": [{"name": cat, "value": val} for cat, val in zip(df[cat_cols[0]], df[num_cols[0]])],
105
+ }],
106
+ }
107
+ else:
108
+ return None
109
+
110
+ return options
111
+
112
+ chart_options = auto_generate_chart(df_selected)
113
 
114
+ if chart_options:
115
+ st_echarts(options=chart_options, height="400px")
116
+ else:
117
+ st.warning("No suitable chart found for this data.")
118
+
119
+ # Static Seaborn Visualization (Optional)
120
  if 'region' in result_df.columns and 'total_sales' in result_df.columns:
121
  st.write("### Total Sales by Region")
122
  fig, ax = plt.subplots()
123
  sns.barplot(x='region', y='total_sales', data=result_df, ax=ax)
124
  st.pyplot(fig)
125
+
126
  else:
127
  st.error(f"Error executing SQL: {response.json().get('error')}")
128
  except Exception as e:
129
+ st.error(f"Error executing SQL: {e}")