|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
import pandas as pd |
|
import io |
|
import mysql.connector |
|
import os |
|
|
|
# Hugging Face API token taken from the environment; None when the variable is unset.
HF_API_KEY = os.environ.get("HF_API_KEY")

# Single inference client shared by all of the LLM-backed functions in this module.
client = InferenceClient(
    token=HF_API_KEY,
    model="mistralai/Mistral-7B-v0.1",
)
|
|
|
def classify_task(user_input):
    """Ask the LLM which kind of SQL/database request *user_input* is.

    The prompt lists five category names (generate_sql, create_table,
    generate_demo_data_db, generate_demo_data_csv, analyze_data) and asks
    the model to answer with one of them.

    Returns:
        The model's reply with surrounding whitespace removed. It is free
        text, so it is expected -- but not guaranteed -- to contain one of
        the category names.
    """
    prompt = f"""
You are an AI assistant that classifies user requests related to SQL and databases.
Your task is to categorize the input into one of the following options:

- **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
- **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.
- **generate_demo_data_db** → If the user wants to insert demo data into a database.
- **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
- **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.

**Examples:**
1. "Give me SQL syntax to create a student table" → **generate_sql**
2. "Create a student table in my database" → **create_table**
3. "Insert some demo data in my database" → **generate_demo_data_db**
4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
5. "Analyze student marks and trends" → **analyze_data**

**User Input:** {user_input}

**Output Format:** Return only the category name without any explanations.
"""

    # 20 new tokens is enough for a single category name.
    raw_reply = client.text_generation(prompt, max_new_tokens=20)
    return raw_reply.strip()
|
def generate_sql_query(user_input):
    """Ask the LLM to write SQL for *user_input*.

    Returns:
        The model's free-form reply (at most 200 new tokens), stripped of
        surrounding whitespace. No parsing or validation is applied.
    """
    reply = client.text_generation(
        f"Generate SQL syntax for: {user_input}",
        max_new_tokens=200,
    )
    return reply.strip()
|
def generate_sql_query_for_create(user_input):
    """Ask the LLM for a bare SQL statement (typically CREATE TABLE) for *user_input*.

    The prompt demands SQL only and shows a fenced ```sql example. The reply is
    post-processed so that, if it contains a Markdown code fence, only the
    fenced body is returned.

    Args:
        user_input: Natural-language description of the table/statement wanted.

    Returns:
        The SQL text produced by the model, stripped of whitespace and of any
        surrounding Markdown code fence. If no fence is present, the whole
        stripped reply is returned unchanged.
    """
    prompt = f"""
Generate **only** the SQL syntax for: {user_input}

**Rules:**
- No explanations, no bullet points, no extra text.
- Return **only valid SQL**.

**Example Input:**
"Create a student table with name, age, and email."

**Example Output:**
```sql
CREATE TABLE student (
student_id INT PRIMARY KEY AUTO_INCREMENT,
name VARCHAR(100) NOT NULL,
age INT NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL
);
```
"""

    response = client.text_generation(prompt, max_new_tokens=200).strip()

    # Unwrap the first Markdown fence, whether it is written ```sql, ```SQL or a
    # bare ```. The closing fence is optional because the 200-token cap can cut
    # the reply off before the model emits it. Without this, a bare ``` fence
    # would be passed through and later executed as part of the SQL.
    fenced = re.search(
        r"```(?:sql)?(.*?)(?:```|\Z)", response, re.DOTALL | re.IGNORECASE
    )
    if fenced:
        response = fenced.group(1).strip()

    return response
|
|
|
import pymysql |
|
|
|
def create_table(user_input, db_user, db_pass, db_host, db_name):
    """Generate a CREATE statement with the LLM and run it on a MySQL server.

    The database *db_name* is created first if it does not exist. Then the
    first CREATE statement found in the model's reply is executed inside it.

    Args:
        user_input: Natural-language description of the table to create.
        db_user: MySQL user name (must be non-empty).
        db_pass: MySQL password (must be non-empty).
        db_host: MySQL host (must be non-empty).
        db_name: Database to create/use (must be non-empty).

    Returns:
        A ``(message, None)`` pair. The second element is always ``None`` so
        the result fits the UI's (text, file) outputs.

    Only ``pymysql.MySQLError`` is turned into an ``"Error: ..."`` message.
    Other exceptions (e.g. from the inference client) propagate to the caller.
    """
    if not all([db_user, db_pass, db_host, db_name, user_input]):
        return "Please provide all required inputs (database credentials and table structure).", None

    schema_response = generate_sql_query_for_create(user_input)
    print(schema_response)

    # sqlparse only splits and tokenizes; it does not validate SQL. Use it to
    # keep just the first CREATE statement. This does two things:
    #   (a) pymysql runs one statement per execute() by default, so a reply
    #       with several statements would otherwise fail;
    #   (b) any non-CREATE statement the model emits (DROP, DELETE, ...) is
    #       never executed.
    create_statement = next(
        (
            str(statement).strip()
            for statement in sqlparse.parse(schema_response)
            if statement.get_type() == "CREATE"
        ),
        None,
    )
    if not create_statement:
        return "Error: Could not generate a valid table schema.", None

    # Identifiers cannot be bound as query parameters. Quote the database name
    # as a MySQL identifier (doubling embedded backticks) instead of
    # interpolating the raw value.
    quoted_db_name = "`" + db_name.replace("`", "``") + "`"

    connection = None
    try:
        connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
        with connection.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_db_name}")
            # Switch to the database on the same connection instead of
            # reconnecting with database=db_name.
            cursor.execute(f"USE {quoted_db_name}")
            cursor.execute(create_statement)
        connection.commit()
        return "Table created successfully.", None

    except pymysql.MySQLError as err:
        return f"Error: {err}", None

    finally:
        if connection is not None and connection.open:
            connection.close()
