# src/streamlit_app.py
# Fixed and Hugging Face Spaces-Compatible Code
import os
import streamlit as st
import pandas as pd
import subprocess
import json
import plotly.express as px
import re
import io
import requests
from sqlalchemy import create_engine, text, inspect
# --- Get HF Token ---
HF_TOKEN = os.environ.get("HF_TOKEN", "") # Safely get token, fallback if missing
# --- Helper: Call Mistral Model ---
def mistral_call(schema=None, question="no questions were asked", hf_token=HF_TOKEN, model_id="mistralai/Mistral-7B-Instruct-v0.3", timeout=60):
    """Call the hosted Mistral model via the HF Inference API and return its text.

    Parameters
    ----------
    schema : str | None
        Database schema description interpolated into the prompt.
    question : str
        Natural-language question to translate into SQL.
    hf_token : str
        Hugging Face API token (falls back to the module-level HF_TOKEN).
    model_id : str
        Model repo id on the Hugging Face hub.
    timeout : int | float
        Seconds to wait for the HTTP call (new, backward-compatible; the
        original call had no timeout and could hang the Streamlit worker).

    Returns
    -------
    str
        The generated text after the "### Question:" marker, or a
        human-readable error string (callers display whatever comes back).
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {
        "Authorization": f"Bearer {hf_token}",
        "Content-Type": "application/json"
    }
    prompt = f"""You are a helpful assistant that translates natural language questions into SQL using a database schema.
### Schema:
{schema}
### Question:
{question}
"""
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 500,
            "do_sample": True,
            "temperature": 0.3,
        }
    }
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Connection failures / timeouts previously propagated and crashed the app.
        return f"API call failed: {e}"
    if response.status_code == 200:
        try:
            # HF text-generation responses are a list of {"generated_text": ...}.
            generated = response.json()[0]['generated_text']
            # The model echoes the prompt; keep only the text after the question marker.
            return generated.split("### Question:")[-1].strip()
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # Narrowed from a bare Exception: only malformed-payload errors belong here.
            return f"Error parsing response: {e}"
    return f"API call failed: {response.status_code}\n{response.text}"
# --- Visualization Suggestion ---
def extract_json(text):
    """Return the first brace-delimited JSON object found in *text*, else None.

    The search is non-greedy and spans newlines, so it grabs the smallest
    {...} span; anything that fails to parse as JSON yields None.
    """
    candidate = re.search(r"\{.*?\}", text, re.DOTALL)
    if candidate is None:
        return None
    try:
        return json.loads(candidate.group(0))
    except json.JSONDecodeError:
        return None
def get_visualization_suggestion(data):
    """Ask the model for one chart spec for *data*'s columns.

    Returns the parsed {"x": ..., "y": ..., "chart_type": ...} dict,
    or None when the model response contains no valid JSON.
    """
    column_names = list(data.columns)
    prompt = f"""
These are the dataset column names: {column_names}.
Suggest one visualization using the format:
{{"x": "column", "y": "column or list", "chart_type": "bar/line/scatter/pie"}}
"""
    return extract_json(mistral_call(question=prompt))
# --- Demo Data Generator ---
def generate_demo_data_csv(user_input, num_rows=10):
    """Generate a small demo dataset from a natural-language description.

    Asks the model for CSV with quoted headers/values, keeps only the lines
    that start with a double quote, and round-trips them through pandas.

    Returns
    -------
    tuple[str, io.StringIO | None]
        (status message, CSV buffer) — the buffer is None on failure.
    """
    prompt = f"""
Generate a {num_rows}-row structured dataset in CSV format with quoted column headers and values:
"{user_input}"
"""
    response = mistral_call(question=prompt)
    # Model output is chatty; only quoted lines are treated as CSV rows.
    quoted_lines = [ln.strip() for ln in response.splitlines() if ln.strip().startswith('"')]
    csv_data = "\n".join(quoted_lines)
    if not csv_data:
        return "No CSV found.", None
    try:
        frame = pd.read_csv(io.StringIO(csv_data))
        buffer = io.StringIO()
        frame.to_csv(buffer, index=False)
        return "Demo data generated.", buffer
    except Exception as e:
        return f"CSV error: {e}", None
# --- SQL Utilities ---
def extract_sql_code_blocks(text):
    """Return the contents of every fenced ```sql ...``` block in *text*."""
    fence = re.compile(r"```sql\s+(.*?)```", re.DOTALL | re.IGNORECASE)
    return fence.findall(text)
def remove_think_tags(text):
    """Strip every <think>...</think> span (case-insensitive, across newlines)."""
    think_span = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
    return think_span.sub('', text)
def classify_sql_task_prompt_engineered(user_input: str) -> str:
    """Classify a natural-language request into one of the supported SQL task labels.

    Returns one of CREATE_TABLE, INSERT_INTO, SELECT, UPDATE, DELETE,
    ALTER_TABLE, INSERT_CSV_EXISTING, INSERT_CSV_NEW, or "UNKNOWN" when the
    model reply contains none of them.
    """
    prompt = f"""
