# Hugging Face Space app (Space status when captured: Sleeping; file size 2,635 bytes).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
app = FastAPI()

# Build the SQLCoder text-generation pipeline once, at import time; the
# request handlers below reuse this single module-level instance.
# NOTE(review): pad_token_id=2 is presumably the model's EOS id, reused as the
# padding id to silence the missing-pad-token warning — confirm.
pipe = pipeline(
    task="text-generation",
    model="defog/sqlcoder-7b-2",
    pad_token_id=2,
)
class QueryRequest(BaseModel):
    """JSON body accepted by the ``/generate`` endpoint."""

    # Natural-language description of the query the caller wants generated.
    text: str
@app.get("/")
def home():
    """Health-check route: confirm that the server is up."""
    status = {"message": "SQL Generation Server is running"}
    return status
# Statement types the endpoint is willing to return. These are compared against
# the first *word* of the generated query, not a raw string prefix, so prose
# such as "Selecting ..." is not mistaken for a SELECT statement.
_ALLOWED_STATEMENTS = frozenset(
    {"select", "show", "describe", "insert", "update", "delete"}
)


def _extract_sql(generated_text: str) -> str:
    """Pull the first SQL statement out of the raw pipeline output.

    The text after the last ``"SQL query:"`` marker is taken as the model's
    answer, and it is truncated at the first ``;``.

    No per-token keyword whitelist is applied. Table and column names, ``*``
    and literals are not SQL keywords, so such a check rejects virtually
    every real query.

    Raises:
        ValueError: if the answer does not begin with one of the statement
            keywords in ``_ALLOWED_STATEMENTS``.
    """
    import re

    # By default the text-generation pipeline returns prompt + completion, and
    # the prompt ends with "SQL query:", so the answer follows its last
    # occurrence.
    answer = generated_text.split("SQL query:")[-1]
    # Keep only the first statement; anything after ';' is discarded.
    sql_query = answer.split(";")[0].strip()

    match = re.match(r"\w+", sql_query.lower())
    first_word = match.group(0) if match else ""
    if first_word not in _ALLOWED_STATEMENTS:
        raise ValueError("Generated text is not a valid SQL query")
    return sql_query


@app.post("/generate")
def generate(request: QueryRequest):
    """Generate a SQL query for the natural-language request in ``request.text``.

    Returns:
        ``{"output": <sql>}`` holding a single statement without the trailing ``;``.

    Raises:
        HTTPException: 422 when ``request.text`` is blank; 500 when generation
            fails or the model output does not start with a supported statement.
    """
    text = request.text
    if not text.strip():
        # Raised outside the try block so it is not rewrapped as a 500 below.
        raise HTTPException(status_code=422, detail="text must not be empty")

    prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
    try:
        output = pipe(prompt, max_new_tokens=100)
        sql_query = _extract_sql(output[0]["generated_text"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"output": sql_query}
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all interfaces,
    # port 7860.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)
|