|
from transformers import pipeline |
|
import re |
|
from typing import Dict |
|
|
|
class NLPToSQL2: |
|
def __init__(self): |
|
self.model = pipeline( |
|
"text2text-generation", |
|
model="mrm8488/t5-base-finetuned-wikiSQL", |
|
tokenizer="t5-base" |
|
) |
|
|
|
def query_to_sql(self, user_query): |
|
prompt = (f"Generate a valid SQL query in the correct format based on the following schema:\n" |
|
f"Table1: Employees\n" |
|
f"Columns: ID, Name, Department, Salary\n" |
|
f"Table2: Departments\n" |
|
f"Columns: Name, Manager\n" |
|
f"Natural Language: {user_query}" |
|
f"SQL query:" |
|
) |
|
|
|
result = self.model(prompt, max_length=200) |
|
sql = result[0]['generated_text'] |
|
|
|
return sql |
|
|
|
class NLPToSQL: |
|
def __init__(self): |
|
self.query_patterns: Dict[str, str] = { |
|
r"show\s+(?:me\s+)?all\s+employees?\s+in\s+(?:the\s+)?(\w+)\s+department": |
|
"SELECT * FROM Employees WHERE LOWER(Department) = LOWER('{}')", |
|
|
|
r"who\s+is\s+(?:the\s+)?manager\s+of\s+(?:the\s+)?(\w+)\s+department": |
|
"SELECT Manager FROM Departments WHERE LOWER(Name) = LOWER('{}')", |
|
|
|
r"list\s+(?:all\s+)?employees?\s+hired\s+after\s+(\d{4}-\d{2}-\d{2})": |
|
"SELECT * FROM Employees WHERE Hire_Date > '{}'", |
|
|
|
r"what\s+is\s+(?:the\s+)?total\s+salary\s+(?:expense\s+)?for\s+(?:the\s+)?(\w+)\s+department": |
|
"SELECT SUM(Salary) as Total_Salary FROM Employees WHERE LOWER(Department) = LOWER('{}')", |
|
|
|
r"show\s+(?:me\s+)?(?:the\s+)?salary\s+of\s+(\w+)": |
|
"SELECT Salary FROM Employees WHERE LOWER(Name) = LOWER('{}')", |
|
|
|
r"list\s+(?:all\s+)?employees?\s+with\s+salary\s+(?:greater|more)\s+than\s+(\d+)": |
|
"SELECT * FROM Employees WHERE Salary > {}", |
|
|
|
r"(?:show|list)\s+(?:me\s+)?all\s+departments": |
|
"SELECT * FROM Departments", |
|
|
|
r"(?:show|list)\s+(?:me\s+)?all\s+employees": |
|
"SELECT * FROM Employees" |
|
} |
|
|
|
def query_to_sql(self, user_query: str) -> str: |
|
normalized_query = " ".join(user_query.lower().split()) |
|
|
|
for pattern, sql_template in self.query_patterns.items(): |
|
match = re.search(pattern, normalized_query, re.IGNORECASE) |
|
if match: |
|
if match.groups(): |
|
return sql_template.format(*match.groups()) |
|
return sql_template |
|
|
|
return self._generate_fallback_query(normalized_query) |
|
|
|
def _generate_fallback_query(self, query: str) -> str: |
|
if any(word in query for word in ['department', 'manager']): |
|
return "SELECT * FROM Departments" |
|
return "SELECT * FROM Employees" |
|
|
|
def sanitize_sql(self, sql: str) -> str: |
|
sql = re.sub(r'[;"]', '', sql) |
|
sql = sql.replace("'", "''") |
|
if not sql.strip().endswith(';'): |
|
sql = f"{sql};" |
|
|
|
return sql |