Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from transformers import pipeline | |
import mysql.connector | |
import os | |
from dotenv import load_dotenv | |
# Pull settings (e.g. the DB_* variables read by get_db_connection) from a
# local .env file into the process environment before anything below runs.
load_dotenv()

app = FastAPI()

# Text-generation pipeline used by generate() to turn prompts into SQL.
# NOTE(review): pad_token_id=2 looks carried over from Llama-2 (where id 2 is
# </s>); confirm it is the intended pad id for this Llama-3-based model.
pipe = pipeline(
    "text-generation",
    model="defog/llama-3-sqlcoder-8b",
    pad_token_id=2,
)
class QueryRequest(BaseModel):
    """JSON request body accepted by generate()."""

    # Natural-language request that the model should turn into a SQL query.
    text: str
def get_db_connection():
    """Open and return a fresh MySQL connection, or ``None`` on failure.

    Connection settings come from the DB_HOST, DB_USER, DB_PASSWORD and
    DB_NAME environment variables. A ``mysql.connector.Error`` raised while
    connecting is printed and swallowed, so callers only need to check the
    result for ``None``.
    """
    settings = {
        "host": os.getenv("DB_HOST"),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "database": os.getenv("DB_NAME"),
        # Escalate server warnings to exceptions rather than ignoring them.
        "raise_on_warnings": True,
    }
    try:
        return mysql.connector.connect(**settings)
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
def get_database_schema():
    """Return the live database schema as ``{table_name: [column_name, ...]}``.

    Tables are enumerated with ``SHOW TABLES`` and each one is inspected with
    ``DESCRIBE``; only the first field of every DESCRIBE row (the column name)
    is kept.

    Any failure (no connection, query error, unexpected exception) is printed
    and an empty dict is returned, so callers cannot tell "no tables" apart
    from "database error". The cursor and connection are closed on every
    path, including failures.
    """
    schema = {}
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            raise ConnectionError("Failed to connect to the database.")
        cursor = conn.cursor()
        cursor.execute("SHOW TABLES")
        for row in cursor.fetchall():
            table_name = row[0]
            # Some connector builds return identifiers as bytes/bytearray;
            # normalize so both the SQL text and the dict key are str.
            if isinstance(table_name, (bytes, bytearray)):
                table_name = table_name.decode("utf-8")
            # Identifiers cannot be bound as parameters, so quote the name
            # with backticks (doubling any embedded backtick) instead of
            # splicing it in raw.
            quoted_name = "`" + table_name.replace("`", "``") + "`"
            cursor.execute(f"DESCRIBE {quoted_name}")
            schema[table_name] = [column[0] for column in cursor.fetchall()]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return {}
    except Exception as e:
        print(f"An error occurred: {e}")
        return {}
    finally:
        # Release DB resources on every path; a failing close() is reported
        # but must not mask the result (or the error) computed above.
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as close_err:
                print(f"Error while closing database resource: {close_err}")
    return schema
@app.get("/")
def home():
    """Return a static status message; a simple liveness check for the server.

    The original definition carried no route decorator, so it was never
    registered on ``app`` and was unreachable over HTTP.
    """
    return {"message": "SQL Generation Server is running"}
# NOTE(review): the route decorator was missing from this handler, so it was
# never registered; the "/generate" path is assumed — confirm against clients.
@app.post("/generate")
def generate(request: QueryRequest):
    """Translate ``request.text`` into a SQL query with the text-generation model.

    The current schema from ``get_database_schema()`` is embedded in the
    prompt so the model can reference real tables and columns. That helper
    returns ``{}`` on database failure, in which case the prompt carries an
    empty schema.

    Returns:
        ``{"output": sql_query}``, where ``sql_query`` is the text after the
        last ``"SQL query:"`` marker in the model output.

    Raises:
        HTTPException: status 500 for any failure, including output that does
            not start with SELECT, SHOW or DESCRIBE (case-insensitive).
    """
    try:
        # Log the incoming request text for debugging.
        print(f"Received request with text: {request.text}")
        text = request.text

        schema = get_database_schema()
        schema_str = "\n".join(
            f"{table}: {', '.join(columns)}" for table, columns in schema.items()
        )

        # The previous prompt shipped a literal "[Additional instructions as
        # in your original code]" placeholder to the model; state the real
        # constraint instead, matching the validation below.
        system_message = "\n".join(
            [
                "You are a helpful, cheerful database assistant.",
                "Use the following dynamically retrieved database schema "
                "when creating your answers:",
                schema_str,
                "Respond with a single read-only SQL query (SELECT, SHOW or "
                "DESCRIBE) and nothing else.",
            ]
        )
        prompt = f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"

        output = pipe(prompt, max_new_tokens=100)
        generated_text = output[0]["generated_text"]
        # The returned text echoes the prompt, so keep only what follows the
        # last "SQL query:" marker.
        sql_query = generated_text.split("SQL query:")[-1].strip()

        # Basic validation: only read-only statement prefixes are accepted.
        if not sql_query.lower().startswith(("select", "show", "describe")):
            raise ValueError("Generated text is not a valid SQL query")
        return {"output": sql_query}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces (not just localhost) on port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)