|
|
|
import mysql.connector |
|
import re |
|
import sqlparse |
|
|
|
def _extract_demo_column_definitions(schema_sql):
    """Return the column list of the first CREATE TABLE statement in *schema_sql*.

    The result is the text between the statement's opening parenthesis and
    its *matching* closing parenthesis, stripped. Because of this, Markdown
    fences, a table name other than ``demo``, and any text after the
    statement are all ignored.

    Returns ``None`` in any of these cases:
      * there is no CREATE TABLE header;
      * the parentheses never balance (e.g. the reply was cut off by the
        token limit);
      * the column list is empty.

    Parentheses inside quoted literals are not special-cased.
    """
    header = re.search(r"\bCREATE\s+TABLE\b[^(]*\(", schema_sql, re.IGNORECASE)
    if header is None:
        return None

    body_start = header.end()
    depth = 1  # the header match already consumed the opening "("
    for offset, char in enumerate(schema_sql[body_start:]):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                columns = schema_sql[body_start:body_start + offset].strip()
                return columns or None
    return None


def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
    """Create a ``demo`` table from an LLM-generated schema and fill it with LLM-generated rows.

    Steps:
      1. Ask the model for a CREATE TABLE statement matching *user_input* and
         extract its column list.
      2. Connect to *db_name* and create table ``demo`` if it does not exist.
      3. Ask the model for *num_rows* INSERT statements.
      4. Execute the statements that match ``INSERT INTO demo VALUES (...);``.

    Args:
        user_input: Natural-language description of the data wanted.
        db_user: MySQL user name (must be non-empty).
        db_pass: MySQL password (must be non-empty).
        db_host: MySQL host (passed through unchecked).
        db_name: Target database (must be non-empty).
        num_rows: Number of rows to request from the model.

    Returns:
        A ``(message, None)`` pair. MySQL errors are reported as
        ``"Error: ..."`` messages, and the connection is closed on every path.
    """
    if not all([db_user, db_pass, db_name]):
        return "Please provide database credentials.", None

    schema_prompt = f"""
Extract column names and types from the following request:

"{user_input}"

**Output Format:**
- The first column should be an "ID" column (INTEGER, PRIMARY KEY).
- Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
- Use proper SQL syntax. No explanations.

Example Output:
```
CREATE TABLE demo (
ID INT PRIMARY KEY,
Name VARCHAR(100),
Age INT
);
```
"""
    schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()

    column_definitions = _extract_demo_column_definitions(schema_response)
    if column_definitions is None:
        return "Error: Could not generate a valid table schema.", None

    # Show the model the schema of the table actually being created (always
    # named "demo"). Otherwise a differently named table in the model's own
    # schema leads to INSERTs that the "INSERT INTO demo" filter below drops.
    normalized_schema = f"CREATE TABLE demo (\n{column_definitions}\n);"

    data_prompt = f"""
Generate {num_rows} rows of structured demo data for this table schema:

```
{normalized_schema}
```

**Output Format:**
- Return valid SQL INSERT statements.
- Ensure all values match their respective column types.
- Use double quotes ("") for text values.
- No explanations, just raw SQL.

Example Output:
```
INSERT INTO demo VALUES (1, "John Doe", 25);
INSERT INTO demo VALUES (2, "Jane Smith", 30);
```
"""

    connection = None
    try:
        connection = mysql.connector.connect(
            host=db_host, user=db_user, password=db_pass, database=db_name
        )
        cursor = connection.cursor()
        cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({column_definitions})")

        data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()

        # Only single-row INSERTs into "demo" are executed; any other SQL the
        # model produces is ignored.
        insert_statements = re.findall(
            r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL
        )
        if not insert_statements:
            return "Error: Could not generate valid data.", None

        for statement in insert_statements:
            cursor.execute(statement)
        connection.commit()
        return "Demo data inserted into the database successfully.", None

    except mysql.connector.Error as err:
        return f"Error: {err}", None

    finally:
        # Runs on every return above, including the early "valid data" error.
        if connection is not None:
            connection.close()
