# Spaces:
# Sleeping
# Sleeping
import os
import uuid
from datetime import datetime
from typing import Optional

from dotenv import load_dotenv
from sqlmodel import SQLModel, create_engine, Session, select, delete

from rag_app.database.schema import Sources
from rag_app.utils.logger import get_console_logger
class DataBaseHandler():
    """
    A class for managing the SQLite database of ``Sources`` entries.

    Attributes:
        sqlite_file_name (str): The SQLite file name for the database.
        logger (Logger): The logger for logging database operations.
        engine (Engine): The SQLAlchemy engine for the database.
        session_id (str): Random UUID4 string identifying the current session.
        session_date_time (str): Creation time of the current session,
            formatted as ``'%Y-%m-%d %H:%M:%S'`` (local, naive time).

    Methods:
        create_all_tables: Create all tables in the database.
        create_new_session: Generate a fresh session_id and session_date_time.
        read_one: Read a single entry from the database by its hash_id.
        add_one: Add a single entry to the database.
        update_one: Update a single entry in the database by its hash_id.
        delete_one: Delete a single entry from the database by its hash_id.
        add_many: Add multiple entries to the database.
        delete_many: Delete multiple entries from the database by their hash_ids.
        read_all: Read all entries from the database, optionally filtered by a query.
        delete_all: Delete all entries from the database.
    """

    def __init__(
        self,
        sqlite_file_name = None,
        logger = get_console_logger("db_handler"),
    ):
        """
        Create the engine for the SQLite file and start a first session.

        Args:
            sqlite_file_name (str, optional): Path of the SQLite file. When
                omitted, the ``SOURCES_CACHE`` environment variable is read
                at call time (not at import time), so a ``.env`` file loaded
                after this module was imported is still honoured.
            logger (Logger, optional): Logger used for all messages.

        Raises:
            ValueError: If no file name is given and ``SOURCES_CACHE`` is unset.
        """
        if sqlite_file_name is None:
            sqlite_file_name = os.getenv('SOURCES_CACHE')
        if sqlite_file_name is None:
            # Without this guard the URL would become "sqlite:///None" and a
            # database file literally named "None" would be created.
            raise ValueError(
                "No SQLite file name given and the SOURCES_CACHE "
                "environment variable is not set"
            )
        self.sqlite_file_name = sqlite_file_name
        self.logger = logger
        sqlite_url = f"sqlite:///{self.sqlite_file_name}"
        self.engine = create_engine(sqlite_url, echo=False)
        self.create_new_session()

    def create_all_tables(self) -> None:
        """Create every table registered on ``SQLModel.metadata`` using this engine."""
        SQLModel.metadata.create_all(self.engine)

    def create_new_session(self) -> None:
        """Create a new session_id (UUID4 string) and session_date_time."""
        self.session_id = str(uuid.uuid4())
        self.session_date_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    def read_one(self, hash_id: str) -> Optional[Sources]:
        """
        Read a single entry from the database by its hash_id.

        Args:
            hash_id (str): The hash_id value to search for.

        Returns:
            Sources: The first matching entry, or None if no match is found.
        """
        with Session(self.engine) as session:
            statement = select(Sources).where(Sources.hash_id == hash_id)
            return session.exec(statement).first()

    def add_one(self, data: dict) -> Optional[Sources]:
        """
        Add a single entry to the database.

        Args:
            data (dict): Field values for the new entry; its ``"hash_id"``
                key is used to detect an already stored entry.

        Returns:
            Sources: The added entry, or None if an entry with the same
            hash_id already exists.
        """
        hash_id = data.get("hash_id")
        with Session(self.engine) as session:
            existing = session.exec(
                select(Sources).where(Sources.hash_id == hash_id)
            ).first()
            if existing:
                self.logger.warning(f"Item with hash_id {hash_id} already exists")
                return None
            sources = Sources(**data)
            session.add(sources)
            session.commit()
            # Reload the attributes expired by the commit so the returned
            # object stays readable after the session is closed.
            session.refresh(sources)
            self.logger.info(f"Item with hash_id {hash_id} added to the database")
            return sources

    def update_one(self, hash_id: str, data: dict) -> Optional[Sources]:
        """
        Update a single entry in the database by its hash_id.

        Args:
            hash_id (str): The hash_id of the entry to update.
            data (dict): Mapping of attribute name to new value.

        Returns:
            Sources: The updated entry, or None if no match is found.
        """
        with Session(self.engine) as session:
            sources = session.exec(
                select(Sources).where(Sources.hash_id == hash_id)
            ).first()
            if not sources:
                self.logger.warning(f"No item with hash_id {hash_id} found for update")
                return None
            for key, value in data.items():
                setattr(sources, key, value)
            session.add(sources)
            session.commit()
            # Same as add_one: refresh before the session closes, otherwise
            # the returned instance is expired and detached.
            session.refresh(sources)
            self.logger.info(f"Item with hash_id {hash_id} updated in the database")
            return sources

    def _delete_by_hash_id(self, hash_id) -> bool:
        """Delete the entry with this hash_id; return True if one was deleted."""
        with Session(self.engine) as session:
            sources = session.exec(
                select(Sources).where(Sources.hash_id == hash_id)
            ).first()
            if not sources:
                self.logger.warning(f"No item with hash_id {hash_id} found for deletion")
                return False
            session.delete(sources)
            session.commit()
            self.logger.info(f"Item with hash_id {hash_id} deleted from the database")
            return True

    def delete_one(self, id: int) -> None:
        """
        Delete a single entry from the database by its hash_id.

        Args:
            id: The hash_id of the entry to delete. The parameter keeps its
                historical name ``id``, but the lookup is on ``Sources.hash_id``.

        Returns:
            None
        """
        self._delete_by_hash_id(id)

    def add_many(self, data: list) -> None:
        """
        Add multiple entries to the database.

        Each item is added and committed individually through ``add_one``,
        which also logs the outcome for that item.

        Args:
            data (list): List of dictionaries, each containing the data for a new entry.

        Returns:
            None
        """
        for info in data:
            if self.add_one(info) is None:
                self.logger.warning(
                    f"Item with hash_id {info.get('hash_id')} could not be added"
                )

    def delete_many(self, ids: list) -> None:
        """
        Delete multiple entries from the database by their hash_ids.

        Each deletion is committed individually; a missing entry is logged
        as a warning and does not stop the remaining deletions.

        Args:
            ids (list): List of hash_ids of the entries to delete.

        Returns:
            None
        """
        for hash_id in ids:
            self._delete_by_hash_id(hash_id)

    def read_all(self, query: Optional[dict] = None):
        """
        Read all entries from the database, optionally filtered by a query.

        Args:
            query (dict, optional): Mapping of ``Sources`` attribute name to
                required value; all conditions are combined. Defaults to None,
                which returns every entry.

        Returns:
            list: List of matching entries from the database.
        """
        with Session(self.engine) as session:
            statement = select(Sources)
            if query:
                conditions = [
                    getattr(Sources, key) == value for key, value in query.items()
                ]
                statement = statement.where(*conditions)
            return session.exec(statement).all()

    def delete_all(self) -> None:
        """
        Delete all entries from the database.

        Returns:
            None
        """
        with Session(self.engine) as session:
            # A single bulk DELETE statement; exec() needs a statement, not
            # the model class itself.
            session.exec(delete(Sources))
            session.commit()
            self.logger.info("All items deleted from the database")