File size: 4,046 Bytes
6a0ec6a
 
5a55ea7
6a0ec6a
1767e22
5a55ea7
 
 
 
 
 
 
 
 
 
 
 
 
1767e22
5a55ea7
 
 
 
20e319d
5a55ea7
 
 
 
 
 
 
 
20e319d
 
5a55ea7
20e319d
 
5a55ea7
20e319d
5a55ea7
 
 
 
20e319d
5a55ea7
20e319d
5a55ea7
042246b
 
 
 
 
 
 
 
 
 
 
5a55ea7
7306c07
5a55ea7
 
 
 
 
 
 
61d9b40
5a55ea7
2443195
5a55ea7
61d9b40
5a55ea7
f8c651a
61d9b40
5a55ea7
1f7ee11
5a55ea7
215368b
5a55ea7
edb7e14
5a55ea7
 
 
 
 
 
 
 
edb7e14
5a55ea7
 
 
 
1f7ee11
5a55ea7
 
 
1f7ee11
1df3c5d
1f7ee11
 
 
5a55ea7
1767e22
5a55ea7
1767e22
 
5a55ea7
 
1767e22
 
 
5a55ea7
 
 
 
 
 
 
 
6a0ec6a
 
0380e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import gradio as gr
from sqlalchemy import text, create_engine, inspect
from smolagents import tool, CodeAgent, HfApiModel
import pandas as pd
import tempfile
from database import engine

