Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import pipeline | |
| import mysql.connector | |
| import os | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file into os.environ so the os.getenv("DB_*")
# lookups used when opening MySQL connections can see them.
load_dotenv()

# ASGI application object; served by uvicorn when this file is run as a script.
app = FastAPI()

# SQLCoder text-generation pipeline, built once at import time and reused by
# every request.
pipe = pipeline(
    "text-generation",
    model="defog/llama-3-sqlcoder-8b",
    pad_token_id=2,  # NOTE(review): presumably this model's pad token id — confirm against its tokenizer config
)
# Request body accepted by generate(): a single free-form text field.
# (Documented with comments rather than a docstring, because pydantic copies a
# class docstring into the generated JSON/OpenAPI schema.)
class QueryRequest(BaseModel):
    # The user's request; generate() inserts it verbatim into the model prompt.
    text: str
def get_db_connection():
    """Open a new MySQL connection configured from DB_* environment variables.

    Reads DB_HOST, DB_USER, DB_PASSWORD and DB_NAME via ``os.getenv``; unset
    variables are passed to the connector as ``None``. ``raise_on_warnings=True``
    asks the connector to raise on MySQL warnings instead of ignoring them.

    Returns:
        The object returned by ``mysql.connector.connect``. Returns ``None`` if
        connecting raised ``mysql.connector.Error``; the error is printed to
        stdout, not re-raised.
    """
    settings = {
        "host": os.getenv("DB_HOST"),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "database": os.getenv("DB_NAME"),
        "raise_on_warnings": True,
    }
    try:
        return mysql.connector.connect(**settings)
    except mysql.connector.Error as err:
        # Best-effort: callers check for None rather than catching.
        print(f"Error: {err}")
        return None
def _as_text(value):
    """Return *value* as ``str``, decoding ``bytes``/``bytearray`` as UTF-8."""
    # NOTE(review): depending on the mysql-connector version and cursor settings,
    # identifiers may come back as bytes/bytearray instead of str. Normalising
    # keeps the DESCRIBE statement and later ", ".join(...) of column names working.
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode("utf-8")
    return str(value)


def _quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled, so table names that are reserved words
    (e.g. ``order``) or contain spaces, hyphens or backticks reach MySQL intact.
    """
    return "`" + name.replace("`", "``") + "`"


def get_database_schema():
    """Return the live database schema as ``{table_name: [column_name, ...]}``.

    Runs ``SHOW TABLES``, then ``DESCRIBE`` on each table, keeping the first
    field of every DESCRIBE row. The cursor and connection are always closed,
    including when a query fails part-way through.

    Returns:
        The schema mapping. Returns ``{}`` if no connection could be opened or
        any query failed; the error is printed to stdout, not raised.
    """
    schema = {}
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if conn is None:
            # get_db_connection() already printed the underlying error.
            raise ConnectionError("Failed to connect to the database.")
        cursor = conn.cursor()
        cursor.execute("SHOW TABLES")
        for row in cursor.fetchall():
            table_name = _as_text(row[0])
            # Quote the identifier: an unquoted reserved word or special
            # character in a table name would make DESCRIBE a syntax error.
            cursor.execute(f"DESCRIBE {_quote_identifier(table_name)}")
            schema[table_name] = [_as_text(column[0]) for column in cursor.fetchall()]
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return {}
    except Exception as e:
        print(f"An error occurred: {e}")
        return {}
    finally:
        # Release resources on every path (previously skipped whenever a query
        # raised). A failure while closing is reported, not propagated.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as close_err:
                    print(f"Error while closing database resource: {close_err}")
    return schema
# Without a route decorator FastAPI never exposes this function.
# NOTE(review): "/" is the assumed path — confirm against existing clients.
@app.get("/")
def home():
    """Liveness endpoint: report that the server is running.

    Returns:
        A JSON-serialisable dict holding a fixed status message.
    """
    return {"message": "SQL Generation Server is running"}
# Statement prefixes accepted as a generated (read-only) query; the prompt
# instruction in _build_prompt lists the same three.
_ALLOWED_SQL_PREFIXES = ("select", "show", "describe")


def _build_prompt(text, schema):
    """Return the model prompt for user request *text* given *schema* ({table: [columns]})."""
    schema_str = "\n".join(
        f"{table}: {', '.join(columns)}" for table, columns in schema.items()
    )
    system_message = "\n".join(
        [
            "You are a helpful, cheerful database assistant.",
            "Use the following dynamically retrieved database schema when creating your answers:",
            schema_str,
            # Replaces a leftover template placeholder that was being sent to
            # the model verbatim; kept in sync with _ALLOWED_SQL_PREFIXES.
            "Respond with a single read-only SQL query (SELECT, SHOW or DESCRIBE) and nothing else.",
        ]
    )
    return f"{system_message}\n\nUser request:\n\n{text}\n\nSQL query:"


def _extract_sql(generated_text):
    """Return the SQL after the last "SQL query:" marker, minus any markdown code fence.

    Raises:
        ValueError: if the result does not start with SELECT, SHOW or DESCRIBE.
    """
    # The pipeline output echoes the prompt, which ends with "SQL query:",
    # so the text after the last marker is the model's answer.
    sql_query = generated_text.split("SQL query:")[-1].strip()
    # Models often wrap SQL in ```sql ... ``` fences, which would otherwise
    # fail the prefix check below.
    if sql_query.startswith("```"):
        sql_query = sql_query[3:]
        if sql_query.lower().startswith("sql"):
            sql_query = sql_query[3:]
        sql_query = sql_query.split("```", 1)[0].strip()
    if not sql_query.lower().startswith(_ALLOWED_SQL_PREFIXES):
        raise ValueError("Generated text is not a valid SQL query")
    return sql_query


# Without a route decorator FastAPI never exposes this function.
# NOTE(review): POST "/generate" is the assumed route — confirm against clients.
@app.post("/generate")
def generate(request: QueryRequest):
    """Translate ``request.text`` into a read-only SQL query for the live database.

    The current schema is fetched on every call and embedded in the prompt.
    Generation is capped at 100 new tokens.
    NOTE(review): get_database_schema() returns {} when the database is
    unreachable, in which case the prompt carries no schema at all.

    Returns:
        ``{"output": <sql query string>}``.

    Raises:
        HTTPException: status 500 with the underlying message if generation
            fails, or if the output is not a SELECT/SHOW/DESCRIBE statement.
    """
    try:
        # Log the incoming request text for debugging.
        print(f"Received request with text: {request.text}")
        schema = get_database_schema()
        prompt = _build_prompt(request.text, schema)
        output = pipe(prompt, max_new_tokens=100)
        sql_query = _extract_sql(output[0]["generated_text"])
        return {"output": sql_query}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn on port 7860.
    # uvicorn is imported here, so it is only required when run as a script.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )