|
import io
import os
import re

import gradio as gr
import mysql.connector
import pandas as pd
from huggingface_hub import InferenceClient
|
|
|
|
|
|
# Hugging Face API token, read from the environment instead of a name that
# was never defined in this module (which raised NameError on import).
# NOTE(review): when HF_API_KEY is unset the token is None and
# huggingface_hub falls back to its own default credentials, if any —
# confirm that is acceptable for your deployment.
HF_API_KEY = os.environ.get("HF_API_KEY")

# Single shared client used by every LLM call in this module.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)
|
|
|
|
def classify_task(user_input):
    """Ask the LLM which request category *user_input* belongs to.

    The prompt lists five labels (generate_sql, create_table,
    generate_demo_data_db, generate_demo_data_csv, analyze_data) with one
    example each. The raw completion (at most 20 new tokens) is returned
    with surrounding whitespace removed; it is not validated here.
    """
    prompt_lines = [
        "",
        "You are an AI assistant that classifies user requests related to SQL and databases.",
        "Your task is to categorize the input into one of the following options:",
        "",
        "- **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.",
        "- **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.",
        "- **generate_demo_data_db** → If the user wants to insert demo data into a database.",
        "- **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.",
        "- **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.",
        "",
        "**Examples:**",
        '1. "Give me SQL syntax to create a student table" → **generate_sql**',
        '2. "Create a student table in my database" → **create_table**',
        '3. "Insert some demo data in my database" → **generate_demo_data_db**',
        '4. "Generate sample student data in CSV format" → **generate_demo_data_csv**',
        '5. "Analyze student marks and trends" → **analyze_data**',
        "",
        f"**User Input:** {user_input}",
        "",
        "**Output Format:** Return only the category name without any explanations.",
        "",
    ]
    completion = client.text_generation("\n".join(prompt_lines), max_new_tokens=20)
    return completion.strip()
|
|
def generate_sql_query(user_input):
    """Return the LLM's SQL for a free-text request, whitespace-trimmed."""
    request = "Generate SQL syntax for: {}".format(user_input)
    completion = client.text_generation(request, max_new_tokens=200)
    return completion.strip()
|
|
def generate_sql_query_for_create(user_input):
    """Ask the LLM for bare SQL (typically DDL) implementing *user_input*.

    The prompt demands SQL only, but the model often wraps its answer in a
    Markdown code fence anyway. If any fence is present (```sql, ```SQL,
    ```mysql or a bare ```), only the text inside the first fence is kept.
    An unterminated fence (output cut off by max_new_tokens) keeps everything
    after the opening marker.

    Returns:
        The generated SQL text with surrounding whitespace removed.
    """
    prompt = f"""
Generate **only** the SQL syntax for: {user_input}

**Rules:**
- No explanations, no bullet points, no extra text.
- Return **only valid SQL**.

**Example Input:**
"Create a student table with name, age, and email."

**Example Output:**
```sql
CREATE TABLE student (
    student_id INT PRIMARY KEY AUTO_INCREMENT,
    name VARCHAR(100) NOT NULL,
    age INT NOT NULL,
    email VARCHAR(255) UNIQUE NOT NULL
);
```
"""

    response = client.text_generation(prompt, max_new_tokens=200).strip()

    # Match any fence language tag (or none), case-insensitively by
    # construction; the lazy body stops at the closing ``` or end of text.
    fenced = re.search(r"```[A-Za-z]*[ \t]*\n?(.*?)(?:```|$)", response, re.DOTALL)
    if fenced:
        response = fenced.group(1).strip()

    return response
|
|
|
|
import pymysql
|
|
|
|
def create_table(user_input, db_user, db_pass, db_host, db_name):
    """Create a MySQL table described in natural language.

    The description is turned into SQL by ``generate_sql_query_for_create``.
    The database ``db_name`` is created if it does not exist and then
    selected. Every generated statement is executed in it, followed by a
    single commit.

    Args:
        user_input: Free-text description of the table to create.
        db_user, db_pass, db_host: MySQL connection settings.
        db_name: Database to create (if missing) and to create the table in.

    Returns:
        ``(message, None)``. The message reports success, missing inputs, an
        empty generation, or the MySQL error. The second element is always
        ``None`` so the result matches the (Textbox, File) Gradio outputs.
    """
    if not all([db_user, db_pass, db_host, db_name, user_input]):
        return "Please provide all required inputs (database credentials and table structure).", None

    schema_response = generate_sql_query_for_create(user_input)
    print(schema_response)

    # PyMySQL rejects several statements in one execute() call by default,
    # so split the generated SQL and run each statement on its own.
    statements = [stmt.strip() for stmt in sqlparse.split(schema_response) if stmt.strip()]
    if not statements:
        return "Error: Could not generate a valid table schema.", None

    # Identifiers cannot be bound as query parameters. Backtick-quote the
    # database name (doubling embedded backticks) so it cannot inject SQL.
    quoted_db_name = "`" + str(db_name).replace("`", "``") + "`"

    connection = None
    try:
        connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
        with connection.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_db_name}")
            # Switch this same connection to the (possibly new) database
            # instead of opening a second connection.
            connection.select_db(db_name)
            for statement in statements:
                cursor.execute(statement)
        connection.commit()
        return "Table created successfully.", None
    except pymysql.MySQLError as err:
        return f"Error: {err}", None
    finally:
        # connection stays None if connect() itself failed.
        if connection is not None and connection.open:
            connection.close()
