File size: 3,344 Bytes
5473610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00295a8
 
 
5473610
 
 
 
 
 
 
 
 
 
 
 
 
00295a8
5473610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00295a8
5473610
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
import mysql.connector
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file so the DB_* settings read by
# get_db_connection() are available.
load_dotenv()

app = FastAPI()

# Text-to-SQL generation pipeline, created once when this module is imported.
# NOTE(review): pad_token_id=2 is hard-coded; confirm it is the intended pad id
# for this model's tokenizer.
pipe = pipeline(
    task="text-generation",
    model="defog/llama-3-sqlcoder-8b",
    pad_token_id=2,
)

# JSON request body accepted by the POST /generate endpoint.
class QueryRequest(BaseModel):
    # Natural-language request; inserted into the prompt under "User request:".
    text: str

def get_db_connection():
    """Open a new MySQL connection using settings from the environment.

    Reads DB_HOST, DB_USER, DB_PASSWORD and DB_NAME via ``os.getenv``.

    Returns:
        A new connection object, or ``None`` if ``mysql.connector`` raised an
        error while connecting. The error is printed rather than raised, so
        callers must check for ``None``.
    """
    settings = {
        "host": os.getenv("DB_HOST"),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "database": os.getenv("DB_NAME"),
        # Promote MySQL warnings to exceptions.
        "raise_on_warnings": True,
    }
    try:
        return mysql.connector.connect(**settings)
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None

def get_database_schema():
    """Return the column names of every table in the configured database.

    Opens a connection with ``get_db_connection()``, lists the tables with
    ``SHOW TABLES`` and runs ``DESCRIBE`` on each one.

    Returns:
        dict mapping each table name to the list of its column names (the
        first field of every ``DESCRIBE`` row). If the connection cannot be
        opened or any query fails, the error is printed and ``{}`` is returned.
    """
    schema = {}
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            raise Exception("Failed to connect to the database.")

        cursor = conn.cursor()

        # Query to get table names
        cursor.execute("SHOW TABLES")
        tables = cursor.fetchall()

        for table in tables:
            table_name = table[0]
            # Identifiers cannot be bound as query parameters, so quote the
            # name with backticks (doubling any embedded backtick). Names with
            # spaces, dashes or reserved words then still DESCRIBE correctly.
            quoted_name = "`" + str(table_name).replace("`", "``") + "`"
            cursor.execute(f"DESCRIBE {quoted_name}")
            columns = cursor.fetchall()
            schema[table_name] = [column[0] for column in columns]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return {}
    except Exception as e:
        print(f"An error occurred: {e}")
        return {}
    finally:
        # Release DB resources even when a query above raised; a failure while
        # closing is reported but must not mask the result or original error.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except mysql.connector.Error as err:
                    print(f"Error: {err}")

    return schema

# Liveness check: GET / reports that the server process is up.
@app.get("/")
def home():
    payload = dict(message="SQL Generation Server is running")
    return payload

@app.post("/generate")
def generate(request: QueryRequest):
    """Translate a natural-language request into a SQL query.

    The current database schema (from ``get_database_schema()``) is embedded
    in the prompt so the model can refer to real table and column names.

    Returns:
        ``{"output": <generated SQL string>}``.

    Raises:
        HTTPException: status 500 if generation fails or the model output does
            not start with SELECT / SHOW / DESCRIBE.
    """
    try:
        # Log the incoming request text for debugging
        print(f"Received request with text: {request.text}")

        text = request.text

        # Fetch the database schema; one "table: col1, col2" line per table.
        schema = get_database_schema()
        schema_str = "\n".join(
            f"{table}: {', '.join(columns)}" for table, columns in schema.items()
        )

        # Build the system message from explicit pieces so the prompt carries
        # no stray source indentation or placeholder text.
        # TODO: append any additional model instructions here.
        system_message = (
            "You are a helpful, cheerful database assistant.\n\n"
            "Use the following dynamically retrieved database schema "
            "when creating your answers:\n\n"
            f"{schema_str}\n"
        )

        prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"
        # return_full_text=False makes the pipeline return only the newly
        # generated text. The answer therefore no longer has to be recovered
        # by splitting on the "SQL query:" marker, which the model might echo.
        output = pipe(prompt, max_new_tokens=100, return_full_text=False)

        sql_query = output[0]["generated_text"].strip()

        # Basic sanity check on the statement prefix only. This does NOT make
        # the SQL safe to execute (e.g. "SELECT 1; DROP TABLE t" passes).
        if not sql_query.lower().startswith(("select", "show", "describe")):
            raise ValueError("Generated text is not a valid SQL query")

        return {"output": sql_query}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run as a script.
    import uvicorn

    listen_host = "0.0.0.0"  # bind on all interfaces
    listen_port = 7860
    uvicorn.run(app, host=listen_host, port=listen_port)