|
def generate_demo_data_csv(user_input, num_rows=10):
    """Ask the LLM for *num_rows* rows of demo data and write them to a CSV file.

    Args:
        user_input: Natural-language description of the dataset wanted.
        num_rows: Number of rows to request from the model.

    Returns:
        ``("Demo data generated as CSV.", "generated_data.csv")`` after the
        parsed data is written to that path in the working directory. If
        pandas cannot parse the reply, an ``"Error: Invalid CSV format. ..."``
        message and ``None`` are returned instead.
    """
    prompt = f"""
Generate a structured dataset with {num_rows} rows based on the following request:

"{user_input}"

**Output Format:**
- Ensure the response is in **valid CSV format** (comma-separated).
- The **first row** must be column headers.
- Use **double quotes for text values** to avoid formatting issues.
- Do **not** include explanations—just the raw CSV data.

Example Output:

"ID","Name","Age","Email"
"1","John Doe","25","[email protected]"
"2","Jane Smith","30","[email protected]"

"""

    raw_csv = client.text_generation(prompt, max_new_tokens=10000).strip()

    # Drop any preamble the model wrote before the quoted "ID" header (the
    # header shape shown in the prompt's example).
    header_index = raw_csv.find('"ID"')
    if header_index >= 0:
        raw_csv = raw_csv[header_index:]

    try:
        frame = pd.read_csv(io.StringIO(raw_csv))
    except Exception as exc:
        return f"Error: Invalid CSV format. {str(exc)}", None

    output_path = "generated_data.csv"
    frame.to_csv(output_path, index=False)
    return "Demo data generated as CSV.", output_path
|
|
|
def analyze_data(user_input):
    """Ask the LLM to analyze the data described in *user_input*.

    Returns:
        The model's free-form analysis (at most 200 new tokens), stripped of
        surrounding whitespace.
    """
    analysis = client.text_generation(
        f"Analyze this data: {user_input}",
        max_new_tokens=200,
    )
    return analysis.strip()
|
|
|
def sql_chatbot(user_input, num_rows=10, db_user=None, db_pass=None, db_host=None, db_name=None):
    """Route *user_input* to the handler picked by the LLM classifier.

    Args:
        user_input: The user's natural-language request.
        num_rows: Rows of demo data to generate. ``None`` (e.g. a cleared
            number field in the UI) falls back to 10.
        db_user: MySQL user name for the database tasks.
        db_pass: MySQL password for the database tasks.
        db_host: MySQL host for the database tasks.
        db_name: MySQL database for the database tasks.
            Each of the four ``db_*`` values left as ``None`` is read from the
            environment variable ``DB_USER``, ``DB_PASS``, ``DB_HOST`` or
            ``DB_NAME`` respectively.

    Returns:
        A ``(text, file_path_or_None)`` pair matching the UI's two outputs.
    """
    if not user_input or not user_input.strip():
        return "Please enter a request.", None

    num_rows = 10 if num_rows is None else int(num_rows)

    # The UI only collects the request and a row count, so database settings
    # default to environment variables. The handlers report missing values
    # themselves.
    db_user = db_user if db_user is not None else os.getenv("DB_USER")
    db_pass = db_pass if db_pass is not None else os.getenv("DB_PASS")
    db_host = db_host if db_host is not None else os.getenv("DB_HOST")
    db_name = db_name if db_name is not None else os.getenv("DB_NAME")

    task = classify_task(user_input)

    handlers = {
        "generate_sql": lambda: (generate_sql_query(user_input), None),
        "create_table": lambda: create_table(user_input, db_user, db_pass, db_host, db_name),
        "generate_demo_data_db": lambda: generate_demo_data_db(
            user_input, db_user, db_pass, db_host, db_name, num_rows
        ),
        "generate_demo_data_csv": lambda: generate_demo_data_csv(user_input, num_rows),
        "analyze_data": lambda: (analyze_data(user_input), None),
    }

    # The classifier's reply is free text and may mention several category
    # names (e.g. if the model keeps generating past its answer). Take the
    # one mentioned first rather than the first one in a fixed check order.
    mentioned = [(task.find(name), name) for name in handlers if name in task]
    if mentioned:
        _, category = min(mentioned)
        return handlers[category]()

    return f"task:{task} \n I could not understand your request.", None
|
|
|
# Two inputs (request text, integer row count) map to sql_chatbot's first two
# parameters. The two outputs receive its (text, file path or None) result.
iface = gr.Interface(
    fn=sql_chatbot,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.Number(label="Number of Rows", interactive=True, value=10, precision=0),
    ],
    outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")],
)

# Start the web server only when run as a script. Importing this module
# (e.g. from tests or tooling) then does not block on launch().
if __name__ == "__main__":
    iface.launch()
|
|
|
|