"""FastAPI service that turns natural-language requests into SQL queries.

The service reads the MySQL schema at request time, embeds it in a prompt and
asks a text-generation model for a read-only SQL statement.
"""

import os
from contextlib import closing
from typing import Dict, List

import mysql.connector
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Load environment variables from the .env file
load_dotenv()

app = FastAPI()

# Initialize the text generation pipeline
pipe = pipeline("text-generation", model="defog/llama-3-sqlcoder-8b", pad_token_id=2)

# Marker that ends the prompt; the model's answer is whatever follows it.
_SQL_MARKER = "SQL query:"

# Lower-case statement prefixes accepted by the basic validation in generate().
_ALLOWED_SQL_PREFIXES = ("select", "show", "describe")


class QueryRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Natural-language description of the data the caller wants.
    text: str


def get_db_connection():
    """Create a new database connection.

    Connection parameters are read from the ``DB_HOST``, ``DB_USER``,
    ``DB_PASSWORD`` and ``DB_NAME`` environment variables.

    Returns:
        The object returned by ``mysql.connector.connect``, or ``None`` when
        the connector raises ``mysql.connector.Error`` (the error is printed).
    """
    try:
        return mysql.connector.connect(
            host=os.getenv("DB_HOST"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            database=os.getenv("DB_NAME"),
            raise_on_warnings=True,
        )
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None


def _quote_identifier(name: str) -> str:
    """Return *name* wrapped in backticks, with embedded backticks doubled."""
    escaped = str(name).replace("`", "``")
    return f"`{escaped}`"


def get_database_schema() -> Dict[str, List[str]]:
    """Retrieve the database schema dynamically.

    Runs ``SHOW TABLES`` and then ``DESCRIBE`` on every table.

    Returns:
        A mapping of table name to the list of its column names. An empty
        dict is returned when the connection cannot be opened or any query
        fails; the error is printed rather than raised.
    """
    schema: Dict[str, List[str]] = {}
    try:
        conn = get_db_connection()
        if conn is None:
            print("An error occurred: Failed to connect to the database.")
            return {}

        # closing() guarantees the cursor and then the connection are released
        # even when one of the queries below raises.
        with closing(conn), closing(conn.cursor()) as cursor:
            # Query to get table names
            cursor.execute("SHOW TABLES")
            tables = cursor.fetchall()

            for table in tables:
                table_name = table[0]
                # Quote the identifier so names that are reserved words or
                # contain special characters do not break the statement.
                cursor.execute(f"DESCRIBE {_quote_identifier(table_name)}")
                columns = cursor.fetchall()
                schema[table_name] = [column[0] for column in columns]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return {}
    except Exception as e:
        print(f"An error occurred: {e}")
        return {}

    return schema


def _format_schema(schema: Dict[str, List[str]]) -> str:
    """Render the schema as one ``table: col1, col2`` line per table."""
    return "\n".join(
        f"{table}: {', '.join(columns)}" for table, columns in schema.items()
    )


def _build_prompt(schema_str: str, text: str) -> str:
    """Assemble the full model prompt from the schema text and the request."""
    # NOTE(review): the bracketed "Additional instructions" line looks like a
    # leftover placeholder that is sent to the model verbatim - confirm and
    # replace it with the real instructions.
    system_message = (
        "\n"
        "You are a helpful, cheerful database assistant. \n"
        "Use the following dynamically retrieved database schema "
        "when creating your answers:\n"
        f"{schema_str}\n"
        "\n"
        "[Additional instructions as in your original code]\n"
    )
    return f"{system_message}\n\nUser request:\n\n{text}\n\n{_SQL_MARKER}"


def _extract_sql(generated_text: str) -> str:
    """Return the text after the last ``SQL query:`` marker, validated.

    Raises:
        ValueError: if the extracted text does not start with one of the
            accepted prefixes (``select``, ``show``, ``describe``).
    """
    # rpartition keeps the whole text when the marker is absent, matching
    # split(marker)[-1].
    sql_query = generated_text.rpartition(_SQL_MARKER)[2].strip()

    # Basic validation
    if not sql_query.lower().startswith(_ALLOWED_SQL_PREFIXES):
        raise ValueError("Generated text is not a valid SQL query")
    return sql_query


@app.get("/")
def home():
    """Health-check endpoint reporting that the server is up."""
    return {"message": "SQL Generation Server is running"}


@app.post("/generate")
def generate(request: QueryRequest):
    """Generate a SQL query for the natural-language text in *request*.

    Returns:
        ``{"output": <sql>}`` where ``<sql>`` is the generated statement.

    Raises:
        HTTPException: with status 500 and the error message as detail when
            generation fails or the model output does not pass validation.
    """
    try:
        # Log the incoming request text for debugging
        print(f"Received request with text: {request.text}")

        # Fetch the database schema
        schema_str = _format_schema(get_database_schema())

        prompt = _build_prompt(schema_str, request.text)
        output = pipe(prompt, max_new_tokens=100)

        return {"output": _extract_sql(output[0]["generated_text"])}
    except Exception as e:
        # Chain the original error so the traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)