CSVAgent / app.py
Quazim0t0's picture
Update app.py
28200f6 verified
raw
history blame
5.86 kB
import os
import gradio as gr
from sqlalchemy import text, create_engine, inspect
from smolagents import tool, CodeAgent, HfApiModel
import pandas as pd
import tempfile
from database import engine, initialize_database
# Initialize the database at import time, before any table below is inspected
# or queried. initialize_database() comes from the project's `database` module;
# what it creates is not visible here (presumably schema/seed data — verify there).
initialize_database()
# Function to execute SQL script from uploaded file
def execute_sql_script(file_path):
    """
    Executes an uploaded SQL file to initialize the database.

    The whole file is run as one script (it may contain several statements)
    and the changes are committed before the connection is released.

    Args:
        file_path (str): Path to the SQL file.

    Returns:
        str: Success message or error description.
    """
    try:
        # "utf-8-sig" also strips a leading byte-order mark, which would
        # otherwise reach the database as part of the first statement.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            sql_script = f.read()
        # Work on the raw DB-API connection: SQLAlchemy's Connection.execute()
        # sends a single statement, and the sqlite3 driver rejects strings
        # that contain more than one.
        raw_con = engine.raw_connection()
        try:
            cursor = raw_con.cursor()
            try:
                if hasattr(cursor, "executescript"):
                    # sqlite3 provides executescript() for multi-statement scripts.
                    cursor.executescript(sql_script)
                else:
                    # NOTE(review): assumes the driver accepts several statements
                    # in one execute() call (psycopg2 does) — confirm for the
                    # configured database.
                    cursor.execute(sql_script)
            finally:
                cursor.close()
            # Nothing is persisted without an explicit commit; by default the
            # SQLAlchemy pool rolls back uncommitted work when close() returns
            # the connection.
            raw_con.commit()
        finally:
            raw_con.close()
        return "SQL file executed successfully."
    except Exception as e:
        return f"Error: {str(e)}"
# Function to fetch table names dynamically
def get_table_names():
    """
    List every table name that SQLAlchemy's inspector reports for ``engine``.

    Returns:
        list: Table names, as produced by ``Inspector.get_table_names()``.
    """
    return inspect(engine).get_table_names()
# Function to fetch table schema dynamically
def get_table_schema(table_name):
    """
    Return the column names of ``table_name``, in the order the inspector lists them.

    Args:
        table_name (str): Table to describe.

    Returns:
        list: The ``"name"`` entry of each column record from ``Inspector.get_columns()``.
    """
    return [column_info["name"] for column_info in inspect(engine).get_columns(table_name)]
# Function to fetch table data dynamically
def get_table_data(table_name):
    """
    Retrieves all rows from the specified table as a Pandas DataFrame.

    Args:
        table_name (str): Name of the table.

    Returns:
        pd.DataFrame: Table data (empty but with column headers when the table
        has no rows), or a single-column "Error" DataFrame if the query fails.
    """
    try:
        # Quote the identifier so names that are reserved words or contain
        # spaces / mixed case still form valid SQL and cannot inject SQL.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as con:
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Take headers from the result itself so they always line up with
            # the row values, without a second inspector round-trip.
            columns = list(result.keys())
            rows = result.fetchall()
        # An empty row list still yields a frame with the right headers.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
# SQL Execution Tool
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL SELECT query and returns the results.
    Args:
        query (str): The SQL query to execute.
    Returns:
        str: The query results as a formatted string, or an error message.
    """
    # The docstring above doubles as the tool description smolagents hands to
    # the model, so it is kept verbatim.
    # NOTE(review): the statement is not restricted to SELECT; writes are
    # presumably discarded because the connection is never committed
    # (SQLAlchemy 2.x behaviour) — confirm against the installed version.
    try:
        with engine.connect() as connection:
            fetched = connection.execute(text(query)).fetchall()
        if fetched:
            # One output line per row, values comma-separated, no header line.
            return "\n".join(", ".join(str(cell) for cell in record) for record in fetched)
        return "No results found."
    except Exception as exc:
        return f"Error: {str(exc)}"
def _strip_code_fence(sql_text):
    """Return ``sql_text`` without surrounding whitespace or a Markdown ``` fence."""
    cleaned = sql_text.strip()
    if cleaned.startswith("```"):
        newline_at = cleaned.find("\n")
        # The opening fence may carry a language tag such as ```sql.
        cleaned = cleaned[newline_at + 1:] if newline_at != -1 else cleaned[3:]
        cleaned = cleaned.rstrip()
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


