# Spaces: Sleeping (Hugging Face Spaces page status captured with this source)
import os
import re
import sqlite3
import tempfile

import gradio as gr
import pandas as pd
from smolagents import tool, CodeAgent, HfApiModel
from sqlalchemy import text, create_engine, inspect

from database import engine, initialize_database
# Ensure the database initializes: run the project's setup from the local
# `database` module once at import time, before the functions below use the
# shared `engine`. What it creates is defined in database.py (not shown here).
initialize_database()
def _is_blank_sql(statement):
    """Return True when *statement* contains only comments, whitespace and ';'."""
    # Only used for the emptiness test, so approximate comment removal is fine:
    # any real statement still leaves a keyword behind.
    without_comments = re.sub(r"--[^\n]*|/\*.*?\*/", "", statement, flags=re.DOTALL)
    return not without_comments.strip(" \t\r\n;")


def _split_sql_statements(sql_script):
    """
    Split an SQL script into its individual statements, in order.

    sqlite3.complete_statement decides where a statement ends, so semicolons
    inside string literals, quoted identifiers, comments and CREATE TRIGGER
    bodies do not cause a split. Comment-only / empty pieces are dropped.

    Args:
        sql_script (str): Full text of the script.

    Returns:
        list: Statement strings, each stripped of surrounding whitespace.
    """
    statements = []
    pending = ""
    for fragment in sql_script.split(";"):
        # Re-joining fragments with ";" reproduces the original text exactly.
        pending += fragment + ";"
        if sqlite3.complete_statement(pending):
            statements.append(pending)
            pending = ""
    if pending.strip(" \t\r\n;"):
        # Trailing text without a terminating semicolon: drop the ";" that the
        # loop appended itself and keep the rest as the final statement.
        statements.append(pending[:-1])
    return [stmt.strip() for stmt in statements if not _is_blank_sql(stmt)]


# Function to execute SQL script from uploaded file
def execute_sql_script(file_path):
    """
    Executes an uploaded SQL file to initialize the database.

    The script is split into individual statements, because many DBAPI drivers
    (e.g. Python's sqlite3) refuse more than one statement per execute call.
    All statements then run in a single transaction that is committed when
    every statement succeeds and rolled back if any of them fails.

    Args:
        file_path (str): Path to the SQL file (decoded as UTF-8; a leading
            byte-order mark is ignored).

    Returns:
        str: Success message or error description.
    """
    try:
        with open(file_path, "r", encoding="utf-8-sig") as f:
            sql_script = f.read()
        # engine.begin() commits on normal exit. A bare engine.connect() block
        # rolls back uncommitted work on close under SQLAlchemy 2.x.
        with engine.begin() as con:
            for statement in _split_sql_statements(sql_script):
                # NOTE(review): text() treats ':name' tokens as bind parameters,
                # so a literal such as ' :x' inside the script would need escaping.
                con.execute(text(statement))
        return "SQL file executed successfully."
    except Exception as e:
        return f"Error: {str(e)}"
# Function to fetch table names dynamically
def get_table_names():
    """
    List every table currently present in the database.

    A new SQLAlchemy inspector is built on each call, so tables created by a
    script executed since the previous call are included.

    Returns:
        list: Table names reported by the inspector.
    """
    return inspect(engine).get_table_names()
# Function to fetch table schema dynamically
def get_table_schema(table_name):
    """
    Look up the column names of a single table.

    Args:
        table_name (str): Table to describe.

    Returns:
        list: Column names, in the order the inspector reports them.
    """
    return [info["name"] for info in inspect(engine).get_columns(table_name)]
# Function to fetch table data dynamically
def get_table_data(table_name):
    """
    Retrieves all rows from the specified table as a Pandas DataFrame.

    The table name is quoted by the engine dialect's identifier preparer.
    Names that are reserved words, mixed-case or contain spaces therefore work,
    and a crafted name cannot splice extra SQL into the query.

    Args:
        table_name (str): Name of the table.

    Returns:
        pd.DataFrame: Table data (an empty frame with the table's columns when
        it has no rows), or a one-column "Error" frame describing the failure.
    """
    try:
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as con:
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Labels come from the result itself, so they always line up with
            # the fetched rows and no second reflection round-trip is needed.
            columns = list(result.keys())
            rows = result.fetchall()
        # An empty `rows` list still yields a frame with the right columns.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})
