Garvitj commited on
Commit
4d86911
Β·
verified Β·
1 Parent(s): 5ee5842

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +216 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,223 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
6
  """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
14
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # Fixed and Hugging Face Spaces-Compatible Code
2
+
3
+ import os
4
  import streamlit as st
5
+ import pandas as pd
6
+ import subprocess
7
+ import json
8
+ import plotly.express as px
9
+ import re
10
+ import io
11
+ import requests
12
+ from sqlalchemy import create_engine, text, inspect
13
+
14
+ # --- Get HF Token ---
15
+ HF_TOKEN = os.environ["HF_TOKEN"] # will raise KeyError if not set
16
+
17
+ # --- Helper: Call Mistral Model ---
18
+ def mistral_call(schema=None, question="no questions were asked", hf_token=HF_TOKEN, model_id="mistralai/Mistral-7B-Instruct-v0.3"):
19
+ api_url = f"https://api-inference.huggingface.co/models/{model_id}"
20
+ headers = {
21
+ "Authorization": f"Bearer {hf_token}",
22
+ "Content-Type": "application/json"
23
+ }
24
+ prompt = f"""You are a helpful assistant that translates natural language questions into SQL using a database schema.
25
+ ### Schema:
26
+ {schema}
27
+ ### Question:
28
+ {question}
29
+ """
30
+ payload = {
31
+ "inputs": prompt,
32
+ "parameters": {
33
+ "max_new_tokens": 500,
34
+ "do_sample": True,
35
+ "temperature": 0.3,
36
+ }
37
+ }
38
+ response = requests.post(api_url, headers=headers, json=payload)
39
+ if response.status_code == 200:
40
+ try:
41
+ generated = response.json()[0]['generated_text']
42
+ return generated.split("### Question:")[-1].strip()
43
+ except Exception as e:
44
+ return f"Error parsing response: {e}"
45
+ else:
46
+ return f"API call failed: {response.status_code}\n{response.text}"
47
+
48
+ # --- Visualization Suggestion ---
49
+ def extract_json(text):
50
+ match = re.search(r"\{.*?\}", text, re.DOTALL)
51
+ if match:
52
+ try:
53
+ return json.loads(match.group(0))
54
+ except json.JSONDecodeError:
55
+ return None
56
+ return None
57
 
58
+ def get_visualization_suggestion(data):
59
+ prompt = f"""
60
+ These are the dataset column names: {list(data.columns)}.
61
+ Suggest one visualization using the format:
62
+ {{"x": "column", "y": "column or list", "chart_type": "bar/line/scatter/pie"}}
63
  """
64
+ response = mistral_call(question=prompt)
65
+ return extract_json(response)
66
+
67
+ # --- Demo Data Generator ---
68
+ def generate_demo_data_csv(user_input, num_rows=10):
69
+ prompt = f"""
70
+ Generate a {num_rows}-row structured dataset in CSV format with quoted column headers and values:
71
+ "{user_input}"
72
+ """
73
+ response = mistral_call(question=prompt)
74
+ csv_data = "\n".join([line.strip() for line in response.splitlines() if line.strip().startswith('"')])
75
+ if csv_data:
76
+ try:
77
+ df = pd.read_csv(io.StringIO(csv_data))
78
+ buffer = io.StringIO()
79
+ df.to_csv(buffer, index=False)
80
+ return "Demo data generated.", buffer
81
+ except Exception as e:
82
+ return f"CSV error: {e}", None
83
+ return "No CSV found.", None
84
+
85
+ # --- SQL Utilities ---
86
+ def extract_sql_code_blocks(text):
87
+ return re.findall(r"```sql\s+(.*?)```", text, re.DOTALL | re.IGNORECASE)
88
 
89
+ def remove_think_tags(text):
90
+ return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
 
91
 
92
+ def classify_sql_task_prompt_engineered(user_input: str) -> str:
93
+ prompt = f"""
94
+ Classify into:
95
+ CREATE_TABLE, INSERT_INTO, SELECT, UPDATE, DELETE, ALTER_TABLE, INSERT_CSV_EXISTING, INSERT_CSV_NEW
96
+ Input: {user_input}
97
+ Only return the task.
98
  """
99
+ classification = mistral_call(question=prompt)
100
+ cleaned = remove_think_tags(classification).strip().upper()
101
+ for t in ["CREATE_TABLE", "INSERT_INTO", "SELECT", "UPDATE", "DELETE", "ALTER_TABLE", "INSERT_CSV_EXISTING", "INSERT_CSV_NEW"]:
102
+ if t in cleaned:
103
+ return t
104
+ return "UNKNOWN"
105
+
106
+ def handle_query(user_input, engine, task_type):
107
+ try:
108
+ inspector = inspect(engine)
109
+ tables = inspector.get_table_names()
110
+ prompt = f"Generate {task_type} SQL for: {user_input} using tables: {tables}"
111
+ sql_code = mistral_call(question=prompt)
112
+ sql_code = extract_sql_code_blocks(sql_code)
113
+ return execute_sql(sql_code, engine)
114
+ except Exception as e:
115
+ return "None", f"Error: {e}"
116
+
117
+ def execute_sql(sql_code, engine):
118
+ try:
119
+ if isinstance(sql_code, list):
120
+ sql_code = "\n".join(sql_code)
121
+ statements = [stmt.strip() for stmt in sql_code.split(';') if stmt.strip()]
122
+ with engine.connect() as conn:
123
+ for stmt in statements:
124
+ conn.execute(text(stmt + ";"))
125
+ conn.commit()
126
+ return sql_code, "βœ… SQL executed."
127
+ except Exception as e:
128
+ return "None", f"SQL error: {e}"
129
+
130
+ def insert_csv_existing(table_name, csv_file, engine):
131
+ try:
132
+ df = pd.read_csv(csv_file)
133
+ df.to_sql(table_name, engine, if_exists='append', index=False)
134
+ return f"βœ… CSV inserted into '{table_name}'."
135
+ except Exception as e:
136
+ return f"CSV insert error: {e}"
137
+
138
+ def insert_csv_new(table_name, csv_file, engine):
139
+ try:
140
+ df = pd.read_csv(csv_file)
141
+ df.to_sql(table_name, engine, if_exists='replace', index=False)
142
+ return f"βœ… CSV inserted into new table '{table_name}'."
143
+ except Exception as e:
144
+ return f"New CSV insert error: {e}"
145
+
146
+ # --- Streamlit App ---
147
+ st.set_page_config(page_title="AI Dashboard", layout="wide")
148
+ st.title("πŸ€– AI-Powered Multi-Feature Dashboard")
149
+
150
+ st.sidebar.title("Navigation")
151
+ option = st.sidebar.radio("Select Feature", ["πŸ“Š Data Visualization", "🧠 SQL Query Generator", "πŸ“„ Demo Data Generator", "🧠 Smart SQL Task Handler"])
152
+
153
+ if option == "πŸ“Š Data Visualization":
154
+ uploaded_file = st.file_uploader("Upload your CSV", type="csv")
155
+ if uploaded_file:
156
+ df = pd.read_csv(uploaded_file)
157
+ df.columns = df.columns.str.strip()
158
+ st.dataframe(df.head())
159
+ with st.spinner("Getting chart suggestion..."):
160
+ suggestion = get_visualization_suggestion(df)
161
+ if suggestion:
162
+ x_col = suggestion.get("x", "").strip()
163
+ y_col = suggestion.get("y", [])
164
+ y_col = [y_col] if isinstance(y_col, str) else y_col
165
+ chart = suggestion.get("chart_type")
166
+ if x_col in df.columns and all(y in df.columns for y in y_col):
167
+ fig = None
168
+ if chart == "bar":
169
+ fig = px.bar(df, x=x_col, y=y_col)
170
+ elif chart == "line":
171
+ fig = px.line(df, x=x_col, y=y_col)
172
+ elif chart == "scatter":
173
+ fig = px.scatter(df, x=x_col, y=y_col)
174
+ elif chart == "pie" and len(y_col) == 1:
175
+ fig = px.pie(df, names=x_col, values=y_col[0])
176
+ if fig:
177
+ st.plotly_chart(fig)
178
+ else:
179
+ st.error("Unsupported chart type.")
180
+ else:
181
+ st.error("Invalid column suggestion from model.")
182
+
183
+ elif option == "🧠 SQL Query Generator":
184
+ user_input = st.text_area("Describe your SQL query in plain English:")
185
+ if st.button("Generate SQL"):
186
+ st.code(mistral_call(question=user_input))
187
+
188
+ elif option == "πŸ“„ Demo Data Generator":
189
+ user_input = st.text_area("Describe your dataset:")
190
+ num_rows = st.number_input("Rows", 1, 1000, 10)
191
+ if st.button("Generate Dataset"):
192
+ msg, buffer = generate_demo_data_csv(user_input, num_rows)
193
+ st.write(msg)
194
+ if buffer:
195
+ st.download_button("Download CSV", buffer.getvalue(), file_name="generated_data.csv", mime="text/csv")
196
+
197
+ elif option == "🧠 Smart SQL Task Handler":
198
+ st.sidebar.header("DB Settings")
199
+ db_type = "SQLite"
200
+ db_path = st.sidebar.text_input("SQLite File Path", value="smart_sql.db")
201
+ connection_url = f"sqlite:///{db_path}"
202
+ try:
203
+ engine = create_engine(connection_url)
204
+ with engine.connect(): pass
205
+ st.sidebar.success("Connected!")
206
+ except Exception as e:
207
+ st.sidebar.error(f"Connection failed: {e}")
208
+ st.stop()
209
 
210
+ user_input = st.text_area("Enter SQL task (or natural language):")
211
+ csv_file = st.file_uploader("Optional CSV Upload")
212
+ table_name = st.text_input("Table name (for CSV):")
213
+ if st.button("Run SQL Task"):
214
+ task = classify_sql_task_prompt_engineered(user_input)
215
+ st.markdown(f"**Detected Task:** `{task}`")
216
+ if task == "INSERT_CSV_EXISTING" and csv_file and table_name:
217
+ st.write(insert_csv_existing(table_name, csv_file, engine))
218
+ elif task == "INSERT_CSV_NEW" and csv_file and table_name:
219
+ st.write(insert_csv_new(table_name, csv_file, engine))
220
+ else:
221
+ sql_code, msg = handle_query(user_input, engine, task)
222
+ st.code(sql_code)
223
+ st.write(msg)