# Function to execute SQL script from uploaded file
def execute_sql_script(file_path):
    """Run every statement in the SQL script at *file_path* and commit the result.

    Args:
        file_path: Path to a UTF-8 text file containing one or more SQL
            statements.

    Returns:
        ``"SQL file executed successfully."`` on success, otherwise
        ``"Error: <message>"``. Exceptions are reported, never raised.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            sql_script = f.read()
        # `engine.begin()` commits when the block exits cleanly and rolls back
        # on error. A plain `engine.connect()` is never committed under
        # SQLAlchemy 2.x, so the script's changes were discarded on close.
        with engine.begin() as con:
            dbapi_con = con.connection.dbapi_connection
            if hasattr(dbapi_con, "executescript"):
                # sqlite3's execute() accepts only one statement.
                # executescript() runs a whole multi-statement script.
                dbapi_con.executescript(sql_script)
            else:
                con.execute(text(sql_script))
        return "SQL file executed successfully."
    except Exception as e:
        return f"Error: {str(e)}"

# Function to fetch table names dynamically
def get_table_names():
    """List the names of the tables visible through the shared ``engine``."""
    table_names = inspect(engine).get_table_names()
    return table_names

# Function to fetch table schema dynamically
def get_table_schema(table_name):
    """Return the column names of *table_name* in the order the inspector reports them."""
    return [column_info["name"] for column_info in inspect(engine).get_columns(table_name)]

# Function to fetch table data dynamically
def get_table_data(table_name):
    """Load every row of *table_name* into a DataFrame.

    Args:
        table_name: Name of a table in the connected database, e.g. one
            returned by ``get_table_names()``.

    Returns:
        A DataFrame with the table's columns (empty when the table has no
        rows), or a single-column ``Error`` frame describing the failure.
    """
    try:
        # Quote the identifier so names that are reserved words, mixed-case,
        # or contain spaces still produce a valid statement.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        with engine.connect() as con:
            result = con.execute(text(f"SELECT * FROM {quoted_name}"))
            # Take column names from the result itself. They are guaranteed to
            # line up with the fetched rows, and no second inspector round
            # trip is needed.
            columns = list(result.keys())
            rows = result.fetchall()
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        return pd.DataFrame({"Error": [str(e)]})

# SQL Execution Tool
@tool
def sql_engine(query: str) -> str:
    """Execute a SQL query against the database and return the result rows as text.

    Each row is rendered as its values joined by ", ", one row per line.
    Returns "No results found." for an empty result and "Error: <message>"
    if the query fails.

    Args:
        query: The SQL query to run, for example "SELECT * FROM my_table".
    """
    # smolagents' @tool builds the tool schema from this docstring and
    # requires an `Args:` entry for every parameter.
    # NOTE(review): the query runs verbatim and the connection is never
    # committed. Some drivers may still persist DDL without a commit — confirm
    # whether the agent should be restricted to read-only statements here too.
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
        if not rows:
            return "No results found."
        return "\n".join(", ".join(map(str, row)) for row in rows)
    except Exception as e:
        return f"Error: {str(e)}"

# Function to generate and execute SQL queries dynamically
def query_sql(user_query: str) -> str:
    """Turn a natural-language request into SQL with the agent and run it.

    The prompt lists every table and its columns so the agent only uses real
    column names. The agent's answer is validated by ``_extract_sql`` before
    it is executed.

    Args:
        user_query: The user's question in plain language.

    Returns:
        The text produced by ``sql_engine`` for the generated query, or
        ``"Error: Only SELECT queries are allowed."`` when the agent's answer
        is not a single read-only statement.
    """
    schema_lines = ["Available tables and columns:"]
    for table in get_table_names():
        columns = get_table_schema(table)
        schema_lines.append(f"Table '{table}' has columns: {', '.join(columns)}.")
    schema_lines.append(
        "Generate a valid SQL SELECT query using ONLY these column names. "
        "DO NOT return anything other than the SQL query itself."
    )
    schema_info = "\n".join(schema_lines)

    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    sql = _extract_sql(generated_sql)
    if sql is None:
        return "Error: Only SELECT queries are allowed."

    return sql_engine(sql)


def _extract_sql(candidate):
    """Return one read-only SQL statement parsed from *candidate*, or None if unusable."""
    import re

    if not isinstance(candidate, str):
        return None
    sql = candidate.strip()
    # LLMs often wrap code in Markdown fences (```sql ... ```); unwrap them.
    fenced = re.fullmatch(r"```[\w+-]*\s*(.*?)\s*```", sql, flags=re.DOTALL)
    if fenced:
        sql = fenced.group(1)
    # A trailing terminator is harmless; drop it so the checks below see one statement.
    sql = sql.strip().rstrip(";").strip()
    # Any remaining ';' means multiple statements, e.g. "SELECT 1; DROP TABLE t".
    # This also rejects ';' inside string literals, which is an accepted trade-off.
    if ";" in sql:
        return None
    if not sql.lower().startswith(("select", "show", "pragma")):
        return None
    return sql

# Function to handle query input
def handle_query(user_input: str) -> str:
    """Answer a natural-language question submitted from the UI textbox.

    Args:
        user_input: The question typed by the user.

    Returns:
        The result text from ``query_sql``, or an empty string when the input
        is blank so no agent run is spent on an empty request.
    """
    if not user_input or not user_input.strip():
        return ""
    return query_sql(user_input)

# Function to handle SQL file uploads
def handle_file_upload(file):
    """Execute an uploaded SQL script and return its status plus the table contents.

    Args:
        file: The upload value from the ``gr.File`` component. It is accepted
            as a filepath string/``PathLike``, an object whose ``.name`` is an
            existing file path, or a readable file object. ``None`` means the
            upload was cleared.

    Returns:
        A ``(status_message, table_frame)`` tuple. ``table_frame`` stacks the
        rows of every table into one DataFrame, with a leading ``_table``
        column naming the source table. ``gr.Dataframe`` renders a single
        DataFrame and cannot display a dict of them.
    """
    if file is None:
        return "No file uploaded.", pd.DataFrame()

    sql_path, is_temp = _upload_to_path(file)
    try:
        result = execute_sql_script(sql_path)
    finally:
        # Only delete files this function created, never the caller's upload.
        if is_temp:
            os.remove(sql_path)

    return result, _stack_tables(get_table_names())


def _upload_to_path(file):
    """Return ``(path, is_temp)`` for an upload, copying to a temp file only when needed."""
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file), False
    existing_path = getattr(file, "name", None)
    if isinstance(existing_path, str) and os.path.isfile(existing_path):
        return existing_path, False
    data = file.read()
    if isinstance(data, str):
        data = data.encode("utf-8")
    # mkstemp returns an open descriptor. Wrap it so it is closed, not leaked.
    fd, temp_path = tempfile.mkstemp(suffix=".sql")
    with os.fdopen(fd, "wb") as temp_file:
        temp_file.write(data)
    return temp_path, True


def _stack_tables(tables):
    """Concatenate every table's rows into one DataFrame tagged with a ``_table`` column."""
    frames = []
    for name in tables:
        frame = get_table_data(name)
        tagged = frame.assign(_table=name)
        frames.append(tagged[["_table", *frame.columns.drop("_table", errors="ignore")]])
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

# Initialize CodeAgent for SQL query generation
_sql_model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
# Code-writing agent used by query_sql to turn questions into SQL; it can call
# the `sql_engine` tool while it works.
agent = CodeAgent(model=_sql_model, tools=[sql_engine])

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SQL Query Interface")

    with gr.Row():
        user_input = gr.Textbox(label="Ask a question about the data")
        query_output = gr.Textbox(label="Result")

    # Run the LLM agent only when the user submits (presses Enter). `.change`
    # fires on every keystroke, which started one agent run per character typed.
    user_input.submit(fn=handle_query, inputs=user_input, outputs=query_output)

    gr.Markdown("## Upload SQL File to Execute")
    # Restrict the picker to .sql files, since the upload is executed as SQL.
    file_upload = gr.File(label="Upload SQL File", file_types=[".sql"])
    upload_output = gr.Textbox(label="Execution Result")

    # Dynamic table display
    table_output = gr.Dataframe(label="Database Tables (Dynamic)")

    file_upload.change(fn=handle_file_upload, inputs=file_upload, outputs=[upload_output, table_output])

if __name__ == "__main__":
    # Bind on every interface at port 7860 and ask Gradio for a public share link.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_kwargs)