File size: 3,645 Bytes
0d9dd80
 
 
95b255d
6f92cbb
 
0d9dd80
 
 
 
6f92cbb
 
 
0d9dd80
 
 
 
 
 
6f92cbb
6bfb9ac
 
 
6f92cbb
 
6bfb9ac
3b0a41a
 
0d9dd80
95b255d
6bfb9ac
0d9dd80
 
6bfb9ac
 
 
 
 
95b255d
bc41330
 
6bfb9ac
 
bc41330
 
 
 
 
 
 
0d9dd80
bc41330
 
 
 
0d9dd80
bc41330
0d9dd80
6bfb9ac
0d9dd80
 
 
95b255d
0d9dd80
 
 
 
 
 
 
 
 
6bfb9ac
 
 
0d9dd80
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import logging
import re

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline

# Configure logging: root logger at INFO and above, one
# "<time> - <LEVEL> - <message>" line per record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

# ASGI application; route handlers and the exception handler are registered on
# it via decorators further down this module.
app: FastAPI = FastAPI()

# Initialize the text generation pipeline.
# This runs at import time, so importing this module loads the model (and,
# presumably, downloads it if it is not cached locally) before any request can
# be served.
# NOTE(review): pad_token_id=2 is presumably the model's pad/EOS token id —
# TODO confirm against the defog/sqlcoder-7b-2 tokenizer config.
try:
    pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
    logger.info("Model loaded successfully")
except Exception as e:
    # Log the failure, then re-raise so the module import (and therefore
    # server startup) fails instead of running without a model.
    logger.error(f"Failed to load the model: {str(e)}")
    raise

class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    # Natural-language description of the desired query; it is interpolated
    # verbatim into the model prompt by the /generate handler.
    text: str

@app.get("/")
def home():
    """Report that the server is up (usable as a simple health check).

    Returns:
        A JSON-serialisable dict with a single ``message`` key.
    """
    status = dict(message="SQL Generation Server is running")
    return status

# Statement keywords a generated query may start with. The whole leading word
# is compared, so prose such as "Selecting the rows..." is not mistaken for SQL.
_ALLOWED_STATEMENTS = frozenset({'select', 'show', 'describe', 'insert', 'update', 'delete'})
_LEADING_WORD = re.compile(r"\w+")


def _extract_sql_query(generated_text: str) -> str:
    """Return the first SQL statement contained in the pipeline output.

    The text-generation pipeline returns the prompt followed by the completion
    by default, so everything up to the last ``"SQL query:"`` marker (the end
    of the prompt) is dropped. Anything from the first ``;`` onward is discarded
    so extra statements or trailing commentary are never returned.
    """
    completion = generated_text.split("SQL query:")[-1]
    return completion.split(';')[0].strip()


def _validate_sql_query(sql_query: str) -> None:
    """Raise ``ValueError`` unless *sql_query* begins with an allowed statement keyword.

    Only the leading keyword is checked. A per-token keyword allowlist cannot
    tell SQL keywords apart from table/column names, literals and operators,
    so it rejects ordinary queries such as ``SELECT * FROM users``.
    """
    leading = _LEADING_WORD.match(sql_query.lstrip().lower())
    if leading is None or leading.group() not in _ALLOWED_STATEMENTS:
        raise ValueError("Generated text is not a valid SQL query")


@app.post("/generate")
async def generate(request: QueryRequest):
    """Generate a SQL query for the natural-language request in ``request.text``.

    Returns:
        ``{"output": <sql>}`` containing the first generated statement, without
        a trailing semicolon.

    Raises:
        HTTPException: 400 when the model output does not start with an allowed
            SQL statement keyword; 500 when generation or output parsing fails.
    """
    text = request.text
    logger.info("Received request: %s", text)

    prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"

    try:
        # NOTE(review): ``pipe`` is synchronous and is called directly inside an
        # ``async def`` handler, so it blocks the event loop for the whole
        # generation (no other request, including GET /, is served meanwhile).
        # Offloading it (e.g. fastapi.concurrency.run_in_threadpool) would fix
        # that — first confirm the pipeline/tokenizer tolerate concurrent calls.
        output = pipe(prompt, max_new_tokens=100)
        sql_query = _extract_sql_query(output[0]['generated_text'])
    except Exception as e:
        # Server-side failure: log the traceback, return a generic message so
        # internal details are not exposed to the client.
        logger.exception("Error in generate endpoint: %s", e)
        raise HTTPException(status_code=500, detail="An error occurred while generating the SQL query") from e

    try:
        _validate_sql_query(sql_query)
    except ValueError as ve:
        logger.warning("Validation error: %s", ve)
        raise HTTPException(status_code=400, detail=str(ve)) from ve

    logger.info("Generated SQL query: %s", sql_query)
    return {"output": sql_query}

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as a ``{"message": <detail>}`` JSON response.

    The exception's status code is kept. Any headers attached to the exception
    are now forwarded on the response, as FastAPI's default handler does, so
    that they are not silently dropped.

    NOTE(review): this is registered for ``fastapi.HTTPException``. FastAPI's
    docs recommend registering for Starlette's ``HTTPException`` so that
    framework-raised errors (e.g. 404 for unknown routes) get the same
    ``{"message": ...}`` body shape — confirm whether that is wanted here.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"message": exc.detail},
        headers=getattr(exc, "headers", None),
    )

if __name__ == "__main__":
    # When executed as a script, serve ``app`` with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")