Quazim0t0 commited on
Commit
9002697
·
verified ·
1 Parent(s): be857eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -8
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import gradio as gr
3
- from sqlalchemy import text, inspect, create_engine
4
  from smolagents import tool, CodeAgent, HfApiModel
5
  import pandas as pd
6
  import tempfile
@@ -9,6 +9,33 @@ from database import engine, initialize_database
9
  # Ensure the database initializes (won't crash if empty)
10
  initialize_database()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Function to execute an uploaded SQL script
13
  def execute_sql_script(file_path):
14
  """
@@ -105,11 +132,48 @@ def handle_file_upload(file):
105
 
106
  return result, table_data
107
 
108
- # Initialize CodeAgent for SQL query generation
109
- agent = CodeAgent(
110
- tools=[sql_engine],
111
- model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
112
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Gradio UI
115
  with gr.Blocks() as demo:
@@ -132,5 +196,3 @@ with gr.Blocks() as demo:
132
 
133
  if __name__ == "__main__":
134
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
135
-
136
-
 
1
  import os
2
  import gradio as gr
3
+ from sqlalchemy import text, inspect
4
  from smolagents import tool, CodeAgent, HfApiModel
5
  import pandas as pd
6
  import tempfile
 
9
  # Ensure the database initializes (won't crash if empty)
10
  initialize_database()
11
 
12
+ # SQL Execution Tool (FIXED - Defined BEFORE Use)
13
+ @tool
14
+ def sql_engine(query: str) -> str:
15
+ """
16
+ Executes an SQL SELECT query and returns the results.
17
+
18
+ Args:
19
+ query (str): The SQL query string to execute. Only SELECT queries are allowed.
20
+
21
+ Returns:
22
+ str: A formatted string containing the query results, or an error message if the query fails.
23
+ """
24
+ try:
25
+ with engine.connect() as con:
26
+ rows = con.execute(text(query)).fetchall()
27
+ if not rows:
28
+ return "No results found."
29
+ return "\n".join([", ".join(map(str, row)) for row in rows])
30
+ except Exception as e:
31
+ return f"Error: {str(e)}"
32
+
33
+ # Initialize CodeAgent for SQL query generation (Moved Below `sql_engine`)
34
+ agent = CodeAgent(
35
+ tools=[sql_engine],
36
+ model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
37
+ )
38
+
39
  # Function to execute an uploaded SQL script
40
  def execute_sql_script(file_path):
41
  """
 
132
 
133
  return result, table_data
134
 
135
+ # Function to handle natural language to SQL conversion
136
+ def query_sql(user_query: str) -> str:
137
+ """
138
+ Converts a user's natural language query into an SQL query.
139
+
140
+ Args:
141
+ user_query (str): The question asked by the user.
142
+
143
+ Returns:
144
+ str: The results of the executed SQL query.
145
+ """
146
+ tables = get_table_names()
147
+ if not tables:
148
+ return "Error: No tables found. Please upload an SQL file first."
149
+
150
+ schema_info = "Available tables and columns:\n"
151
+
152
+ for table in tables:
153
+ columns = get_table_schema(table)
154
+ schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
155
+
156
+ schema_info += "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
157
+
158
+ generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
159
+
160
+ if not isinstance(generated_sql, str) or not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
161
+ return "Error: Only SELECT queries are allowed."
162
+
163
+ return sql_engine(generated_sql)
164
+
165
+ # Function to handle query input
166
+ def handle_query(user_input: str) -> str:
167
+ """
168
+ Handles user input and returns the SQL query result.
169
+
170
+ Args:
171
+ user_input (str): User's natural language query.
172
+
173
+ Returns:
174
+ str: The query result or error message.
175
+ """
176
+ return query_sql(user_input)
177
 
178
  # Gradio UI
179
  with gr.Blocks() as demo:
 
196
 
197
  if __name__ == "__main__":
198
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)