Garvitj committed on
Commit
18a9a60
·
verified ·
1 Parent(s): 2fa8568

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -276
app.py CHANGED
@@ -1,277 +1,280 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- import pandas as pd
4
- import io
5
- import mysql.connector
6
-
7
- # Hugging Face API Key (Replace with your actual key)
8
- client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)
9
-
10
- def classify_task(user_input):
11
- """Classifies user task using Mistral."""
12
- prompt = f"""
13
- You are an AI assistant that classifies user requests related to SQL and databases.
14
- Your task is to categorize the input into one of the following options:
15
-
16
- - **generate_sql** If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
17
- - **create_table** If the user explicitly wants to **create** a database/table on a local MySQL server.
18
- - **generate_demo_data_db** → If the user wants to insert demo data into a database.
19
- - **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
20
- - **analyze_data** → If the user asks for insights, trends, or statistical analysis of data.
21
-
22
- **Examples:**
23
- 1. "Give me SQL syntax to create a student table" **generate_sql**
24
- 2. "Create a student table in my database" → **create_table**
25
- 3. "Insert some demo data in my database" → **generate_demo_data_db**
26
- 4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
27
- 5. "Analyze student marks and trends" → **analyze_data**
28
-
29
- **User Input:** {user_input}
30
-
31
- **Output Format:** Return only the category name without any explanations.
32
- """
33
-
34
- response = client.text_generation(prompt, max_new_tokens=20).strip()
35
- return response
36
- def generate_sql_query(user_input):
37
- """Generates SQL queries using Mistral."""
38
- prompt = f"Generate SQL syntax for: {user_input}"
39
- return client.text_generation(prompt, max_new_tokens=200).strip()
40
- def generate_sql_query_for_create(user_input):
41
- """Generates SQL queries using Mistral."""
42
- prompt = f"""
43
- Generate **only** the SQL syntax for: {user_input}
44
-
45
- **Rules:**
46
- - No explanations, no bullet points, no extra text.
47
- - Return **only valid SQL**.
48
-
49
- **Example Input:**
50
- "Create a student table with name, age, and email."
51
-
52
- **Example Output:**
53
- ```sql
54
- CREATE TABLE student (
55
- student_id INT PRIMARY KEY AUTO_INCREMENT,
56
- name VARCHAR(100) NOT NULL,
57
- age INT NOT NULL,
58
- email VARCHAR(255) UNIQUE NOT NULL
59
- );
60
- ```
61
- """
62
-
63
- response = client.text_generation(prompt, max_new_tokens=200).strip()
64
-
65
- # Remove unnecessary text (if any)
66
- if "```sql" in response:
67
- response = response.split("```sql")[1].split("```")[0].strip()
68
-
69
- return response
70
-
71
- import pymysql
72
-
73
- def create_table(user_input, db_user, db_pass, db_host, db_name):
74
- try:
75
- # Validate inputs
76
- if not all([db_user, db_pass, db_host, db_name, user_input]):
77
- return "Please provide all required inputs (database credentials and table structure).", None
78
-
79
- # Generate SQL schema using Mistral
80
-
81
- schema_response = generate_sql_query_for_create(user_input)
82
- print(schema_response)
83
- # Validate schema using sqlparse
84
- parsed_schema = sqlparse.parse(schema_response)
85
- if not parsed_schema:
86
- return "Error: Could not generate a valid table schema.", None
87
-
88
- # Connect to MySQL Server
89
- connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
90
- cursor = connection.cursor()
91
-
92
- # Create Database if it doesn't exist
93
- cursor.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
94
- connection.commit()
95
- connection.close()
96
-
97
- # Connect to the specified database
98
- connection = pymysql.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
99
- cursor = connection.cursor()
100
-
101
- # Execute the generated CREATE TABLE statement
102
- cursor.execute(schema_response)
103
- connection.commit()
104
-
105
- return "Table created successfully.", None
106
-
107
- except pymysql.MySQLError as err:
108
- return f"Error: {err}", None
109
-
110
- finally:
111
- if 'connection' in locals() and connection.open:
112
- cursor.close()
113
- connection.close()
114
-
115
- import mysql.connector
116
- import re
117
- import sqlparse # Install via: pip install sqlparse
118
-
119
- def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
120
- """Generates and inserts structured demo data into a database using LLM."""
121
-
122
- if not all([db_user, db_pass, db_name]):
123
- return "Please provide database credentials.", None
124
-
125
- # Generate column definitions using LLM
126
- schema_prompt = f"""
127
- Extract column names and types from the following request:
128
-
129
- "{user_input}"
130
-
131
- **Output Format:**
132
- - The first column should be an "ID" column (INTEGER, PRIMARY KEY).
133
- - Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
134
- - Use proper SQL syntax. No explanations.
135
-
136
- Example Output:
137
- ```
138
- CREATE TABLE demo (
139
- ID INT PRIMARY KEY,
140
- Name VARCHAR(100),
141
- Age INT
142
- );
143
- ```
144
- """
145
- schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()
146
-
147
- # Validate schema using sqlparse
148
- parsed_schema = sqlparse.parse(schema_response)
149
- if not parsed_schema:
150
- return "Error: Could not generate a valid table schema.", None
151
-
152
- # Extract table schema
153
- table_schema = schema_response.replace("CREATE TABLE demo (", "").replace(");", "").strip()
154
-
155
- # Connect to MySQL and create the table dynamically
156
- connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
157
- cursor = connection.cursor()
158
- cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")
159
-
160
- # Generate demo data using LLM
161
- data_prompt = f"""
162
- Generate {num_rows} rows of structured demo data for this table schema:
163
-
164
- ```
165
- {schema_response}
166
- ```
167
-
168
- **Output Format:**
169
- - Return valid SQL INSERT statements.
170
- - Ensure all values match their respective column types.
171
- - Use double quotes ("") for text values.
172
- - No explanations, just raw SQL.
173
-
174
- Example Output:
175
- ```
176
- INSERT INTO demo VALUES (1, "John Doe", 25);
177
- INSERT INTO demo VALUES (2, "Jane Smith", 30);
178
- ```
179
- """
180
- data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()
181
-
182
- # Extract SQL INSERT statements using a better regex
183
- insert_statements = re.findall(r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL)
184
- if not insert_statements:
185
- return "Error: Could not generate valid data.", None
186
-
187
- # Insert data into the database
188
- for statement in insert_statements:
189
- cursor.execute(statement)
190
-
191
- connection.commit()
192
- connection.close()
193
-
194
- return "Demo data inserted into the database successfully.", None
195
- def generate_demo_data_csv(user_input, num_rows=10):
196
- """Generates realistic demo data using the LLM in valid CSV format."""
197
-
198
- prompt = f"""
199
- Generate a structured dataset with {num_rows} rows based on the following request:
200
-
201
- "{user_input}"
202
-
203
- **Output Format:**
204
- - Ensure the response is in **valid CSV format** (comma-separated).
205
- - The **first row** must be column headers.
206
- - Use **double quotes for text values** to avoid formatting issues.
207
- - Do **not** include explanations—just the raw CSV data.
208
-
209
- Example Output:
210
-
211
- "ID","Name","Age","Email"
212
- "1","John Doe","25","john.doe@example.com"
213
- "2","Jane Smith","30","jane.smith@example.com"
214
-
215
- """
216
-
217
- # Get LLM response
218
- response = client.text_generation(prompt, max_new_tokens=10000).strip()
219
-
220
- # Ensure we extract only the CSV part (some models may add explanations)
221
- csv_start = response.find('"ID"') # Find where the CSV starts
222
- if csv_start != -1:
223
- response = response[csv_start:] # Remove anything before the CSV
224
-
225
- # Convert to DataFrame
226
- try:
227
- df = pd.read_csv(io.StringIO(response)) # Read as CSV
228
- except Exception as e:
229
- return f"Error: Invalid CSV format. {str(e)}", None
230
-
231
- # Save to a CSV file
232
- file_path = "generated_data.csv"
233
- df.to_csv(file_path, index=False)
234
-
235
- return "Demo data generated as CSV.", file_path # Return file path
236
-
237
- def analyze_data(user_input):
238
- """Analyzes data using Mistral."""
239
- prompt = f"Analyze this data: {user_input}"
240
- return client.text_generation(prompt, max_new_tokens=200).strip()
241
-
242
- def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
243
- task = classify_task(user_input)
244
-
245
- if "generate_sql" in task:
246
- return generate_sql_query(user_input), None
247
-
248
- elif "create_table" in task:
249
- return create_table(user_input, db_user, db_pass, db_host, db_name)
250
-
251
- elif "generate_demo_data_db" in task:
252
- return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
253
-
254
- elif "generate_demo_data_csv" in task:
255
- response, file_path = generate_demo_data_csv(user_input, num_rows)
256
- return response, file_path
257
- elif "analyze_data" in task:
258
- return analyze_data(user_input), None
259
-
260
- return f"task:{task} \n I could not understand your request.", None
261
-
262
- iface = gr.Interface(
263
- fn=sql_chatbot,
264
- inputs=[
265
- gr.Textbox(label="User Input"),
266
- gr.Textbox(label="MySQL Username", interactive=True),
267
- gr.Textbox(label="MySQL Password", interactive=True, type="password"),
268
- gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
269
- gr.Textbox(label="Database Name", interactive=True),
270
- gr.Number(label="Number of Rows", interactive=True, value=10, precision=0)
271
- ],
272
- outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")]
273
- )
274
-
275
- iface.launch()
276
- # print("hi")
 
 
 
