Spaces:
Runtime error
Runtime error
from asyncio import create_task | |
from dependency_injector.resources import AsyncResource | |
from motor.motor_asyncio import AsyncIOMotorClient | |
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError | |
from pymongo.operations import SearchIndexModel | |
from loguru import logger | |
from pydantic import BaseModel, PrivateAttr | |
from typing import Any, Dict, Optional, Self | |
from ctp_slack_bot.core.config import Settings | |
from ctp_slack_bot.utils import sanitize_mongo_db_uri | |
class MongoDB(BaseModel):
    """
    MongoDB connection manager using Motor for async operations.

    The Motor client and database handles are created by ``connect()``, or
    lazily on first access of the ``client`` / ``db`` properties. They are
    released by ``close()``.
    """

    settings: Settings
    # Private (non-field) state. Both are None until connect() succeeds and again
    # after close(). The None defaults keep attribute access safe before connecting.
    _client: Optional[AsyncIOMotorClient] = PrivateAttr(default=None)
    _db: Any = PrivateAttr(default=None)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self: Self, **data: Any) -> None:
        super().__init__(**data)
        logger.debug("Created {}", self.__class__.__name__)

    def connect(self: Self) -> None:
        """
        Initialize the MongoDB client and database handle from settings.

        Reads ``settings.MONGODB_URI`` (a secret) and ``settings.MONGODB_NAME``.

        Raises:
            Exception: Whatever the Motor client constructor or database lookup
                raises. Any partially created client is closed and both private
                handles are reset to None before re-raising.
        """
        try:
            connection_string = self.settings.MONGODB_URI.get_secret_value()
            logger.debug("Connecting to MongoDB using URI: {}", sanitize_mongo_db_uri(connection_string))
            # Create client with appropriate settings
            self._client = AsyncIOMotorClient(
                connection_string,
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=10000,
                socketTimeoutMS=45000,
                maxPoolSize=100,
                retryWrites=True,
                w="majority",
            )
            # Set database
            db_name = self.settings.MONGODB_NAME
            self._db = self._client[db_name]
            logger.debug("MongoDB client initialized for database: {}", db_name)
        except Exception as e:
            logger.error("Failed to initialize MongoDB client: {}", e)
            # Don't leak a client that was constructed before the failure.
            if self._client is not None:
                self._client.close()
            self._client = None
            self._db = None
            raise

    @property
    def client(self: Self) -> AsyncIOMotorClient:
        """
        Return the MongoDB client instance, connecting first if necessary.

        Raises:
            ConnectionError: If the client is still unset after connecting.
        """
        if self._client is None:
            logger.warning("MongoDB client not initialized. Attempting to initialize…")
            self.connect()
            if self._client is None:
                raise ConnectionError("Failed to initialize MongoDB client.")
        return self._client

    @property
    def db(self: Self) -> Any:
        """
        Return the MongoDB database instance, connecting first if necessary.

        Raises:
            ConnectionError: If the database handle is still unset after connecting.
        """
        if self._db is None:
            logger.warning("MongoDB database not initialized. Attempting to initialize client…")
            self.connect()
            if self._db is None:
                raise ConnectionError("Failed to initialize MongoDB database.")
        return self._db

    async def ping(self: Self) -> bool:
        """
        Check whether the MongoDB connection is alive.

        Returns:
            True if the server answered the ``ping`` admin command, False on any
            error (errors are logged, never raised).
        """
        try:
            # Accessing the property connects lazily if needed.
            client = self.client
            # Try a simple ping command
            await client.admin.command('ping')
            logger.debug("MongoDB connection is active!")
            return True
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            logger.error("MongoDB connection failed: {}", e)
            return False
        except Exception as e:
            logger.error("Unexpected error during MongoDB ping: {}", e)
            return False

    async def get_collection(self: Self, name: str) -> Any:
        """
        Get a collection by name with validation.

        Creates the collection if it doesn't exist.

        Args:
            name: Name of the collection.

        Returns:
            The Motor collection handle.

        Raises:
            ConnectionError: If the server cannot be pinged.
            Exception: Any error from listing or creating collections (logged,
                then re-raised).
        """
        # CollectionInvalid is raised by create_collection when the collection exists.
        from pymongo.errors import CollectionInvalid

        # First ensure we can connect at all
        if not await self.ping():
            logger.error("Cannot get collection '{}' because a MongoDB connection is not available.", name)
            raise ConnectionError("MongoDB connection is not available.")
        try:
            logger.debug("Checking if collection '{}' exists…", name)
            db = self.db
            # Ask only for this name rather than listing every collection.
            collection_names = await db.list_collection_names(filter={"name": name})
            if name not in collection_names:
                logger.info("Collection '{}' does not exist. Creating it…", name)
                try:
                    await db.create_collection(name)
                    logger.debug("Successfully created collection: {}", name)
                except CollectionInvalid:
                    # Another client created it between the check and the create.
                    logger.debug("Collection '{}' already exists!", name)
            else:
                logger.debug("Collection '{}' already exists!", name)
            # Get and return the collection
            return db[name]
        except Exception as e:
            logger.error("Error accessing collection '{}': {}", name, e)
            raise

    async def create_indexes(self: Self, collection_name: str) -> None:
        """
        Create a vector search index on a collection.

        The index covers the ``embedding`` field with cosine similarity and
        ``settings.VECTOR_DIMENSION`` dimensions. If the server reports
        "command not found" (no search-index support), a standard index on
        ``embedding`` is created instead.

        Args:
            collection_name: Name of the collection.

        Raises:
            Exception: Any index-creation error other than "command not found".
        """
        collection = await self.get_collection(collection_name)
        try:
            # Create search index model using MongoDB's recommended approach
            search_index_model = SearchIndexModel(
                definition={
                    "fields": [
                        {
                            "type": "vector",
                            "path": "embedding",
                            "numDimensions": self.settings.VECTOR_DIMENSION,
                            "similarity": "cosine",
                            "quantization": "scalar",
                        }
                    ]
                },
                name=f"{collection_name}_vector_index",
                type="vectorSearch",
            )
            # Create the search index using the motor collection
            result = await collection.create_search_index(search_index_model)
            logger.info("Vector search index '{}' created for collection {}.", result, collection_name)
        except Exception as e:
            if "command not found" in str(e).lower():
                logger.warning("Vector search not supported by this MongoDB instance. Some functionality may be limited.")
                # Create a fallback standard index on embedding field
                await collection.create_index("embedding")
                logger.info("Created standard index on 'embedding' field as fallback.")
            else:
                logger.error("Failed to create vector index: {}", e)
                raise

    async def close(self: Self) -> None:
        """Close the MongoDB client, if open, and clear the cached handles."""
        if self._client is not None:
            self._client.close()
            logger.info("Closed MongoDB connection.")
        self._client = None
        self._db = None
class MongoDBResource(AsyncResource):
    """dependency_injector async resource that owns the lifecycle of a ``MongoDB`` manager."""

    async def init(self: Self, settings: Settings) -> MongoDB:
        """
        Build a ``MongoDB`` manager from ``settings``, connect it and probe it.

        A failed probe is logged but does not prevent the manager from being
        returned.
        """
        logger.info("Initializing MongoDB connection for database: {}", settings.MONGODB_NAME)
        manager = MongoDB(settings=settings)
        manager.connect()
        await self._test_connection(manager)
        return manager

    async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
        """
        Ping MongoDB and log whether the probe succeeded.

        An unsuccessful ping is only logged. An exception raised while probing
        is logged and re-raised.
        """
        try:
            if await mongo_db.ping():
                logger.info("MongoDB connection test successful!")
            else:
                logger.error("MongoDB connection test failed!")
        except Exception as exc:
            logger.error("Error testing MongoDB connection: {}", exc)
            raise

    async def shutdown(self: Self, mongo_db: MongoDB) -> None:
        """Close the manager's connection on shutdown; errors are logged, not raised."""
        try:
            await mongo_db.close()
        except Exception as exc:
            logger.error("Error closing MongoDB connection: {}", exc)