"""Gradio chatbot that routes SQL-related requests to a Mistral model and MySQL.

A user request is first classified by the LLM into one of five tasks
(generate SQL, create a table, insert demo data, export demo data as CSV,
analyze data) and then dispatched to the matching handler. Every handler
returns a ``(message, file_path_or_None)`` pair so it can feed the two
Gradio outputs directly.
"""

import io
import os
import re

import gradio as gr
import mysql.connector
import pandas as pd
import pymysql
import sqlparse  # Install via: pip install sqlparse
from huggingface_hub import InferenceClient

# Hugging Face API key. Read from the environment so the secret is never
# committed with the source; export HF_API_KEY before starting the app.
HF_API_KEY = os.environ.get("HF_API_KEY")

client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)

# Matches the first Markdown code fence (optionally tagged sql/csv). The
# closing fence is optional so truncated model output is still usable.
_FENCE_RE = re.compile(r"```(?:sql|csv)?\s*(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE)

# Characters MySQL permits in an unquoted identifier; anything else is rejected
# before the name is interpolated into a statement.
_IDENTIFIER_RE = re.compile(r"[A-Za-z0-9_$]+")

# One complete "INSERT INTO demo VALUES (...);" statement. Quoted values are
# consumed as a unit so a ")" inside a string does not end the match early.
_INSERT_RE = re.compile(
    r"INSERT INTO demo VALUES\s*\((?:\"[^\"]*\"|'[^']*'|[^)\"'])+\)\s*;"
)


def _strip_code_fence(text):
    """Return the contents of the first Markdown code fence in *text*.

    If *text* contains no fence it is returned unchanged apart from
    surrounding whitespace.
    """
    if "```" not in text:
        return text.strip()
    return _FENCE_RE.search(text).group(1).strip()


def classify_task(user_input):
    """Classify a user request into one of the supported task categories.

    Args:
        user_input: Free-form request typed by the user.

    Returns:
        The raw, stripped model response. It is expected (but not guaranteed)
        to contain one of: generate_sql, create_table, generate_demo_data_db,
        generate_demo_data_csv, analyze_data.
    """
    prompt = f"""
    You are an AI assistant that classifies user requests related to SQL and databases.
    Your task is to categorize the input into one of the following options:

    - **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
    - **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.
    - **generate_demo_data_db** → If the user wants to insert demo data into a database.
    - **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
    - **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.

    **Examples:**
    1. "Give me SQL syntax to create a student table" → **generate_sql**
    2. "Create a student table in my database" → **create_table**
    3. "Insert some demo data in my database" → **generate_demo_data_db**
    4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
    5. "Analyze student marks and trends" → **analyze_data**

    **User Input:** {user_input}

    **Output Format:** Return only the category name without any explanations.
    """
    return client.text_generation(prompt, max_new_tokens=20).strip()


def generate_sql_query(user_input):
    """Ask the model for SQL syntax matching *user_input* and return its reply."""
    prompt = f"Generate SQL syntax for: {user_input}"
    return client.text_generation(prompt, max_new_tokens=200).strip()


def generate_sql_query_for_create(user_input):
    """Ask the model for a bare SQL statement suitable for direct execution.

    Args:
        user_input: Description of the table (or other object) to create.

    Returns:
        The generated SQL with any surrounding Markdown code fence removed.
    """
    prompt = f"""
    Generate **only** the SQL syntax for: {user_input}

    **Rules:**
    - No explanations, no bullet points, no extra text.
    - Return **only valid SQL**.

    **Example Input:**
    "Create a student table with name, age, and email."

    **Example Output:**
    ```sql
    CREATE TABLE student (
        student_id INT PRIMARY KEY AUTO_INCREMENT,
        name VARCHAR(100) NOT NULL,
        age INT NOT NULL,
        email VARCHAR(255) UNIQUE NOT NULL
    );
    ```
    """
    response = client.text_generation(prompt, max_new_tokens=200).strip()
    # The model tends to mirror the fenced example; keep only the SQL inside.
    return _strip_code_fence(response)


def create_table(user_input, db_user, db_pass, db_host, db_name):
    """Create the database (if missing) and an LLM-generated table in MySQL.

    Args:
        user_input: Description of the table to create.
        db_user: MySQL user name.
        db_pass: MySQL password.
        db_host: MySQL host.
        db_name: Database to create/use. Must consist only of letters,
            digits, ``_`` and ``$``.

    Returns:
        ``(message, None)`` where *message* reports success or the error.
    """
    if not all([db_user, db_pass, db_host, db_name, user_input]):
        return "Please provide all required inputs (database credentials and table structure).", None

    # db_name is interpolated into SQL below (identifiers cannot be bound as
    # parameters), so reject anything that is not a plain identifier.
    if not _IDENTIFIER_RE.fullmatch(db_name):
        return "Error: Invalid database name.", None

    schema_response = generate_sql_query_for_create(user_input)
    print(schema_response)

    # Only run the model output if it parses and really is a CREATE statement.
    parsed_schema = sqlparse.parse(schema_response)
    if not parsed_schema or parsed_schema[0].get_type() != "CREATE":
        return "Error: Could not generate a valid table schema.", None

    connection = None
    try:
        # First connection has no default database so the database can be created.
        connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
        with connection.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{db_name}`")
        connection.commit()
        connection.close()

        # Reconnect with the target database selected and create the table.
        connection = pymysql.connect(
            host=db_host, user=db_user, password=db_pass, database=db_name
        )
        with connection.cursor() as cursor:
            cursor.execute(schema_response)
        connection.commit()
        return "Table created successfully.", None
    except pymysql.MySQLError as err:
        return f"Error: {err}", None
    finally:
        if connection is not None and connection.open:
            connection.close()


def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
    """Generate a ``demo`` table and fill it with LLM-generated rows.

    Args:
        user_input: Description of the data to generate.
        db_user: MySQL user name.
        db_pass: MySQL password.
        db_host: MySQL host.
        db_name: Existing database in which the ``demo`` table is created.
        num_rows: Number of rows the model is asked to produce.

    Returns:
        ``(message, None)`` where *message* reports success or the error.
    """
    if not all([db_user, db_pass, db_name]):
        return "Please provide database credentials.", None

    schema_prompt = f"""
    Extract column names and types from the following request:
    "{user_input}"

    **Output Format:**
    - The first column should be an "ID" column (INTEGER, PRIMARY KEY).
    - Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
    - Use proper SQL syntax. No explanations.

    Example Output:
    ```
    CREATE TABLE demo (
        ID INT PRIMARY KEY,
        Name VARCHAR(100),
        Age INT
    );
    ```
    """
    # The prompt's example is fenced, so the reply usually is too.
    schema_response = _strip_code_fence(
        client.text_generation(schema_prompt, max_new_tokens=200)
    )

    if not sqlparse.parse(schema_response):
        return "Error: Could not generate a valid table schema.", None

    # Column definitions are everything between the outermost parentheses,
    # regardless of what table name the model chose.
    open_idx = schema_response.find("(")
    close_idx = schema_response.rfind(")")
    if open_idx == -1 or close_idx <= open_idx:
        return "Error: Could not generate a valid table schema.", None
    table_schema = schema_response[open_idx + 1:close_idx].strip()

    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(
            host=db_host, user=db_user, password=db_pass, database=db_name
        )
        cursor = connection.cursor()
        cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")

        data_prompt = f"""
        Generate {num_rows} rows of structured demo data for this table schema:
        ```
        {schema_response}
        ```

        **Output Format:**
        - Return valid SQL INSERT statements.
        - Ensure all values match their respective column types.
        - Use double quotes ("") for text values.
        - No explanations, just raw SQL.

        Example Output:
        ```
        INSERT INTO demo VALUES (1, "John Doe", 25);
        INSERT INTO demo VALUES (2, "Jane Smith", 30);
        ```
        """
        data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()

        # Only complete INSERT statements targeting the demo table are executed;
        # any other text in the reply (including a truncated last row) is ignored.
        insert_statements = _INSERT_RE.findall(data_response)
        if not insert_statements:
            return "Error: Could not generate valid data.", None

        for statement in insert_statements:
            cursor.execute(statement)
        # Single commit: either every generated row is stored or none is.
        connection.commit()
        return "Demo data inserted into the database successfully.", None
    except mysql.connector.Error as err:
        return f"Error: {err}", None
    finally:
        # Always release the cursor and connection, including on early returns.
        if cursor is not None:
            cursor.close()
        if connection is not None:
            connection.close()


def generate_demo_data_csv(user_input, num_rows=10):
    """Generate demo data with the LLM and save it as ``generated_data.csv``.

    Args:
        user_input: Description of the dataset to generate.
        num_rows: Number of rows the model is asked to produce.

    Returns:
        ``(message, file_path)`` on success, or ``(error_message, None)`` if
        the model output cannot be parsed as CSV.
    """
    prompt = f"""
    Generate a structured dataset with {num_rows} rows based on the following request:
    "{user_input}"

    **Output Format:**
    - Ensure the response is in **valid CSV format** (comma-separated).
    - The **first row** must be column headers.
    - Use **double quotes for text values** to avoid formatting issues.
    - Do **not** include explanations—just the raw CSV data.

    Example Output:
    "ID","Name","Age","Email"
    "1","John Doe","25","john.doe@example.com"
    "2","Jane Smith","30","jane.smith@example.com"
    """
    response = client.text_generation(prompt, max_new_tokens=10000).strip()

    # Drop a Markdown fence if present; otherwise its closing line would be
    # parsed as a bogus data row.
    response = _strip_code_fence(response)

    # Some models prepend an explanation; start at the header row if we can find it.
    csv_start = response.find('"ID"')
    if csv_start != -1:
        response = response[csv_start:]

    try:
        df = pd.read_csv(io.StringIO(response))
    except (pd.errors.EmptyDataError, pd.errors.ParserError, ValueError) as e:
        return f"Error: Invalid CSV format. {str(e)}", None

    file_path = "generated_data.csv"
    df.to_csv(file_path, index=False)
    return "Demo data generated as CSV.", file_path


def analyze_data(user_input):
    """Ask the model to analyze the data described in *user_input*."""
    prompt = f"Analyze this data: {user_input}"
    return client.text_generation(prompt, max_new_tokens=200).strip()


def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
    """Classify the request and dispatch it to the matching handler.

    Args:
        user_input: Free-form request typed by the user.
        db_user: MySQL user name (needed for database tasks).
        db_pass: MySQL password (needed for database tasks).
        db_host: MySQL host; a blank value falls back to ``"localhost"``.
        db_name: Database name (needed for database tasks).
        num_rows: Rows of demo data to generate; a blank value falls back to 10.

    Returns:
        ``(response_text, file_path_or_None)`` for the two Gradio outputs.
    """
    # The UI passes "" / None when a field is cleared; restore the defaults.
    db_host = db_host or "localhost"
    num_rows = 10 if num_rows is None else int(num_rows)

    task = classify_task(user_input)

    # Substring checks tolerate decoration around the label (e.g. "**create_table**").
    if "generate_sql" in task:
        return generate_sql_query(user_input), None
    if "create_table" in task:
        return create_table(user_input, db_user, db_pass, db_host, db_name)
    if "generate_demo_data_db" in task:
        return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
    if "generate_demo_data_csv" in task:
        return generate_demo_data_csv(user_input, num_rows)
    if "analyze_data" in task:
        return analyze_data(user_input), None

    return f"task:{task} \n I could not understand your request.", None


iface = gr.Interface(
    fn=sql_chatbot,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.Textbox(label="MySQL Username", interactive=True),
        gr.Textbox(label="MySQL Password", interactive=True, type="password"),
        gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
        gr.Textbox(label="Database Name", interactive=True),
        gr.Number(label="Number of Rows", interactive=True, value=10, precision=0),
    ],
    outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")],
)

iface.launch()