|
|
|
|
import mysql.connector
|
|
import re
|
|
import sqlparse
|
|
|
|
def _extract_column_definitions(create_sql):
    """Return the text inside the parentheses of the first CREATE TABLE, or None.

    Parentheses are matched by nesting depth, so type arguments such as
    ``VARCHAR(100)`` stay inside the result. Parentheses inside quoted
    literals are not special-cased. Returns ``None`` when *create_sql* has no
    ``CREATE TABLE`` or its parenthesised body is never closed.
    """
    header = re.search(r"\bCREATE\s+TABLE\b", create_sql, re.IGNORECASE)
    if header is None:
        return None
    start = create_sql.find("(", header.end())
    if start == -1:
        return None
    depth = 0
    for offset, char in enumerate(create_sql[start:], start):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return create_sql[start + 1:offset].strip()
    return None


def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
    """Generate demo rows with the LLM and insert them into a ``demo`` table.

    Steps:
      1. Ask the LLM for a column layout derived from *user_input*.
      2. Ask it for ``INSERT INTO demo`` statements matching that layout.
      3. Only once both answers are usable, connect to MySQL, run
         ``CREATE TABLE IF NOT EXISTS demo (...)``, execute the inserts and
         commit once.

    An existing ``demo`` table is reused as-is, even if its columns differ
    from the generated layout.

    Args:
        user_input: Free-text description of the data wanted.
        db_user, db_pass, db_host, db_name: MySQL connection settings.
        num_rows: Row count requested from the LLM (not enforced).

    Returns:
        ``(message, None)``. The message reports success or the error.
    """
    if not all([db_user, db_pass, db_name]):
        return "Please provide database credentials.", None

    schema_prompt = f"""
Extract column names and types from the following request:

"{user_input}"

**Output Format:**
- The first column should be an "ID" column (INTEGER, PRIMARY KEY).
- Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
- Use proper SQL syntax. No explanations.

Example Output:
```
CREATE TABLE demo (
    ID INT PRIMARY KEY,
    Name VARCHAR(100),
    Age INT
);
```
"""
    schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()

    # Drop Markdown fence markers, then take the column list from whatever
    # CREATE TABLE the model wrote, whatever table name it chose.
    schema_sql = re.sub(r"```[A-Za-z]*", "", schema_response).strip()
    table_schema = _extract_column_definitions(schema_sql)
    if table_schema is None:
        # No recognizable CREATE TABLE (...) wrapper: treat the answer as a
        # bare column list.
        table_schema = schema_sql.rstrip(";").strip()
    if not table_schema:
        return "Error: Could not generate a valid table schema.", None

    # Show the model the schema under the table name the inserts must use.
    demo_schema = f"CREATE TABLE demo (\n{table_schema}\n);"
    data_prompt = f"""
Generate {num_rows} rows of structured demo data for this table schema:

```
{demo_schema}
```

**Output Format:**
- Return valid SQL INSERT statements.
- Ensure all values match their respective column types.
- Use double quotes ("") for text values.
- No explanations, just raw SQL.

Example Output:
```
INSERT INTO demo VALUES (1, "John Doe", 25);
INSERT INTO demo VALUES (2, "Jane Smith", 30);
```
"""
    data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()

    # Accept any letter case, an optional (backticked) table name, an
    # optional column list and multi-row VALUES. The lazy body ends at the
    # first ")" followed by ";", so a value may itself contain ")".
    insert_statements = re.findall(
        r"INSERT\s+INTO\s+`?demo`?\s*(?:\([^)]*\)\s*)?VALUES\s*\(.*?\)\s*;",
        data_response,
        re.IGNORECASE | re.DOTALL,
    )
    if not insert_statements:
        return "Error: Could not generate valid data.", None

    connection = None
    try:
        connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
        cursor = connection.cursor()
        cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")
        for statement in insert_statements:
            cursor.execute(statement)
        connection.commit()
    except mysql.connector.Error as err:
        return f"Error: {err}", None
    finally:
        # Always release the connection. It stays None if connect() failed.
        if connection is not None and connection.is_connected():
            connection.close()

    return "Demo data inserted into the database successfully.", None
