# sql / app.py — Hugging Face Space source by Garvitj
# Snapshot: commit 5903789 ("Update app.py"), 9.66 kB
import gradio as gr
from huggingface_hub import InferenceClient
import pandas as pd
import io
import mysql.connector
import os
# Hugging Face Inference API token, read from the "hf_key" environment variable
# (e.g. a Space secret). os.getenv returns None when the variable is unset, in
# which case the client below is constructed with token=None.
HF_API_KEY = os.getenv("hf_key")

# Single client shared by every Mistral call in this module.
client = InferenceClient(
    token=HF_API_KEY,
    model="mistralai/Mistral-7B-Instruct-v0.3",
)
def classify_task(user_input):
    """Ask Mistral which task category *user_input* belongs to.

    The prompt lists the categories generate_sql, create_table,
    generate_demo_data_db, generate_demo_data_csv and analyze_data. The model's
    reply is returned unvalidated (only surrounding whitespace is stripped) and
    may contain extra text, which is why sql_chatbot matches on substrings.
    """
    instructions = [
        "",
        "You are an AI assistant that classifies user requests related to SQL and databases.",
        "Your task is to categorize the input into one of the following options:",
        "- **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.",
        "- **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.",
        "- **generate_demo_data_db** → If the user wants to insert demo data into a database.",
        "- **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.",
        "- **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.",
        "**Examples:**",
        '1. "Give me SQL syntax to create a student table" → **generate_sql**',
        '2. "Create a student table in my database" → **create_table**',
        '3. "Insert some demo data in my database" → **generate_demo_data_db**',
        '4. "Generate sample student data in CSV format" → **generate_demo_data_csv**',
        '5. "Analyze student marks and trends" → **analyze_data**',
        f"**User Input:** {user_input}",
        "**Output Format:** Return only the category name without any explanations.",
        "",
    ]
    # A short token budget is enough for a single category label.
    reply = client.text_generation("\n".join(instructions), max_new_tokens=20)
    return reply.strip()
def generate_sql_query(user_input):
    """Return Mistral's reply to "Generate SQL syntax for: <user_input>".

    The prompt puts no format constraints on the answer, so the reply may mix
    SQL with explanation; only surrounding whitespace is stripped.
    """
    request = f"Generate SQL syntax for: {user_input}"
    answer = client.text_generation(request, max_new_tokens=200)
    return answer.strip()
def generate_sql_query_for_create(user_input):
    """Ask Mistral for bare SQL (typically a CREATE TABLE statement) for *user_input*.

    The prompt asks for SQL only, but the model often wraps its answer in a
    Markdown code fence anyway. When the reply contains a fence, only the text
    of the first fenced block is returned. A language tag on the opening fence
    line (``sql``, ``SQL``, ``mysql``, ...) is dropped, and a bare fence is
    handled too. Without a fence, the stripped reply is returned unchanged.
    """
    prompt = "\n".join([
        "",
        f"Generate **only** the SQL syntax for: {user_input}",
        "**Rules:**",
        "- No explanations, no bullet points, no extra text.",
        "- Return **only valid SQL**.",
        "**Example Input:**",
        '"Create a student table with name, age, and email."',
        "**Example Output:**",
        "```sql",
        "CREATE TABLE student (",
        "student_id INT PRIMARY KEY AUTO_INCREMENT,",
        "name VARCHAR(100) NOT NULL,",
        "age INT NOT NULL,",
        "email VARCHAR(255) UNIQUE NOT NULL",
        ");",
        "```",
        "",
    ])
    response = client.text_generation(prompt, max_new_tokens=200).strip()

    fence = "```"
    if fence in response:
        # Text between the first fence and the next one (or the end if unclosed).
        inner = response.split(fence, 2)[1]
        tag, newline, rest = inner.partition("\n")
        if newline and (not tag.strip() or tag.strip().isalnum()):
            # The opening-fence line holds only a language tag (or nothing): drop it.
            inner = rest
        elif inner[:3].lower() == "sql":
            # Single-line form such as "```sql CREATE TABLE t (...);```".
            inner = inner[3:]
        response = inner.strip()
    return response
