LiKenun's picture
Clean up and restore ability to shut down gracefully
92e41ba
raw
history blame
7.9 kB
from asyncio import create_task
from typing import Any, Dict, Optional, Self

from dependency_injector.resources import AsyncResource
from loguru import logger
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel, PrivateAttr
from pymongo.errors import CollectionInvalid, ConnectionFailure, OperationFailure, ServerSelectionTimeoutError
from pymongo.operations import SearchIndexModel

from ctp_slack_bot.core.config import Settings
from ctp_slack_bot.utils import sanitize_mongo_db_uri
class MongoDB(BaseModel):
    """
    MongoDB connection manager using Motor for async operations.

    ``connect`` creates the Motor client and selects the configured database. The
    ``client`` and ``db`` properties call ``connect`` on first access when no handle
    exists yet. ``close`` closes the client and resets both handles to None, so a later
    property access reconnects.
    """

    settings: Settings
    # Motor client; None until connect() succeeds and again after close().
    _client: Optional[AsyncIOMotorClient] = PrivateAttr(default=None)
    # Database handle taken from _client[settings.MONGODB_NAME]; None while disconnected.
    _db: Optional[Any] = PrivateAttr(default=None)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self: Self, **data: Any) -> None:
        super().__init__(**data)
        logger.debug("Created {}", self.__class__.__name__)

    def connect(self: Self) -> None:
        """
        Initialize the MongoDB client and database handle from settings.

        Reads ``settings.MONGODB_URI`` (a secret) and ``settings.MONGODB_NAME``.

        Raises:
            Exception: Whatever client construction raised. Both handles are reset
                to None before the exception is re-raised.
        """
        try:
            connection_string = self.settings.MONGODB_URI.get_secret_value()
            # Only the sanitized URI is logged so credentials do not reach the logs.
            logger.debug("Connecting to MongoDB using URI: {}", sanitize_mongo_db_uri(connection_string))
            self._client = AsyncIOMotorClient(
                connection_string,
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=10000,
                socketTimeoutMS=45000,
                maxPoolSize=100,
                retryWrites=True,
                w="majority",
            )
            db_name = self.settings.MONGODB_NAME
            self._db = self._client[db_name]
            logger.debug("MongoDB client initialized for database: {}", db_name)
        except Exception as e:
            logger.error("Failed to initialize MongoDB client: {}", e)
            self._client = None
            self._db = None
            raise

    @property
    def client(self: Self) -> AsyncIOMotorClient:
        """
        Get the MongoDB client instance, connecting first if necessary.

        Raises:
            ConnectionError: If no client is available after attempting to connect.
            Exception: Any error raised by ``connect`` propagates unchanged.
        """
        if self._client is None:
            logger.warning("MongoDB client not initialized. Attempting to initialize…")
            self.connect()
            if self._client is None:
                raise ConnectionError("Failed to initialize MongoDB client.")
        return self._client

    @property
    def db(self: Self) -> Any:
        """
        Get the MongoDB database instance, connecting first if necessary.

        Raises:
            ConnectionError: If no database handle is available after attempting to connect.
            Exception: Any error raised by ``connect`` propagates unchanged.
        """
        if self._db is None:
            logger.warning("MongoDB database not initialized. Attempting to initialize client…")
            self.connect()
            if self._db is None:
                raise ConnectionError("Failed to initialize MongoDB database.")
        return self._db

    async def ping(self: Self) -> bool:
        """
        Check whether the MongoDB server answers an admin ``ping`` command.

        Returns:
            True if the ping succeeded. False on any ``Exception``, including a
            failure to connect; such errors are logged, not raised.
        """
        try:
            # Accessing the property connects lazily if needed.
            client = self.client
            await client.admin.command('ping')
            logger.debug("MongoDB connection is active!")
            return True
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            logger.error("MongoDB connection failed: {}", e)
            return False
        except Exception as e:
            logger.error("Unexpected error during MongoDB ping: {}", e)
            return False

    async def get_collection(self: Self, name: str) -> Any:
        """
        Get a collection by name, creating it if it does not exist yet.

        Args:
            name: Name of the collection.

        Returns:
            The collection handle ``db[name]``.

        Raises:
            ConnectionError: If the server does not answer a ping.
            Exception: Any other error is logged and re-raised.
        """
        # First ensure we can connect at all.
        if not await self.ping():
            logger.error("Cannot get collection '{}' because a MongoDB connection is not available.", name)
            raise ConnectionError("MongoDB connection is not available.")
        try:
            db = self.db
            logger.debug("Checking if collection '{}' exists…", name)
            # Ask only about this collection instead of listing every collection.
            collection_names = await db.list_collection_names(filter={"name": name})
            if name in collection_names:
                logger.debug("Collection '{}' already exists!", name)
            else:
                logger.info("Collection '{}' does not exist. Creating it…", name)
                try:
                    await db.create_collection(name)
                    logger.debug("Successfully created collection: {}", name)
                except CollectionInvalid:
                    # Another client created it between our check and this call.
                    logger.debug("Collection '{}' already exists!", name)
                except OperationFailure as error:
                    # 48 = NamespaceExists: same race, detected by the server.
                    if error.code != 48:
                        raise
                    logger.debug("Collection '{}' already exists!", name)
            return db[name]
        except Exception as e:
            logger.error("Error accessing collection '{}': {}", name, e)
            raise

    async def create_indexes(self: Self, collection_name: str) -> None:
        """
        Create a vector search index on a collection.

        The index is named ``f"{collection_name}_vector_index"`` and covers the
        ``embedding`` field with ``settings.VECTOR_DIMENSION`` dimensions, cosine
        similarity and scalar quantization.

        - If an index with that name already exists, the call succeeds without
          changing anything.
        - If the server does not support search-index commands, a standard index on
          ``embedding`` is created instead.

        Args:
            collection_name: Name of the collection; it is created if missing.

        Raises:
            ConnectionError: If the server is unreachable (from ``get_collection``).
            Exception: Any other index-creation failure is logged and re-raised.
        """
        collection = await self.get_collection(collection_name)
        index_name = f"{collection_name}_vector_index"
        try:
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "path": "embedding",
                            "numDimensions": self.settings.VECTOR_DIMENSION,
                            "similarity": "cosine",
                            "quantization": "scalar",
                        }
                    ]
                },
                name=index_name,
                type="vectorSearch",
            )
            result = await collection.create_search_index(search_index_model)
        except Exception as e:
            error_code = e.code if isinstance(e, OperationFailure) else None
            message = str(e).lower()
            # 68 = IndexAlreadyExists: make repeated calls (e.g. on every startup) idempotent.
            if error_code == 68 or "already exists" in message or "duplicate index" in message:
                logger.info("Vector search index '{}' already exists for collection {}.", index_name, collection_name)
                return
            # 59 = CommandNotFound, 115 = CommandNotSupported, 31082 = SearchNotEnabled.
            # PyMongo renders an unknown command as "no such command ... 'codeName': 'CommandNotFound'",
            # so a plain "command not found" substring check alone never matches it.
            unsupported_markers = ("command not found", "no such command", "commandnotfound")
            if error_code in (59, 115, 31082) or any(marker in message for marker in unsupported_markers):
                logger.warning("Vector search not supported by this MongoDB instance. Some functionality may be limited.")
                # Create a fallback standard index on embedding field
                await collection.create_index("embedding")
                logger.info("Created standard index on 'embedding' field as fallback.")
                return
            logger.error("Failed to create vector index: {}", e)
            raise
        else:
            logger.info("Vector search index '{}' created for collection {}.", result, collection_name)

    async def close(self: Self) -> None:
        """
        Close the MongoDB client, if any, and reset the cached handles.

        Safe to call before ``connect`` or more than once.
        """
        if self._client is not None:
            self._client.close()
            logger.info("Closed MongoDB connection.")
        self._client = None
        self._db = None