277
  # print(create_table("create a SQL student table","root", "123456", "localhost", "demo"))
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import pandas as pd
4
+ import io
5
+ import mysql.connector
6
+ import os
7
+
8
+ HF_API_KEY = os.getenv("HF_API_KEY")
9
+
10
+ # Hugging Face API Key (Replace with your actual key)
11
+ client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3", token=HF_API_KEY)
12
+
13
+ def classify_task(user_input):
14
+ """Classifies user task using Mistral."""
15
+ prompt = f"""
16
+ You are an AI assistant that classifies user requests related to SQL and databases.
17
+ Your task is to categorize the input into one of the following options:
18
+
19
+ - **generate_sql** → If the user asks to generate **SQL syntax**, like SELECT, INSERT, CREATE TABLE, etc.
20
+ - **create_table** → If the user explicitly wants to **create** a database/table on a local MySQL server.
21
+ - **generate_demo_data_db** → If the user wants to insert demo data into a database.
22
+ - **generate_demo_data_csv** → If the user wants to generate demo data in CSV format.
23
+ - **analyze_data** If the user asks for insights, trends, or statistical analysis of data.
24
+
25
+ **Examples:**
26
+ 1. "Give me SQL syntax to create a student table" → **generate_sql**
27
+ 2. "Create a student table in my database" → **create_table**
28
+ 3. "Insert some demo data in my database" → **generate_demo_data_db**
29
+ 4. "Generate sample student data in CSV format" → **generate_demo_data_csv**
30
+ 5. "Analyze student marks and trends" → **analyze_data**
31
+
32
+ **User Input:** {user_input}
33
+
34
+ **Output Format:** Return only the category name without any explanations.
35
+ """
36
+
37
+ response = client.text_generation(prompt, max_new_tokens=20).strip()
38
+ return response
39
+ def generate_sql_query(user_input):
40
+ """Generates SQL queries using Mistral."""
41
+ prompt = f"Generate SQL syntax for: {user_input}"
42
+ return client.text_generation(prompt, max_new_tokens=200).strip()
43
+ def generate_sql_query_for_create(user_input):
44
+ """Generates SQL queries using Mistral."""
45
+ prompt = f"""
46
+ Generate **only** the SQL syntax for: {user_input}
47
+
48
+ **Rules:**
49
+ - No explanations, no bullet points, no extra text.
50
+ - Return **only valid SQL**.
51
+
52
+ **Example Input:**
53
+ "Create a student table with name, age, and email."
54
+
55
+ **Example Output:**
56
+ ```sql
57
+ CREATE TABLE student (
58
+ student_id INT PRIMARY KEY AUTO_INCREMENT,
59
+ name VARCHAR(100) NOT NULL,
60
+ age INT NOT NULL,
61
+ email VARCHAR(255) UNIQUE NOT NULL
62
+ );
63
+ ```
64
+ """
65
+
66
+ response = client.text_generation(prompt, max_new_tokens=200).strip()
67
+
68
+ # Remove unnecessary text (if any)
69
+ if "```sql" in response:
70
+ response = response.split("```sql")[1].split("```")[0].strip()
71
+
72
+ return response
73
+
74
+ import pymysql
75
+
76
+ def create_table(user_input, db_user, db_pass, db_host, db_name):
77
+ try:
78
+ # Validate inputs
79
+ if not all([db_user, db_pass, db_host, db_name, user_input]):
80
+ return "Please provide all required inputs (database credentials and table structure).", None
81
+
82
+ # Generate SQL schema using Mistral
83
+
84
+ schema_response = generate_sql_query_for_create(user_input)
85
+ print(schema_response)
86
+ # Validate schema using sqlparse
87
+ parsed_schema = sqlparse.parse(schema_response)
88
+ if not parsed_schema:
89
+ return "Error: Could not generate a valid table schema.", None
90
+
91
+ # Connect to MySQL Server
92
+ connection = pymysql.connect(host=db_host, user=db_user, password=db_pass)
93
+ cursor = connection.cursor()
94
+
95
+ # Create Database if it doesn't exist
96
+ cursor.execute(f"CREATE DATABASE IF NOT EXISTS {db_name}")
97
+ connection.commit()
98
+ connection.close()
99
+
100
+ # Connect to the specified database
101
+ connection = pymysql.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
102
+ cursor = connection.cursor()
103
+
104
+ # Execute the generated CREATE TABLE statement
105
+ cursor.execute(schema_response)
106
+ connection.commit()
107
+
108
+ return "Table created successfully.", None
109
+
110
+ except pymysql.MySQLError as err:
111
+ return f"Error: {err}", None
112
+
113
+ finally:
114
+ if 'connection' in locals() and connection.open:
115
+ cursor.close()
116
+ connection.close()
117
+
118
+ import mysql.connector
119
+ import re
120
+ import sqlparse # Install via: pip install sqlparse
121
+
122
+ def generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows=10):
123
+ """Generates and inserts structured demo data into a database using LLM."""
124
+
125
+ if not all([db_user, db_pass, db_name]):
126
+ return "Please provide database credentials.", None
127
+
128
+ # Generate column definitions using LLM
129
+ schema_prompt = f"""
130
+ Extract column names and types from the following request:
131
+
132
+ "{user_input}"
133
+
134
+ **Output Format:**
135
+ - The first column should be an "ID" column (INTEGER, PRIMARY KEY).
136
+ - Provide appropriate SQL data types (VARCHAR(100) for text, INT for numbers).
137
+ - Use proper SQL syntax. No explanations.
138
+
139
+ Example Output:
140
+ ```
141
+ CREATE TABLE demo (
142
+ ID INT PRIMARY KEY,
143
+ Name VARCHAR(100),
144
+ Age INT
145
+ );
146
+ ```
147
+ """
148
+ schema_response = client.text_generation(schema_prompt, max_new_tokens=200).strip()
149
+
150
+ # Validate schema using sqlparse
151
+ parsed_schema = sqlparse.parse(schema_response)
152
+ if not parsed_schema:
153
+ return "Error: Could not generate a valid table schema.", None
154
+
155
+ # Extract table schema
156
+ table_schema = schema_response.replace("CREATE TABLE demo (", "").replace(");", "").strip()
157
+
158
+ # Connect to MySQL and create the table dynamically
159
+ connection = mysql.connector.connect(host=db_host, user=db_user, password=db_pass, database=db_name)
160
+ cursor = connection.cursor()
161
+ cursor.execute(f"CREATE TABLE IF NOT EXISTS demo ({table_schema})")
162
+
163
+ # Generate demo data using LLM
164
+ data_prompt = f"""
165
+ Generate {num_rows} rows of structured demo data for this table schema:
166
+
167
+ ```
168
+ {schema_response}
169
+ ```
170
+
171
+ **Output Format:**
172
+ - Return valid SQL INSERT statements.
173
+ - Ensure all values match their respective column types.
174
+ - Use double quotes ("") for text values.
175
+ - No explanations, just raw SQL.
176
+
177
+ Example Output:
178
+ ```
179
+ INSERT INTO demo VALUES (1, "John Doe", 25);
180
+ INSERT INTO demo VALUES (2, "Jane Smith", 30);
181
+ ```
182
+ """
183
+ data_response = client.text_generation(data_prompt, max_new_tokens=1000).strip()
184
+
185
+ # Extract SQL INSERT statements using a better regex
186
+ insert_statements = re.findall(r'INSERT INTO demo VALUES \([^)]+\);', data_response, re.DOTALL)
187
+ if not insert_statements:
188
+ return "Error: Could not generate valid data.", None
189
+
190
+ # Insert data into the database
191
+ for statement in insert_statements:
192
+ cursor.execute(statement)
193
+
194
+ connection.commit()
195
+ connection.close()
196
+
197
+ return "Demo data inserted into the database successfully.", None
198
+ def generate_demo_data_csv(user_input, num_rows=10):
199
+ """Generates realistic demo data using the LLM in valid CSV format."""
200
+
201
+ prompt = f"""
202
+ Generate a structured dataset with {num_rows} rows based on the following request:
203
+
204
+ "{user_input}"
205
+
206
+ **Output Format:**
207
+ - Ensure the response is in **valid CSV format** (comma-separated).
208
+ - The **first row** must be column headers.
209
+ - Use **double quotes for text values** to avoid formatting issues.
210
+ - Do **not** include explanations—just the raw CSV data.
211
+
212
+ Example Output:
213
+
214
+ "ID","Name","Age","Email"
215
+ "1","John Doe","25","john.doe@example.com"
216
+ "2","Jane Smith","30","jane.smith@example.com"
217
+
218
+ """
219
+
220
+ # Get LLM response
221
+ response = client.text_generation(prompt, max_new_tokens=10000).strip()
222
+
223
+ # Ensure we extract only the CSV part (some models may add explanations)
224
+ csv_start = response.find('"ID"') # Find where the CSV starts
225
+ if csv_start != -1:
226
+ response = response[csv_start:] # Remove anything before the CSV
227
+
228
+ # Convert to DataFrame
229
+ try:
230
+ df = pd.read_csv(io.StringIO(response)) # Read as CSV
231
+ except Exception as e:
232
+ return f"Error: Invalid CSV format. {str(e)}", None
233
+
234
+ # Save to a CSV file
235
+ file_path = "generated_data.csv"
236
+ df.to_csv(file_path, index=False)
237
+
238
+ return "Demo data generated as CSV.", file_path # Return file path
239
+
240
+ def analyze_data(user_input):
241
+ """Analyzes data using Mistral."""
242
+ prompt = f"Analyze this data: {user_input}"
243
+ return client.text_generation(prompt, max_new_tokens=200).strip()
244
+
245
+ def sql_chatbot(user_input, db_user=None, db_pass=None, db_host="localhost", db_name=None, num_rows=10):
246
+ task = classify_task(user_input)
247
+
248
+ if "generate_sql" in task:
249
+ return generate_sql_query(user_input), None
250
+
251
+ elif "create_table" in task:
252
+ return create_table(user_input, db_user, db_pass, db_host, db_name)
253
+
254
+ elif "generate_demo_data_db" in task:
255
+ return generate_demo_data_db(user_input, db_user, db_pass, db_host, db_name, num_rows)
256
+
257
+ elif "generate_demo_data_csv" in task:
258
+ response, file_path = generate_demo_data_csv(user_input, num_rows)
259
+ return response, file_path
260
+ elif "analyze_data" in task:
261
+ return analyze_data(user_input), None
262
+
263
+ return f"task:{task} \n I could not understand your request.", None
264
+
265
+ iface = gr.Interface(
266
+ fn=sql_chatbot,
267
+ inputs=[
268
+ gr.Textbox(label="User Input"),
269
+ gr.Textbox(label="MySQL Username", interactive=True),
270
+ gr.Textbox(label="MySQL Password", interactive=True, type="password"),
271
+ gr.Textbox(label="MySQL Host", interactive=True, value="localhost"),
272
+ gr.Textbox(label="Database Name", interactive=True),
273
+ gr.Number(label="Number of Rows", interactive=True, value=10, precision=0)
274
+ ],
275
+ outputs=[gr.Textbox(label="Response"), gr.File(label="File Output")]
276
+ )
277
+
278
+ iface.launch()
279
+ # print("hi")
280
  # print(create_table("create a SQL student table","root", "123456", "localhost", "demo"))