# SQL Execution Tool
# This function is handed to the LLM agent as a tool, so the docstring below is
# also the tool description the model sees (keep the Google-style "Args:"
# section: smolagents parses it). Writes are refused here rather than trusting
# every caller to pre-filter, using the same allow-list as query_sql.
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query (SELECT, SHOW or PRAGMA) and returns the results.
    Any other kind of statement is rejected with an error message.

    Args:
        query (str): The SQL query to execute.

    Returns:
        str: The query results as a formatted string, or an error message.
    """
    allowed_prefixes = ("select", "show", "pragma")
    if not isinstance(query, str) or not query.strip().lower().startswith(allowed_prefixes):
        return "Error: Only SELECT queries are allowed."
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
        if not rows:
            return "No results found."
        return "\n".join(", ".join(map(str, row)) for row in rows)
    except Exception as e:
        return f"Error: {str(e)}"
def _strip_code_fence(answer):
    """Return *answer* trimmed, unwrapping it if it is a single Markdown ``` code block."""
    cleaned = answer.strip()
    fenced = re.fullmatch(r"```[^\n]*\n(.*?)\s*```", cleaned, flags=re.DOTALL)
    return fenced.group(1).strip() if fenced else cleaned


# Function to generate and execute SQL queries dynamically
def query_sql(user_query: str) -> str:
    """
    Processes a user’s natural language query and generates an SQL query dynamically.

    The prompt lists every table with its column names. The agent's answer is
    trimmed and unwrapped from a Markdown code fence if present, and it is
    executed only when it is a read-only statement (SELECT, SHOW or PRAGMA).

    Args:
        user_query (str): The question asked by the user.

    Returns:
        str: SQL query results or an error message.
    """
    tables = get_table_names()
    if not tables:
        return "Error: No tables found. Please upload an SQL file first."
    schema_info = "Available tables and columns:\n"
    for table in tables:
        columns = get_table_schema(table)
        schema_info += f"Table '{table}' has columns: {', '.join(columns)}.\n"
    schema_info += "Generate a valid SQL SELECT query using ONLY these column names. DO NOT return anything other than the SQL query itself."
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    if not isinstance(generated_sql, str):
        # e.g. the agent returned query results instead of SQL text.
        return "Error: The model did not return an SQL query."
    # Models often wrap code in ```sql fences, which would fail the prefix check.
    generated_sql = _strip_code_fence(generated_sql)
    if not generated_sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    return sql_engine(generated_sql)
# Function to handle query input
def handle_query(user_input: str) -> str:
    """
    Handles user input and returns the SQL query result.

    Blank input is answered immediately without calling the model. Any
    exception raised while generating or running the query (e.g. a failed
    model request) is turned into an "Error: ..." string, so the UI shows it
    in the result box.

    Args:
        user_input (str): User's natural language query.

    Returns:
        str: The query result or error message.
    """
    if not user_input or not user_input.strip():
        return "Please enter a question about the data."
    try:
        return query_sql(user_input)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the event handler.
        return f"Error: {str(e)}"
def _resolve_upload_path(file):
    """Return ``(path, is_temporary)`` for an uploaded-file value of any supported shape."""
    # gr.File with type="filepath" (the default in recent Gradio) passes a path string.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file), False
    # Older Gradio versions pass a tempfile wrapper whose .name is a real path.
    name = getattr(file, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        return name, False
    # Otherwise copy a readable file-like object into a temporary .sql file.
    fd, temp_path = tempfile.mkstemp(suffix=".sql")
    try:
        with os.fdopen(fd, "wb") as temp_file:  # closes the fd from mkstemp
            data = file.read()
            temp_file.write(data.encode("utf-8") if isinstance(data, str) else data)
    except BaseException:
        os.remove(temp_path)
        raise
    return temp_path, True


# Function to handle SQL file uploads
def handle_file_upload(file):
    """
    Handles file upload, executes SQL, and updates database schema dynamically.

    Args:
        file: Value from the Gradio File component. This can be a filepath
            string, an object with a ``.name`` path, or a readable file-like
            object. ``None`` means the upload was cleared.

    Returns:
        tuple: Execution result message and a single DataFrame for the table
        display. Rows from every table are stacked, with a leading column
        naming each row's table; an "Error" frame is returned if no tables exist.
    """
    if file is None:
        # The change event also fires when the widget is cleared.
        return "No file uploaded.", pd.DataFrame()
    sql_path, is_temporary = _resolve_upload_path(file)
    try:
        result = execute_sql_script(sql_path)
    finally:
        if is_temporary:
            os.remove(sql_path)
    tables = get_table_names()
    if not tables:
        return result, pd.DataFrame(
            {"Error": ["No tables found after execution. Ensure your SQL file creates tables."]}
        )
    frames = [get_table_data(table) for table in tables]
    # gr.Dataframe shows one grid, so tables are stacked under a label column.
    # The label name is chosen so it cannot clash with an existing column.
    label = "table"
    taken = {str(column) for frame in frames for column in frame.columns}
    while label in taken:
        label = f"_{label}"
    for table, frame in zip(tables, frames):
        frame.insert(0, label, table)
    return result, pd.concat(frames, ignore_index=True)
# Initialize CodeAgent for SQL query generation.
# smolagents agents expect Tool instances in `tools`, so the plain sql_engine
# function is wrapped with smolagents.tool() here. sql_engine itself stays a
# plain function for the direct call in query_sql.
agent = CodeAgent(
    tools=[tool(sql_engine)],
    model=HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
)
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")
    with gr.Row():
        user_input = gr.Textbox(
            label="Ask a question about the data",
            placeholder="Type a question and press Enter",
        )
        query_output = gr.Textbox(label="Result")
    # .submit (Enter) rather than .change: .change fires on every keystroke,
    # which would send each partial question to the LLM and the database.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)
    gr.Markdown("## Upload SQL File to Execute")
    file_upload = gr.File(label="Upload SQL File", file_types=[".sql"])
    upload_output = gr.Textbox(label="Execution Result")
    # Dynamic table display
    table_output = gr.Dataframe(label="Database Tables (Dynamic)")
    file_upload.change(
        fn=handle_file_upload,
        inputs=file_upload,
        outputs=[upload_output, table_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)