class MongoDBResource(AsyncResource):
    """
    dependency_injector async resource that builds a connected ``MongoDB`` manager
    and closes it on shutdown.
    """

    async def init(self: Self, settings: Settings) -> MongoDB:
        """
        Create a ``MongoDB`` manager, connect it, and run a connection test.

        A failed ping is only logged; the manager is still returned.

        Args:
            settings: Application settings providing the MongoDB URI and database name.

        Returns:
            The connected ``MongoDB`` instance.

        Raises:
            Exception: Whatever ``connect`` or the connection test raised.
        """
        logger.info("Initializing MongoDB connection for database: {}", settings.MONGODB_NAME)
        mongo_db = MongoDB(settings=settings)
        mongo_db.connect()
        try:
            await self._test_connection(mongo_db)
        except BaseException:
            # If init() does not return, the instance never reaches the caller and so
            # cannot be passed to shutdown(); release the client here instead. This
            # also covers cancellation during the test, which ping() does not swallow.
            await mongo_db.close()
            raise
        return mongo_db

    async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
        """Ping MongoDB and log whether the connection test succeeded."""
        try:
            is_connected = await mongo_db.ping()
            if is_connected:
                logger.info("MongoDB connection test successful!")
            else:
                logger.error("MongoDB connection test failed!")
        except Exception as e:
            logger.error("Error testing MongoDB connection: {}", e)
            raise

    async def shutdown(self: Self, mongo_db: MongoDB) -> None:
        """Close the MongoDB connection on shutdown; errors are logged, not raised."""
        try:
            await mongo_db.close()
        except Exception as e:
            logger.error("Error closing MongoDB connection: {}", e)