File size: 2,976 Bytes
696b8fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import sqlite3
import re


def utils_extract_db_schema_as_string(
    db_id, base_path, normalize=False, sql: str | None = None
):
    """
    Extracts the schema of an SQLite database into a single string.

    :param db_id: Database identifier. Currently unused by this function; it is
        kept so that existing callers do not have to change.
    :param base_path: Value passed directly to ``sqlite3.connect``, i.e. the
        path of the SQLite database file itself. No existence check is
        performed before connecting.
    :param normalize: Whether to normalize the schema string.
    :param sql: Optional SQL query; when given, the schema is restricted to
        the tables whose names appear in it (see ``_get_schema_entries``).
    :return: Schema of the database as a single string.
    """
    connection = sqlite3.connect(base_path)
    try:
        cursor = connection.cursor()

        # Get the schema entries based on the provided SQL query
        schema_entries = _get_schema_entries(cursor, sql)

        # Combine all schema definitions into a single string
        return _combine_schema_entries(schema_entries, normalize)
    finally:
        # Always release the database handle, including when a query or the
        # string assembly raises; the original never closed the connection.
        connection.close()


def _get_schema_entries(cursor, sql):
    """
    Retrieves schema entries from the SQLite database.

    :param cursor: SQLite cursor object.
    :param sql: Optional SQL query to filter specific tables.
    :return: List of schema entries.
    """
    if sql:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
        if tables:
            tbl_names = ", ".join(f"'{tbl}'" for tbl in tables)
            query = f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({tbl_names}) AND sql IS NOT NULL;"
        else:
            query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"
    else:
        query = "SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;"

    cursor.execute(query)
    return cursor.fetchall()


def _combine_schema_entries(schema_entries, normalize):
    """
    Combines schema entries into a single string.

    :param schema_entries: List of schema entries.
    :param normalize: Whether to normalize the schema string.
    :return: Combined schema string.
    """
    if not normalize:
        return "\n".join(entry[0] for entry in schema_entries)

    return "\n".join(
        re.sub(
            r"\s*\)",
            ")",
            re.sub(
                r"\(\s*",
                "(",
                re.sub(
                    r"(`\w+`)\s+\(",
                    r"\1(",
                    re.sub(
                        r"^\s*([^\s(]+)",
                        r"`\1`",
                        re.sub(
                            r"\s+",
                            " ",
                            entry[0].replace("CREATE TABLE", "").replace("\t", " "),
                        ).strip(),
                    ),
                ),
            ),
        )
        for entry in schema_entries
    )