# Function to generate and execute SQL queries dynamically
def query_sql(user_query: str) -> str:
    """
    Processes a user's natural language query and generates an SQL query dynamically.

    The schema of every table is put into the prompt, the agent's answer is
    cleaned of a Markdown code fence, checked to start with SELECT, SHOW or
    PRAGMA, and then run through ``sql_engine``.

    Args:
        user_query (str): The question asked by the user.

    Returns:
        str: SQL query results or an error message.
    """
    tables = get_table_names()
    if not tables:
        return "Error: No tables found. Please upload an SQL file first."
    schema_info = "Available tables and columns:\n"
    for table in tables:
        columns = get_table_schema(table)
        schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
    schema_info += "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
    # NOTE(review): CodeAgent.run returns the agent's final answer; since the
    # agent also holds the sql_engine tool it may answer with query results
    # rather than SQL, which the check below then rejects — confirm intent.
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    if not isinstance(generated_sql, str):
        return "Error: Only SELECT queries are allowed."
    # Models often wrap SQL in a ```sql fence despite the instruction above.
    generated_sql = _strip_code_fence(generated_sql)
    # NOTE(review): PRAGMA statements can change settings in SQLite; the
    # allow-list is kept as before.
    if not generated_sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    return sql_engine(generated_sql)
# Function to handle query input
def handle_query(user_input: str) -> str:
    """
    Handles user input and returns the SQL query result.

    Blank input (``None``, ``""`` or whitespace only) returns ``""`` without
    calling the model, so clearing the textbox does not trigger an LLM call.

    Args:
        user_input (str): User's natural language query.

    Returns:
        str: The query result, an error message, or "" for blank input.
    """
    if not user_input or not user_input.strip():
        return ""
    return query_sql(user_input)
def _execute_uploaded_stream(file):
    """Copy a readable upload into a temporary .sql file, execute it, then delete the copy."""
    fd, temp_file_path = tempfile.mkstemp(suffix=".sql")
    try:
        # Write through the descriptor mkstemp already opened so it gets closed.
        with os.fdopen(fd, "wb") as temp_file:
            content = file.read()
            if isinstance(content, str):
                content = content.encode("utf-8")
            temp_file.write(content)
        return execute_sql_script(temp_file_path)
    finally:
        os.remove(temp_file_path)


def _tables_as_dataframe():
    """Return every table's rows stacked in one DataFrame, tagged with the table name."""
    tables = get_table_names()
    if not tables:
        return pd.DataFrame(
            {"Error": ["No tables found after execution. Ensure your SQL file creates tables."]}
        )
    frames = {table: get_table_data(table) for table in tables}
    # gr.Dataframe renders a single frame, so the tables are stacked and each
    # row is labelled with its source table, under a column name that no
    # table already uses.
    label = "table"
    while any(label in frame.columns for frame in frames.values()):
        label = f"_{label}"
    for table, frame in frames.items():
        frame.insert(0, label, table)
    return pd.concat(list(frames.values()), ignore_index=True)


# Function to handle SQL file uploads
def handle_file_upload(file):
    """
    Handles file upload, executes SQL, and updates database schema dynamically.

    Args:
        file: Value of the Gradio ``File`` component. Accepted forms:
            - a path string (newer Gradio);
            - an object whose ``name`` is the uploaded file's path;
            - any other readable file-like object;
            - ``None`` once the upload is cleared.

    Returns:
        tuple: Execution result message and a DataFrame holding the rows of
        all tables (the first column names the source table), or a
        one-column "Error" DataFrame when no tables exist.
    """
    if file is None:
        # The File component's change event also fires when the file is removed.
        return "No file uploaded.", pd.DataFrame()
    if isinstance(file, (str, os.PathLike)):
        result = execute_sql_script(os.fspath(file))
    elif isinstance(getattr(file, "name", None), str) and os.path.isfile(file.name):
        # NOTE(review): older Gradio wrappers presumably expose the upload via
        # .name rather than through read(); reading by path avoids relying on
        # the object's own stream position.
        result = execute_sql_script(file.name)
    else:
        result = _execute_uploaded_stream(file)
    return result, _tables_as_dataframe()
# Initialize CodeAgent for SQL query generation: the Qwen2.5-Coder-32B-Instruct
# model is reached through smolagents' HfApiModel, and sql_engine is the only
# tool the agent may call.
agent = CodeAgent(
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
    tools=[sql_engine],
)
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")
    with gr.Row():
        user_input = gr.Textbox(
            label="Ask a question about the data",
            placeholder="Type a question and press Enter",
        )
        query_output = gr.Textbox(label="Result")
    # Run the LLM-backed query only when the user presses Enter: a .change
    # listener fires on every keystroke and would call the model per character.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)
    gr.Markdown("## Upload SQL File to Execute")
    file_upload = gr.File(label="Upload SQL File")
    upload_output = gr.Textbox(label="Execution Result")
    # Dynamic table display
    table_output = gr.Dataframe(label="Database Tables (Dynamic)")
    file_upload.change(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])
if __name__ == "__main__":
    # Listen on every interface at port 7860; share=True additionally asks
    # Gradio for a public share link.
    # NOTE(review): that link exposes SQL upload/execution to anyone holding
    # it — confirm this is intended.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)