import pymysql
def create_table(user_input, db_user, db_pass, db_host, db_name):
    """Generate a CREATE statement with Mistral and execute it on a MySQL server.

    The database ``db_name`` is created first if it does not already exist.

    Args:
        user_input: Natural-language description of the table to create.
        db_user: MySQL user name.
        db_pass: MySQL password.
        db_host: MySQL server host.
        db_name: Database that receives the new table.

    Returns:
        ``(message, None)``. The second element is always ``None`` so the result
        matches the chatbot's (text, file) output pair. MySQL errors are
        reported in ``message`` rather than raised.
    """
    if not all([db_user, db_pass, db_host, db_name, user_input]):
        return "Please provide all required inputs (database credentials and table structure).", None

    schema_response = generate_sql_query_for_create(user_input)
    print(schema_response)

    # Split into individual statements: pymysql executes one statement per call.
    statements = [stmt for stmt in sqlparse.split(schema_response) if stmt.strip()]
    if not statements:
        return "Error: Could not generate a valid table schema.", None
    # The SQL is model-generated, so only run statements that create objects;
    # anything else (DROP, DELETE, ...) is refused instead of executed.
    if any(sqlparse.parse(stmt)[0].get_type() != "CREATE" for stmt in statements):
        return "Error: Generated SQL is not a CREATE statement; refusing to execute it.", None

    # Backtick-quote the user-supplied name (doubling embedded backticks) so it is
    # always treated as a single identifier and cannot inject extra SQL.
    quoted_db = "`" + db_name.replace("`", "``") + "`"
    connection = None
    try:
        connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
        with connection.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_db}")
            # Switch databases on the same connection instead of reconnecting.
            connection.select_db(db_name)
            for stmt in statements:
                cursor.execute(stmt)
        connection.commit()
        return "Table created successfully.", None
    except pymysql.MySQLError as err:
        return f"Error: {err}", None
    finally:
        if connection is not None and connection.open:
            connection.close()
import mysql.connector
import re
import sqlparse # Install via: pip install sqlparse
def _demo_column_definitions(schema_sql):
    """Return the column-definition text inside the first CREATE TABLE's parentheses.

    Returns None when no complete, non-empty parenthesised list is found (for
    example when the model's reply was cut off by the token limit).
    """
    header = re.search(r"CREATE\s+TABLE", schema_sql, re.IGNORECASE)
    open_pos = schema_sql.find("(", header.end() if header else 0)
    if open_pos == -1:
        return None
    depth = 0
    # Match parentheses so nested ones such as VARCHAR(100) stay inside the list;
    # Markdown fences and prose around the statement are ignored.
    for pos, char in enumerate(schema_sql[open_pos:], open_pos):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return schema_sql[open_pos + 1:pos].strip() or None
    return None


def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
    """Create a ``demo`` table from an LLM-generated schema and fill it with LLM-generated rows.

    Args:
        user_input: Natural-language description of the columns wanted.
        db_user: MySQL user name.
        db_pass: MySQL password.
        db_host: MySQL server host (not validated here).
        db_name: Database in which the ``demo`` table is created.
        num_rows: Number of rows requested from the model (not enforced).

    Returns:
        ``(message, None)``. The second element is always ``None`` so the result
        matches the chatbot's (text, file) output pair. MySQL errors are
        reported in ``message`` rather than raised.
    """
    if not all([db_user, db_pass, db_name]):
        return "Please provide database credentials.", None

    schema_prompt = "\n".join([
        "",
        "Extract column names and types from the following request:",
        f'"{user_input}"',
        "**Output Format:**",
        '- The first column should be an "ID" column (INTEGER, PRIMARY KEY).',
        "- Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).",
        "- Use proper SQL syntax. No explanations.",
        "Example Output:",
        "```",
        "CREATE TABLE demo (",
        "ID INT PRIMARY KEY,",
        "Name VARCHAR(100),",
        "Age INT",
        ");",
        "```",
        "",
    ])
    schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()

    table_schema = _demo_column_definitions(schema_response)
    if not table_schema:
        return "Error: Could not generate a valid table schema.", None

    data_prompt = "\n".join([
        "",
        f"Generate {num_rows} rows of structured demo data for this table schema:",
        "```",
        f"{schema_response}",
        "```",
        "**Output Format:**",
        "- Return valid SQL INSERT statements.",
        "- Ensure all values match their respective column types.",
        '- Use double quotes ("") for text values.',
        "- No explanations, just raw SQL.",
        "Example Output:",
        "```",
        'INSERT INTO demo VALUES (1, "John Doe", 25);',
        'INSERT INTO demo VALUES (2, "Jane Smith", 30);',
        "```",
        "",
    ])

    connection = None
    try:
        connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
        cursor = connection.cursor()
        cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")

        data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()
        # Only statements targeting the demo table are executed.
        insert_statements = re.findall(r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL)
        if not insert_statements:
            return "Error: Could not generate valid data.", None

        for statement in insert_statements:
            cursor.execute(statement)
        connection.commit()
        return "Demo data inserted into the database successfully.", None
    except mysql.connector.Error as err:
        return f"Error: {err}", None
    finally:
        # Runs on every exit path, including the early "no data" return above.
        if connection is not None:
            connection.close()
