# Source snapshot: commit f0fe0fd ("Refactor #6") by LiKenun.
from asyncio import iscoroutine, isfuture
from dependency_injector import providers
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.providers import Callable, Configuration, Dict, List, Resource, Singleton
from importlib import import_module
from itertools import chain
from openai import AsyncOpenAI
from pkgutil import iter_modules
from types import ModuleType
from typing import Any, Iterator, Sequence
from ctp_slack_bot.controllers import ControllerBase, ControllerRegistry
from ctp_slack_bot.core import Settings
from ctp_slack_bot.db.mongo_db import MongoDBResource
from ctp_slack_bot.db.repositories.mongo_db_vectorized_chunk_repository import MongoVectorizedChunkRepositoryResource
from ctp_slack_bot.mime_type_handlers import MimeTypeHandlerRegistry
from ctp_slack_bot.services.answer_retrieval_service import AnswerRetrievalService
from ctp_slack_bot.services.application_health_service import ApplicationHealthService
from ctp_slack_bot.services.content_ingestion_service import ContentIngestionServiceResource
from ctp_slack_bot.services.context_retrieval_service import ContextRetrievalService
from ctp_slack_bot.services.embeddings_model_service import EmbeddingsModelService
from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
from ctp_slack_bot.services.google_drive_service import GoogleDriveService
from ctp_slack_bot.services.http_client_service import HTTPClientServiceResource
from ctp_slack_bot.services.http_server_service import HTTPServerResource
from ctp_slack_bot.services.language_model_service import LanguageModelService
from ctp_slack_bot.services.question_dispatch_service import QuestionDispatchServiceResource
from ctp_slack_bot.services.slack_service import SlackServiceResource
from ctp_slack_bot.services.task_service import TaskServiceResource
from ctp_slack_bot.services.vectorization_service import VectorizationService
async def _await_or_return(value):
if iscoroutine(value) or isfuture(value):
return await value
return value
class Container(DeclarativeContainer): # TODO: audit for potential async-related bugs.
    """Dependency-injection container wiring settings, storage, services and the HTTP server.

    Every public class attribute is a ``dependency_injector`` provider. The
    keyword arguments given to each provider name the other providers whose
    values are injected into it.
    """

    async def __get_http_controller_providers(container) -> Sequence[ControllerBase]:
        """Instantiate every controller class yielded by ``ControllerRegistry``.

        Despite its name, this returns controller *instances*, not providers.
        For each controller class, every name that is both one of the class's
        ``model_fields`` and a provider of *container* is resolved by calling
        that provider. The result is awaited when it is a coroutine or future,
        then passed to the controller's constructor as a keyword argument.

        Args:
            container: The container instance, injected via ``__self__``.

        Returns:
            The controller instances, in the order the registry yields their classes.
        """
        controllers = []
        for controller_class in ControllerRegistry.get_registry():
            # Inject only what the controller declares AND the container can supply.
            dependency_names = controller_class.model_fields.keys() & container.providers.keys()
            dependencies = {}
            for dependency_name in dependency_names:
                provided = container.providers[dependency_name]()
                dependencies[dependency_name] = await _await_or_return(provided)
            controllers.append(controller_class(**dependencies))
        return controllers

    def __iter_mime_type_handler_providers() -> Iterator[tuple[str, Singleton]]:
        """Yield ``(mime_type, provider)`` pairs for every registered MIME-type handler.

        A handler class registered for several MIME types gets one shared
        ``Singleton`` provider rather than a separate provider per MIME type.
        """
        provider_by_handler = {}
        for mime_type, handler in MimeTypeHandlerRegistry.get_registry().items():
            provider = provider_by_handler.get(handler)
            if provider is None:
                provider = Singleton(handler)
                provider_by_handler[handler] = provider
            yield mime_type, provider

    # Lets a provider receive this container itself (used by `controllers` below).
    __self__ = providers.Self()

    settings = Singleton(Settings)
    event_brokerage_service = Singleton(EventBrokerageService)
    http_client = Resource(HTTPClientServiceResource)
    mongo_db = Resource(MongoDBResource,
                        settings=settings)
    vectorized_chunk_repository = Resource(MongoVectorizedChunkRepositoryResource,
                                           settings=settings,
                                           mongo_db=mongo_db)
    open_ai_client = Singleton(lambda settings: AsyncOpenAI(api_key=settings.openai_api_key.get_secret_value()),
                               settings=settings)
    embeddings_model_service = Singleton(EmbeddingsModelService,
                                         settings=settings,
                                         open_ai_client=open_ai_client)
    vectorization_service = Singleton(VectorizationService,
                                      settings=settings,
                                      embeddings_model_service=embeddings_model_service)
    content_ingestion_service = Resource(ContentIngestionServiceResource,
                                         settings=settings,
                                         event_brokerage_service=event_brokerage_service,
                                         vectorized_chunk_repository=vectorized_chunk_repository,
                                         vectorization_service=vectorization_service)
    context_retrieval_service = Singleton(ContextRetrievalService,
                                          settings=settings,
                                          vectorization_service=vectorization_service,
                                          vectorized_chunk_repository=vectorized_chunk_repository)
    language_model_service = Singleton(LanguageModelService,
                                       settings=settings,
                                       open_ai_client=open_ai_client)
    answer_retrieval_service = Singleton(AnswerRetrievalService,
                                         settings=settings,
                                         event_brokerage_service=event_brokerage_service,
                                         language_model_service=language_model_service)
    question_dispatch_service = Resource(QuestionDispatchServiceResource,
                                         settings=settings,
                                         event_brokerage_service=event_brokerage_service,
                                         context_retrieval_service=context_retrieval_service,
                                         answer_retrieval_service=answer_retrieval_service)
    slack_service = Resource(SlackServiceResource,
                             settings=settings,
                             event_brokerage_service=event_brokerage_service,
                             http_client=http_client)
    # The registry is read once, while this class body executes, so only
    # handlers registered in MimeTypeHandlerRegistry by then are included.
    mime_type_handlers = Dict(dict(__iter_mime_type_handler_providers()))
    google_drive_service = Singleton(GoogleDriveService,
                                     settings=settings)
    # NOTE(review): FileMonitorService is not imported here. The provider this
    # used to wire (`mime_type_handler_factory`) no longer exists, and
    # `mime_type_handlers` appears to replace it. Confirm the
    # FileMonitorService parameter name before re-enabling.
    # file_monitor_service = Singleton(FileMonitorService,
    #                                  settings=settings,
    #                                  google_drive_service=google_drive_service,
    #                                  mime_type_handler_factory=mime_type_handlers)
    application_health_service = Singleton(ApplicationHealthService,
                                           services=List(mongo_db, slack_service))
    task_service = Resource(TaskServiceResource,
                            settings=settings)
    # `controllers` is produced by an async function; the container is passed in via `__self__`.
    http_server = Resource(HTTPServerResource,
                           settings=settings,
                           controllers=Callable(__get_http_controller_providers, __self__))