Quazim0t0 committed on
Commit
eff3c87
·
verified ·
1 Parent(s): cf110d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -112
app.py CHANGED
@@ -1,20 +1,48 @@
1
  import os
2
  import gradio as gr
3
- from sqlalchemy import text
4
- from smolagents import tool, CodeAgent, HfApiModel
5
- import spaces
6
  import pandas as pd
7
- from database import (
8
- engine,
9
- create_dynamic_table,
10
- clear_database,
11
- insert_rows_into_table,
12
- get_table_schema
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_data_table():
16
  """
17
- Fetches all data from the current table and returns it as a Pandas DataFrame.
18
  """
19
  try:
20
  # Get list of tables
@@ -28,18 +56,10 @@ def get_data_table():
28
 
29
  # Use the first table found
30
  table_name = tables[0][0]
31
-
32
- with engine.connect() as con:
33
- result = con.execute(text(f"SELECT * FROM {table_name}"))
34
- rows = result.fetchall()
35
-
36
- if not rows:
37
- return pd.DataFrame()
38
-
39
- columns = result.keys()
40
- df = pd.DataFrame(rows, columns=columns)
41
- return df
42
-
43
  except Exception as e:
44
  return pd.DataFrame({"Error": [str(e)]})
45
 
@@ -160,7 +180,7 @@ def process_uploaded_file(file):
160
  def sql_engine(query: str) -> str:
161
  """
162
  Executes an SQL query and returns formatted results.
163
-
164
  Args:
165
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
166
 
@@ -182,48 +202,10 @@ def sql_engine(query: str) -> str:
182
  except Exception as e:
183
  return f"Error: {str(e)}"
184
 
185
- agent = CodeAgent(
186
- tools=[sql_engine],
187
- model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
188
- )
189
-
190
- def query_sql(user_query: str) -> str:
191
  """
192
- Converts natural language input to an SQL query using CodeAgent.
193
  """
194
- table_name, column_names, column_info = get_table_info()
195
-
196
- if not table_name:
197
- return "Error: No data table exists. Please upload a file first."
198
-
199
- schema_info = (
200
- f"The database has a table named '{table_name}' with the following columns:\n"
201
- + "\n".join([
202
- f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
203
- for col, info in column_info.items()
204
- ])
205
- + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
206
- "The table name is '" + table_name + "'.\n"
207
- "If column names contain spaces, they must be quoted.\n"
208
- "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
209
- "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
210
- )
211
-
212
- # Get SQL from the agent
213
- generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
214
-
215
- if not isinstance(generated_sql, str):
216
- return "Error: Invalid query generated"
217
-
218
- # Clean up the SQL
219
- if generated_sql.isnumeric(): # If the agent returned just a number
220
- return generated_sql
221
-
222
- # Extract just the SQL query if there's additional text
223
- sql_lines = [line for line in generated_sql.split('\n') if 'select' in line.lower()]
224
- if sql_lines:
225
- generated_sql = sql_lines[0]
226
-
227
  # Remove any trailing semicolons
228
  generated_sql = generated_sql.strip().rstrip(';')
229
 
@@ -255,12 +237,103 @@ def query_sql(user_query: str) -> str:
255
  return generated_sql
256
  return f"Error executing query: {str(e)}"
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Create the Gradio interface
259
  with gr.Blocks() as demo:
260
  with gr.Group() as upload_group:
261
  gr.Markdown("""
262
  # CSVAgent
263
-
264
  Upload your data file to begin.
265
 
266
  ### Supported File Types:
@@ -277,10 +350,7 @@ with gr.Blocks() as demo:
277
  https://tableconvert.com/sql-to-csv
278
  - Will work on the handling of SQL files soon.
279
 
280
-
281
  ### Try it out! Upload a CSV file and then ask a question about the data!
282
- - There is issues with the UI displaying the answer correctly, some questions such as "How many Customers are located in Korea?"
283
- The right answer will appear in the logs, but throws an error on the "Results" section.
284
  """)
285
 
286
  file_input = gr.File(
@@ -295,6 +365,9 @@ with gr.Blocks() as demo:
295
  with gr.Column(scale=1):
296
  user_input = gr.Textbox(label="Ask a question about the data")
297
  query_output = gr.Textbox(label="Result")
 
 
 
298
 
299
  with gr.Column(scale=2):
300
  gr.Markdown("### Current Data")
@@ -307,48 +380,6 @@ with gr.Blocks() as demo:
307
  schema_display = gr.Markdown(value="Loading schema...")
308
  refresh_btn = gr.Button("Refresh Data")
309
 
310
- def handle_upload(file_obj):
311
- if file_obj is None:
312
- return (
313
- "Please upload a file.",
314
- None,
315
- "No schema available",
316
- gr.update(visible=True),
317
- gr.update(visible=False)
318
- )
319
-
320
- success, message = process_uploaded_file(file_obj)
321
- if success:
322
- df = get_data_table()
323
- _, _, column_info = get_table_info()
324
- schema = "\n".join([
325
- f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
326
- for col, info in column_info.items()
327
- ])
328
- return (
329
- message,
330
- df,
331
- f"### Current Schema:\n```\n{schema}\n```",
332
- gr.update(visible=False),
333
- gr.update(visible=True)
334
- )
335
- return (
336
- message,
337
- None,
338
- "No schema available",
339
- gr.update(visible=True),
340
- gr.update(visible=False)
341
- )
342
-
343
- def refresh_data():
344
- df = get_data_table()
345
- _, _, column_info = get_table_info()
346
- schema = "\n".join([
347
- f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
348
- for col, info in column_info.items()
349
- ])
350
- return df, f"### Current Schema:\n```\n{schema}\n```"
351
-
352
  # Event handlers
353
  file_input.upload(
354
  fn=handle_upload,
@@ -364,8 +395,15 @@ with gr.Blocks() as demo:
364
 
365
  user_input.change(
366
  fn=query_sql,
367
- inputs=user_input,
368
- outputs=query_output
 
 
 
 
 
 
 
369
  )
370
 
371
  refresh_btn.click(
 
1
  import os
2
  import gradio as gr
 
 
 
3
  import pandas as pd
4
+ from sqlalchemy import create_engine, text
5
+ from code_agent import CodeAgent
6
+ from hf_api_model import HfApiModel
7
+
8
+ # Initialize SQLite database engine
9
+ engine = create_engine('sqlite:///data.db')
10
+
11
def clear_database():
    """
    Drop every user table from the SQLite database.

    Table names are read from ``sqlite_master`` (SQLite's internal
    ``sqlite_%`` tables are excluded) and each one is dropped in turn.
    """
    # engine.begin() commits when the block exits cleanly. A plain
    # engine.connect() does not autocommit under SQLAlchemy 2.x, so the
    # DROP statements would be rolled back when the connection closes.
    with engine.begin() as con:
        tables = con.execute(text(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )).fetchall()

        for row in tables:
            # Quote the identifier so names containing spaces, keywords or
            # quote characters cannot break (or alter) the statement.
            quoted_name = '"' + str(row[0]).replace('"', '""') + '"'
            con.execute(text(f"DROP TABLE IF EXISTS {quoted_name}"))
24
+
25
def create_dynamic_table(df):
    """
    Store *df* as the table ``data_table`` and return that table name.

    Any existing ``data_table`` is replaced; the DataFrame index is not
    written as a column.
    """
    table_name = 'data_table'
    df.to_sql(table_name, engine, if_exists='replace', index=False)
    return table_name
31
+
32
def insert_rows_into_table(records, table_name):
    """
    Insert records into the specified table inside a single transaction.

    Args:
        records: Iterable of mappings, each mapping column name -> value.
        table_name: Name of the table to insert into.

    Empty records are skipped, since an INSERT with no columns is not valid SQL.
    """
    def _quote(identifier):
        # Double-quote an SQL identifier, escaping embedded double quotes.
        return '"' + str(identifier).replace('"', '""') + '"'

    quoted_table = _quote(table_name)

    with engine.begin() as conn:
        for record in records:
            if not record:
                continue

            columns = list(record.keys())
            column_sql = ", ".join(_quote(col) for col in columns)

            # text() only understands named ":param" binds, not "?" markers,
            # and bindparams() does not accept raw positional values.
            # Positional names (p0, p1, ...) are used because column names
            # may contain characters that are not valid in a bind name.
            placeholders = ", ".join(f":p{i}" for i in range(len(columns)))
            params = {f"p{i}": record[col] for i, col in enumerate(columns)}

            conn.execute(
                text(f"INSERT INTO {quoted_table} ({column_sql}) VALUES ({placeholders})"),
                params,
            )
42
 
43
  def get_data_table():
44
  """
45
+ Get the current data table as a DataFrame.
46
  """
47
  try:
48
  # Get list of tables
 
56
 
57
  # Use the first table found
58
  table_name = tables[0][0]
59
+
60
+ # Read the table into a DataFrame
61
+ return pd.read_sql_table(table_name, engine)
62
+
 
 
 
 
 
 
 
 
63
  except Exception as e:
64
  return pd.DataFrame({"Error": [str(e)]})
65
 
 
180
  def sql_engine(query: str) -> str:
181
  """
182
  Executes an SQL query and returns formatted results.
183
+
184
  Args:
185
  query: The SQL query string to execute on the database. Must be a valid SELECT query.
186
 
 
202
  except Exception as e:
203
  return f"Error: {str(e)}"
204
 
205
+ def process_sql_result(generated_sql, table_name, column_names):
 
 
 
 
 
206
  """
207
+ Process and execute the generated SQL query.
208
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # Remove any trailing semicolons
210
  generated_sql = generated_sql.strip().rstrip(';')
211
 
 
237
  return generated_sql
238
  return f"Error executing query: {str(e)}"
239
 
240
def query_sql(user_query: str, show_full: bool) -> tuple:
    """
    Convert a natural-language question into SQL using the CodeAgent and run it.

    Args:
        user_query: The question typed by the user.
        show_full: State of the "Show Full Response" switch. It is not used
            here: both responses are always returned, and the visibility of
            the full-response box is toggled by the switch's own change event.

    Returns:
        A ``(short_response, full_response)`` tuple. ``short_response`` is the
        processed query result (or an error message); ``full_response`` is the
        agent's raw answer, or ``""`` when no usable answer was produced.
    """
    # This handler fires on every textbox change, including when the box is
    # cleared; do not spend an agent call on an empty question.
    if not user_query or not user_query.strip():
        return "", ""

    table_name, column_names, column_info = get_table_info()

    if not table_name:
        return "Error: No data table exists. Please upload a file first.", ""

    schema_info = (
        f"The database has a table named '{table_name}' with the following columns:\n"
        + "\n".join([
            f"- {col} ({info['type']}{' primary key' if info['is_primary'] else ''})"
            for col, info in column_info.items()
        ])
        + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
        "The table name is '" + table_name + "'.\n"
        "If column names contain spaces, they must be quoted.\n"
        "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    # Get full response from the agent
    full_response = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    # The agent may answer directly with a number (e.g. a COUNT result)
    # rather than a string; treat that as the final answer instead of an error.
    if isinstance(full_response, (int, float)) and not isinstance(full_response, bool):
        answer = str(full_response)
        return answer, answer

    if not isinstance(full_response, str):
        return "Error: Invalid query generated", ""

    # A purely numeric string is already the answer; nothing to execute.
    if full_response.isnumeric():
        return full_response, full_response

    # Keep only the first line that looks like a SELECT statement, falling
    # back to the whole response when no such line exists.
    sql_lines = [line for line in full_response.split('\n') if 'select' in line.lower()]
    generated_sql = sql_lines[0] if sql_lines else full_response

    short_response = process_sql_result(generated_sql, table_name, column_names)
    return short_response, full_response
283
+
284
def handle_upload(file_obj):
    """
    Process an uploaded file and build the values for the upload event outputs.

    Args:
        file_obj: The uploaded file object from the file input, or None.

    Returns:
        A 5-tuple: status message, the data as a DataFrame (None on failure),
        the schema as Markdown, and two ``gr.update`` visibility updates
        (presumably for the upload group and the main view — confirm against
        the ``file_input.upload`` outputs).
    """
    if file_obj is None:
        return (
            "Please upload a file.",
            None,
            "No schema available",
            gr.update(visible=True),
            gr.update(visible=False)
        )

    success, message = process_uploaded_file(file_obj)
    if not success:
        return (
            message,
            None,
            "No schema available",
            gr.update(visible=True),
            gr.update(visible=False)
        )

    df = get_data_table()
    _, _, column_info = get_table_info()
    # Leading space before "primary key" keeps it separated from the type,
    # e.g. "- id (INTEGER) primary key" rather than "(INTEGER)primary key".
    schema = "\n".join([
        f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
        for col, info in column_info.items()
    ])
    return (
        message,
        df,
        f"### Current Schema:\n```\n{schema}\n```",
        gr.update(visible=False),
        gr.update(visible=True)
    )
316
+
317
def refresh_data():
    """
    Reload the current table and its schema for display.

    Returns:
        A ``(DataFrame, schema_markdown)`` tuple.
    """
    df = get_data_table()
    _, _, column_info = get_table_info()
    # Leading space before "primary key" keeps it separated from the type,
    # e.g. "- id (INTEGER) primary key" rather than "(INTEGER)primary key".
    schema = "\n".join([
        f"- {col} ({info['type']}){' primary key' if info['is_primary'] else ''}"
        for col, info in column_info.items()
    ])
    return df, f"### Current Schema:\n```\n{schema}\n```"
325
+
326
# Build the module-level agent: a CodeAgent backed by the hosted
# Qwen2.5-Coder model, with sql_engine as its only tool.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)
331
+
332
  # Create the Gradio interface
333
  with gr.Blocks() as demo:
334
  with gr.Group() as upload_group:
335
  gr.Markdown("""
336
  # CSVAgent
 
337
  Upload your data file to begin.
338
 
339
  ### Supported File Types:
 
350
  https://tableconvert.com/sql-to-csv
351
  - Will work on the handling of SQL files soon.
352
 
 
353
  ### Try it out! Upload a CSV file and then ask a question about the data!
 
 
354
  """)
355
 
356
  file_input = gr.File(
 
365
  with gr.Column(scale=1):
366
  user_input = gr.Textbox(label="Ask a question about the data")
367
  query_output = gr.Textbox(label="Result")
368
+ # Add the switch and secondary result box
369
+ full_response_switch = gr.Switch(label="Show Full Response", value=False)
370
+ full_response_output = gr.Textbox(label="Full Response", visible=False)
371
 
372
  with gr.Column(scale=2):
373
  gr.Markdown("### Current Data")
 
380
  schema_display = gr.Markdown(value="Loading schema...")
381
  refresh_btn = gr.Button("Refresh Data")
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  # Event handlers
384
  file_input.upload(
385
  fn=handle_upload,
 
395
 
396
  user_input.change(
397
  fn=query_sql,
398
+ inputs=[user_input, full_response_switch],
399
+ outputs=[query_output, full_response_output]
400
+ )
401
+
402
+ # Add switch change event to control visibility of full response
403
+ full_response_switch.change(
404
+ fn=lambda x: gr.update(visible=x),
405
+ inputs=full_response_switch,
406
+ outputs=full_response_output
407
  )
408
 
409
  refresh_btn.click(