qatch-demo / utils_get_db_tables_info.py
simone-papicchio's picture
Fix prompts buttons, and NL2SQL bug (#24)
af2b1fd verified
raw
history blame
4.88 kB
import os
import sqlite3
import re
import utilities as us
def utils_extract_db_schema_as_string(
    db_id,
    base_path,
    model: str | None = None,
    normalize=False,
    sql: str | None = None,
    get_insert_into: bool = False,
    prompt: str | None = None,
):
    """
    Extracts the schema of an SQLite database into a single string.

    :param db_id: Identifier of the database. Not used by this function; the
        database that is read is the one at ``base_path``.
    :param base_path: Path of the SQLite database file, passed to ``sqlite3.connect``.
    :param model: Optional model name. When given, every row of each table is
        rendered as an INSERT statement and the per-table entries are cropped by
        ``utilities.crop_entries_per_token``. When ``None``, at most 3 rows per
        table are rendered.
    :param normalize: Whether to collapse each entry into the compact form produced
        by ``_combine_schema_entries``.
    :param sql: Optional SQL query; only tables whose name occurs in it are kept.
    :param get_insert_into: Whether to append INSERT INTO statements with table data.
    :param prompt: Optional prompt forwarded to ``utilities.crop_entries_per_token``.
    :return: Schema entries joined by newlines.
    """
    # NOTE(review): sqlite3.connect creates an empty database file if base_path
    # does not exist, so a wrong path yields an empty schema rather than an error.
    connection = sqlite3.connect(base_path)
    try:
        cursor = connection.cursor()
        # Get the schema entries based on the provided SQL query
        schema_entries = _get_schema_entries(cursor, sql, get_insert_into, model, prompt)
    finally:
        # Close explicitly: the connection was previously leaked on every call,
        # and using the connection as a context manager would not close it either.
        connection.close()
    # Combine all schema definitions into a single string
    return _combine_schema_entries(schema_entries, normalize)
def _get_schema_entries(cursor, sql=None, get_insert_into=False, model: str | None = None, prompt : str | None = None):
"""
Retrieves schema entries and optionally data entries from the SQLite database.
:param cursor: SQLite cursor object.
:param sql: Optional SQL query to filter specific tables.
:param get_insert_into: Boolean flag to include INSERT INTO statements.
:return: List of schema and optionally data entries.
"""
entries = []
if sql:
# Extract table names from the provided SQL query
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [tbl[0] for tbl in cursor.fetchall() if tbl[0].lower() in sql.lower()]
else:
# Retrieve all table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [tbl[0] for tbl in cursor.fetchall()]
for table in tables:
entries_per_table = []
# Retrieve the CREATE TABLE statement for each table
cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}' AND sql IS NOT NULL;")
create_table_stmt = cursor.fetchone()
if create_table_stmt:
stmt = create_table_stmt[0].strip()
if not stmt.endswith(';'):
stmt += ';'
entries_per_table.append(stmt)
if get_insert_into:
# Retrieve all data from the table
cursor.execute(f"SELECT * FROM {table};")
rows = cursor.fetchall()
column_names = [description[0] for description in cursor.description]
# Generate INSERT INTO statements for each row
if model==None :
max_len=3
else:
max_len = len(rows)
for row in rows[:max_len]:
values = ', '.join(f"'{str(value)}'" if isinstance(value, str) else str(value) for value in row)
insert_stmt = f"INSERT INTO {table} ({', '.join(column_names)}) VALUES ({values});"
entries_per_table.append(insert_stmt)
if model != None : entries_per_table = us.crop_entries_per_token(entries_per_table, model, prompt)
entries.extend(entries_per_table)
return entries
def _combine_schema_entries(schema_entries, normalize):
"""
Combines schema entries into a single string.
:param schema_entries: List of schema entries.
:param normalize: Whether to normalize the schema string.
:return: Combined schema string.
"""
if not normalize:
return "\n".join(entry for entry in schema_entries)
return "\n".join(
re.sub(
r"\s*\)",
")",
re.sub(
r"\(\s*",
"(",
re.sub(
r"(`\w+`)\s+\(",
r"\1(",
re.sub(
r"^\s*([^\s(]+)",
r"`\1`",
re.sub(
r"\s+",
" ",
entry.replace("CREATE TABLE", "").replace("\t", " "),
).strip(),
),
),
),
)
for entry in schema_entries
)
def create_db_temp(schema_sql: str) -> sqlite3.Connection:
    """
    Builds an in-memory SQLite database from an SQL script.

    Args:
        schema_sql (str): SQL script (e.g. CREATE TABLE and INSERT INTO statements)
            run with ``executescript``.

    Returns:
        sqlite3.Connection: Open connection to the populated in-memory database.

    Raises:
        sqlite3.Error: If the script fails; the connection is closed before re-raising.
    """
    memory_db = sqlite3.connect(':memory:')
    try:
        memory_db.cursor().executescript(schema_sql)
        memory_db.commit()
    except sqlite3.Error:
        # Don't hand back a half-built database; release it and surface the error.
        memory_db.close()
        raise
    return memory_db