from __future__ import annotations from contextlib import asynccontextmanager, contextmanager from typing import TYPE_CHECKING from loguru import logger from langflow.services.schema import ServiceType if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator from sqlmodel import Session from sqlmodel.ext.asyncio.session import AsyncSession from langflow.services.cache.service import AsyncBaseCacheService, CacheService from langflow.services.chat.service import ChatService from langflow.services.database.service import DatabaseService from langflow.services.session.service import SessionService from langflow.services.settings.service import SettingsService from langflow.services.socket.service import SocketIOService from langflow.services.state.service import StateService from langflow.services.storage.service import StorageService from langflow.services.store.service import StoreService from langflow.services.task.service import TaskService from langflow.services.telemetry.service import TelemetryService from langflow.services.tracing.service import TracingService from langflow.services.variable.service import VariableService def get_service(service_type: ServiceType, default=None): """Retrieves the service instance for the given service type. Args: service_type (ServiceType): The type of service to retrieve. default (ServiceFactory, optional): The default ServiceFactory to use if the service is not found. Defaults to None. Returns: Any: The service instance. """ from langflow.services.manager import service_manager if not service_manager.factories: # ! This is a workaround to ensure that the service manager is initialized # ! Not optimal, but it works for now service_manager.register_factories() return service_manager.get(service_type, default) def get_telemetry_service() -> TelemetryService: """Retrieves the TelemetryService instance from the service manager. Returns: TelemetryService: The TelemetryService instance. """ from langflow.services.telemetry.factory import TelemetryServiceFactory return get_service(ServiceType.TELEMETRY_SERVICE, TelemetryServiceFactory()) def get_tracing_service() -> TracingService: """Retrieves the TracingService instance from the service manager. Returns: TracingService: The TracingService instance. """ from langflow.services.tracing.factory import TracingServiceFactory return get_service(ServiceType.TRACING_SERVICE, TracingServiceFactory()) def get_state_service() -> StateService: """Retrieves the StateService instance from the service manager. Returns: The StateService instance. """ from langflow.services.state.factory import StateServiceFactory return get_service(ServiceType.STATE_SERVICE, StateServiceFactory()) def get_socket_service() -> SocketIOService: """Get the SocketIOService instance from the service manager. Returns: SocketIOService: The SocketIOService instance. """ return get_service(ServiceType.SOCKETIO_SERVICE) # type: ignore[attr-defined] def get_storage_service() -> StorageService: """Retrieves the storage service instance. Returns: The storage service instance. """ from langflow.services.storage.factory import StorageServiceFactory return get_service(ServiceType.STORAGE_SERVICE, default=StorageServiceFactory()) def get_variable_service() -> VariableService: """Retrieves the VariableService instance from the service manager. Returns: The VariableService instance. """ from langflow.services.variable.factory import VariableServiceFactory return get_service(ServiceType.VARIABLE_SERVICE, VariableServiceFactory()) def get_settings_service() -> SettingsService: """Retrieves the SettingsService instance. If the service is not yet initialized, it will be initialized before returning. Returns: The SettingsService instance. Raises: ValueError: If the service cannot be retrieved or initialized. """ from langflow.services.settings.factory import SettingsServiceFactory return get_service(ServiceType.SETTINGS_SERVICE, SettingsServiceFactory()) def get_db_service() -> DatabaseService: """Retrieves the DatabaseService instance from the service manager. Returns: The DatabaseService instance. """ from langflow.services.database.factory import DatabaseServiceFactory return get_service(ServiceType.DATABASE_SERVICE, DatabaseServiceFactory()) async def get_session() -> AsyncGenerator[AsyncSession, None]: """Retrieves an async session from the database service. Yields: AsyncSession: An async session object. """ async with get_db_service().with_async_session() as session: yield session @contextmanager def session_scope() -> Generator[Session, None, None]: """Context manager for managing a session scope. This context manager is used to manage a session scope for database operations. It ensures that the session is properly committed if no exceptions occur, and rolled back if an exception is raised. Yields: Session: The session object. Raises: Exception: If an error occurs during the session scope. """ db_service = get_db_service() with db_service.with_session() as session: try: yield session session.commit() except Exception: logger.exception("An error occurred during the session scope.") session.rollback() raise @asynccontextmanager async def async_session_scope() -> AsyncGenerator[AsyncSession, None]: """Context manager for managing an async session scope. This context manager is used to manage an async session scope for database operations. It ensures that the session is properly committed if no exceptions occur, and rolled back if an exception is raised. Yields: AsyncSession: The async session object. Raises: Exception: If an error occurs during the session scope. """ db_service = get_db_service() async with db_service.with_async_session() as session: try: yield session await session.commit() except Exception: logger.exception("An error occurred during the session scope.") await session.rollback() raise def get_cache_service() -> CacheService | AsyncBaseCacheService: """Retrieves the cache service from the service manager. Returns: The cache service instance. """ from langflow.services.cache.factory import CacheServiceFactory return get_service(ServiceType.CACHE_SERVICE, CacheServiceFactory()) def get_shared_component_cache_service() -> CacheService: """Retrieves the cache service from the service manager. Returns: The cache service instance. """ from langflow.services.shared_component_cache.factory import SharedComponentCacheServiceFactory return get_service(ServiceType.SHARED_COMPONENT_CACHE_SERVICE, SharedComponentCacheServiceFactory()) def get_session_service() -> SessionService: """Retrieves the session service from the service manager. Returns: The session service instance. """ from langflow.services.session.factory import SessionServiceFactory return get_service(ServiceType.SESSION_SERVICE, SessionServiceFactory()) def get_task_service() -> TaskService: """Retrieves the TaskService instance from the service manager. Returns: The TaskService instance. """ from langflow.services.task.factory import TaskServiceFactory return get_service(ServiceType.TASK_SERVICE, TaskServiceFactory()) def get_chat_service() -> ChatService: """Get the chat service instance. Returns: ChatService: The chat service instance. """ return get_service(ServiceType.CHAT_SERVICE) def get_store_service() -> StoreService: """Retrieves the StoreService instance from the service manager. Returns: StoreService: The StoreService instance. """ return get_service(ServiceType.STORE_SERVICE)