Quazim0t0 commited on
Commit
6c1c88d
·
verified ·
1 Parent(s): a573881

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -207,6 +207,7 @@ def query_sql(user_query: str) -> str:
207
  + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
208
  "The table name is '" + table_name + "'.\n"
209
  "If column names contain spaces, they must be quoted.\n"
 
210
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
211
  )
212
 
@@ -215,7 +216,11 @@ def query_sql(user_query: str) -> str:
215
  if not isinstance(generated_sql, str):
216
  return f"{generated_sql}"
217
 
218
- if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
 
 
 
 
219
  return "Error: Only SELECT queries are allowed."
220
 
221
  # Fix table names
@@ -229,6 +234,7 @@ def query_sql(user_query: str) -> str:
229
  if col in generated_sql and f'"{col}"' not in generated_sql and f'`{col}`' not in generated_sql:
230
  generated_sql = generated_sql.replace(col, f'"{col}"')
231
 
 
232
  result = sql_engine(generated_sql)
233
 
234
  try:
 
207
  + "\n\nGenerate a valid SQL SELECT query using ONLY these column names.\n"
208
  "The table name is '" + table_name + "'.\n"
209
  "If column names contain spaces, they must be quoted.\n"
210
+ "You can use aggregate functions like COUNT, AVG, SUM, etc.\n"
211
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
212
  )
213
 
 
216
  if not isinstance(generated_sql, str):
217
  return f"{generated_sql}"
218
 
219
+ # Normalize the SQL for checking
220
+ normalized_sql = generated_sql.strip().lower()
221
+
222
+ # Check if it's a valid SELECT query
223
+ if not (normalized_sql.startswith("select") and "from" in normalized_sql):
224
  return "Error: Only SELECT queries are allowed."
225
 
226
  # Fix table names
 
234
  if col in generated_sql and f'"{col}"' not in generated_sql and f'`{col}`' not in generated_sql:
235
  generated_sql = generated_sql.replace(col, f'"{col}"')
236
 
237
+ # Execute the query
238
  result = sql_engine(generated_sql)
239
 
240
  try: