Devashish-Nagpal's picture
Updated the approach and shifted the Flask app to a Gradio app for deployment on Hugging Face
ddf9ea0
raw
history blame
3.15 kB
from transformers import pipeline
import re
from typing import Dict
class NLPToSQL2:
    """Translate natural-language questions to SQL with a text2text model.

    The model is prompted with a fixed two-table schema (Employees,
    Departments) followed by the user's question.
    """

    def __init__(self):
        """Build the ``text2text-generation`` pipeline used by this class."""
        self.model = pipeline(
            "text2text-generation",
            model="mrm8488/t5-base-finetuned-wikiSQL",
            tokenizer="t5-base",
        )

    def query_to_sql(self, user_query: str) -> str:
        """Return the SQL text generated by the model for *user_query*.

        Args:
            user_query: The question in natural language.

        Returns:
            The ``generated_text`` of the first candidate produced by the
            pipeline, unmodified.
        """
        prompt = (
            "Generate a valid SQL query in the correct format based on the following schema:\n"
            "Table1: Employees\n"
            "Columns: ID, Name, Department, Salary\n"
            "Table2: Departments\n"
            "Columns: Name, Manager\n"
            # The newline keeps the question separate from the "SQL query:"
            # cue; without it the two run together on a single line.
            f"Natural Language: {user_query}\n"
            "SQL query:"
        )
        candidates = self.model(prompt, max_length=200)
        return candidates[0]["generated_text"]
class NLPToSQL:
    """Rule-based translator from natural-language questions to SQL.

    Each entry of ``query_patterns`` maps a regular expression to a SQL
    template. The first pattern that matches the normalized question wins,
    and its captured groups are substituted into the template.
    """

    def __init__(self):
        """Register the supported question patterns and their SQL templates."""
        # Insertion order matters: lookup stops at the first match, so the
        # specific "employees in <dept>" style patterns must come before the
        # generic "all employees" / "all departments" ones.
        self.query_patterns: Dict[str, str] = {
            r"show\s+(?:me\s+)?all\s+employees?\s+in\s+(?:the\s+)?(\w+)\s+department":
                "SELECT * FROM Employees WHERE LOWER(Department) = LOWER('{}')",
            r"who\s+is\s+(?:the\s+)?manager\s+of\s+(?:the\s+)?(\w+)\s+department":
                "SELECT Manager FROM Departments WHERE LOWER(Name) = LOWER('{}')",
            # NOTE(review): assumes Employees has a Hire_Date column - confirm
            # against the actual schema.
            r"list\s+(?:all\s+)?employees?\s+hired\s+after\s+(\d{4}-\d{2}-\d{2})":
                "SELECT * FROM Employees WHERE Hire_Date > '{}'",
            r"what\s+is\s+(?:the\s+)?total\s+salary\s+(?:expense\s+)?for\s+(?:the\s+)?(\w+)\s+department":
                "SELECT SUM(Salary) as Total_Salary FROM Employees WHERE LOWER(Department) = LOWER('{}')",
            r"show\s+(?:me\s+)?(?:the\s+)?salary\s+of\s+(\w+)":
                "SELECT Salary FROM Employees WHERE LOWER(Name) = LOWER('{}')",
            r"list\s+(?:all\s+)?employees?\s+with\s+salary\s+(?:greater|more)\s+than\s+(\d+)":
                "SELECT * FROM Employees WHERE Salary > {}",
            r"(?:show|list)\s+(?:me\s+)?all\s+departments":
                "SELECT * FROM Departments",
            r"(?:show|list)\s+(?:me\s+)?all\s+employees":
                "SELECT * FROM Employees",
        }

    def query_to_sql(self, user_query: str) -> str:
        """Convert *user_query* to SQL using the registered patterns.

        The query is lower-cased and its whitespace collapsed before
        matching. When no pattern matches, a table-wide fallback query is
        returned instead.

        Args:
            user_query: The question in natural language.

        Returns:
            A SQL statement without a trailing semicolon.
        """
        normalized_query = " ".join(user_query.lower().split())

        for pattern, sql_template in self.query_patterns.items():
            match = re.search(pattern, normalized_query, re.IGNORECASE)
            if match is None:
                continue
            captured = match.groups()
            # Templates without placeholders are returned verbatim.
            return sql_template.format(*captured) if captured else sql_template

        return self._generate_fallback_query(normalized_query)

    def _generate_fallback_query(self, query: str) -> str:
        """Pick a whole-table query based on keywords found in *query*."""
        if "department" in query or "manager" in query:
            return "SELECT * FROM Departments"
        return "SELECT * FROM Employees"

    def sanitize_sql(self, sql: str) -> str:
        """Strip statement separators and double quotes, then terminate *sql*.

        All ``;`` and ``"`` characters are removed and exactly one ``;`` is
        appended. Single quotes are left untouched: doubling them across a
        whole statement would turn a literal such as ``'sales'`` into
        ``''sales''`` and make the statement unparsable.

        Args:
            sql: A complete SQL statement.

        Returns:
            The cleaned statement ending in a single semicolon.
        """
        cleaned = re.sub(r'[;"]', "", sql)
        # Every ';' was just removed, so the terminator is always needed.
        return f"{cleaned};"