muryshev's picture
update
fd78d64
raw
history blame
20.1 kB
import asyncio
import logging
from typing import Callable, Optional
from uuid import UUID
import numpy as np
from ntr_fileparser import ParsedDocument
from ntr_text_fragmentation import (EntitiesExtractor, EntityRepository,
InjectionBuilder, InMemoryEntityRepository, LinkerEntity)
from common.configuration import Configuration
from components.dbo.chunk_repository import ChunkRepository
from components.embedding_extraction import EmbeddingExtractor
from components.llm.deepinfra_api import DeepInfraApi
from components.search.appendices_chunker import APPENDICES_CHUNKER
from components.search.faiss_vector_search import FaissVectorSearch
from components.services.llm_config import LLMConfigService
logger = logging.getLogger(__name__)
class EntityService:
"""
Сервис для работы с сущностями.
Объединяет функциональность chunk_repository, destructurer, injection_builder и faiss_vector_search.
"""
def __init__(
self,
vectorizer: EmbeddingExtractor,
chunk_repository: ChunkRepository,
config: Configuration,
llm_api: DeepInfraApi,
llm_config_service: LLMConfigService,
) -> None:
"""
Инициализация сервиса.
Args:
vectorizer: Модель для извлечения эмбеддингов
chunk_repository: Репозиторий для работы с чанками
config: Конфигурация приложения
llm_api: Клиент для взаимодействия с LLM API
llm_config_service: Сервис для получения конфигурации LLM
"""
self.vectorizer = vectorizer
self.config = config
self.chunk_repository = chunk_repository
self.llm_api = llm_api
self.llm_config_service = llm_config_service
self.faiss_search = None
self.current_dataset_id = None
self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance
self.max_entities_per_message = config.db_config.search.max_entities_per_message
self.max_entities_per_dialogue = (
config.db_config.search.max_entities_per_dialogue
)
self.main_extractor = EntitiesExtractor(
strategy_name=config.db_config.entities.strategy_name,
strategy_params=config.db_config.entities.strategy_params,
process_tables=config.db_config.entities.process_tables,
)
self.appendices_extractor = EntitiesExtractor(
strategy_name=APPENDICES_CHUNKER,
strategy_params={
"llm_api": self.llm_api,
"llm_config_service": self.llm_config_service,
},
process_tables=False,
)
self._in_memory_cache: InMemoryEntityRepository = None
self._cached_dataset_id: int | None = None
def invalidate_cache(self) -> None:
"""Инвалидирует (удаляет) текущий кеш в памяти."""
if self._in_memory_cache:
self._in_memory_cache = None
self._cached_dataset_id = None
else:
logger.info("In-memory кеш уже пуст. Ничего не делаем.")
def build_cache(self, dataset_id: int) -> None:
"""Строит кеш для указанного датасета."""
all_entities = self.chunk_repository.get_all_entities_for_dataset(dataset_id)
in_memory_repo = InMemoryEntityRepository(entities=all_entities)
self._in_memory_cache = in_memory_repo
self._cached_dataset_id = dataset_id
async def build_or_rebuild_cache_async(self, dataset_id: int) -> None:
"""
Строит или перестраивает кеш для указанного датасета, удаляя предыдущий кеш.
"""
all_entities = await self.chunk_repository.get_all_entities_for_dataset_async(dataset_id)
if not all_entities:
logger.warning(f"No entities found for dataset {dataset_id}. Cache not built.")
self._in_memory_cache = None
self._cached_dataset_id = None
return
logger.info(f"Building new in-memory cache for dataset {dataset_id}")
in_memory_repo = InMemoryEntityRepository(entities=all_entities)
self._in_memory_cache = in_memory_repo
self._cached_dataset_id = dataset_id
logger.info(f"Cached {len(all_entities)} entities for dataset {dataset_id}")
def _get_repository_for_dataset(self, dataset_id: int) -> EntityRepository:
"""
Возвращает кешированный репозиторий, если он существует и соответствует
запрошенному dataset_id, иначе возвращает основной репозиторий ChunkRepository.
"""
# Проверяем совпадение ID с закешированным
if self._cached_dataset_id == dataset_id and self._in_memory_cache is not None:
return self._in_memory_cache
else:
# Логируем причину промаха кеша для диагностики
if not self._in_memory_cache:
logger.warning(f"Cache miss for dataset {dataset_id}: Cache is empty. Using ChunkRepository (DB).")
elif self._cached_dataset_id != dataset_id:
logger.warning(f"Cache miss for dataset {dataset_id}: Cache contains data for dataset {self._cached_dataset_id}. Using ChunkRepository (DB).")
else: # На случай непредвиденной ситуации
logger.warning(f"Cache miss for dataset {dataset_id}: Unknown reason. Using ChunkRepository (DB).")
return self.chunk_repository
def _ensure_faiss_initialized(self, dataset_id: int) -> None:
"""
Проверяет и при необходимости инициализирует или обновляет FAISS индекс.
Args:
dataset_id: ID датасета для инициализации
"""
# Переинициализируем FAISS, только если ID датасета изменился
if self.faiss_search is None or self.current_dataset_id != dataset_id:
logger.info(f'Initializing FAISS for dataset {dataset_id}')
entities, embeddings = self.chunk_repository.get_searching_entities(
dataset_id
)
if entities:
embeddings_dict = {
str(entity.id): embedding # Преобразуем UUID в строку для ключа
for entity, embedding in zip(entities, embeddings)
if embedding is not None
}
if embeddings_dict: # Проверяем, что есть хотя бы один эмбеддинг
self.faiss_search = FaissVectorSearch(
self.vectorizer,
embeddings_dict,
)
self.current_dataset_id = dataset_id
logger.info(
f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings'
)
else:
logger.warning(
f'No valid embeddings found for dataset {dataset_id}'
)
self.faiss_search = None
self.current_dataset_id = None
else:
logger.warning(f'No entities found for dataset {dataset_id}')
self.faiss_search = None
self.current_dataset_id = None
async def process_document(
self,
document: ParsedDocument,
dataset_id: int,
progress_callback: Optional[Callable] = None,
) -> None:
"""
Асинхронная обработка документа: разбиение на чанки и сохранение в базу.
Args:
document: Документ для обработки
dataset_id: ID датасета
progress_callback: Функция для отслеживания прогресса
"""
logger.info(f"Processing document {document.name} for dataset {dataset_id}")
# Определяем экстрактор в зависимости от имени документа
if 'Приложение' in document.name:
entities = await self.appendices_extractor.extract_async(document)
else:
entities = await self.main_extractor.extract_async(document)
# Фильтруем сущности для поиска
filtering_entities = [
entity for entity in entities if entity.in_search_text is not None
]
filtering_texts = [entity.in_search_text for entity in filtering_entities]
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
# Собираем словарь эмбеддингов только для найденных сущностей
embeddings_dict = {}
if embeddings is not None:
embeddings_dict = {
str(entity.id): embedding
for entity, embedding in zip(filtering_entities, embeddings)
if embedding is not None
}
else:
logger.warning(f"Vectorizer returned None for document {document.name}")
# Сохраняем в базу
await self.chunk_repository.add_entities_async(entities, dataset_id, embeddings_dict)
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
async def add_entities_batch_async(
self,
dataset_id: int,
entities: list[LinkerEntity],
embeddings: dict[str, np.ndarray],
):
"""Асинхронно добавляет батч сущностей и их эмбеддингов в БД."""
if not entities:
logger.info("add_entities_batch_async called with empty entities list. Nothing to add.")
return
logger.info(f"Starting batch insertion of {len(entities)} entities for dataset {dataset_id}...")
try:
await asyncio.to_thread(
self.chunk_repository.add_entities,
entities,
dataset_id,
embeddings
)
logger.info(f"Batch insertion of {len(entities)} entities finished for dataset {dataset_id}.")
except Exception as e:
logger.error(
f"Error during batch insertion for dataset {dataset_id}: {e}",
exc_info=True,
)
raise e
async def prepare_document_data_async(
self,
document: ParsedDocument,
progress_callback: Optional[Callable] = None,
) -> tuple[list[LinkerEntity], dict[str, np.ndarray]]:
"""Асинхронно извлекает сущности и векторы для документа.
Не сохраняет данные в репозиторий, а возвращает их для последующей
батчевой обработки.
Args:
document: Документ для обработки.
progress_callback: Функция для отслеживания прогресса векторизации.
Returns:
Кортеж: (список извлеченных LinkerEntity, словарь эмбеддингов {id_str: embedding}).
"""
logger.debug(f"Preparing data for document {document.name}")
# 1. Извлечение сущностей
if 'Приложение' in document.name:
entities = await self.appendices_extractor.extract_async(document)
else:
entities = await self.main_extractor.extract_async(document)
# 2. Векторизация (если нужно)
filtering_entities = [
entity for entity in entities if entity.in_search_text is not None
]
filtering_texts = [entity.in_search_text for entity in filtering_entities]
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
embeddings_dict = {}
if embeddings is not None:
embeddings_dict = {
str(entity.id): embedding
for entity, embedding in zip(filtering_entities, embeddings)
if embedding is not None
}
else:
logger.warning(f"Vectorizer returned None for document {document.name}")
logger.debug(f"Prepared data for document {document.name}: {len(entities)} entities, {len(embeddings_dict)} embeddings.")
return entities, embeddings_dict
async def build_text_async(
self,
entities: list[str],
dataset_id: int,
chunk_scores: Optional[list[float]] = None,
include_tables: bool = True,
max_documents: Optional[int] = None,
) -> str:
"""
Асинхронная сборка текста из сущностей с использованием кешированного или основного репозитория.
Args:
entities: Список идентификаторов сущностей (строки UUID)
dataset_id: ID датасета для получения репозитория (кешированного или БД)
chunk_scores: Список весов чанков (соответствует порядку entities)
include_tables: Флаг включения таблиц
max_documents: Максимальное количество документов
Returns:
Собранный текст
"""
if not entities:
logger.warning("build_text called with empty entities list.")
return ""
try:
entity_ids = [UUID(entity) for entity in entities]
except ValueError as e:
logger.error(f"Invalid UUID format found in entities list: {e}")
raise ValueError(f"Invalid UUID format in entities list: {entities}") from e
repository = self._get_repository_for_dataset(dataset_id)
# Передаем репозиторий (кеш или БД) в InjectionBuilder
builder = InjectionBuilder(repository=repository)
# Создаем словарь score_map UUID -> score, если chunk_scores предоставлены
scores_map: dict[UUID, float] | None = None
if chunk_scores is not None:
if len(entity_ids) == len(chunk_scores):
scores_map = {eid: score for eid, score in zip(entity_ids, chunk_scores)}
else:
logger.warning(f"Length mismatch between entities ({len(entity_ids)}) and chunk_scores ({len(chunk_scores)}). Scores ignored.")
logger.info(f"Building text for {len(entity_ids)} entities from dataset {dataset_id} using {repository.__class__.__name__}")
# Вызываем асинхронный метод сборщика
return await builder.build_async(
entities=entity_ids, # Передаем список UUID
scores=scores_map, # Передаем словарь UUID -> score
include_tables=include_tables,
neighbors_max_distance=self.neighbors_max_distance,
max_documents=max_documents,
)
def search_similar_old(
self,
query: str,
dataset_id: int,
k: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Поиск похожих сущностей.
Args:
query: Текст запроса
dataset_id: ID датасета
k: Максимальное количество возвращаемых результатов (по умолчанию - все).
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]:
- Вектор запроса
- Оценки сходства
- Идентификаторы найденных сущностей
"""
logger.info(f"Searching similar entities for dataset {dataset_id} with k={k}")
self._ensure_faiss_initialized(dataset_id)
if self.faiss_search is None:
logger.warning(
f"FAISS search not initialized for dataset {dataset_id}. Returning empty results."
)
return np.array([]), np.array([]), np.array([])
# Выполняем поиск с использованием параметра k
query_vector, scores, ids = self.faiss_search.search_vectors(query, max_entities=k)
logger.info(f"Found {len(ids)} similar entities.")
return query_vector, scores, ids
def search_similar(
self,
query: str,
dataset_id: int,
previous_entities: list[list[str]] = None,
) -> tuple[list[list[str]], list[str], list[float]]:
"""
Поиск похожих сущностей.
Args:
query: Текст запроса
dataset_id: ID датасета
previous_entities: Список идентификаторов сущностей, которые уже были найдены
Returns:
tuple[list[list[str]], list[str], list[float]]:
- Перефильтрованный список идентификаторов сущностей из прошлых запросов
- Список идентификаторов найденных сущностей (строки UUID)
- Скоры найденных сущностей
"""
self._ensure_faiss_initialized(dataset_id)
if self.faiss_search is None:
return previous_entities, [], []
if (
sum(len(entities) for entities in previous_entities)
< self.max_entities_per_dialogue - self.max_entities_per_message
):
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_message
)
try:
scores = scores.tolist()
ids = ids.tolist()
except:
scores = list(scores)
ids = list(ids)
return previous_entities, ids, scores
if previous_entities:
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_dialogue
)
scores = scores.tolist()
ids = ids.tolist()
print(ids)
previous_entities_ids = [
[entity for entity in sublist if entity in ids]
for sublist in previous_entities
]
previous_entities_flat = [
entity for sublist in previous_entities_ids for entity in sublist
]
new_entities = []
new_scores = []
for id_, score in zip(ids, scores):
if id_ not in previous_entities_flat:
new_entities.append(id_)
new_scores.append(score)
if len(new_entities) >= self.max_entities_per_message:
break
return previous_entities, new_entities, new_scores
else:
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_dialogue
)
scores = scores.tolist()
ids = ids.tolist()
return [], ids, scores