def generate_demo_data_csv(user_input, num_rows=10):
    """Generate demo data with Mistral and save it to ``generated_data.csv``.

    Args:
        user_input: Description of the dataset wanted.
        num_rows: Number of rows requested in the prompt (the model is not forced to comply).

    Returns:
        ``("Demo data generated as CSV.", file_path)`` on success, or
        ``(error_message, None)`` when the reply cannot be parsed as CSV.
    """
    prompt = "\n".join([
        "",
        f"Generate a structured dataset with {num_rows} rows based on the following request:",
        f'"{user_input}"',
        "**Output Format:**",
        "- Ensure the response is in **valid CSV format** (comma-separated).",
        "- The **first row** must be column headers.",
        "- Use **double quotes for text values** to avoid formatting issues.",
        "- Do **not** include explanations—just the raw CSV data.",
        "Example Output:",
        '"ID","Name","Age","Email"',
        '"1","John Doe","25","[email protected]"',
        '"2","Jane Smith","30","[email protected]"',
        "",
    ])
    response = client.text_generation(prompt, max_new_tokens=10000).strip()

    # Models often wrap CSV in a Markdown fence (```csv ... ```). Keep only the
    # fenced body, so the fence lines and any trailing commentary are not parsed
    # by pandas as extra data rows.
    fence = "```"
    if fence in response:
        inner = response.split(fence, 2)[1]
        tag, newline, rest = inner.partition("\n")
        if newline and (not tag.strip() or tag.strip().isalnum()):
            inner = rest  # drop a language-tag line such as "csv"
        response = inner.strip()

    # Skip any prose before the header row when the header starts with "ID".
    csv_start = response.find('"ID"')
    if csv_start != -1:
        response = response[csv_start:]

    try:
        df = pd.read_csv(io.StringIO(response))
    except Exception as e:  # user-facing boundary: report any parse failure as text
        return f"Error: Invalid CSV format. {str(e)}", None

    file_path = "generated_data.csv"
    df.to_csv(file_path, index=False)
    return "Demo data generated as CSV.", file_path
def analyze_data(user_input):
    """Return Mistral's reply to "Analyze this data: <user_input>".

    The data must be included in *user_input* itself; this function only sends
    the prompt to the model. The reply is capped at 200 new tokens and stripped.
    """
    question = f"Analyze this data: {user_input}"
    reply = client.text_generation(question, max_new_tokens=200)
    return reply.strip()
def sql_chatbot(user_input, num_rows=10, db_user=None, db_pass=None, db_host="localhost", db_name=None):
    """Classify *user_input* with Mistral and dispatch it to the matching handler.

    Args:
        user_input: The user's natural-language request.
        num_rows: Rows of demo data to generate (demo-data tasks only). ``None``
            (e.g. a cleared Gradio number box) falls back to 10.
        db_user: MySQL user name for the database tasks.
        db_pass: MySQL password for the database tasks.
        db_host: MySQL server host for the database tasks.
        db_name: Target database for the database tasks.

    The four DB settings are only used by the ``create_table`` and
    ``generate_demo_data_db`` tasks. The Gradio interface passes only the first
    two arguments, so with the defaults those handlers return their
    "please provide ..." message instead of connecting.

    Returns:
        ``(text, file_path_or_None)``, matching the interface's two outputs.
    """
    num_rows = 10 if num_rows is None else int(num_rows)
    task = classify_task(user_input)

    # Substring checks: the reply may carry Markdown (the prompt's examples bold
    # the labels) or extra words around the category name.
    if "generate_sql" in task:
        return generate_sql_query(user_input), None
    if "create_table" in task:
        return create_table(user_input, db_user, db_pass, db_host, db_name)
    if "generate_demo_data_db" in task:
        return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
    if "generate_demo_data_csv" in task:
        return generate_demo_data_csv(user_input, num_rows)
    if "analyze_data" in task:
        return analyze_data(user_input), None
    return f"task:{task} \n I could not understand your request.", None
# The two inputs map positionally onto sql_chatbot(user_input, num_rows); its
# (text, file path) result fills the two outputs.
iface = gr.Interface(
    fn=sql_chatbot,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.Number(label="Number of Rows", interactive=True, value=10, precision=0),
    ],
    outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")],
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests) does not launch it.
if __name__ == "__main__":
    iface.launch()