"""Streamlit front-end for a natural-language SQL agent.

The user types a question, the agent at ``API_URL`` returns a SQL query, the
user may edit it, and the agent executes it. The result is shown in an
interactive grid together with an automatically chosen ECharts chart.
"""

import io

import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
import streamlit as st
from st_aggrid import AgGrid
from streamlit_echarts import st_echarts

# API endpoint for SQL Agent.
# NOTE(review): an ngrok tunnel URL changes whenever the tunnel restarts;
# consider loading it from st.secrets or an environment variable.
API_URL = "https://b6ae-35-234-51-197.ngrok-free.app/query"

# Seconds to wait for the agent. requests has no default timeout, so without
# this a dead tunnel would block the app indefinitely.
REQUEST_TIMEOUT = 60

# Column names recognised as the x-axis of a time-series chart, in priority order.
TIME_COLUMNS = ("date", "timestamp")


def _post_to_agent(payload):
    """POST *payload* as JSON to the agent and return the response.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    return requests.post(API_URL, json=payload, timeout=REQUEST_TIMEOUT)


def _error_detail(response):
    """Return the ``error`` field of a JSON error body, else the raw body."""
    try:
        body = response.json()
    except ValueError:  # body is not JSON (e.g. an HTML error page from the tunnel)
        return response.text
    return body.get("error") if isinstance(body, dict) else body


def _request_sql(question):
    """Ask the agent to translate *question* into SQL.

    Returns:
        ``(sql, None)`` on success, ``(None, error_message)`` on failure.
    """
    try:
        response = _post_to_agent({"question": question})
    except requests.RequestException as exc:
        return None, f"API Error: {exc}"
    if response.status_code != 200:
        return None, f"API Error: Status Code {response.status_code}"
    try:
        return response.json()["sql_query"], None
    except (ValueError, KeyError, TypeError) as exc:
        return None, f"API Error: unexpected response ({exc!r})"


def _result_to_dataframe(result):
    """Build a DataFrame from the agent's ``result`` payload.

    A JSON string of row records is parsed with ``pd.read_json``; it is
    wrapped in ``StringIO`` because passing literal JSON text is deprecated
    in pandas 2.1+. An already-decoded list of records is also accepted.
    """
    if isinstance(result, str):
        return pd.read_json(io.StringIO(result), orient="records")
    return pd.DataFrame(result)


def _execute_sql(sql):
    """Ask the agent to run *sql*.

    Returns:
        ``(result_df, None)`` on success, ``(None, error_message)`` on failure.
    """
    try:
        response = _post_to_agent({"sql_query": sql})
        if response.status_code != 200:
            return None, f"Error executing SQL: {_error_detail(response)}"
        return _result_to_dataframe(response.json()["result"]), None
    except (requests.RequestException, ValueError, KeyError, TypeError) as exc:
        return None, f"Error executing SQL: {exc}"


def _json_ready(series):
    """Return *series* as a list of plain Python values for ECharts.

    Datetimes become strings and missing values become ``None`` (JSON
    ``null``), since ``Timestamp`` objects are not JSON-serialisable and
    ``NaN`` is not valid JSON.
    """
    missing = series.isna().tolist()
    if pd.api.types.is_datetime64_any_dtype(series):
        series = series.astype(str)
    return [None if gap else value for value, gap in zip(series.tolist(), missing)]


def auto_generate_chart(df):
    """Pick an ECharts options dict suited to the columns of *df*.

    Heuristics, in priority order:

    * a ``date``/``timestamp`` column plus another numeric column -> line chart;
    * a text column and exactly one numeric column -> bar chart;
    * exactly two numeric columns -> scatter plot;
    * a text column and several numeric columns -> stacked bar chart;
    * a single numeric column (no text column) -> pie chart labelled by row index.

    Returns:
        The options dict for ``st_echarts``, or ``None`` when *df* is empty or
        no heuristic applies.
    """
    if df.empty:
        return None

    num_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(include=["object"]).columns.tolist()
    time_col = next((col for col in TIME_COLUMNS if col in df.columns), None)
    # A numeric time column (e.g. a year) must not be plotted against itself.
    value_cols = [col for col in num_cols if col != time_col]

    if time_col is not None and value_cols:
        return {
            "title": {"text": "Time-Series Data"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_ready(df[time_col])},
            "yAxis": {"type": "value"},
            "series": [{"name": value_cols[0], "type": "line",
                        "data": _json_ready(df[value_cols[0]])}],
        }
    if cat_cols and len(num_cols) == 1:
        return {
            "title": {"text": "Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_ready(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [{"name": num_cols[0], "type": "bar",
                        "data": _json_ready(df[num_cols[0]])}],
        }
    if len(num_cols) == 2:
        x_col, y_col = num_cols
        return {
            "title": {"text": "Scatter Plot"},
            "tooltip": {},
            "xAxis": {"type": "value"},
            "yAxis": {"type": "value"},
            "series": [{"name": "Data", "type": "scatter",
                        "data": [[x, y] for x, y in zip(_json_ready(df[x_col]),
                                                        _json_ready(df[y_col]))]}],
        }
    if cat_cols and len(num_cols) > 1:
        return {
            "title": {"text": "Stacked Bar Chart"},
            "tooltip": {},
            "xAxis": {"type": "category", "data": _json_ready(df[cat_cols[0]])},
            "yAxis": {"type": "value"},
            "series": [{"name": col, "type": "bar", "stack": "stack",
                        "data": _json_ready(df[col])} for col in num_cols],
        }
    if len(num_cols) == 1:
        # Only reached when there is no text column, so slices are labelled
        # by row index rather than by a (non-existent) category column.
        value_col = num_cols[0]
        return {
            "title": {"text": "Pie Chart"},
            "tooltip": {},
            "series": [{
                "name": value_col,
                "type": "pie",
                "radius": "50%",
                "data": [{"name": str(label), "value": value}
                         for label, value in zip(df.index, _json_ready(df[value_col]))],
            }],
        }
    return None


def _render_results(result_df):
    """Show *result_df* as a grid, an automatic chart and an optional seaborn plot."""
    st.write("### Query Results:")
    grid_response = AgGrid(result_df, height=250, fit_columns_on_grid_load=True)
    selected_rows = grid_response["selected_rows"]
    # Depending on the st_aggrid version this is a list, a DataFrame or None;
    # a DataFrame's truth value is ambiguous, so test emptiness explicitly.
    # NOTE(review): row selection is not enabled in the grid options here, so
    # this is presumably always empty and the chart uses the full result.
    if selected_rows is not None and len(selected_rows) > 0:
        df_selected = pd.DataFrame(selected_rows)
    else:
        df_selected = result_df

    # 🚀 Auto-Generate Charts
    chart_options = auto_generate_chart(df_selected)
    if chart_options:
        st_echarts(options=chart_options, height="400px")
    else:
        st.warning("No suitable chart found for this data.")

    # Static Seaborn Visualization (Optional)
    if {"region", "total_sales"}.issubset(result_df.columns):
        st.write("### Total Sales by Region")
        fig, ax = plt.subplots()
        sns.barplot(x="region", y="total_sales", data=result_df, ax=ax)
        st.pyplot(fig)
        plt.close(fig)  # pyplot keeps every open figure alive until closed


st.set_page_config(page_title="SQL Agent with Streamlit", page_icon=":bar_chart:", layout="wide")

# Sidebar Information
with st.sidebar:
    st.write("## About Me")
    st.write("**Mahmoud Hassanen**")
    st.write("**[LinkedIn Profile](https://www.linkedin.com/in/mahmoudhassanen99/)**")

# Main Page
st.title("SQL Agent with Streamlit")
st.header("Analyze Sales Data with Natural Language Queries")

# User Input
question = st.text_input("Enter your question:")

# Generate SQL Query
if st.button("Generate SQL"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        generated_sql, error = _request_sql(question)
        if error is not None:
            st.error(error)
        else:
            st.session_state.generated_sql = generated_sql
            # Results of a previous query no longer match the new SQL.
            st.session_state.pop("result_df", None)
            st.write("### Generated SQL Query:")
            st.code(generated_sql, language="sql")

# Modify and Execute SQL Query
if "generated_sql" in st.session_state:
    modified_sql = st.text_area("Modify the SQL query (if needed):", st.session_state.generated_sql, height=200)
    if st.button("Execute Modified Query"):
        st.session_state.pop("result_df", None)
        result_df, error = _execute_sql(modified_sql)
        if error is not None:
            st.error(error)
        else:
            # Persist the result: any widget interaction (e.g. in the grid)
            # reruns the script with st.button False, so results rendered
            # only inside this branch would disappear on the next rerun.
            st.session_state.result_df = result_df

    if "result_df" in st.session_state:
        try:
            _render_results(st.session_state.result_df)
        except Exception as exc:  # UI boundary: report rendering failures, don't crash the page
            st.error(f"Error displaying results: {exc}")