# common/dependencies.py — FastAPI dependency providers for the generic-chatbot-backend.
import logging
import os
from logging import Logger
from typing import Annotated
from fastapi import Depends
from ntr_text_fragmentation import InjectionBuilder
from sqlalchemy.orm import Session, sessionmaker
from common.configuration import Configuration
from common.db import session_factory
from components.dbo.chunk_repository import ChunkRepository
from components.embedding_extraction import EmbeddingExtractor
from components.llm.common import LlmParams
from components.llm.deepinfra_api import DeepInfraApi
from components.services.dataset import DatasetService
from components.services.document import DocumentService
from components.services.entity import EntityService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
def get_config() -> Configuration:
    """Load the application configuration.

    The YAML path is read from the CONFIG_PATH environment variable and
    falls back to 'config_dev.yaml' for local development.
    """
    config_path = os.environ.get('CONFIG_PATH', 'config_dev.yaml')
    return Configuration(config_path)
def get_db() -> sessionmaker:
    """Expose the module-level SQLAlchemy session factory as a dependency.

    NOTE: this returns the factory itself (a sessionmaker), not an open
    Session — downstream consumers are expected to call it.
    """
    return session_factory
def get_logger() -> Logger:
    """Return the logger associated with this module's name."""
    return logging.getLogger(__name__)
def get_embedding_extractor(
    config: Annotated[Configuration, Depends(get_config)],
) -> EmbeddingExtractor:
    """Build the embedding extractor from the FAISS section of the config."""
    faiss_settings = config.db_config.faiss
    return EmbeddingExtractor(
        faiss_settings.model_embedding_path,
        faiss_settings.device,
    )
def get_chunk_repository(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> ChunkRepository:
    """Provide the chunk repository.

    `get_db` returns the session factory (a sessionmaker), not an open
    Session, so the parameter is annotated accordingly — the previous
    `Session` annotation was misleading and inconsistent with
    get_dataset_service/get_document_service. FastAPI ignores the
    annotation when Depends() is supplied, so runtime behavior is
    unchanged.
    """
    return ChunkRepository(db)
def get_injection_builder(
    chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
) -> InjectionBuilder:
    """Provide an injection builder backed by the chunk repository."""
    builder = InjectionBuilder(chunk_repository)
    return builder
def get_entity_service(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
    config: Annotated[Configuration, Depends(get_config)],
) -> EntityService:
    """Provide the entity service via DI (vectorizer + chunk storage + config)."""
    return EntityService(vectorizer, chunk_repository, config)
def get_dataset_service(
    entity_service: Annotated[EntityService, Depends(get_entity_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DatasetService:
    """Provide the dataset service via DI."""
    service = DatasetService(entity_service, config, db)
    return service
def get_document_service(
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DocumentService:
    """Provide the document service, layered on top of the dataset service."""
    return DocumentService(dataset_service, config, db)
def get_llm_config_service(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> LLMConfigService:
    """Provide the LLM configuration service.

    Annotated as `sessionmaker` because `get_db` returns the session
    factory, not an open Session (the previous `Session` annotation was
    misleading). FastAPI resolves the value via Depends(), so runtime
    behavior is unchanged.
    """
    return LLMConfigService(db)
def get_llm_service(
    config: Annotated[Configuration, Depends(get_config)],
) -> DeepInfraApi:
    """Construct the DeepInfra LLM client from the configured parameters.

    The API key is read from the environment variable named by
    `config.llm_config.api_key_env`.
    """
    llm_config = config.llm_config
    params = LlmParams(
        url=llm_config.base_url,
        model=llm_config.model,
        tokenizer=llm_config.tokenizer,
        type="deepinfra",
        default=True,
        predict_params=None,  # must be supplied with every request
        api_key=os.environ.get(llm_config.api_key_env),
        context_length=128000,
    )
    return DeepInfraApi(params=params)
def get_llm_prompt_service(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> LlmPromptService:
    """Provide the LLM prompt service.

    Annotated as `sessionmaker` because `get_db` returns the session
    factory, not an open Session (the previous `Session` annotation was
    misleading). FastAPI resolves the value via Depends(), so runtime
    behavior is unchanged.
    """
    return LlmPromptService(db)