Classify into:
CREATE_TABLE, INSERT_INTO, SELECT, UPDATE, DELETE, ALTER_TABLE, INSERT_CSV_EXISTING, INSERT_CSV_NEW
Input: {user_input}
Only return the task.
"""
    verdict = remove_think_tags(mistral_call(question=prompt)).strip().upper()
    known_tasks = (
        "CREATE_TABLE", "INSERT_INTO", "SELECT", "UPDATE",
        "DELETE", "ALTER_TABLE", "INSERT_CSV_EXISTING", "INSERT_CSV_NEW",
    )
    # First label found as a substring wins; check order matches the original list.
    return next((task for task in known_tasks if task in verdict), "UNKNOWN")
def handle_query(user_input, engine, task_type):
    """Generate SQL for a classified task and execute it against *engine*.

    Parameters
    ----------
    user_input : str
        The user's natural-language request.
    engine : sqlalchemy.Engine
        Target database connection.
    task_type : str
        Label from classify_sql_task_prompt_engineered (e.g. "SELECT").

    Returns
    -------
    tuple[str, str]
        (executed SQL, status message) — ("None", error text) on failure.
    """
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()
        prompt = f"Generate {task_type} SQL for: {user_input} using tables: {tables}"
        raw = mistral_call(question=prompt)
        blocks = extract_sql_code_blocks(raw)
        # Fix: when the model omits a fenced ```sql``` block, the original
        # passed an empty list to execute_sql, which silently reported a
        # successful no-op. Fall back to the raw model text so a real
        # statement (or a clear SQL error) surfaces instead.
        sql_code = blocks if blocks else raw
        return execute_sql(sql_code, engine)
    except Exception as e:
        return "None", f"Error: {e}"
def execute_sql(sql_code, engine):
    """Execute one or more ';'-separated SQL statements atomically.

    Parameters
    ----------
    sql_code : str | list[str]
        Raw SQL text, or a list of SQL snippets (joined with newlines).
    engine : sqlalchemy.Engine
        Target database connection.

    Returns
    -------
    tuple[str, str]
        (executed SQL, success message), or ("None", error message).
    """
    try:
        if isinstance(sql_code, list):
            sql_code = "\n".join(sql_code)
        # NOTE(review): naive split — a ';' inside a quoted string literal
        # would break a statement apart; acceptable for model-generated DDL/DML.
        statements = [stmt.strip() for stmt in sql_code.split(';') if stmt.strip()]
        if not statements:
            # Fix: empty input was previously reported as a success.
            return "None", "SQL error: no executable statements found."
        # engine.begin() opens a transaction that commits once on success and
        # rolls back on any failure, so a multi-statement task is applied
        # all-or-nothing instead of leaving a partially-applied batch behind.
        with engine.begin() as conn:
            for stmt in statements:
                conn.execute(text(stmt + ";"))
        return sql_code, "βœ… SQL executed."
    except Exception as e:
        return "None", f"SQL error: {e}"
def insert_csv_existing(table_name, csv_file, engine):
    """Append rows from an uploaded CSV file into an existing table.

    Returns a human-readable status string either way; the exception text
    is folded into the message rather than raised.
    """
    try:
        frame = pd.read_csv(csv_file)
        frame.to_sql(table_name, engine, if_exists='append', index=False)
    except Exception as e:
        return f"CSV insert error: {e}"
    return f"βœ… CSV inserted into '{table_name}'."
def insert_csv_new(table_name, csv_file, engine):
    """Load an uploaded CSV into *table_name*, replacing the table if it exists.

    Returns a human-readable status string either way; failures are reported
    in the message instead of being raised.
    """
    try:
        frame = pd.read_csv(csv_file)
        frame.to_sql(table_name, engine, if_exists='replace', index=False)
    except Exception as e:
        return f"New CSV insert error: {e}"
    return f"βœ… CSV inserted into new table '{table_name}'."
# --- Streamlit App ---
# Flat Streamlit script: widget declaration order defines the rendered UI,
# so the statement sequence below is intentional.
st.set_page_config(page_title="AI Dashboard", layout="wide")
st.title("πŸ€– AI-Powered Multi-Feature Dashboard")
st.sidebar.title("Navigation")
# Sidebar radio picks which of the four features renders below.
option = st.sidebar.radio("Select Feature", ["πŸ“Š Data Visualization", "🧠 SQL Query Generator", "πŸ“„ Demo Data Generator", "🧠 Smart SQL Task Handler"])
if option == "πŸ“Š Data Visualization":
    # Feature 1: upload a CSV, ask the model for a chart spec, render with Plotly.
    uploaded_file = st.file_uploader("Upload your CSV", type="csv")
    if uploaded_file:
        try:
            content = uploaded_file.getvalue().decode("utf-8")
            df = pd.read_csv(io.StringIO(content))
            # Normalize headers (trim + spaces->underscores) so model-suggested
            # column names have a chance of matching exactly.
            df.columns = df.columns.str.strip().str.replace(" ", "_")
            st.write("CSV Preview")
            st.dataframe(df.head())
            st.write("Shape:", df.shape)
            with st.spinner("Getting chart suggestion..."):
                suggestion = get_visualization_suggestion(df)
            st.write("Model suggestion:")
            st.code(suggestion)
            if suggestion:
                x_col = suggestion.get("x", "").strip()
                y_col = suggestion.get("y", [])
                # "y" may come back as a single column name or a list; coerce to list.
                y_col = [y_col] if isinstance(y_col, str) else y_col
                chart = suggestion.get("chart_type")
                # Only plot when every suggested column actually exists in the CSV.
                if x_col in df.columns and all(y in df.columns for y in y_col):
                    fig = None
                    if chart == "bar":
                        fig = px.bar(df, x=x_col, y=y_col)
                    elif chart == "line":
                        fig = px.line(df, x=x_col, y=y_col)
                    elif chart == "scatter":
                        fig = px.scatter(df, x=x_col, y=y_col)
                    elif chart == "pie" and len(y_col) == 1:
                        # Pie charts need exactly one values column.
                        fig = px.pie(df, names=x_col, values=y_col[0])
                    if fig:
                        st.plotly_chart(fig)
                    else:
                        st.error("Unsupported chart type.")
                else:
                    st.error("⚠️ Column suggestion doesn't match your CSV.")
            else:
                st.error("❌ No valid visualization suggestion returned.")
        except Exception as e:
            st.error(f"❌ Error reading CSV: {e}")
elif option == "🧠 SQL Query Generator":
    # Feature 2: free-form English -> SQL via the model; output shown verbatim.
    user_input = st.text_area("Describe your SQL query in plain English:")
    if st.button("Generate SQL"):
        st.code(mistral_call(question=user_input))
elif option == "πŸ“„ Demo Data Generator":
    # Feature 3: model-generated demo dataset, downloadable as CSV.
    user_input = st.text_area("Describe your dataset:")
    num_rows = st.number_input("Rows", 1, 1000, 10)
    if st.button("Generate Dataset"):
        msg, buffer = generate_demo_data_csv(user_input, num_rows)
        st.write(msg)
        if buffer:
            st.download_button("Download CSV", buffer.getvalue(), file_name="generated_data.csv", mime="text/csv")
elif option == "🧠 Smart SQL Task Handler":
    # Feature 4: classify the request, then generate + execute SQL against SQLite.
    st.sidebar.header("DB Settings")
    db_type = "SQLite"  # NOTE(review): currently unused — only SQLite is supported.
    db_path = st.sidebar.text_input("SQLite File Path", value="smart_sql.db")
    connection_url = f"sqlite:///{db_path}"
    try:
        engine = create_engine(connection_url)
        # Smoke-test the connection before offering the task form.
        with engine.connect(): pass
        st.sidebar.success("Connected!")
    except Exception as e:
        st.sidebar.error(f"Connection failed: {e}")
        st.stop()  # halt the script run; nothing below can work without a DB
    user_input = st.text_area("Enter SQL task (or natural language):")
    csv_file = st.file_uploader("Optional CSV Upload")
    table_name = st.text_input("Table name (for CSV):")
    if st.button("Run SQL Task"):
        task = classify_sql_task_prompt_engineered(user_input)
        st.markdown(f"**Detected Task:** `{task}`")
        # CSV-based tasks are handled directly; everything else goes through
        # model-generated SQL via handle_query.
        if task == "INSERT_CSV_EXISTING" and csv_file and table_name:
            st.write(insert_csv_existing(table_name, csv_file, engine))
        elif task == "INSERT_CSV_NEW" and csv_file and table_name:
            st.write(insert_csv_new(table_name, csv_file, engine))
        else:
            sql_code, msg = handle_query(user_input, engine, task)
            st.code(sql_code)
            st.write(msg)