|
|
|
|
|
import io
import json
import os
import re
import sqlite3
import subprocess

import pandas as pd
import plotly.express as px
import requests
import streamlit as st
from sqlalchemy import create_engine, inspect, text
|
|
|
|
|
# Hugging Face API token, looked up once when the module is imported.
# Falls back to an empty string when the environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
|
|
|
|
def mistral_call(schema=None, question="no questions were asked", hf_token=HF_TOKEN, model_id="mistralai/Mistral-7B-Instruct-v0.3", timeout=60):
    """Send a prompt to the Hugging Face Inference API and return the reply.

    Args:
        schema: Optional schema text inserted under "### Schema:".
        question: Task/question text inserted under "### Question:".
        hf_token: API token; when empty, no Authorization header is sent.
        model_id: Hugging Face model repository id.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        Only the model's generated continuation (never the prompt), stripped.
        Failures are returned as strings starting with "API call failed:" or
        "Error parsing response:" instead of being raised.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"

    headers = {"Content-Type": "application/json"}
    if hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"

    schema_text = schema if schema else "(no schema provided)"
    prompt = (
        "You are a helpful assistant that translates natural language questions "
        "into SQL using a database schema.\n\n"
        "### Schema:\n\n"
        f"{schema_text}\n\n"
        "### Question:\n\n"
        f"{question}\n\n"
    )

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 500,
            "do_sample": True,
            "temperature": 0.3,
            # By default the API echoes the prompt; callers would then parse the
            # echoed question (its label list, JSON template, quoted lines) as
            # if it were the model's answer.
            "return_full_text": False,
        },
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # Report network failures/timeouts the same way as HTTP errors.
        return f"API call failed: {e}"

    if response.status_code != 200:
        return f"API call failed: {response.status_code}\n{response.text}"

    try:
        generated = response.json()[0]["generated_text"]
        # Defensive: if a backend ignores return_full_text, drop the echo.
        _, echoed, rest = generated.partition(prompt)
        if echoed:
            generated = rest
        return generated.strip()
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        return f"Error parsing response: {e}"
|
|
|
|
|
def extract_json(text):
    """Return the first JSON object embedded in *text*, or ``None``.

    Every ``{`` is tried in order as the start of a complete JSON object, so
    nested objects decode correctly and a stray ``{...}`` in surrounding prose
    does not hide a valid object that follows it.  Non-string input yields
    ``None``.
    """
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return obj
    return None
|
|
|
def get_visualization_suggestion(data):
    """Ask the model to suggest one chart for the DataFrame *data*.

    Returns the first JSON object in the model's reply as a dict (expected
    shape ``{"x": ..., "y": ..., "chart_type": ...}``; the caller must
    validate keys and values), or ``None`` if the reply has no JSON object.
    """
    prompt = (
        f"These are the dataset column names: {list(data.columns)}.\n"
        "Suggest one visualization using the format:\n"
        '{"x": "column", "y": "column or list", "chart_type": "bar/line/scatter/pie"}\n'
        "Reply with only the JSON object."
    )
    response = mistral_call(question=prompt)

    # If the question is echoed back, the JSON template above would be the
    # first object found; parse only what follows the echo.
    _, echoed, answer = response.partition(prompt)
    if not echoed:
        answer = response

    suggestion = extract_json(answer)
    return suggestion if isinstance(suggestion, dict) else None
|
|
|
|
|
def generate_demo_data_csv(user_input, num_rows=10):
    """Ask the model for a synthetic CSV dataset described by *user_input*.

    Only reply lines starting with a double quote are kept (the prompt asks
    for quoted headers and values), which drops prose and code fences.

    Returns:
        (message, buffer): *buffer* is an ``io.StringIO`` with the CSV
        re-serialised by pandas, or ``None`` when no parsable CSV came back.
    """
    prompt = (
        f"Generate a {num_rows}-row structured dataset in CSV format "
        "with quoted column headers and values.\n"
        # Deliberately not wrapped in quotes: an echoed description line that
        # starts with '"' would otherwise be taken as the CSV header.
        f"Dataset description: {user_input}"
    )
    response = mistral_call(question=prompt)

    # Ignore an echoed prompt so only the model's own lines are considered.
    _, echoed, answer = response.partition(prompt)
    if not echoed:
        answer = response

    csv_lines = [line.strip() for line in answer.splitlines() if line.strip().startswith('"')]
    if not csv_lines:
        return "No CSV found.", None

    try:
        df = pd.read_csv(io.StringIO("\n".join(csv_lines)))
    except ValueError as e:  # pandas ParserError / EmptyDataError subclass ValueError
        return f"CSV error: {e}", None

    buffer = io.StringIO()
    df.to_csv(buffer, index=False)
    return "Demo data generated.", buffer
|
|
|
|
|
# Opening ```sql fence (any case), at least one whitespace char, then the
# shortest body up to the next ``` fence.
_SQL_FENCE_PATTERN = re.compile(r"```sql\s+(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_sql_code_blocks(text):
    """Return the body of every ```sql fenced block in *text*, in order.

    Untagged ``` fences are ignored; an empty list means no block was found.
    """
    return _SQL_FENCE_PATTERN.findall(text)
|
|
|
# A <think>...</think> span; tags match in any case and the span may
# cover several lines.
_THINK_SPAN_PATTERN = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)


def remove_think_tags(text):
    """Return *text* with every <think>...</think> span deleted."""
    return _THINK_SPAN_PATTERN.sub('', text)
|
|
|
def classify_sql_task_prompt_engineered(user_input: str) -> str:
    """Ask the model which kind of SQL task *user_input* describes.

    Returns one of CREATE_TABLE, INSERT_INTO, SELECT, UPDATE, DELETE,
    ALTER_TABLE, INSERT_CSV_EXISTING, INSERT_CSV_NEW, or "UNKNOWN" when the
    reply names none of them.  If the reply mentions several labels, the one
    appearing first in the reply wins (not the first in the label list).
    """
    tasks = (
        "CREATE_TABLE", "INSERT_INTO", "SELECT", "UPDATE", "DELETE",
        "ALTER_TABLE", "INSERT_CSV_EXISTING", "INSERT_CSV_NEW",
    )
    prompt = (
        "Classify into:\n"
        f"{', '.join(tasks)}\n"
        f"Input: {user_input}\n"
        "Only return the task."
    )
    reply = mistral_call(question=prompt)

    # An echoed prompt lists every label; classify only the model's answer.
    _, echoed, answer = reply.partition(prompt)
    if not echoed:
        answer = reply

    answer = remove_think_tags(answer).upper()
    # Models often write "CREATE TABLE" or "insert-into"; map separators to
    # underscores so those spellings match the labels.
    answer = re.sub(r"[\s\-]+", "_", answer)

    # Longest labels first; the letter/digit guards stop "SELECTED" matching
    # SELECT while still allowing "_" or punctuation around a label.
    alternatives = "|".join(sorted(tasks, key=len, reverse=True))
    match = re.search(rf"(?<![A-Z0-9])({alternatives})(?![A-Z0-9])", answer)
    return match.group(1) if match else "UNKNOWN"
|
|
|
def handle_query(user_input, engine, task_type):
    """Have the model write SQL for *user_input* and execute it on *engine*.

    The model gets every table's column list as the schema (table names alone
    are not enough to write a correct SELECT/UPDATE) and is asked to put the
    SQL in a code block tagged ``sql``, the only shape
    extract_sql_code_blocks recognises.

    Returns:
        The (sql, message) pair from execute_sql.  If the reply contains no
        sql-tagged block, returns (raw_reply, error_message) without executing
        anything.  Any other failure returns ("None", "Error: ...").
    """
    try:
        inspector = inspect(engine)
        tables = inspector.get_table_names()

        schema_lines = []
        for table in tables:
            columns = ", ".join(f"{col['name']} {col['type']}" for col in inspector.get_columns(table))
            schema_lines.append(f"{table}({columns})")
        schema = "\n".join(schema_lines) or "(no tables yet)"

        # No literal triple backticks in the prompt: if the prompt were echoed,
        # that fence would be picked up by extract_sql_code_blocks.
        prompt = (
            f"Generate {task_type} SQL for: {user_input} using tables: {tables}. "
            "Wrap the SQL in a markdown code block tagged sql."
        )
        reply = mistral_call(schema=schema, question=prompt)

        sql_blocks = extract_sql_code_blocks(reply)
        if not sql_blocks:
            # Previously this executed nothing and still reported success.
            return reply, "Error: the model reply contains no sql code block."
        return execute_sql(sql_blocks, engine)
    except Exception as e:
        return "None", f"Error: {e}"
|
|
|
def _split_sql_statements(sql):
    """Split *sql* into statements, cutting only at statement-ending ``;``.

    A ``;`` ends a statement only when ``sqlite3.complete_statement`` agrees,
    so semicolons inside string literals, quoted identifiers, comments and
    CREATE TRIGGER bodies do not split.  Empty statements are dropped.
    """
    statements = []
    pending = ""
    for piece in sql.split(";"):
        pending += piece + ";"
        if sqlite3.complete_statement(pending):
            if pending.strip().rstrip(";").strip():
                statements.append(pending.strip())
            pending = ""
    # The loop added a ';' after the final piece that the input did not have.
    # An unfinished remainder is passed on as written so the driver reports
    # the real syntax error.
    leftover = pending[:-1].strip()
    if leftover:
        statements.append(leftover)
    return statements


def execute_sql(sql_code, engine):
    """Execute one or more SQL statements on *engine*.

    Args:
        sql_code: SQL text, or a list of SQL snippets (e.g. fenced blocks from
            extract_sql_code_blocks) that are joined with newlines.
        engine: SQLAlchemy engine to run against.

    Returns:
        (executed_sql, message).  On failure, including when there is no
        statement to run, returns ("None", "SQL error: ...").  All statements
        run inside one ``engine.begin()`` block, which commits only if every
        statement succeeds.  Whether earlier DDL is undone on error depends on
        the driver's transaction handling; TODO confirm for SQLite DDL.
    """
    try:
        if isinstance(sql_code, list):
            sql_code = "\n".join(sql_code)
        statements = _split_sql_statements(sql_code)
        if not statements:
            return "None", "SQL error: no SQL statements to execute."
        with engine.begin() as conn:
            for stmt in statements:
                # exec_driver_sql passes the text through unchanged; text()
                # would treat any ":word" in generated SQL as a bind parameter.
                conn.exec_driver_sql(stmt)
        return sql_code, "✅ SQL executed."
    except Exception as e:
        return "None", f"SQL error: {e}"
|
|
|
def insert_csv_existing(table_name, csv_file, engine):
    """Append the rows of *csv_file* to the existing table *table_name*.

    Refuses to run when the table does not exist.  ``to_sql(if_exists="append")``
    would otherwise silently create a new table, for example when the table
    name is mistyped.

    Returns:
        A status message; errors are reported in the message, not raised.
    """
    try:
        if not inspect(engine).has_table(table_name):
            return f"CSV insert error: table '{table_name}' does not exist."
        if hasattr(csv_file, "seek"):
            # Rewind so a file object that was already read is read from the start.
            csv_file.seek(0)
        df = pd.read_csv(csv_file)
        df.to_sql(table_name, engine, if_exists="append", index=False)
        return f"✅ CSV inserted into '{table_name}'."
    except Exception as e:
        return f"CSV insert error: {e}"
|
|
|
def insert_csv_new(table_name, csv_file, engine, if_exists="replace"):
    """Load *csv_file* into the table *table_name*, creating it.

    Args:
        table_name: Destination table.
        csv_file: Path or file-like object accepted by ``pd.read_csv``.
        engine: Connection target accepted by ``DataFrame.to_sql``.
        if_exists: Passed to ``DataFrame.to_sql``.  The default "replace" DROPS
            any existing table of that name, with its data, first.  Pass
            "fail" to refuse instead.

    Returns:
        A status message; errors are reported in the message, not raised.
    """
    try:
        if hasattr(csv_file, "seek"):
            # Rewind so a file object that was already read is read from the start.
            csv_file.seek(0)
        df = pd.read_csv(csv_file)
        df.to_sql(table_name, engine, if_exists=if_exists, index=False)
        return f"✅ CSV inserted into new table '{table_name}'."
    except Exception as e:
        return f"New CSV insert error: {e}"
|
|
|
|
|
st.set_page_config(page_title="AI Dashboard", layout="wide")

st.title("🤖 AI-Powered Multi-Feature Dashboard")

# Chart types drawn from an x column plus one or more y columns.  Pie charts
# are handled separately because Plotly takes names/values for them.
_XY_CHARTS = {"bar": px.bar, "line": px.line, "scatter": px.scatter}


def _read_uploaded_csv(uploaded_file):
    """Decode an uploaded CSV into a DataFrame with underscore column names."""
    # "utf-8-sig" drops a UTF-8 BOM.  str.strip() does not remove U+FEFF, so
    # the BOM would otherwise stay on the first column name.
    content = uploaded_file.getvalue().decode("utf-8-sig")
    df = pd.read_csv(io.StringIO(content))
    df.columns = df.columns.str.strip().str.replace(" ", "_")
    return df


def _suggested_columns(suggestion, columns):
    """Return (x, [y, ...]) from a model suggestion, or None if unusable.

    "x" must be a string and "y" a string or non-empty list of strings, all
    naming existing columns (surrounding whitespace is ignored).
    """
    x_col = suggestion.get("x")
    y_cols = suggestion.get("y")
    if isinstance(y_cols, str):
        y_cols = [y_cols]
    if not isinstance(x_col, str) or not isinstance(y_cols, list) or not y_cols:
        return None
    if not all(isinstance(y, str) for y in y_cols):
        return None
    x_col = x_col.strip()
    y_cols = [y.strip() for y in y_cols]
    if x_col not in columns or any(y not in columns for y in y_cols):
        return None
    return x_col, y_cols


def _render_visualization_page():
    """Upload a CSV, ask the model for a chart suggestion, and draw it."""
    uploaded_file = st.file_uploader("Upload your CSV", type="csv")
    if not uploaded_file:
        return
    try:
        df = _read_uploaded_csv(uploaded_file)
    except Exception as e:  # UI boundary: report any decode/parse failure
        st.error(f"❌ Error reading CSV: {e}")
        return

    st.write("CSV Preview")
    st.dataframe(df.head())
    st.write("Shape:", df.shape)

    with st.spinner("Getting chart suggestion..."):
        try:
            suggestion = get_visualization_suggestion(df)
        except Exception as e:  # UI boundary: network/model failures
            st.error(f"❌ Could not get a chart suggestion: {e}")
            return
    if not suggestion:
        st.error("❌ No valid visualization suggestion returned.")
        return
    st.write("Model suggestion:")
    st.json(suggestion)

    picked = _suggested_columns(suggestion, df.columns)
    if picked is None:
        st.error("⚠️ Column suggestion doesn't match your CSV.")
        return
    x_col, y_cols = picked
    chart = str(suggestion.get("chart_type", "")).strip().lower()

    try:
        if chart in _XY_CHARTS:
            fig = _XY_CHARTS[chart](df, x=x_col, y=y_cols)
        elif chart == "pie" and len(y_cols) == 1:
            fig = px.pie(df, names=x_col, values=y_cols[0])
        else:
            st.error("Unsupported chart type.")
            return
    except Exception as e:  # Plotly raises on data it cannot plot
        st.error(f"❌ Could not draw chart: {e}")
        return
    st.plotly_chart(fig)


def _render_sql_generator_page():
    """Turn a plain-English description into SQL text (it is not executed)."""
    user_input = st.text_area("Describe your SQL query in plain English:")
    if not st.button("Generate SQL"):
        return
    if not user_input.strip():
        st.warning("Please describe the query first.")
        return
    with st.spinner("Generating SQL..."):
        generated = mistral_call(question=user_input)
    st.code(generated)


def _render_demo_data_page():
    """Generate a synthetic CSV from a description and offer it for download."""
    user_input = st.text_area("Describe your dataset:")
    num_rows = st.number_input("Rows", 1, 1000, 10)
    if not st.button("Generate Dataset"):
        return
    if not user_input.strip():
        st.warning("Please describe the dataset first.")
        return
    with st.spinner("Generating dataset..."):
        msg, buffer = generate_demo_data_csv(user_input, int(num_rows))
    st.write(msg)
    if buffer is not None:
        st.download_button("Download CSV", buffer.getvalue(), file_name="generated_data.csv", mime="text/csv")


def _render_smart_sql_page():
    """Classify a SQL task with the model, then run it on a SQLite file."""
    st.sidebar.header("DB Settings")
    db_path = st.sidebar.text_input("SQLite File Path", value="smart_sql.db")
    try:
        engine = create_engine(f"sqlite:///{db_path}")
        with engine.connect():
            pass
    except Exception as e:  # UI boundary: show why the database is unusable
        st.sidebar.error(f"Connection failed: {e}")
        st.stop()
        return  # st.stop() halts the script run; kept for readers/linters
    st.sidebar.success("Connected!")

    user_input = st.text_area("Enter SQL task (or natural language):")
    csv_file = st.file_uploader("Optional CSV Upload")
    table_name = st.text_input("Table name (for CSV):").strip()
    if not st.button("Run SQL Task"):
        return
    if not user_input.strip():
        st.warning("Please describe the SQL task first.")
        return

    task = classify_sql_task_prompt_engineered(user_input)
    st.markdown(f"**Detected Task:** `{task}`")

    csv_handlers = {
        "INSERT_CSV_EXISTING": insert_csv_existing,
        "INSERT_CSV_NEW": insert_csv_new,
    }
    if task in csv_handlers:
        # Previously a CSV task without a file fell through to LLM SQL
        # generation, which cannot load the data.
        if not (csv_file and table_name):
            st.warning("This task needs an uploaded CSV and a table name.")
            return
        st.write(csv_handlers[task](table_name, csv_file, engine))
        return

    sql_code, msg = handle_query(user_input, engine, task)
    st.code(sql_code)
    st.write(msg)


# Sidebar label -> page renderer; the keys are also the radio options shown.
_PAGES = {
    "📊 Data Visualization": _render_visualization_page,
    "🧠 SQL Query Generator": _render_sql_generator_page,
    "📁 Demo Data Generator": _render_demo_data_page,
    "🧠 Smart SQL Task Handler": _render_smart_sql_page,
}

st.sidebar.title("Navigation")
option = st.sidebar.radio("Select Feature", list(_PAGES))
_PAGES[option]()
|
|