muryshev's picture
update
0341212
raw
history blame
12 kB
import logging
from typing import Callable, Optional
from uuid import UUID
import numpy as np
from ntr_fileparser import ParsedDocument
from ntr_text_fragmentation import EntitiesExtractor, InjectionBuilder
from common.configuration import Configuration
from components.dbo.chunk_repository import ChunkRepository
from components.embedding_extraction import EmbeddingExtractor
from components.llm.deepinfra_api import DeepInfraApi
from components.search.appendices_chunker import APPENDICES_CHUNKER
from components.search.faiss_vector_search import FaissVectorSearch
from components.services.llm_config import LLMConfigService
logger = logging.getLogger(__name__)
class EntityService:
"""
Сервис для работы с сущностями.
Объединяет функциональность chunk_repository, destructurer, injection_builder и faiss_vector_search.
"""
def __init__(
self,
vectorizer: EmbeddingExtractor,
chunk_repository: ChunkRepository,
config: Configuration,
llm_api: DeepInfraApi,
llm_config_service: LLMConfigService,
) -> None:
"""
Инициализация сервиса.
Args:
vectorizer: Модель для извлечения эмбеддингов
chunk_repository: Репозиторий для работы с чанками
config: Конфигурация приложения
llm_api: Клиент для взаимодействия с LLM API
llm_config_service: Сервис для получения конфигурации LLM
"""
self.vectorizer = vectorizer
self.config = config
self.chunk_repository = chunk_repository
self.llm_api = llm_api
self.llm_config_service = llm_config_service
self.faiss_search = None
self.current_dataset_id = None
self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance
self.max_entities_per_message = config.db_config.search.max_entities_per_message
self.max_entities_per_dialogue = (
config.db_config.search.max_entities_per_dialogue
)
self.main_extractor = EntitiesExtractor(
strategy_name=config.db_config.entities.strategy_name,
strategy_params=config.db_config.entities.strategy_params,
process_tables=config.db_config.entities.process_tables,
)
self.appendices_extractor = EntitiesExtractor(
strategy_name=APPENDICES_CHUNKER,
strategy_params={
"llm_api": self.llm_api,
"llm_config_service": self.llm_config_service,
},
process_tables=False,
)
def _ensure_faiss_initialized(self, dataset_id: int) -> None:
"""
Проверяет и при необходимости инициализирует или обновляет FAISS индекс.
Args:
dataset_id: ID датасета для инициализации
"""
# Если индекс не инициализирован или датасет изменился
if self.faiss_search is None or self.current_dataset_id != dataset_id:
logger.info(f'Initializing FAISS for dataset {dataset_id}')
entities, embeddings = self.chunk_repository.get_searching_entities(
dataset_id
)
if entities:
embeddings_dict = {
str(entity.id): embedding # Преобразуем UUID в строку для ключа
for entity, embedding in zip(entities, embeddings)
if embedding is not None
}
if embeddings_dict: # Проверяем, что есть хотя бы один эмбеддинг
self.faiss_search = FaissVectorSearch(
self.vectorizer,
embeddings_dict,
)
self.current_dataset_id = dataset_id
logger.info(
f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings'
)
else:
logger.warning(
f'No valid embeddings found for dataset {dataset_id}'
)
self.faiss_search = None
self.current_dataset_id = None
else:
logger.warning(f'No entities found for dataset {dataset_id}')
self.faiss_search = None
self.current_dataset_id = None
async def process_document(
self,
document: ParsedDocument,
dataset_id: int,
progress_callback: Optional[Callable] = None,
) -> None:
"""
Асинхронная обработка документа: разбиение на чанки и сохранение в базу.
Args:
document: Документ для обработки
dataset_id: ID датасета
progress_callback: Функция для отслеживания прогресса
"""
logger.info(f"Processing document {document.name} for dataset {dataset_id}")
if 'Приложение' in document.name:
entities = await self.appendices_extractor.extract_async(document)
else:
entities = await self.main_extractor.extract_async(document)
# Фильтруем сущности для поиска
filtering_entities = [
entity for entity in entities if entity.in_search_text is not None
]
filtering_texts = [entity.in_search_text for entity in filtering_entities]
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
embeddings_dict = {
str(entity.id): embedding
for entity, embedding in zip(filtering_entities, embeddings)
}
# Сохраняем в базу
self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict)
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
def build_text(
self,
entities: list[str],
chunk_scores: Optional[list[float]] = None,
include_tables: bool = True,
max_documents: Optional[int] = None,
) -> str:
"""
Сборка текста из сущностей.
Args:
entities: Список идентификаторов сущностей
chunk_scores: Список весов чанков
include_tables: Флаг включения таблиц
max_documents: Максимальное количество документов
Returns:
Собранный текст
"""
entities = [UUID(entity) for entity in entities]
entities = self.chunk_repository.get_entities_by_ids(entities)
logger.info(f"Building text for {len(entities)} entities")
if chunk_scores is not None:
chunk_scores = {
entity.id: score for entity, score in zip(entities, chunk_scores)
}
builder = InjectionBuilder(self.chunk_repository)
return builder.build(
entities,
scores=chunk_scores,
include_tables=include_tables,
neighbors_max_distance=self.neighbors_max_distance,
max_documents=max_documents,
)
def search_similar_old(
self,
query: str,
dataset_id: int,
k: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Поиск похожих сущностей.
Args:
query: Текст запроса
dataset_id: ID датасета
k: Максимальное количество возвращаемых результатов (по умолчанию - все).
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]:
- Вектор запроса
- Оценки сходства
- Идентификаторы найденных сущностей
"""
logger.info(f"Searching similar entities for dataset {dataset_id} with k={k}")
# Убедимся, что индекс для нужного датасета загружен
self._ensure_faiss_initialized(dataset_id)
if self.faiss_search is None:
logger.warning(
f"FAISS search not initialized for dataset {dataset_id}. Returning empty results."
)
return np.array([]), np.array([]), np.array([])
# Выполняем поиск с использованием параметра k
query_vector, scores, ids = self.faiss_search.search_vectors(query, max_entities=k)
logger.info(f"Found {len(ids)} similar entities.")
return query_vector, scores, ids
def search_similar(
self,
query: str,
dataset_id: int,
previous_entities: list[list[str]] = None,
) -> tuple[list[list[str]], list[str], list[float]]:
"""
Поиск похожих сущностей.
Args:
query: Текст запроса
dataset_id: ID датасета
previous_entities: Список идентификаторов сущностей, которые уже были найдены
Returns:
tuple[list[list[str]], list[str], list[float]]:
- Перефильтрованный список идентификаторов сущностей из прошлых запросов
- Список идентификаторов найденных сущностей
- Скоры найденных сущностей
"""
self._ensure_faiss_initialized(dataset_id)
if self.faiss_search is None:
return previous_entities, [], []
if (
sum(len(entities) for entities in previous_entities)
< self.max_entities_per_dialogue - self.max_entities_per_message
):
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_message
)
try:
scores = scores.tolist()
ids = ids.tolist()
except:
scores = list(scores)
ids = list(ids)
return previous_entities, ids, scores
if previous_entities:
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_dialogue
)
scores = scores.tolist()
ids = ids.tolist()
print(ids)
previous_entities_ids = [
[entity for entity in sublist if entity in ids]
for sublist in previous_entities
]
previous_entities_flat = [
entity for sublist in previous_entities_ids for entity in sublist
]
new_entities = []
new_scores = []
for id_, score in zip(ids, scores):
if id_ not in previous_entities_flat:
new_entities.append(id_)
new_scores.append(score)
if len(new_entities) >= self.max_entities_per_message:
break
return previous_entities, new_entities, new_scores
else:
_, scores, ids = self.faiss_search.search_vectors(
query, self.max_entities_per_dialogue
)
scores = scores.tolist()
ids = ids.tolist()
return [], ids, scores