Spaces:
Runtime error
Runtime error
from typing import Any, Self

from dependency_injector.resources import AsyncResource
from loguru import logger
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCollection, AsyncIOMotorDatabase
from pydantic import ConfigDict
from pymongo.errors import CollectionInvalid, ConnectionFailure, ServerSelectionTimeoutError

from ctp_slack_bot.core import HealthReportingApplicationComponentBase, Settings
from ctp_slack_bot.utils import sanitize_mongo_db_uri
class MongoDB(HealthReportingApplicationComponentBase):
    """
    MongoDB connection manager using Motor for async operations.

    Call ``connect()`` before ``ping()`` or ``get_collection()``, and ``close()`` to release the client.
    ``is_healthy()`` reports the outcome of a ``ping`` command sent to the ``admin`` database.
    """

    model_config = ConfigDict(frozen=True)

    # Provides ``mongodb_uri`` (a secret value) and ``mongodb_name``.
    settings: Settings
    # Private connection state: ``None`` until ``connect()`` succeeds, and again after ``close()``.
    # Defaulting to ``None`` lets ``close()``/``ping()`` run safely on a never-connected instance.
    _client: AsyncIOMotorClient | None = None
    _db: AsyncIOMotorDatabase | None = None

    def connect(self: Self) -> None:
        """
        Initialize the MongoDB client with settings and select the configured database.

        Raises:
            Exception: Anything raised while building the client or selecting the database. Before
                re-raising, any partially created client is closed and the state is reset to ``None``.
        """
        try:
            connection_string = self.settings.mongodb_uri.get_secret_value()
            logger.debug("Connecting to MongoDB using URI: {}", sanitize_mongo_db_uri(connection_string))
            # Create client with appropriate settings (timeouts are in milliseconds).
            self._client = AsyncIOMotorClient(
                connection_string,
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=10000,
                socketTimeoutMS=45000,
                maxPoolSize=100,
                retryWrites=True,
                w="majority",
            )
            db_name = self.settings.mongodb_name
            self._db = self._client[db_name]
            logger.debug("MongoDB client initialized for database: {}", db_name)
        except Exception as e:
            logger.error("Failed to initialize MongoDB client: {}", e)
            # Do not leak a client that was created before a later step failed.
            if self._client is not None:
                self._client.close()
            self._client = None
            self._db = None
            raise

    async def ping(self: Self) -> bool:
        """
        Check if the MongoDB connection is alive.

        Returns:
            ``True`` if the ``ping`` command succeeded. ``False`` if the client is not connected or any
            error occurred; errors are logged, never raised.
        """
        if self._client is None:
            logger.error("MongoDB ping attempted before the client was connected.")
            return False
        try:
            await self._client.admin.command("ping")
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            logger.error("MongoDB connection failed: {}", e)
            return False
        except Exception as e:
            logger.error("Unexpected error during MongoDB ping: {}", e)
            return False
        logger.debug("MongoDB connection is active!")
        return True

    async def get_collection(self: Self, name: str) -> AsyncIOMotorCollection:
        """
        Get a collection by name, creating it if it doesn't exist.

        Args:
            name: Collection name within the configured database.

        Returns:
            The Motor collection handle.

        Raises:
            Exception: Any database error; it is logged and re-raised.
        """
        # Local import keeps this method self-contained; pymongo is already a dependency of this module.
        from pymongo.errors import CollectionInvalid

        try:
            # Filter server-side instead of fetching every collection name.
            if await self._db.list_collection_names(filter={"name": name}):
                logger.debug("Retrieved collection, {}.", name)
                return self._db[name]
            try:
                collection = await self._db.create_collection(name)
            except CollectionInvalid:
                # Another client created it between the existence check and the create: just use it.
                logger.debug("Retrieved collection, {}.", name)
                return self._db[name]
            logger.debug("Created previously nonexistent collection, {}.", name)
            return collection
        except Exception as e:
            logger.error("Error accessing collection, {}: {}", name, e)
            raise

    def close(self: Self) -> None:
        """Close the MongoDB connection; safe to call repeatedly or before ``connect()``."""
        if self._client is not None:
            self._client.close()
            logger.info("Closed MongoDB connection.")
        self._client = None
        self._db = None

    def name(self: Self) -> str:
        """Return this component's name, ``"mongo_db"``."""
        return "mongo_db"

    async def is_healthy(self: Self) -> bool:
        """Report health as the result of ``ping()``."""
        return await self.ping()
class MongoDBResource(AsyncResource):
    """Async dependency-injector resource that owns the lifecycle of a ``MongoDB`` component."""

    async def init(self: Self, settings: Settings) -> MongoDB:
        """
        Build a ``MongoDB`` component from ``settings``, connect it, and probe it once.

        A failed probe is only logged; the component is returned either way.
        """
        logger.info("Initializing MongoDB connection for database: {}", settings.mongodb_name)
        component = MongoDB(settings=settings)
        component.connect()
        await self._test_connection(component)
        return component

    async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
        """Ping ``mongo_db`` and log whether the connection test passed."""
        reachable = await mongo_db.ping()
        if not reachable:
            logger.error("MongoDB connection test failed!")
            return
        logger.info("MongoDB connection test successful!")

    async def shutdown(self: Self, mongo_db: MongoDB) -> None:
        """Release the component's MongoDB client when the resource is shut down."""
        mongo_db.close()