qatch-demo / utils_get_db_tables_info.py
simone-papicchio's picture
More stable version. Link all acc, but still miss prediction (#6)
aff05a7 verified
raw
history blame
2.98 kB
import os
import sqlite3
import re
def utils_extract_db_schema_as_string(
    db_id, base_path, normalize=False, sql: str | None = None
):
    """
    Extracts the full schema of an SQLite database into a single string.

    :param db_id: Identifier of the database. Currently unused by this
        function; kept so existing callers keep working.
    :param base_path: Path of the SQLite database file. It is handed to
        ``sqlite3.connect`` as-is (it is not joined with ``db_id``).
    :param normalize: Whether to normalize the schema string.
    :param sql: Optional SQL query to filter specific tables.
    :return: Schema of the database as a single string.
    """
    # NOTE: sqlite3.connect() is called directly on base_path; no existence
    # check is performed here.
    connection = sqlite3.connect(base_path)
    try:
        cursor = connection.cursor()
        # Get the schema entries based on the provided SQL query
        schema_entries = _get_schema_entries(cursor, sql)
        # Combine all schema definitions into a single string
        return _combine_schema_entries(schema_entries, normalize)
    finally:
        # Always release the database handle, even when a helper raises;
        # previously the connection was left open on every call.
        connection.close()
def _get_schema_entries(cursor, sql):
    """
    Retrieves schema entries from the SQLite database.

    When ``sql`` is provided, the result is restricted to tables whose name
    occurs in the query text (case-insensitive substring match). If ``sql``
    is empty/None, or no table name occurs in it, every ``sqlite_master``
    entry that has a stored definition is returned instead.

    :param cursor: SQLite cursor object.
    :param sql: Optional SQL query to filter specific tables.
    :return: List of schema entries (one-element tuples holding the stored
        definition text).
    """
    tables = []
    if sql:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        sql_lower = sql.lower()
        tables = [
            name for (name,) in cursor.fetchall() if name.lower() in sql_lower
        ]

    if not tables:
        # No filter requested, or nothing matched: fall back to everything.
        cursor.execute("SELECT sql FROM sqlite_master WHERE sql IS NOT NULL;")
        return cursor.fetchall()

    # Bind the names as parameters instead of splicing them into the SQL
    # text, so table names containing quotes cannot break the statement.
    placeholders = ", ".join("?" for _ in tables)
    cursor.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type='table' AND name IN ({placeholders}) AND sql IS NOT NULL;",
        tables,
    )
    return cursor.fetchall()
def _combine_schema_entries(schema_entries, normalize):
    """
    Joins the definition text of each schema entry with newlines.

    With ``normalize`` enabled, each definition is first compacted: the
    literal ``CREATE TABLE`` is dropped, whitespace runs collapse to a single
    space, the leading token is wrapped in backticks, and the padding around
    parentheses is removed.

    :param schema_entries: List of schema entries (definition text at index 0).
    :param normalize: Whether to normalize the schema string.
    :return: Combined schema string.
    """
    if not normalize:
        return "\n".join(row[0] for row in schema_entries)

    # Applied in this exact order, after whitespace has been collapsed.
    rewrites = (
        (r"^\s*([^\s(]+)", r"`\1`"),  # backtick-quote the leading token
        (r"(`\w+`)\s+\(", r"\1("),  # glue the quoted name to its "("
        (r"\(\s*", "("),  # no padding after "("
        (r"\s*\)", ")"),  # no padding before ")"
    )

    def _compact(definition):
        # Strip the keyword and tabs, then squeeze all whitespace runs.
        text = definition.replace("CREATE TABLE", "").replace("\t", " ")
        text = re.sub(r"\s+", " ", text).strip()
        for pattern, replacement in rewrites:
            text = re.sub(pattern, replacement, text)
        return text

    return "\n".join(_compact(row[0]) for row in schema_entries)