import os import sqlite3 import re def utils_extract_db_schema_as_string( db_id, base_path, normalize=False, sql: str | None = None ): """ Extracts the full schema of an SQLite database into a single string. :param base_path: Base path where the database is located. :param db_id: Path to the SQLite database file. :param normalize: Whether to normalize the schema string. :param sql: Optional SQL query to filter specific tables. :return: Schema of the database as a single string. """ #db_path = os.path.join(base_path, db_id, f"{db_id}.sqlite") # Connect to the SQLite database #if not os.path.exists(db_path): # raise FileNotFoundError(f"Database file not found at: {db_path}") connection = sqlite3.connect(base_path) cursor = connection.cursor() # Get the schema entries based on the provided SQL query schema_entries = _get_schema_entries(cursor, sql) # Combine all schema definitions into a single string schema_string = _combine_schema_entries(schema_entries, normalize) return schema_string def _get_schema_entries(cursor, sql): """ Retrieves schema entries from the SQLite database. :param cursor: SQLite cursor object. :param sql: Optional SQL query to filter specific tables. :return: List of schema entries. """ if sql: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()] if tables: tbl_names = ", ".join(f"'{tbl}'" for tbl in tables) query = f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({tbl_names}) AND sql IS NOT NULL;" else: query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;" else: query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;" cursor.execute(query) return cursor.fetchall() def _combine_schema_entries(schema_entries, normalize): """ Combines schema entries into a single string. :param schema_entries: List of schema entries. :param normalize: Whether to normalize the schema string. :return: Combined schema string. """ if not normalize: return "\n".join(entry[0] for entry in schema_entries) return "\n".join( re.sub( r"\s*\)", ")", re.sub( r"\(\s*", "(", re.sub( r"(`\w+`)\s+\(", r"\1(", re.sub( r"^\s*([^\s(]+)", r"`\1`", re.sub( r"\s+", " ", entry[0].replace("CREATE TABLE", "").replace("\t", " "), ).strip(), ), ), ), ) for entry in schema_entries )