File size: 3,151 Bytes
2d31646
ddf9ea0
 
2d31646
ddf9ea0
2d31646
 
 
 
 
 
 
 
ddf9ea0
 
 
 
 
 
 
 
 
2d31646
ddf9ea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from transformers import pipeline
import re
from typing import Dict

class NLPToSQL2:
    """Model-backed translator from natural-language questions to SQL.

    Wraps a ``transformers`` ``text2text-generation`` pipeline built from the
    ``mrm8488/t5-base-finetuned-wikiSQL`` checkpoint with the ``t5-base``
    tokenizer. The pipeline is created once, when the instance is constructed.
    """

    def __init__(self):
        # Built eagerly so every query_to_sql call reuses the same pipeline.
        self.model = pipeline(
            "text2text-generation",
            model="mrm8488/t5-base-finetuned-wikiSQL",
            tokenizer="t5-base"
        )

    def query_to_sql(self, user_query: str, max_length: int = 200) -> str:
        """Ask the model for a SQL statement answering ``user_query``.

        The prompt describes the Employees and Departments tables, then gives
        the question and a trailing ``SQL query:`` cue, each on its own line.

        Args:
            user_query: The question in natural language; inserted into the
                prompt verbatim.
            max_length: Forwarded to the pipeline call as ``max_length``.
                Defaults to 200, the value that used to be hard-coded.

        Returns:
            The ``generated_text`` field of the first result returned by the
            pipeline. The text is returned as-is, without validation.
        """
        prompt = (
            "Generate a valid SQL query in the correct format based on the following schema:\n"
            "Table1: Employees\n"
            "Columns: ID, Name, Department, Salary\n"
            "Table2: Departments\n"
            "Columns: Name, Manager\n"
            # The newline keeps the question from running straight into the
            # "SQL query:" cue (previously "...<question>SQL query:").
            f"Natural Language: {user_query}\n"
            "SQL query:"
        )

        result = self.model(prompt, max_length=max_length)
        return result[0]['generated_text']
    
class NLPToSQL:
    """Rule-based translator from English questions to SQL strings.

    Every supported phrasing is a regular expression paired with a SQL
    template. The first pattern (in insertion order) that matches the
    normalized question decides the result; questions that match nothing
    fall back to a whole-table ``SELECT``.
    """

    def __init__(self):
        # Maps regex -> SQL template. Dict order is the match priority, so the
        # specific phrasings sit ahead of the generic "all ..." catch-alls.
        self.query_patterns: Dict[str, str] = {
            # Employees of one department.
            r"show\s+(?:me\s+)?all\s+employees?\s+in\s+(?:the\s+)?(\w+)\s+department":
                "SELECT * FROM Employees WHERE LOWER(Department) = LOWER('{}')",

            # Manager of one department.
            r"who\s+is\s+(?:the\s+)?manager\s+of\s+(?:the\s+)?(\w+)\s+department":
                "SELECT Manager FROM Departments WHERE LOWER(Name) = LOWER('{}')",

            # Employees hired after a YYYY-MM-DD date.
            r"list\s+(?:all\s+)?employees?\s+hired\s+after\s+(\d{4}-\d{2}-\d{2})":
                "SELECT * FROM Employees WHERE Hire_Date > '{}'",

            # Salary total of one department.
            r"what\s+is\s+(?:the\s+)?total\s+salary\s+(?:expense\s+)?for\s+(?:the\s+)?(\w+)\s+department":
                "SELECT SUM(Salary) as Total_Salary FROM Employees WHERE LOWER(Department) = LOWER('{}')",

            # Salary of one (single-word) name.
            r"show\s+(?:me\s+)?(?:the\s+)?salary\s+of\s+(\w+)":
                "SELECT Salary FROM Employees WHERE LOWER(Name) = LOWER('{}')",

            # Employees above a numeric salary threshold.
            r"list\s+(?:all\s+)?employees?\s+with\s+salary\s+(?:greater|more)\s+than\s+(\d+)":
                "SELECT * FROM Employees WHERE Salary > {}",

            # Whole Departments table.
            r"(?:show|list)\s+(?:me\s+)?all\s+departments":
                "SELECT * FROM Departments",

            # Whole Employees table.
            r"(?:show|list)\s+(?:me\s+)?all\s+employees":
                "SELECT * FROM Employees"
        }

    def query_to_sql(self, user_query: str) -> str:
        """Translate ``user_query`` into a SQL string.

        The question is lower-cased and its whitespace runs are collapsed to
        single spaces before matching, so captured values appear lower-cased
        in the returned SQL.

        Args:
            user_query: The question in natural language.

        Returns:
            The template of the first matching pattern, filled with the
            pattern's captured groups when it has any; otherwise the result
            of ``_generate_fallback_query``.
        """
        words = user_query.lower().split()
        normalized_query = " ".join(words)

        for regex, template in self.query_patterns.items():
            found = re.search(regex, normalized_query, re.IGNORECASE)
            if found is None:
                continue
            captured = found.groups()
            # Templates of group-less patterns are returned untouched.
            return template.format(*captured) if captured else template

        return self._generate_fallback_query(normalized_query)

    def _generate_fallback_query(self, query: str) -> str:
        """Pick a whole-table ``SELECT`` for a question no pattern matched.

        Args:
            query: The question text; ``query_to_sql`` passes the normalized
                (lower-cased, whitespace-collapsed) form.

        Returns:
            ``SELECT * FROM Departments`` when the text contains the substring
            ``department`` or ``manager``, else ``SELECT * FROM Employees``.
        """
        for keyword in ('department', 'manager'):
            if keyword in query:
                return "SELECT * FROM Departments"
        return "SELECT * FROM Employees"

    def sanitize_sql(self, sql: str) -> str:
        """Strip ``;`` and ``"`` from ``sql``, double each ``'``, append ``;``.

        NOTE(review): single quotes that delimit string literals are doubled
        as well, so a complete statement such as ``... = 'x'`` comes back as
        ``... = ''x''`` -- confirm what kind of input callers pass here.

        Args:
            sql: The text to process.

        Returns:
            The processed text, always ending in exactly one ``;``.
        """
        without_delimiters = re.sub(r'[;"]', '', sql)
        quotes_doubled = without_delimiters.replace("'", "''")
        # Every ';' was removed above, so the terminator is always missing
        # at this point and can be appended unconditionally.
        return quotes_doubled + ";"