Spaces:
Sleeping
Sleeping
File size: 3,344 Bytes
5473610 00295a8 5473610 00295a8 5473610 00295a8 5473610 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import logging
import os
from contextlib import closing

import mysql.connector
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# Populate os.environ from a local .env file; by default python-dotenv does
# not override variables that are already set in the process environment.
load_dotenv()

# ASGI application; the routes below are registered on it via decorators.
app = FastAPI()

# Hugging Face text-generation pipeline for the SQLCoder model. It is built
# once, at import time, and the same object is reused by every request.
# NOTE(review): pad_token_id=2 may be carried over from Llama-2 configs;
# confirm it is appropriate for the Llama-3 tokenizer.
pipe = pipeline(
    "text-generation",
    model="defog/llama-3-sqlcoder-8b",
    pad_token_id=2,
)
class QueryRequest(BaseModel):
    """JSON request body accepted by the ``POST /generate`` endpoint."""

    # Natural-language request that the model should turn into SQL.
    text: str
def get_db_connection():
    """Open a new MySQL connection configured from environment variables.

    Reads ``DB_HOST``, ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME`` via
    ``os.getenv`` (this module calls ``load_dotenv()`` at import, so a
    ``.env`` file can supply them). Server warnings are promoted to errors
    with ``raise_on_warnings=True``.

    Returns:
        A new connection on success; the caller is responsible for closing
        it. ``None`` if the connector raised ``mysql.connector.Error``. The
        error is logged rather than propagated, and callers check for
        ``None``.
    """
    try:
        return mysql.connector.connect(
            host=os.getenv("DB_HOST"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            database=os.getenv("DB_NAME"),
            raise_on_warnings=True,
        )
    except mysql.connector.Error as err:
        # Keep the best-effort contract (return None), but report through
        # logging so the failure carries a severity and reaches server logs.
        logging.getLogger(__name__).error("Error: %s", err)
        return None
def get_database_schema():
    """Retrieve the live database schema as ``{table_name: [column_name, ...]}``.

    Lists tables with ``SHOW TABLES``, then reads each table's column names
    (first field of every ``DESCRIBE`` row).

    This is best-effort. If no connection can be opened or any query fails,
    the error is logged and an empty dict is returned. An empty result can
    therefore mean either "no tables" or "lookup failed".

    The cursor and connection are always closed, including when a query
    raises.
    """
    logger = logging.getLogger(__name__)
    schema = {}
    try:
        conn = get_db_connection()
        if conn is None:
            raise Exception("Failed to connect to the database.")
        # closing() runs cursor.close() and then conn.close() on exit, so a
        # failing query no longer leaks either resource.
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("SHOW TABLES")
            # Materialize the table list before the cursor is reused for DESCRIBE.
            table_names = [row[0] for row in cursor.fetchall()]
            for table_name in table_names:
                # Identifiers cannot be bound as query parameters, so quote
                # with backticks (doubling any embedded backtick). Names with
                # spaces, hyphens or reserved words then still work.
                quoted = "`" + str(table_name).replace("`", "``") + "`"
                cursor.execute(f"DESCRIBE {quoted}")
                schema[table_name] = [column[0] for column in cursor.fetchall()]
    except mysql.connector.Error as err:
        logger.error("Error: %s", err)
        return {}
    except Exception as e:
        logger.error("An error occurred: %s", e)
        return {}
    return schema
@app.get("/")
def home():
    """Health/landing endpoint: confirm the server process is up."""
    status = {"message": "SQL Generation Server is running"}
    return status
def _format_schema(schema):
    """Render ``{table: [column, ...]}`` as one ``table: col1, col2`` line per table."""
    return "\n".join(f"{table}: {', '.join(columns)}" for table, columns in schema.items())


def _build_prompt(schema_str, text):
    """Build the model prompt.

    The prompt is a system message containing the schema, then the user
    request, then an open ``SQL query:`` slot for the model to complete.
    """
    # TODO: add any real extra instructions for the model here. The earlier
    # template contained an unfilled placeholder line
    # ("[Additional instructions as in your original code]") that was being
    # sent to the model verbatim, so it is intentionally omitted.
    system_message = (
        "\n"
        "You are a helpful, cheerful database assistant.\n"
        "Use the following dynamically retrieved database schema when creating your answers:\n"
        f"{schema_str}\n"
    )
    return f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"


def _extract_sql(generated_text):
    """Return the text after the last ``SQL query:`` marker, whitespace-stripped.

    By default the text-generation pipeline returns the prompt followed by
    the continuation. The prompt ends with this marker, so the last
    occurrence is where the model's answer begins, unless the model emits
    the marker again itself.
    """
    return generated_text.split("SQL query:")[-1].strip()


@app.post("/generate")
def generate(request: QueryRequest):
    """Translate ``request.text`` into a SQL query using the SQLCoder pipeline.

    The database schema is re-read on every call and embedded in the prompt.
    Generation is capped at 100 new tokens.

    Returns:
        ``{"output": sql_query}`` with the extracted query text.

    Raises:
        HTTPException: status 500 for any failure. This includes generated
            text that does not start with SELECT, SHOW or DESCRIBE
            (case-insensitive).
    """
    logger = logging.getLogger(__name__)
    try:
        text = request.text
        logger.info("Received request with text: %s", text)

        schema_str = _format_schema(get_database_schema())
        prompt = _build_prompt(schema_str, text)

        output = pipe(prompt, max_new_tokens=100)
        sql_query = _extract_sql(output[0]["generated_text"])

        # Basic validation: only statements starting with these keywords are
        # accepted; anything else (prose, DML/DDL) counts as a failed generation.
        if not sql_query.lower().startswith(("select", "show", "describe")):
            raise ValueError("Generated text is not a valid SQL query")
        return {"output": sql_query}
    except Exception as e:
        # Log the traceback server-side; the client only sees the message.
        logger.exception("SQL generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn on all
    # interfaces, port 7860. uvicorn is imported here because it is only
    # needed when the module is executed directly.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|