import logging
from typing import Callable, Optional
from uuid import UUID

import numpy as np
from ntr_fileparser import ParsedDocument
from ntr_text_fragmentation import EntitiesExtractor, InjectionBuilder

from common.configuration import Configuration
from components.dbo.chunk_repository import ChunkRepository
from components.embedding_extraction import EmbeddingExtractor
from components.llm.deepinfra_api import DeepInfraApi
from components.search.appendices_chunker import APPENDICES_CHUNKER
from components.search.faiss_vector_search import FaissVectorSearch
from components.services.llm_config import LLMConfigService

logger = logging.getLogger(__name__)


class EntityService:
    """
    Service for working with entities.

    Combines the functionality of chunk_repository, destructurer,
    injection_builder and faiss_vector_search.
    """

    def __init__(
        self,
        vectorizer: EmbeddingExtractor,
        chunk_repository: ChunkRepository,
        config: Configuration,
        llm_api: DeepInfraApi,
        llm_config_service: LLMConfigService,
    ) -> None:
        """
        Initialize the service.

        Args:
            vectorizer: Model used to extract embeddings
            chunk_repository: Repository for working with chunks
            config: Application configuration
            llm_api: Client for talking to the LLM API
            llm_config_service: Service providing the LLM configuration
        """
        self.vectorizer = vectorizer
        self.config = config
        self.chunk_repository = chunk_repository
        self.llm_api = llm_api
        self.llm_config_service = llm_config_service
        self.faiss_search = None
        self.current_dataset_id = None
        self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance
        self.max_entities_per_message = config.db_config.search.max_entities_per_message
        self.max_entities_per_dialogue = (
            config.db_config.search.max_entities_per_dialogue
        )

        self.main_extractor = EntitiesExtractor(
            strategy_name=config.db_config.entities.strategy_name,
            strategy_params=config.db_config.entities.strategy_params,
            process_tables=config.db_config.entities.process_tables,
        )
        self.appendices_extractor = EntitiesExtractor(
            strategy_name=APPENDICES_CHUNKER,
            strategy_params={
                "llm_api": self.llm_api,
                "llm_config_service": self.llm_config_service,
            },
            process_tables=False,
        )

    def _ensure_faiss_initialized(self, dataset_id: int) -> None:
        """
        Check and, if necessary, initialize or refresh the FAISS index.

        Args:
            dataset_id: ID of the dataset to initialize
        """
        # The index is missing or was built for a different dataset
        if self.faiss_search is None or self.current_dataset_id != dataset_id:
            logger.info(f'Initializing FAISS for dataset {dataset_id}')
            entities, embeddings = self.chunk_repository.get_searching_entities(
                dataset_id
            )
            if entities:
                embeddings_dict = {
                    str(entity.id): embedding  # convert the UUID to a string key
                    for entity, embedding in zip(entities, embeddings)
                    if embedding is not None
                }
                if embeddings_dict:  # make sure there is at least one embedding
                    self.faiss_search = FaissVectorSearch(
                        self.vectorizer,
                        embeddings_dict,
                    )
                    self.current_dataset_id = dataset_id
                    logger.info(
                        f'FAISS initialized for dataset {dataset_id} '
                        f'with {len(embeddings_dict)} embeddings'
                    )
                else:
                    logger.warning(
                        f'No valid embeddings found for dataset {dataset_id}'
                    )
                    self.faiss_search = None
                    self.current_dataset_id = None
            else:
                logger.warning(f'No entities found for dataset {dataset_id}')
                self.faiss_search = None
                self.current_dataset_id = None
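
    # Note: only one dataset's FAISS index is cached at a time, so queries
    # that alternate between datasets will rebuild the index on every switch.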

    async def process_document(
        self,
        document: ParsedDocument,
        dataset_id: int,
        progress_callback: Optional[Callable] = None,
    ) -> None:
        """
        Asynchronously process a document: split it into chunks and persist them.

        Args:
            document: Document to process
            dataset_id: Dataset ID
            progress_callback: Callback used to report progress
        """
        logger.info(f"Processing document {document.name} for dataset {dataset_id}")

        # 'Приложение' ("Appendix") in the name routes the document to the
        # dedicated appendices extractor
        if 'Приложение' in document.name:
            entities = await self.appendices_extractor.extract_async(document)
        else:
            entities = await self.main_extractor.extract_async(document)

        # Keep only the entities that participate in search
        filtering_entities = [
            entity for entity in entities if entity.in_search_text is not None
        ]
        filtering_texts = [entity.in_search_text for entity in filtering_entities]

        embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
        embeddings_dict = {
            str(entity.id): embedding
            for entity, embedding in zip(filtering_entities, embeddings)
        }

        # Persist to the database
        self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict)

        logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
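
    # Illustrative progress callback for `process_document` above. The exact
    # signature expected by `EmbeddingExtractor.vectorize` is an assumption
    # here; the callback is only passed straight through:
    #
    #     def log_progress(done: int, total: int) -> None:
    #         logger.info(f"Vectorized {done}/{total} texts")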

    def build_text(
        self,
        entities: list[str],
        chunk_scores: Optional[list[float]] = None,
        include_tables: bool = True,
        max_documents: Optional[int] = None,
    ) -> str:
        """
        Assemble text from entities.

        Args:
            entities: List of entity identifiers
            chunk_scores: List of chunk scores
            include_tables: Whether to include tables
            max_documents: Maximum number of documents

        Returns:
            The assembled text
        """
        entity_ids = [UUID(entity) for entity in entities]
        found_entities = self.chunk_repository.get_entities_by_ids(entity_ids)

        logger.info(f"Building text for {len(found_entities)} entities")

        if chunk_scores is not None:
            chunk_scores = {
                entity.id: score for entity, score in zip(found_entities, chunk_scores)
            }

        builder = InjectionBuilder(self.chunk_repository)
        return builder.build(
            found_entities,
            scores=chunk_scores,
            include_tables=include_tables,
            neighbors_max_distance=self.neighbors_max_distance,
            max_documents=max_documents,
        )
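
    # NB: in `build_text` above, `chunk_scores` is paired with entities
    # positionally, so it assumes `get_entities_by_ids` returns entities in
    # the same order as the ids passed in.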

    def search_similar_old(
        self,
        query: str,
        dataset_id: int,
        k: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Search for similar entities.

        Args:
            query: Query text
            dataset_id: Dataset ID
            k: Maximum number of results to return (defaults to all)

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]:
                - Query vector
                - Similarity scores
                - Identifiers of the entities found
        """
        logger.info(f"Searching similar entities for dataset {dataset_id} with k={k}")

        # Make sure the index for the requested dataset is loaded
        self._ensure_faiss_initialized(dataset_id)

        if self.faiss_search is None:
            logger.warning(
                f"FAISS search not initialized for dataset {dataset_id}. "
                "Returning empty results."
            )
            return np.array([]), np.array([]), np.array([])

        # Run the search, honouring the k parameter
        query_vector, scores, ids = self.faiss_search.search_vectors(
            query, max_entities=k
        )
        logger.info(f"Found {len(ids)} similar entities.")
        return query_vector, scores, ids
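
    # The method below enforces two budgets taken from the configuration:
    # `max_entities_per_message` caps how many new entities a single query
    # may add, while `max_entities_per_dialogue` caps the running total
    # across the whole dialogue. Once the dialogue budget is nearly spent,
    # previously found entities are re-checked against the new query and
    # only those still relevant are kept.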

    def search_similar(
        self,
        query: str,
        dataset_id: int,
        previous_entities: Optional[list[list[str]]] = None,
    ) -> tuple[list[list[str]], list[str], list[float]]:
        """
        Search for similar entities.

        Args:
            query: Query text
            dataset_id: Dataset ID
            previous_entities: Lists of entity identifiers found by earlier queries

        Returns:
            tuple[list[list[str]], list[str], list[float]]:
                - Re-filtered entity identifiers from previous queries
                - Identifiers of the newly found entities
                - Scores of the newly found entities
        """
        if previous_entities is None:
            previous_entities = []

        self._ensure_faiss_initialized(dataset_id)

        if self.faiss_search is None:
            return previous_entities, [], []

        # Enough dialogue budget left: simply return the top matches
        if (
            sum(len(entities) for entities in previous_entities)
            < self.max_entities_per_dialogue - self.max_entities_per_message
        ):
            _, scores, ids = self.faiss_search.search_vectors(
                query, self.max_entities_per_message
            )
            try:
                scores = scores.tolist()
                ids = ids.tolist()
            except AttributeError:  # already plain sequences, not numpy arrays
                scores = list(scores)
                ids = list(ids)
            return previous_entities, ids, scores

        if previous_entities:
            _, scores, ids = self.faiss_search.search_vectors(
                query, self.max_entities_per_dialogue
            )
            scores = scores.tolist()
            ids = ids.tolist()
            logger.debug(f"Search returned ids: {ids}")

            # Drop previously found entities that are no longer relevant to
            # the current query
            previous_entities_ids = [
                [entity for entity in sublist if entity in ids]
                for sublist in previous_entities
            ]
            previous_entities_flat = [
                entity for sublist in previous_entities_ids for entity in sublist
            ]

            new_entities = []
            new_scores = []
            for id_, score in zip(ids, scores):
                if id_ not in previous_entities_flat:
                    new_entities.append(id_)
                    new_scores.append(score)
                    if len(new_entities) >= self.max_entities_per_message:
                        break
            # Return the re-filtered previous entities, as documented above
            return previous_entities_ids, new_entities, new_scores
        else:
            _, scores, ids = self.faiss_search.search_vectors(
                query, self.max_entities_per_dialogue
            )
            scores = scores.tolist()
            ids = ids.tolist()
            return [], ids, scores
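

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The real wiring lives elsewhere in the
# application; the constructor calls below are assumptions, not the actual
# setup.
#
#     service = EntityService(
#         vectorizer=vectorizer,                  # EmbeddingExtractor
#         chunk_repository=chunk_repository,      # ChunkRepository
#         config=config,                          # Configuration
#         llm_api=llm_api,                        # DeepInfraApi
#         llm_config_service=llm_config_service,  # LLMConfigService
#     )
#     await service.process_document(parsed_document, dataset_id=1)
#     prev, ids, scores = service.search_similar("query text", dataset_id=1)
#     context = service.build_text(ids, chunk_scores=scores)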