# isayahc's picture
# did some refactoring, added documentation
# fdb6484 unverified
import os
import uuid
from datetime import datetime

from dotenv import load_dotenv
from sqlmodel import Session, SQLModel, create_engine, delete, select

from rag_app.database.schema import Sources
from rag_app.utils.logger import get_console_logger
class DataBaseHandler():
    """
    A class for managing the SQLite database that stores `Sources` rows.

    Every public method opens its own short-lived `Session` on `self.engine`
    and commits before returning, so each call is an independent transaction.

    Attributes:
        sqlite_file_name (str): The SQLite file name for the database.
        logger (Logger): The logger for logging database operations.
        engine (Engine): The SQLAlchemy engine for the database.
        session_id (str): Random UUID4 string identifying the current session.
        session_date_time (str): Creation time of the current session,
            formatted as '%Y-%m-%d %H:%M:%S' (naive local time).

    Methods:
        create_all_tables: Create all tables in the database.
        create_new_session: Generate a new session_id and session_date_time.
        read_one: Read a single entry from the database by its hash_id.
        add_one: Add a single entry to the database.
        update_one: Update a single entry in the database by its hash_id.
        delete_one: Delete a single entry from the database by its hash_id.
        add_many: Add multiple entries to the database.
        delete_many: Delete multiple entries from the database by their hash_ids.
        read_all: Read all entries from the database, optionally filtered by a query.
        delete_all: Delete all entries from the database.
    """

    def __init__(
        self,
        sqlite_file_name = None,
        logger = get_console_logger("db_handler"),
    ):
        """
        Create the engine and initialise the session bookkeeping.

        Args:
            sqlite_file_name (str, optional): Path of the SQLite file. When
                omitted, the `SOURCES_CACHE` environment variable is read at
                construction time.
            logger (Logger, optional): Logger used for all messages. The
                default is created once, when the class is defined, and is
                shared by every instance that does not pass its own.
        """
        # Resolve the environment variable here rather than in the signature:
        # a default in the signature is evaluated only once, at class
        # definition, and would miss variables set later (e.g. by load_dotenv()).
        if sqlite_file_name is None:
            sqlite_file_name = os.getenv('SOURCES_CACHE')
        self.sqlite_file_name = sqlite_file_name
        self.logger = logger
        if self.sqlite_file_name is None:
            # Kept non-fatal for backward compatibility, but made visible.
            self.logger.warning(
                "SOURCES_CACHE is not set and no sqlite_file_name was given; "
                "the database URL will point at a file literally named 'None'"
            )
        sqlite_url = f"sqlite:///{self.sqlite_file_name}"
        self.engine = create_engine(sqlite_url, echo=False)
        self.create_new_session()

    def create_all_tables(self) -> None:
        """Create every table registered on `SQLModel.metadata` using this engine."""
        SQLModel.metadata.create_all(self.engine)

    def create_new_session(self) -> None:
        """Generate a fresh `session_id` (UUID4) and `session_date_time` (now)."""
        self.session_id = str(uuid.uuid4())
        self.session_date_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    def read_one(self, hash_id: dict):
        """
        Read a single entry from the database by its hash_id.

        Args:
            hash_id: Value compared for equality against `Sources.hash_id`.
                NOTE(review): annotated as `dict`, but `add_one` compares the
                same column against `data.get("hash_id")`, i.e. a single
                value -- confirm the real type against the `Sources` schema.

        Returns:
            Sources: The first matching entry, or None if no match is found.
        """
        with Session(self.engine) as session:
            statement = select(Sources).where(Sources.hash_id == hash_id)
            return session.exec(statement).first()

    def add_one(self, data: dict):
        """
        Add a single entry to the database.

        Args:
            data (dict): Keyword arguments for the `Sources` constructor. The
                value under the "hash_id" key is used for the duplicate check.

        Returns:
            Sources: The added (and refreshed) entry, or None if an entry with
            the same hash_id already exists.
        """
        hash_id = data.get("hash_id")
        with Session(self.engine) as session:
            already_present = session.exec(
                select(Sources).where(Sources.hash_id == hash_id)
            ).first()
            if already_present:
                self.logger.warning(f"Item with hash_id {hash_id} already exists")
                return None

            sources = Sources(**data)
            session.add(sources)
            session.commit()
            # Reload the row so its attributes stay readable after the
            # session is closed.
            session.refresh(sources)
            self.logger.info(f"Item with hash_id {hash_id} added to the database")
            return sources

    def update_one(self, hash_id: dict, data: dict):
        """
        Update a single entry in the database by its hash_id.

        Args:
            hash_id: Value compared for equality against `Sources.hash_id`.
                NOTE(review): annotated as `dict`; confirm the real type
                against the `Sources` schema.
            data (dict): Attribute name -> new value. Each pair is applied
                with `setattr`; the keys are not validated here.

        Returns:
            Sources: The updated (and refreshed) entry, or None if no match is found.
        """
        with Session(self.engine) as session:
            sources = session.exec(
                select(Sources).where(Sources.hash_id == hash_id)
            ).first()
            if not sources:
                self.logger.warning(f"No item with hash_id {hash_id} found for update")
                return None

            for key, value in data.items():
                setattr(sources, key, value)
            session.add(sources)
            session.commit()
            # commit() expires the instance; refresh it (as add_one does) so
            # the caller can still read its attributes once the session closes.
            session.refresh(sources)
            self.logger.info(f"Item with hash_id {hash_id} updated in the database")
            return sources

    def delete_one(self, id: int):
        """
        Delete a single entry from the database by its hash_id.

        Args:
            id: Despite its name, this value is matched against
                `Sources.hash_id`. The parameter name is kept for existing
                callers. NOTE(review): annotated as `int`; confirm the real
                type against the `Sources` schema.

        Returns:
            None. Both outcomes (deleted / not found) are reported via the logger.
        """
        with Session(self.engine) as session:
            sources = session.exec(
                select(Sources).where(Sources.hash_id == id)
            ).first()
            if not sources:
                self.logger.warning(f"No item with hash_id {id} found for deletion")
                return None

            session.delete(sources)
            session.commit()
            self.logger.info(f"Item with hash_id {id} deleted from the database")

    def add_many(self, data: list):
        """
        Add multiple entries to the database.

        Each item is inserted through `add_one`, so every insert is its own
        committed transaction and duplicates are skipped individually.

        Args:
            data (list): List of dictionaries, each containing the data for a new entry.

        Returns:
            None
        """
        for info in data:
            # add_one already logs the success case; only report the skip here.
            if self.add_one(info) is None:
                self.logger.warning(
                    f"Item with hash_id {info.get('hash_id')} could not be added"
                )

    def delete_many(self, ids: list):
        """
        Delete multiple entries from the database by their hash_ids.

        Each item is removed through `delete_one`, which commits and logs the
        outcome (deleted / not found) per item.

        Args:
            ids (list): Values matched against `Sources.hash_id`.

        Returns:
            None
        """
        for hash_id in ids:
            # delete_one always returns None, so its result cannot be used to
            # distinguish success from failure; rely on its own logging.
            self.delete_one(hash_id)

    def read_all(self, query: dict = None):
        """
        Read all entries from the database, optionally filtered by a query.

        Args:
            query (dict, optional): Attribute name -> required value. All
                pairs are combined as equality conditions in one `where`
                clause. A key that is not an attribute of `Sources` raises
                AttributeError. Defaults to None (no filtering).

        Returns:
            list: List of matching entries from the database.
        """
        with Session(self.engine) as session:
            statement = select(Sources)
            if query:
                conditions = [
                    getattr(Sources, key) == value for key, value in query.items()
                ]
                statement = statement.where(*conditions)
            return session.exec(statement).all()

    def delete_all(self):
        """
        Delete all entries from the database.

        Returns:
            None
        """
        with Session(self.engine) as session:
            # Issue a single bulk DELETE against the Sources table.
            session.exec(delete(Sources))
            session.commit()
        self.logger.info("All items deleted from the database")