|
|
def generate_demo_data_csv(user_input, num_rows=10):
    """Generate demo data with the LLM and save it as ``generated_data.csv``.

    The prompt asks for CSV with a header row and double-quoted values.
    Before parsing with pandas, two clean-up steps run:
      - Markdown fence lines are dropped.
      - Any text before a ``"ID"`` header is skipped.
    The parsed frame is then re-written without the index.

    Returns:
        ``(message, file_path)`` on success, or ``(error_message, None)``
        when the model output cannot be parsed as CSV.
    """
    prompt = f"""
Generate a structured dataset with {num_rows} rows based on the following request:

"{user_input}"

**Output Format:**
- Ensure the response is in **valid CSV format** (comma-separated).
- The **first row** must be column headers.
- Use **double quotes for text values** to avoid formatting issues.
- Do **not** include explanations—just the raw CSV data.

Example Output:

"ID","Name","Age","Email"
"1","John Doe","25","john.doe@example.com"
"2","Jane Smith","30","jane.smith@example.com"

"""

    response = client.text_generation(prompt, max_new_tokens=10000).strip()

    # Drop Markdown fence lines (```csv / ```). If left in, pandas would
    # treat them as the header or as a bogus data row.
    kept_lines = [line for line in response.splitlines() if not line.lstrip().startswith("```")]
    response = "\n".join(kept_lines)

    # Skip any preamble before the header row the prompt asks for.
    csv_start = response.find('"ID"')
    if csv_start != -1:
        response = response[csv_start:]

    try:
        df = pd.read_csv(io.StringIO(response))
    except Exception as e:
        # UI boundary: report any parse failure to the user instead of crashing.
        return f"Error: Invalid CSV format. {str(e)}", None

    # NOTE(review): fixed relative path, so concurrent users overwrite each
    # other's file. Consider a per-request temp file if Gradio's allowed
    # paths permit it.
    file_path = "generated_data.csv"
    df.to_csv(file_path, index=False)

    return "Demo data generated as CSV.", file_path
|
|
|
|
def analyze_data(user_input):
    """Return the LLM's free-text analysis of *user_input*, whitespace-trimmed."""
    request = "Analyze this data: {}".format(user_input)
    answer = client.text_generation(request, max_new_tokens=200)
    return answer.strip()
|
|
|
|
def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
    """Classify *user_input* with the LLM and dispatch to the matching handler.

    Args:
        user_input: The user's free-text request.
        db_user, db_pass, db_host, db_name: MySQL settings, used only by the
            database handlers.
        num_rows: Row count forwarded to the demo-data handlers.

    Returns:
        ``(message, file_path_or_None)`` for the Gradio ``Textbox`` and
        ``File`` outputs. If no known label is found in the classifier's
        answer, a "could not understand" message echoing that answer is
        returned.
    """
    task = classify_task(user_input)

    # Handlers are wrapped in lambdas so only the chosen one runs (each
    # makes LLM and/or database calls).
    handlers = {
        "generate_sql": lambda: (generate_sql_query(user_input), None),
        "create_table": lambda: create_table(user_input, db_user, db_pass, db_host, db_name),
        "generate_demo_data_db": lambda: generate_demo_data_db(
            user_input, db_user, db_pass, db_host, db_name, num_rows
        ),
        "generate_demo_data_csv": lambda: generate_demo_data_csv(user_input, num_rows),
        "analyze_data": lambda: (analyze_data(user_input), None),
    }

    # The raw completion may capitalise the label, wrap it in Markdown, or
    # mention more than one label. Match case-insensitively and pick the
    # label that appears earliest in the text. No label is a substring of
    # another, so positions never tie.
    lowered = task.lower()
    positions = {label: lowered.find(label) for label in handlers}
    matches = [label for label, pos in positions.items() if pos != -1]
    if not matches:
        return f"task:{task} \n I could not understand your request.", None

    return handlers[min(matches, key=positions.get)]()
|
|
|
|
# Web form for sql_chatbot. The six inputs map, in order, onto its
# parameters (user_input, db_user, db_pass, db_host, db_name, num_rows).
# The two outputs receive the (message, file_path) tuple it returns.
iface = gr.Interface(
    fn=sql_chatbot,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.Textbox(label="MySQL Username", interactive=True),
        gr.Textbox(label="MySQL Password", interactive=True, type="password"),
        gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
        gr.Textbox(label="Database Name", interactive=True),
        gr.Number(label="Number of Rows", interactive=True, value=10, precision=0),
    ],
    outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")],
)

# Launch only when run as a script, so importing this module (e.g. from
# tests or another app) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()
|
|
|
|
|