diff --git a/common/configuration.py b/common/configuration.py index 3d189e0716773406064b761aecf3a9cf9b79cb32..1a1de94db3f22f1a0bc9f1d19392846f3a31d17f 100644 --- a/common/configuration.py +++ b/common/configuration.py @@ -1,221 +1,48 @@ """This module includes classes to define configurations.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from pyaml_env import parse_config -from pydantic import BaseModel -class Query(BaseModel): - query: str - query_abbreviation: str - abbreviations_replaced: Optional[List] = None - userName: Optional[str] = None - - -class SemanticChunk(BaseModel): - index_answer: int - doc_name: str - title: str - text_answer: str - # doc_number: str # TODO Потом поменять название переменной на doc_id везде с чем это будет связанно - other_info: List - start_index_paragraph: int - - -class FilterChunks(BaseModel): - id: str - filename: str - title: str - chunks: List[SemanticChunk] - - -class BusinessProcess(BaseModel): - production_activities_section: Optional[str] - processes_name: Optional[str] - level_process: Optional[str] - - -class Lead(BaseModel): - person: Optional[str] - leads: Optional[str] - - -class Subordinate(BaseModel): - person_name: Optional[str] - position: Optional[str] - - -class OrganizationalStructure(BaseModel): - position: Optional[str] = None - leads: Optional[List[Lead]] = None - subordinates: Optional[Subordinate] = None - - -class RocksNN(BaseModel): - division: Optional[str] - company_name: Optional[str] - - -class RocksNNSearch(BaseModel): - division: Optional[str] - company_name: Optional[List] - - -class SegmentationSearch(BaseModel): - segmentation_model: Optional[str] - company_name: Optional[List] - - -class Group(BaseModel): - group_name: Optional[str] - position_in_group: Optional[str] - block: Optional[str] - - -class GroupComposition(BaseModel): - person_name: Optional[str] - position_in_group: Optional[str] - - -class SearchGroupComposition(BaseModel): - group_name: Optional[str] - group_composition: Optional[List[GroupComposition]] - - -class PeopleChunks(BaseModel): - business_processes: Optional[List[BusinessProcess]] = None - organizatinal_structure: Optional[List[OrganizationalStructure]] = None - business_curator: Optional[List[RocksNN]] = None - groups: Optional[List[Group]] = None - person_name: str - - -class SummaryChunks(BaseModel): - doc_chunks: Optional[List[FilterChunks]] = None - people_search: Optional[List[PeopleChunks]] = None - groups_search: Optional[SearchGroupComposition] = None - rocks_nn_search: Optional[RocksNNSearch] = None - segmentation_search: Optional[SegmentationSearch] = None - query_type: str = '[3]' - - -class ElasticConfiguration: - def __init__(self, config_data): - self.es_host = str(config_data['es_host']) - self.es_port = int(config_data['es_port']) - self.use_elastic = bool(config_data['use_elastic']) - self.people_path = str(config_data['people_path']) - - -class FaissDataConfiguration: +class EntitiesExtractorConfiguration: def __init__(self, config_data): - self.model_embedding_path = str(config_data['model_embedding_path']) - self.device = str(config_data['device']) - self.path_to_metadata = str(config_data['path_to_metadata']) - - -class ChunksElasticSearchConfiguration: - def __init__(self, config_data): - self.use_chunks_search = bool(config_data['use_chunks_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class PeopleSearchConfiguration: - def __init__(self, config_data): - self.use_people_search = bool(config_data['use_people_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class VectorSearchConfiguration: - def __init__(self, config_data): - self.use_vector_search = bool(config_data['use_vector_search']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class GroupsSearchConfiguration: - def __init__(self, config_data): - self.use_groups_search = bool(config_data['use_groups_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class RocksNNSearchConfiguration: - def __init__(self, config_data): - self.use_rocks_nn_search = bool(config_data['use_rocks_nn_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class AbbreviationSearchConfiguration: - def __init__(self, config_data): - self.use_abbreviation_search = bool(config_data['use_abbreviation_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) - - -class SegmentationSearchConfiguration: - def __init__(self, config_data): - self.use_segmentation_search = bool(config_data['use_segmentation_search']) - self.index_name = str(config_data['index_name']) - self.k_neighbors = int(config_data['k_neighbors']) + self.strategy_name = str(config_data['strategy_name']) + self.strategy_params: dict = config_data['strategy_params'] + self.process_tables = bool(config_data['process_tables']) + self.neighbors_max_distance = int(config_data['neighbors_max_distance']) class SearchConfiguration: def __init__(self, config_data): - self.vector_search = VectorSearchConfiguration(config_data['vector_search']) - self.people_elastic_search = PeopleSearchConfiguration( - config_data['people_elastic_search'] - ) - self.chunks_elastic_search = ChunksElasticSearchConfiguration( - config_data['chunks_elastic_search'] - ) - self.groups_elastic_search = GroupsSearchConfiguration( - config_data['groups_elastic_search'] - ) - self.rocks_nn_elastic_search = RocksNNSearchConfiguration( - config_data['rocks_nn_elastic_search'] - ) - self.segmentation_elastic_search = SegmentationSearchConfiguration( - config_data['segmentation_elastic_search'] - ) - self.stop_index_names = list(config_data['stop_index_names']) - self.abbreviation_search = AbbreviationSearchConfiguration( - config_data['abbreviation_search'] - ) + self.use_vector_search = bool(config_data['use_vector_search']) + self.vectorizer_path = str(config_data['vectorizer_path']) + self.device = str(config_data['device']) + self.max_entities_per_message = int(config_data['max_entities_per_message']) + self.max_entities_per_dialogue = int(config_data['max_entities_per_dialogue']) self.use_qe = bool(config_data['use_qe']) class FilesConfiguration: def __init__(self, config_data): self.empty_start = bool(config_data['empty_start']) - self.regulations_path = str(config_data['regulations_path']) - self.default_regulations_path = str(config_data['default_regulations_path']) self.documents_path = str(config_data['documents_path']) -class RankingConfiguration: - def __init__(self, config_data): - self.use_ranging = bool(config_data['use_ranging']) - self.alpha = float(config_data['alpha']) - self.beta = float(config_data['beta']) - self.k_neighbors = int(config_data['k_neighbors']) - - class DataBaseConfiguration: def __init__(self, config_data): - self.elastic = ElasticConfiguration(config_data['elastic']) - self.faiss = FaissDataConfiguration(config_data['faiss']) + self.entities = EntitiesExtractorConfiguration(config_data['entities']) self.search = SearchConfiguration(config_data['search']) self.files = FilesConfiguration(config_data['files']) - self.ranker = RankingConfiguration(config_data['ranging']) class LLMConfiguration: def __init__(self, config_data): - self.base_url = str(config_data['base_url']) if config_data['base_url'] not in ("", "null", "None") else None + self.base_url = ( + str(config_data['base_url']) + if config_data['base_url'] not in ("", "null", "None") + else None + ) self.api_key_env = ( str(config_data['api_key_env']) if config_data['api_key_env'] not in ("", "null", "None") @@ -235,6 +62,7 @@ class CommonConfiguration: def __init__(self, config_data): self.log_file_path = str(config_data['log_file_path']) self.log_sql_path = str(config_data['log_sql_path']) + self.log_level = str(config_data['log_level']) class Configuration: diff --git a/common/dependencies.py b/common/dependencies.py index f9b6d4a5eb16c37806402637416f18a939abc45e..a72e264466389d7ae330dc9d05ec98609036346e 100644 --- a/common/dependencies.py +++ b/common/dependencies.py @@ -37,12 +37,13 @@ def get_embedding_extractor( config: Annotated[Configuration, Depends(get_config)], ) -> EmbeddingExtractor: return EmbeddingExtractor( - config.db_config.faiss.model_embedding_path, - config.db_config.faiss.device, + config.db_config.search.vectorizer_path, + config.db_config.search.device, ) -def get_chunk_repository(db: Annotated[Session, Depends(get_db)]) -> ChunkRepository: +def get_chunk_repository(db: Annotated[sessionmaker, Depends(get_db)]) -> ChunkRepository: + """Получение репозитория чанков через DI.""" return ChunkRepository(db) diff --git a/components/dbo/chunk_repository.py b/components/dbo/chunk_repository.py index d0f748997ea45f525f31625970fb98f0b88d8f90..8a0a3c5705a534f19db0581c86d8494589dada0c 100644 --- a/components/dbo/chunk_repository.py +++ b/components/dbo/chunk_repository.py @@ -1,249 +1,184 @@ +import logging from uuid import UUID import numpy as np from ntr_text_fragmentation import LinkerEntity -from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository -from sqlalchemy import and_, select -from sqlalchemy.orm import Session +from ntr_text_fragmentation.integrations.sqlalchemy import \ + SQLAlchemyEntityRepository +from sqlalchemy import func, select +from sqlalchemy.orm import Session, sessionmaker from components.dbo.models.entity import EntityModel +logger = logging.getLogger(__name__) + class ChunkRepository(SQLAlchemyEntityRepository): - def __init__(self, db: Session): - super().__init__(db) + """ + Репозиторий для работы с сущностями (чанками, документами, связями), + хранящимися в базе данных с использованием SQL Alchemy. + Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию + для модели EntityModel. + """ + + def __init__(self, db_session_factory: sessionmaker[Session]): + """ + Инициализация репозитория. + + Args: + db_session_factory: Фабрика сессий SQLAlchemy. + """ + super().__init__(db_session_factory) + @property def _entity_model_class(self): + """Возвращает класс модели SQLAlchemy.""" return EntityModel - def _map_db_entity_to_linker_entity(self, db_entity: EntityModel): + def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity: """ - Преобразует сущность из базы данных в LinkerEntity. - + Преобразует объект EntityModel из базы данных в объект LinkerEntity + или его соответствующий подкласс. + Args: - db_entity: Сущность из базы данных - + db_entity: Сущность EntityModel из базы данных. + Returns: - LinkerEntity + Объект LinkerEntity или его подкласс. """ - # Преобразуем строковые ID в UUID - entity = LinkerEntity( - id=UUID(db_entity.uuid), # Преобразуем строку в UUID + # Создаем базовый LinkerEntity со всеми данными из БД + # Преобразуем строковые UUID обратно в объекты UUID + base_data = LinkerEntity( + id=UUID(db_entity.uuid), name=db_entity.name, text=db_entity.text, - type=db_entity.entity_type, in_search_text=db_entity.in_search_text, - metadata=db_entity.metadata_json, - source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID - target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID + metadata=db_entity.metadata_json or {}, + source_id=UUID(db_entity.source_id) if db_entity.source_id else None, + target_id=UUID(db_entity.target_id) if db_entity.target_id else None, number_in_relation=db_entity.number_in_relation, + type=db_entity.entity_type, + groupper=db_entity.entity_type, ) - return LinkerEntity.deserialize(entity) + + # Используем LinkerEntity._deserialize для получения объекта нужного типа + # на основе поля 'type', взятого из db_entity.entity_type + try: + deserialized_entity = base_data.deserialize() + return deserialized_entity + except Exception as e: + logger.error( + f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}" + ) + return base_data def add_entities( self, entities: list[LinkerEntity], dataset_id: int, - embeddings: dict[str, np.ndarray], + embeddings: dict[str, np.ndarray] | None = None, ): """ - Добавляет сущности в базу данных. - + Добавляет список сущностей LinkerEntity в базу данных. + Args: - entities: Список сущностей для добавления - dataset_id: ID датасета - embeddings: Словарь эмбеддингов {entity_id: embedding} + entities: Список сущностей LinkerEntity для добавления. + dataset_id: ID датасета, к которому принадлежат сущности. + embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID. """ + embeddings = embeddings or {} with self.db() as session: + db_entities_to_add = [] for entity in entities: # Преобразуем UUID в строку для хранения в базе - entity_id = str(entity.id) - - if entity_id in embeddings: - embedding = embeddings[entity_id] - else: - embedding = None - - session.add( - EntityModel( - uuid=str(entity.id), # UUID в строку - name=entity.name, - text=entity.text, - entity_type=entity.type, - in_search_text=entity.in_search_text, - metadata_json=entity.metadata, - source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку - target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку - number_in_relation=entity.number_in_relation, - chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index - dataset_id=dataset_id, - embedding=embedding, - ) + entity_id_str = str(entity.id) + embedding = embeddings.get(entity_id_str) + + db_entity = EntityModel( + uuid=entity_id_str, + name=entity.name, + text=entity.text, + entity_type=entity.type, + in_search_text=entity.in_search_text, + metadata_json=( + entity.metadata if isinstance(entity.metadata, dict) else {} + ), + source_id=str(entity.source_id) if entity.source_id else None, + target_id=str(entity.target_id) if entity.target_id else None, + number_in_relation=entity.number_in_relation, + dataset_id=dataset_id, + embedding=embedding, ) + db_entities_to_add.append(db_entity) + session.add_all(db_entities_to_add) session.commit() def get_searching_entities( self, dataset_id: int, ) -> tuple[list[LinkerEntity], list[np.ndarray]]: - with self.db() as session: - models = ( - session.query(EntityModel) - .filter(EntityModel.in_search_text is not None) - .filter(EntityModel.dataset_id == dataset_id) - .all() - ) - return ( - [self._map_db_entity_to_linker_entity(model) for model in models], - [model.embedding for model in models], - ) - - def get_chunks_by_ids( - self, - chunk_ids: list[str], - ) -> list[LinkerEntity]: """ - Получение чанков по их ID. - + Получает сущности из указанного датасета, которые имеют текст для поиска + (in_search_text не None), вместе с их эмбеддингами. + Args: - chunk_ids: Список ID чанков - + dataset_id: ID датасета. + Returns: - Список чанков + Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array). + Порядок эмбеддингов соответствует порядку сущностей. """ - # Преобразуем все ID в строки для единообразия - str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] - + entity_model = self._entity_model_class + linker_entities = [] + embeddings_list = [] + with self.db() as session: - models = ( - session.query(EntityModel) - .filter(EntityModel.uuid.in_(str_chunk_ids)) - .all() + stmt = select(entity_model).where( + entity_model.in_search_text.isnot(None), + entity_model.dataset_id == dataset_id, + entity_model.embedding.isnot(None) ) - return [self._map_db_entity_to_linker_entity(model) for model in models] + db_models = session.execute(stmt).scalars().all() - def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]: - """ - Получить сущности по списку идентификаторов. - - Args: - entity_ids: Список идентификаторов сущностей - - Returns: - Список сущностей, соответствующих указанным идентификаторам - """ - if not entity_ids: - return [] - - # Преобразуем UUID в строки - str_entity_ids = [str(entity_id) for entity_id in entity_ids] - - with self.db() as session: - entity_model = self._entity_model_class() - db_entities = session.execute( - select(entity_model).where(entity_model.uuid.in_(str_entity_ids)) - ).scalars().all() - - return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities] - - def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]: + # Переносим цикл внутрь сессии + for model in db_models: + # Теперь маппинг происходит при активной сессии + linker_entity = self._map_db_entity_to_linker_entity(model) + linker_entities.append(linker_entity) + + # Извлекаем эмбеддинг. + # _map_db_entity_to_linker_entity может поместить его в метаданные. + embedding = linker_entity.metadata.get('_embedding') + if embedding is None and hasattr(model, 'embedding'): # Fallback + embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии + + if embedding is not None: + embeddings_list.append(embedding) + else: + # Обработка случая отсутствия эмбеддинга + print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.") + linker_entities.pop() + + # Возвращаем результаты после закрытия сессии + return linker_entities, embeddings_list + + def count_entities_by_dataset_id(self, dataset_id: int) -> int: """ - Получить соседние чанки для указанных чанков. - + Подсчитывает общее количество сущностей для указанного датасета. + Args: - chunk_ids: Список идентификаторов чанков - max_distance: Максимальное расстояние до соседа - + dataset_id: ID датасета. + Returns: - Список соседних чанков + Общее количество сущностей в датасете. """ - if not chunk_ids: - return [] - - # Преобразуем UUID в строки - str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] - + entity_model = self._entity_model_class + id_column = self._get_id_column() # Получаем колонку ID (uuid или id) + with self.db() as session: - entity_model = self._entity_model_class() - result = [] - - # Сначала получаем указанные чанки, чтобы узнать их индексы и документы - chunks = session.execute( - select(entity_model).where( - and_( - entity_model.uuid.in_(str_chunk_ids), - entity_model.entity_type == "Chunk" # Используем entity_type вместо type - ) - ) - ).scalars().all() - - if not chunks: - return [] - - # Находим документы для чанков через связи - doc_ids = set() - chunk_indices = {} - - for chunk in chunks: - chunk_indices[chunk.uuid] = chunk.chunk_index - - # Находим связь от документа к чанку - links = session.execute( - select(entity_model).where( - and_( - entity_model.target_id == chunk.uuid, - entity_model.name == "document_to_chunk" - ) - ) - ).scalars().all() - - for link in links: - doc_ids.add(link.source_id) - - if not doc_ids or not any(idx is not None for idx in chunk_indices.values()): - return [] - - # Для каждого документа находим все его чанки - for doc_id in doc_ids: - # Находим все связи от документа к чанкам - links = session.execute( - select(entity_model).where( - and_( - entity_model.source_id == doc_id, - entity_model.name == "document_to_chunk" - ) - ) - ).scalars().all() - - doc_chunk_ids = [link.target_id for link in links] - - # Получаем все чанки документа - doc_chunks = session.execute( - select(entity_model).where( - and_( - entity_model.uuid.in_(doc_chunk_ids), - entity_model.entity_type == "Chunk" # Используем entity_type вместо type - ) - ) - ).scalars().all() - - # Для каждого чанка в документе проверяем, является ли он соседом - for doc_chunk in doc_chunks: - if doc_chunk.uuid in str_chunk_ids: - continue - - if doc_chunk.chunk_index is None: - continue - - # Проверяем, является ли чанк соседом какого-либо из исходных чанков - is_neighbor = False - for orig_chunk_id, orig_index in chunk_indices.items(): - if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance: - is_neighbor = True - break - - if is_neighbor: - result.append(self._map_db_entity_to_linker_entity(doc_chunk)) - - return result + stmt = select(func.count(id_column)).where( + entity_model.dataset_id == dataset_id + ) + count = session.execute(stmt).scalar_one() + return count diff --git a/components/embedding_extraction.py b/components/embedding_extraction.py index 125eb7e5a1706d63a05a6ffe92132b5cdb3f1b39..35bb9ac44c8ab8e74a1ec233e15d8e83c642e711 100644 --- a/components/embedding_extraction.py +++ b/components/embedding_extraction.py @@ -5,15 +5,23 @@ import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader -from transformers import (AutoModel, AutoTokenizer, BatchEncoding, - XLMRobertaModel, PreTrainedTokenizer, PreTrainedTokenizerFast) -from transformers.modeling_outputs import \ - BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput +from transformers import ( + AutoModel, + AutoTokenizer, + BatchEncoding, + XLMRobertaModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput, +) from common.decorators import singleton logger = logging.getLogger(__name__) + @singleton class EmbeddingExtractor: """Класс обрабатывает текст вопроса и возвращает embedding""" @@ -26,7 +34,7 @@ class EmbeddingExtractor: do_normalization: bool = True, max_len: int = 510, model: XLMRobertaModel = None, - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None, ): """ Класс, соединяющий в себе модель, токенизатор и параметры векторизации. @@ -46,18 +54,25 @@ class EmbeddingExtractor: device = torch.device(device) self.device = device - + # Инициализация модели if model is not None and tokenizer is not None: self.tokenizer = tokenizer self.model = model elif model_id is not None: - print('EmbeddingExtractor: model loading '+model_id+' to '+str(self.device)) - self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True) - self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id, local_files_only=True).to( - self.device + print( + 'EmbeddingExtractor: model loading ' + + model_id + + ' to ' + + str(self.device) ) - + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, local_files_only=True + ) + self.model: XLMRobertaModel = AutoModel.from_pretrained( + model_id, local_files_only=True + ).to(self.device) + print('EmbeddingExtractor: model loaded') self.model.eval() self.model.share_memory() diff --git a/components/nmd/faiss_vector_search.py b/components/nmd/faiss_vector_search.py index b2dd50b95642a92b906c79cc00660446c72a725f..b0846e52fd01a74f0ac51ebcc9b97aa49fd78040 100644 --- a/components/nmd/faiss_vector_search.py +++ b/components/nmd/faiss_vector_search.py @@ -3,7 +3,6 @@ import logging import faiss import numpy as np -from common.configuration import DataBaseConfiguration from common.constants import DO_NORMALIZATION from components.embedding_extraction import EmbeddingExtractor @@ -12,23 +11,16 @@ logger = logging.getLogger(__name__) class FaissVectorSearch: def __init__( - self, - model: EmbeddingExtractor, + self, + model: EmbeddingExtractor, ids_to_embeddings: dict[str, np.ndarray], - config: DataBaseConfiguration, ): self.model = model - self.config = config - self.path_to_metadata = config.faiss.path_to_metadata - if self.config.ranker.use_ranging: - self.k_neighbors = config.ranker.k_neighbors - else: - self.k_neighbors = config.search.vector_search.k_neighbors self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())} self.__create_index(ids_to_embeddings) def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]): - """Load the metadata file.""" + """Создает индекс для векторного поиска.""" if len(ids_to_embeddings) == 0: self.index = None return @@ -37,12 +29,17 @@ class FaissVectorSearch: self.index = faiss.IndexFlatIP(dim) self.index.add(embeddings) - def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def search_vectors( + self, + query: str, + max_entities: int = 100, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Поиск векторов в индексе. - + Args: query: Строка, запрос для поиска. + max_entities: Максимальное количество найденных сущностей. Returns: tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов: @@ -54,6 +51,6 @@ class FaissVectorSearch: if self.index is None: return (np.array([]), np.array([]), np.array([])) query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION) - similarities, indexes = self.index.search(query_embeds, self.k_neighbors) + similarities, indexes = self.index.search(query_embeds, max_entities) ids = [self.index_to_id[index] for index in indexes[0]] return query_embeds, similarities[0], np.array(ids) diff --git a/components/services/dataset.py b/components/services/dataset.py index f946da7b4da689d1a499a9e57c61316136891af7..680fc8b966cd1356391e4a8ae350421dd23c7a2e 100644 --- a/components/services/dataset.py +++ b/components/services/dataset.py @@ -6,9 +6,9 @@ import zipfile from datetime import datetime from pathlib import Path -import pandas as pd import torch from fastapi import BackgroundTasks, HTTPException, UploadFile +from components.dbo.models.entity import EntityModel from ntr_fileparser import ParsedDocument, UniversalParser from sqlalchemy.orm import Session @@ -34,9 +34,9 @@ class DatasetService: """ def __init__( - self, + self, entity_service: EntityService, - config: Configuration, + config: Configuration, db: Session, ) -> None: """ @@ -52,7 +52,6 @@ class DatasetService: self.config = config self.parser = UniversalParser() self.entity_service = entity_service - self.regulations_path = Path(config.db_config.files.regulations_path) self.documents_path = Path(config.db_config.files.documents_path) self.tmp_path = Path(os.environ.get("APP_TMP_PATH", '.')) logger.info("DatasetService initialized") @@ -214,7 +213,8 @@ class DatasetService: raise HTTPException( status_code=403, detail='Active dataset cannot be deleted' ) - + + session.query(EntityModel).filter(EntityModel.dataset_id == dataset_id).delete() session.delete(dataset) session.commit() @@ -222,6 +222,7 @@ class DatasetService: """ Метод для выполнения в отдельном процессе. """ + logger.info(f"apply_draft_task started") try: with self.db() as session: dataset = ( @@ -244,6 +245,7 @@ class DatasetService: active_dataset.is_active = False session.commit() + logger.info(f"apply_draft_task finished") except Exception as e: logger.error(f"Error applying draft: {e}") raise @@ -326,7 +328,7 @@ class DatasetService: logger.info(f"Uploading ZIP file {file.filename}") self.raise_if_processing() - file_location = Path(self.tmp_path / 'tmp.json' / 'tmp.zip') + file_location = Path(self.tmp_path / 'tmp' / 'tmp.zip') logger.debug(f"Saving uploaded file to {file_location}") file_location.parent.mkdir(parents=True, exist_ok=True) with open(file_location, 'wb') as f: @@ -338,7 +340,6 @@ class DatasetService: dataset = self.create_dataset_from_directory( is_default=False, directory_with_documents=file_location.parent, - directory_with_ready_dataset=None, ) file_location.unlink() @@ -386,8 +387,10 @@ class DatasetService: TMP_PATH.touch() - documents: list[Document] = [doc_dataset_link.document for doc_dataset_link in dataset.documents] - + documents: list[Document] = [ + doc_dataset_link.document for doc_dataset_link in dataset.documents + ] + for document in documents: path = self.documents_path / f'{document.id}.DOCX' parsed = self.parser.parse_by_path(str(path)) @@ -396,16 +399,12 @@ class DatasetService: logger.warning(f"Failed to parse document {document.id}") continue - # Используем EntityService для обработки документа с callback self.entity_service.process_document( parsed, dataset.id, progress_callback=progress_callback, - words_per_chunk=50, - overlap_words=25, - respect_sentence_boundaries=True, ) - + TMP_PATH.unlink() def raise_if_processing(self) -> None: @@ -422,7 +421,6 @@ class DatasetService: self, is_default: bool, directory_with_documents: Path, - directory_with_ready_dataset: Path | None = None, ) -> Dataset: """ Создать датасет из директории с xml-документами. @@ -446,7 +444,7 @@ class DatasetService: dataset = Dataset( name=name, - is_draft=True if directory_with_ready_dataset is None else False, + is_draft=True, is_active=True if is_default else False, ) session.add(dataset) @@ -465,16 +463,6 @@ class DatasetService: session.flush() - if directory_with_ready_dataset is not None: - shutil.move( - directory_with_ready_dataset, - self.regulations_path / str(dataset.id), - ) - - logger.info( - f"Moved ready dataset to {self.regulations_path / str(dataset.id)}" - ) - self.documents_path.mkdir(parents=True, exist_ok=True) for document in documents: diff --git a/components/services/entity.py b/components/services/entity.py index 02c485d1eb1e57ed4e12fb353aa7ce76032b5cc5..36083e4c99ef74b8f77fc045732cface4946a4d8 100644 --- a/components/services/entity.py +++ b/components/services/entity.py @@ -2,9 +2,10 @@ import logging from typing import Callable, Optional from uuid import UUID -import numpy as np from ntr_fileparser import ParsedDocument -from ntr_text_fragmentation import Destructurer, InjectionBuilder, LinkerEntity +from ntr_text_fragmentation import (EntitiesExtractor, InjectionBuilder, + LinkerEntity) +import numpy as np from common.configuration import Configuration from components.dbo.chunk_repository import ChunkRepository @@ -39,6 +40,16 @@ class EntityService: self.chunk_repository = chunk_repository self.faiss_search = None # Инициализируется при необходимости self.current_dataset_id = None # Текущий dataset_id + + self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance + self.max_entities_per_message = config.db_config.search.max_entities_per_message + self.max_entities_per_dialogue = config.db_config.search.max_entities_per_dialogue + + self.entities_extractor = EntitiesExtractor( + strategy_name=config.db_config.entities.strategy_name, + strategy_params=config.db_config.entities.strategy_params, + process_tables=config.db_config.entities.process_tables, + ) def _ensure_faiss_initialized(self, dataset_id: int) -> None: """ @@ -50,7 +61,9 @@ class EntityService: # Если индекс не инициализирован или датасет изменился if self.faiss_search is None or self.current_dataset_id != dataset_id: logger.info(f'Initializing FAISS for dataset {dataset_id}') - entities, embeddings = self.chunk_repository.get_searching_entities(dataset_id) + entities, embeddings = self.chunk_repository.get_searching_entities( + dataset_id + ) if entities: # Создаем словарь только из не-None эмбеддингов embeddings_dict = { @@ -62,12 +75,15 @@ class EntityService: self.faiss_search = FaissVectorSearch( self.vectorizer, embeddings_dict, - self.config.db_config, ) self.current_dataset_id = dataset_id - logger.info(f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings') + logger.info( + f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings' + ) else: - logger.warning(f'No valid embeddings found for dataset {dataset_id}') + logger.warning( + f'No valid embeddings found for dataset {dataset_id}' + ) self.faiss_search = None self.current_dataset_id = None else: @@ -80,7 +96,6 @@ class EntityService: document: ParsedDocument, dataset_id: int, progress_callback: Optional[Callable] = None, - **destructurer_kwargs, ) -> None: """ Обработка документа: разбиение на чанки и сохранение в базу. @@ -89,49 +104,33 @@ class EntityService: document: Документ для обработки dataset_id: ID датасета progress_callback: Функция для отслеживания прогресса - **destructurer_kwargs: Дополнительные параметры для Destructurer """ logger.info(f"Processing document {document.name} for dataset {dataset_id}") - - # Создаем деструктуризатор с параметрами по умолчанию - destructurer = Destructurer( - document, - strategy_name="fixed_size", - process_tables=True, - **{ - "words_per_chunk": 50, - "overlap_words": 25, - "respect_sentence_boundaries": True, - **destructurer_kwargs, - } - ) - + # Получаем сущности - entities = destructurer.destructure() - + entities = self.entities_extractor.extract(document) + # Фильтруем сущности для поиска - filtering_entities = [entity for entity in entities if entity.in_search_text is not None] + filtering_entities = [ + entity for entity in entities if entity.in_search_text is not None + ] filtering_texts = [entity.in_search_text for entity in filtering_entities] - + # Получаем эмбеддинги с поддержкой callback embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback) embeddings_dict = { str(entity.id): embedding # Преобразуем UUID в строку для ключа for entity, embedding in zip(filtering_entities, embeddings) } - + # Сохраняем в базу self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict) - - # Переинициализируем FAISS индекс, если это текущий датасет - if self.current_dataset_id == dataset_id: - self._ensure_faiss_initialized(dataset_id) - + logger.info(f"Added {len(entities)} entities to dataset {dataset_id}") def build_text( self, - entities: list[LinkerEntity], + entities: list[str], chunk_scores: Optional[list[float]] = None, include_tables: bool = True, max_documents: Optional[int] = None, @@ -140,7 +139,7 @@ class EntityService: Сборка текста из сущностей. Args: - entities: Список сущностей + entities: Список идентификаторов сущностей chunk_scores: Список весов чанков include_tables: Флаг включения таблиц max_documents: Максимальное количество документов @@ -148,18 +147,23 @@ class EntityService: Returns: Собранный текст """ + entities = [UUID(entity) for entity in entities] + entities = self.chunk_repository.get_entities_by_ids(entities) logger.info(f"Building text for {len(entities)} entities") if chunk_scores is not None: - chunk_scores = {entity.id: score for entity, score in zip(entities, chunk_scores)} + chunk_scores = { + entity.id: score for entity, score in zip(entities, chunk_scores) + } builder = InjectionBuilder(self.chunk_repository) return builder.build( - [entity.id for entity in entities], # Передаем UUID напрямую - chunk_scores=chunk_scores, + entities, + scores=chunk_scores, include_tables=include_tables, + neighbors_max_distance=self.neighbors_max_distance, max_documents=max_documents, ) - def search_similar( + def search_similar_old( self, query: str, dataset_id: int, @@ -185,26 +189,64 @@ class EntityService: # Выполняем поиск return self.faiss_search.search_vectors(query) - - def add_neighboring_chunks( + + def search_similar( self, - entities: list[LinkerEntity], - max_distance: int = 1, - ) -> list[LinkerEntity]: + query: str, + dataset_id: int, + previous_entities: list[list[str]] = None, + ) -> tuple[list[list[str]], list[str], list[float]]: """ - Добавление соседних чанков. + Поиск похожих сущностей. Args: - entities: Список сущностей - max_distance: Максимальное расстояние для поиска соседей + query: Текст запроса + dataset_id: ID датасета + previous_entities: Список идентификаторов сущностей, которые уже были найдены Returns: - Расширенный список сущностей + tuple[list[list[str]], list[str], list[float]]: + - Перефильтрованный список идентификаторов сущностей из прошлых запросов + - Список идентификаторов найденных сущностей + - Скоры найденных сущностей """ - # Убедимся, что все ID представлены в UUID формате - for entity in entities: - if not isinstance(entity.id, UUID): - entity.id = UUID(str(entity.id)) + self._ensure_faiss_initialized(dataset_id) + + if self.faiss_search is None: + return previous_entities, [], [] + + if sum(len(entities) for entities in previous_entities) < self.max_entities_per_dialogue - self.max_entities_per_message: + _, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_message) + try: + scores = scores.tolist() + ids = ids.tolist() + except: + scores = list(scores) + ids = list(ids) + return previous_entities, ids, scores + + if previous_entities: + _, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_dialogue) + scores = scores.tolist() + ids = ids.tolist() + + print(ids) + + previous_entities_ids = [[entity for entity in sublist if entity in ids] for sublist in previous_entities] + previous_entities_flat = [entity for sublist in previous_entities_ids for entity in sublist] + new_entities = [] + new_scores = [] + for id_, score in zip(ids, scores): + if id_ not in previous_entities_flat: + new_entities.append(id_) + new_scores.append(score) + if len(new_entities) >= self.max_entities_per_message: + break - builder = InjectionBuilder(self.chunk_repository) - return builder.add_neighboring_chunks(entities, max_distance) \ No newline at end of file + return previous_entities, new_entities, new_scores + + else: + _, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_dialogue) + scores = scores.tolist() + ids = ids.tolist() + return [], ids, scores diff --git a/config_dev.yaml b/config_dev.yaml index 3ab9cb88e4cc2d2d613d890bfb721d1d8a944204..c9c2df5f3dd10456bfb3aedb220707929c86d366 100644 --- a/config_dev.yaml +++ b/config_dev.yaml @@ -1,69 +1,28 @@ common: log_file_path: !ENV ${LOG_FILE_PATH:/data/logs/common.log} log_sql_path: !ENV ${SQLALCHEMY_DATABASE_URL:sqlite:////data/logs.db} + log_level: !ENV ${LOG_LEVEL:INFO} bd: - faiss: - model_embedding_path: !ENV ${EMBEDDING_MODEL_PATH:intfloat/multilingual-e5-large} - path_to_metadata: !ENV ${PATH_TO_METADATA:/data/regulation_datasets} - device: !ENV ${FAISS_DEVICE:cuda} - - elastic: - use_elastic: False - es_host: !ENV ${ELASTIC_HOST:localhost} - es_port: !ENV ${ELASTIC_PORT:9200} - people_path: /data/person_card - - ranging: - use_ranging: false - alpha: 0.35 - beta: -0.15 - k_neighbors: 100 + entities: + strategy_name: !ENV ${ENTITIES_STRATEGY_NAME:fixed_size} + strategy_params: + words_per_chunk: 50 + overlap_words: 25 + respect_sentence_boundaries: true + process_tables: true + neighbors_max_distance: 1 search: use_qe: true - - vector_search: - use_vector_search: true - k_neighbors: 100 - - people_elastic_search: - use_people_search: false - index_name: 'people_search' - k_neighbors: 10 - - chunks_elastic_search: - use_chunks_search: true - index_name: 'nmd_full_text' - k_neighbors: 5 - - groups_elastic_search: - use_groups_search: false - index_name: 'group_search_elastic_nn' - k_neighbors: 1 - - rocks_nn_elastic_search: - use_rocks_nn_search: false - index_name: 'rocks_nn_search_elastic' - k_neighbors: 1 - - segmentation_elastic_search: - use_segmentation_search: false - index_name: 'segmentation_search_elastic' - k_neighbors: 1 - - # Если поиск будет не по чанкам, то добавить название ключа из функции search_answer словаря answer!!! - stop_index_names: ['people_answer', 'groups_answer', 'rocks_nn_answer', 'segmentation_answer'] - - abbreviation_search: - use_abbreviation_search: true - index_name: 'nmd_abbreviation_elastic' - k_neighbors: 10 + use_vector_search: true + vectorizer_path: !ENV ${EMBEDDING_MODEL_PATH:BAAI/bge-m3} + device: !ENV ${DEVICE:cuda} + max_entities_per_message: 75 + max_entities_per_dialogue: 500 files: empty_start: true - regulations_path: /data/regulation_datasets - default_regulations_path: /data/regulation_datasets/default documents_path: /data/documents llm: diff --git a/lib/extractor/ntr_text_fragmentation/__init__.py b/lib/extractor/ntr_text_fragmentation/__init__.py index af29834fc7fb3c7c4b4218fbaeab8b38fc2a2e69..18b93658c10114c45961a981668bb57cd35289b9 100644 --- a/lib/extractor/ntr_text_fragmentation/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/__init__.py @@ -2,18 +2,23 @@ Модуль извлечения и сборки документов. """ -from .core.destructurer import Destructurer -from .core.entity_repository import EntityRepository, InMemoryEntityRepository +from .core.extractor import EntitiesExtractor +from .repositories.entity_repository import EntityRepository from .core.injection_builder import InjectionBuilder -from .models import Chunk, DocumentAsEntity, LinkerEntity +from .repositories import InMemoryEntityRepository +from .models import DocumentAsEntity, LinkerEntity, Link, Entity, register_entity +from .chunking import FIXED_SIZE __all__ = [ - "Destructurer", - "InjectionBuilder", - "EntityRepository", + "EntitiesExtractor", + "InjectionBuilder", + "EntityRepository", "InMemoryEntityRepository", "LinkerEntity", - "Chunk", + "Entity", + "Link", + "register_entity", "DocumentAsEntity", "integrations", -] + "FIXED_SIZE", +] diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py b/lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py index f163392c046ac2890faf8800c7a80c24b0b66db5..8ec87d88f2419c0d10183bde21347731f59fa627 100644 --- a/lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py @@ -1,5 +1,7 @@ -from .table_entity import TableEntity +from .models.table_entity import TableEntity +from .table_processor import TableProcessor __all__ = [ 'TableEntity', + 'TableProcessor', ] diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/models/__init__.py b/lib/extractor/ntr_text_fragmentation/additors/tables/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9d1034df41fd4fa80e80315180323bec938544 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/models/__init__.py @@ -0,0 +1,10 @@ +from .table_entity import TableEntity +from .subtable_entity import SubTableEntity +from .table_row_entity import TableRowEntity + + +__all__ = [ + 'TableEntity', + 'SubTableEntity', + 'TableRowEntity', +] diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/models/subtable_entity.py b/lib/extractor/ntr_text_fragmentation/additors/tables/models/subtable_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..20cb066b09d8fac4097fe10e67cbfc944b802078 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/models/subtable_entity.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +from ....models import Entity, register_entity + + +@register_entity +@dataclass +class SubTableEntity(Entity): + """ + Сущность подтаблицы из документа. + + Расширяет основную сущность LinkerEntity, добавляя информацию о таблице. + """ + + header: list[str] | None = None + title: str | None = None + + @classmethod + def _deserialize_to_me(cls, data: Entity) -> 'SubTableEntity': + """ + Десериализует SubTableEntity из объекта Entity. + + Args: + data (Entity): Объект Entity для десериализации. + + Returns: + SubTableEntity: Новый экземпляр SubTableEntity с данными из Entity. + + Raises: + TypeError: Если data не является экземпляром Entity. + """ + if not isinstance(data, Entity): + raise TypeError(f"Ожидался Entity, получен {type(data)}") + + # Пытаемся получить из полей объекта, если это уже SubTableEntity или его потомок + header = getattr(data, 'header', None) + title = getattr(data, 'title', None) + + # Если не нашли в полях, ищем в метаданных + metadata = data.metadata or {} + if header is None: + header = metadata.get('_header') + if title is None: + title = metadata.get('_title') + + # Переписываем блок return, чтобы точно включить groupper + return cls( + id=data.id, + name=data.name, + text=data.text, + in_search_text=data.in_search_text, + metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные + source_id=data.source_id, + target_id=data.target_id, + number_in_relation=data.number_in_relation, + groupper=data.groupper, # Убеждаемся, что groupper здесь + type=cls.__name__, # Используем имя класса для типа + # Специфичные поля + header=header, + title=title + ) diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_entity.py b/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..fd58f0d4bea2ea35a14abeb767682a51d96d35f4 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_entity.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass + +from ....models import Entity, register_entity + + +@register_entity +@dataclass +class TableEntity(Entity): + """ + Сущность таблицы из документа. + + Расширяет основную сущность LinkerEntity, добавляя информацию о таблице. + """ + + title: str | None = None + header: list[str] | None = None + note: str | None = None + + @classmethod + def _deserialize_to_me(cls, data: Entity) -> 'TableEntity': + """ + Десериализует TableEntity из объекта Entity. + + Args: + data (Entity): Объект Entity для десериализации. + + Returns: + TableEntity: Новый экземпляр TableEntity с данными из Entity. + + Raises: + TypeError: Если data не является экземпляром Entity. + """ + if not isinstance(data, Entity): + raise TypeError(f"Ожидался Entity, получен {type(data)}") + + # Пытаемся получить из полей объекта, если это уже TableEntity или его потомок + title = getattr(data, 'title', None) + header = getattr(data, 'header', None) + note = getattr(data, 'note', None) + + # Если не нашли в полях, ищем в метаданных + metadata = data.metadata or {} + if title is None: + title = metadata.get('_title') + if header is None: + header = metadata.get('_header') + if note is None: + note = metadata.get('_note') + + # Переписываем блок return, чтобы точно включить groupper + return cls( + id=data.id, + name=data.name, + text=data.text, + in_search_text=data.in_search_text, + metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные + source_id=data.source_id, + target_id=data.target_id, + number_in_relation=data.number_in_relation, + groupper=data.groupper, # Убеждаемся, что groupper здесь + type=cls.__name__, # Используем имя класса для типа + # Специфичные поля + title=title, + header=header, + note=note + ) diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_row_entity.py b/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_row_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..804d9eab206cb9ba212b4bbc718cc4eb75da2ae1 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/models/table_row_entity.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass, field + +from ....models import Entity, register_entity + + +@register_entity +@dataclass +class TableRowEntity(Entity): + """ + Сущность строки таблицы из документа. + + Расширяет основную сущность LinkerEntity, добавляя информацию о строке таблицы. + """ + + cells: list[str] = field(default_factory=list) + + @classmethod + def _deserialize_to_me(cls, data: Entity) -> "TableRowEntity": + """ + Десериализует TableRowEntity из объекта Entity. + + Args: + data (Entity): Объект Entity для десериализации. + + Returns: + TableRowEntity: Новый экземпляр TableRowEntity с данными из Entity. + + Raises: + TypeError: Если data не является экземпляром Entity. + """ + if not isinstance(data, Entity): + raise TypeError(f"Ожидался Entity, получен {type(data)}") + + # Пытаемся получить из полей объекта, если это уже TableRowEntity или его потомок + cells = getattr(data, 'cells', None) + + # Если не нашли в полях, ищем в метаданных + metadata = data.metadata or {} + if cells is None: + cells = metadata.get('_cells', []) + + # Переписываем блок return, чтобы точно включить groupper + return cls( + id=data.id, + name=data.name, + text=data.text, + in_search_text=data.in_search_text, + metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные + source_id=data.source_id, + target_id=data.target_id, + number_in_relation=data.number_in_relation, + groupper=data.groupper, # Убеждаемся, что groupper здесь + type=cls.__name__, # Используем имя класса для типа + # Специфичные поля + cells=cells + ) + \ No newline at end of file diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables/table_processor.py b/lib/extractor/ntr_text_fragmentation/additors/tables/table_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..59763e06d1cf18cec534b115591338a7a346dec6 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/additors/tables/table_processor.py @@ -0,0 +1,178 @@ +from ntr_fileparser import ParsedRow, ParsedSubtable, ParsedTable + +from ...models import LinkerEntity +from ...repositories.entity_repository import EntityRepository, GroupedEntities +from .models import SubTableEntity, TableEntity, TableRowEntity + + +class TableProcessor: + def __init__(self): + pass + + def extract( + self, + table: ParsedTable, + doc_entity: LinkerEntity, + ) -> list[LinkerEntity]: + """ + Извлекает таблицу из документа и создает для нее сущность, а также сущности для всех строк таблицы. + """ + entities = [] + table_entity = self._create_table_entity(table, doc_entity) + + entities.append(table_entity) + + for i, subtable in enumerate(table.subtables): + index_in_relation = i + 1 + subtable_entity = self._create_subtable_entity( + subtable, + table_entity, + index_in_relation, + ) + + entities.append(subtable_entity) + + for j, row in enumerate(subtable.rows): + index_in_relation = j + 1 + row_entity = self._create_row_entity( + row, + subtable_entity, + row.to_string(), + index_in_relation, + ) + + entities.append(row_entity) + + return entities + + def _create_table_entity( + self, + table: ParsedTable, + doc_entity: LinkerEntity, + ) -> TableEntity: + entity = TableEntity( + name=table.title or 'NonameTable', + text=table.title or '', + title=table.title, + header=self._create_header(table), + note=table.note, + groupper=f'Table_{doc_entity.id}', + number_in_relation=table.index_in_document, + ) + entity.owner_id = doc_entity.id + + return entity + + def _create_header(self, table: ParsedTable) -> list[str] | None: + if len(table.headers) == 0: + return None + + rows = table.headers + header: list[list[str]] = [[] for _ in range(len(rows[0].cells))] + for row in rows: + for i, cell in enumerate(row.cells): + header[i].append(cell) + result = [" > ".join(column) for column in header] + + return result + + def _create_subtable_entity( + self, + subtable: ParsedSubtable, + table_entity: TableEntity, + number_in_relation: int, + ) -> SubTableEntity: + header = None + if subtable.header: + header = subtable.header.cells + entity = SubTableEntity( + name=subtable.title or 'NonameSubTable', + text=subtable.title or '', + title=subtable.title, + header=header, + groupper=f'SubTable_{table_entity.id}', + number_in_relation=number_in_relation, + ) + entity.owner_id = table_entity.id + return entity + + def _create_row_entity( + self, + row: ParsedRow, + subtable_entity: SubTableEntity, + in_search_text: str, + number_in_relation: int, + ) -> TableRowEntity: + entity = TableRowEntity( + name=f'{row.index}', + text='', + cells=row.cells, + in_search_text=in_search_text, + groupper=f'Row_{subtable_entity.id}', + number_in_relation=number_in_relation, + ) + entity.owner_id = subtable_entity.id + return entity + + def build( + self, + repository: EntityRepository, + group: GroupedEntities[TableEntity], + ) -> str: + """ + Собирает текст таблицы из списка сущностей. + """ + table = group.composer + entities = group.entities + + subtable_grouped: list[GroupedEntities[SubTableEntity]] = ( + repository.group_entities_hierarchically( + entities=entities, + root_type=SubTableEntity, + sort=True, + ) + ) + + result = "" + + if table.title: + result += f"#### {table.title}\n" + else: + result += f"#### Таблица {table.number_in_relation}\n" + + table_header = table.header + + for subtable_group in subtable_grouped: + subtable = subtable_group.composer + subtable_header = subtable.header + rows = [ + row + for row in subtable_group.entities + if isinstance(row, TableRowEntity) + ] + if subtable.title: + result += f"##### {subtable.title}\n" + for row in rows: + result += self._prepare_row( + row, + subtable_header or table_header, + ) + + if table.note: + result += f"**Примечание:** {table.note}\n" + + return result + + def _prepare_row( + self, + row: TableRowEntity, + header: list[str] | None = None, + ) -> str: + row_name = f'Строка {row.number_in_relation}' + if header is None: + cells = "\n".join([f"- - {cell}" for cell in row.cells]) + else: + normalized_header = [h.replace('\n', '') for h in header] + cells = "\n".join([f" - **{normalized_header[i]}**: {row.cells[i]}".replace('\n', '\n -') for i in range(len(header))]) + + return f"- {row_name}\n{cells}\n" diff --git a/lib/extractor/ntr_text_fragmentation/additors/tables_processor.py b/lib/extractor/ntr_text_fragmentation/additors/tables_processor.py index dc40b87322292ce9e5a7d490dbb6f79cc4c5cebc..c4f75da1476322bc7489f62fa9d71d10eff73868 100644 --- a/lib/extractor/ntr_text_fragmentation/additors/tables_processor.py +++ b/lib/extractor/ntr_text_fragmentation/additors/tables_processor.py @@ -2,12 +2,11 @@ Процессор таблиц из документа. """ -from uuid import uuid4 - from ntr_fileparser import ParsedDocument from ..models import LinkerEntity -from .tables import TableEntity +from .tables import TableProcessor, TableEntity +from ..repositories import EntityRepository, GroupedEntities class TablesProcessor: @@ -17,101 +16,42 @@ class TablesProcessor: def __init__(self): """Инициализация процессора таблиц.""" - pass + self.table_processor = TableProcessor() - def process( + def extract( self, document: ParsedDocument, doc_entity: LinkerEntity, ) -> list[LinkerEntity]: - """ - Извлекает таблицы из документа и создает для них сущности. - - Args: - document: Документ для обработки - doc_entity: Сущность документа для связи с таблицами - - Returns: - Список сущностей TableEntity и связей - """ - if not document.tables: - return [] - - table_entities = [] - links = [] - - rows = '\n\n'.join([table.to_string() for table in document.tables]).split( - '\n\n' - ) - - # Обрабатываем каждую таблицу - for idx, row in enumerate(rows): - # Создаем сущность таблицы - table_entity = self._create_table_entity( - table_text=row, - table_index=idx, - doc_name=doc_entity.name, - ) - - # Создаем связь между документом и таблицей - link = self._create_link(doc_entity, table_entity, idx) + """Извлекает таблицы из документа и создает для них сущности.""" + entities = [] + for table in document.tables: + entities.extend(self.table_processor.extract(table, doc_entity)) + return entities - table_entities.append(table_entity) - links.append(link) - - # Возвращаем список таблиц и связей - return table_entities + links - - def _create_table_entity( + def build( self, - table_text: str, - table_index: int, - doc_name: str, - ) -> TableEntity: + repository: EntityRepository, + entities: list[LinkerEntity], + ) -> str: """ - Создает сущность таблицы. - - Args: - table_text: Текст таблицы - table_index: Индекс таблицы в документе - doc_name: Имя документа - - Returns: - Сущность TableEntity + Собирает текст таблиц из списка сущностей. """ - entity_name = f"{doc_name}_table_{table_index}" - return TableEntity( - id=uuid4(), - name=entity_name, - text=table_text, - in_search_text=table_text, - metadata={}, - type=TableEntity.__name__, - table_index=table_index, + groups: list[GroupedEntities[TableEntity]] = ( + repository.group_entities_hierarchically( + entities=entities, + root_type=TableEntity, + sort=True, + ) ) - def _create_link( - self, doc_entity: LinkerEntity, table_entity: TableEntity, index: int - ) -> LinkerEntity: - """ - Создает связь между документом и таблицей. - - Args: - doc_entity: Сущность документа - table_entity: Сущность таблицы - index: Индекс таблицы в документе + groups = sorted( + groups, key=lambda x: x.composer.number_in_relation, + ) - Returns: - Объект связи LinkerEntity - """ - return LinkerEntity( - id=uuid4(), - name="document_to_table", - text="", - metadata={}, - source_id=doc_entity.id, - target_id=table_entity.id, - number_in_relation=index, - type="Link", + result = "\n\n".join( + self.table_processor.build(repository, group) for group in groups ) + + return result diff --git a/lib/extractor/ntr_text_fragmentation/chunking/__init__.py b/lib/extractor/ntr_text_fragmentation/chunking/__init__.py index 0d616841c4d30ce26f8be0214dcc79d7b815ef5c..77f7cb6807b489abd40e5d4e447a5e38e659495d 100644 --- a/lib/extractor/ntr_text_fragmentation/chunking/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/chunking/__init__.py @@ -3,9 +3,21 @@ """ from .chunking_strategy import ChunkingStrategy -from .specific_strategies import FixedSizeChunkingStrategy +from .specific_strategies import ( + FixedSizeChunk, + FixedSizeChunkingStrategy, + FIXED_SIZE, +) +from .text_to_text_base import TextToTextBaseStrategy + +from .chunking_registry import register_chunking_strategy, chunking_registry __all__ = [ "ChunkingStrategy", + "FixedSizeChunk", "FixedSizeChunkingStrategy", + "FIXED_SIZE", + "TextToTextBaseStrategy", + "register_chunking_strategy", + "chunking_registry", ] diff --git a/lib/extractor/ntr_text_fragmentation/chunking/chunking_registry.py b/lib/extractor/ntr_text_fragmentation/chunking/chunking_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..2a377e934141edc2c4fb5d7e397ff26458436b09 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/chunking/chunking_registry.py @@ -0,0 +1,40 @@ +from ..chunking.chunking_strategy import ChunkingStrategy + + +class _ChunkingRegistry: + def __init__(self): + self._chunking_strategies: dict[str, ChunkingStrategy] = {} + + def register(self, name: str, strategy: ChunkingStrategy): + self._chunking_strategies[name] = strategy + + def get(self, name: str) -> ChunkingStrategy: + return self._chunking_strategies[name] + + def get_names(self) -> list[str]: + return list(self._chunking_strategies.keys()) + + def __len__(self) -> int: + return len(self._chunking_strategies) + + def __contains__(self, name: str | ChunkingStrategy) -> bool: + if isinstance(name, ChunkingStrategy): + return name in self._chunking_strategies.values() + return name in self._chunking_strategies + + def __dict__(self) -> dict: + return self._chunking_strategies + + def __getitem__(self, name: str) -> ChunkingStrategy: + return self._chunking_strategies[name] + + +chunking_registry = _ChunkingRegistry() + + +def register_chunking_strategy(name: str | None = None): + def decorator(cls: type[ChunkingStrategy]) -> type[ChunkingStrategy]: + chunking_registry.register(name or cls.__name__, cls) + return cls + + return decorator diff --git a/lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py b/lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py index 65f79be59d467c377c9465a6dccc28c5e9061292..e81a31291fbe275eb6e4c55d2ab0c634a9389bb8 100644 --- a/lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py +++ b/lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py @@ -1,86 +1,99 @@ """ -Базовый класс для всех стратегий чанкинга. +Абстрактный базовый класс для стратегий чанкинга. """ +import logging from abc import ABC, abstractmethod from ntr_fileparser import ParsedDocument -from ..models import Chunk, DocumentAsEntity, LinkerEntity +from ..models import DocumentAsEntity, LinkerEntity +from ..repositories import EntityRepository +from .models import Chunk + +logger = logging.getLogger(__name__) class ChunkingStrategy(ABC): - """ - Базовый абстрактный класс для всех стратегий чанкинга. - """ - + """Абстрактный класс для стратегий чанкинга.""" + @abstractmethod - def chunk(self, document: ParsedDocument, doc_entity: DocumentAsEntity | None = None) -> list[LinkerEntity]: + def chunk( + self, + document: ParsedDocument, + doc_entity: DocumentAsEntity, + ) -> list[LinkerEntity]: """ Разбивает документ на чанки в соответствии со стратегией. - + Args: - document: ParsedDocument для извлечения текста - doc_entity: Опциональная сущность документа для привязки чанков. - Если не указана, будет создана новая. - + document: ParsedDocument для извлечения текста и структуры. + doc_entity: Сущность документа-владельца, к которой будут привязаны чанки. + Returns: - list[LinkerEntity]: Список сущностей (документ, чанки, связи) + Список сущностей (чанки) """ raise NotImplementedError("Стратегия чанкинга должна реализовать метод chunk") - def dechunk(self, chunks: list[LinkerEntity], repository: 'EntityRepository' = None) -> str: + @classmethod + def dechunk( + cls, + repository: EntityRepository, + filtered_entities: list[LinkerEntity], + ) -> str: """ - Собирает документ из чанков и связей. - - Базовая реализация сортирует чанки по chunk_index и объединяет их тексты, - сохраняя структуру параграфов и избегая дублирования текста. - + Собирает текст из отфильтрованных чанков к одному документу. + Args: - chunks: Список отфильтрованных чанков в случайном порядке - repository: Репозиторий сущностей для получения дополнительной информации (может быть None) - + repository: Репозиторий (может понадобиться для получения доп. информации, + хотя в текущей реализации не используется). + filtered_entities: Список отфильтрованных сущностей (чанков), + относящихся к одному документу. + Returns: - Восстановленный текст документа + Собранный текст из чанков. + """ + chunks = [e for e in filtered_entities if isinstance(e, Chunk)] + chunks.sort(key=lambda x: x.number_in_relation) + + groups: list[list[Chunk]] = [] + for chunk in chunks: + if len(groups) == 0: + groups.append([chunk]) + continue + + last_chunk = groups[-1][-1] + if chunk.number_in_relation == last_chunk.number_in_relation + 1: + groups[-1].append(chunk) + else: + groups.append([chunk]) + + result = "" + previous_last_index = 0 + for group in groups: + if previous_last_index is not None: + missing_chunks = group[0].number_in_relation - previous_last_index - 1 + missing_string = f'\n_<...Пропущено {missing_chunks} фрагментов...>_\n' + else: + missing_string = '\n_<...>_\n' + result += missing_string + cls._build_sequenced_chunks(repository, group) + previous_last_index = group[-1].number_in_relation + + return result.strip() + + @classmethod + def _build_sequenced_chunks( + cls, + repository: EntityRepository, + group: list[Chunk], + ) -> str: + """ + Строит текст для последовательных чанков. + Стоит переопределить в конкретной стратегии, если она предполагает сложную логику """ - import re - - # Проверяем, есть ли чанки для сборки - if not chunks: - return "" - - # Отбираем только чанки - valid_chunks = [c for c in chunks if isinstance(c, Chunk)] - - # Сортируем чанки по chunk_index - sorted_chunks = sorted(valid_chunks, key=lambda c: c.chunk_index or 0) - - # Собираем текст документа с учетом структуры параграфов - result_text = "" - - for chunk in sorted_chunks: - # Получаем текст чанка (предпочитаем text, а не in_search_text для избежания дублирования) - chunk_text = chunk.text if hasattr(chunk, 'text') and chunk.text else "" - - # Добавляем текст чанка с сохранением структуры параграфов - if result_text and result_text[-1] != "\n" and chunk_text and chunk_text[0] != "\n": - result_text += " " - result_text += chunk_text - - # Пост-обработка результата - # Заменяем множественные переносы строк на одиночные - result_text = re.sub(r'\n+', '\n', result_text) - - # Заменяем множественные пробелы на одиночные - result_text = re.sub(r' +', ' ', result_text) - - # Убираем пробелы перед переносами строк - result_text = re.sub(r' +\n', '\n', result_text) - - # Убираем пробелы после переносов строк - result_text = re.sub(r'\n +', '\n', result_text) - - # Убираем лишние переносы строк в начале и конце текста - result_text = result_text.strip() - - return result_text \ No newline at end of file + return " ".join([cls._build_chunk(chunk) for chunk in group]) + + @classmethod + def _build_chunk(cls, chunk: Chunk) -> str: + """Строит текст для одного чанка.""" + return chunk.text diff --git a/lib/extractor/ntr_text_fragmentation/chunking/models/__init__.py b/lib/extractor/ntr_text_fragmentation/chunking/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1caf06f8f1e4a1da83dfdf5b37c5f4d43af93723 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/chunking/models/__init__.py @@ -0,0 +1,7 @@ +from .chunk import Chunk +from .custom_chunk import CustomChunk + +__all__ = [ + "Chunk", + "CustomChunk", +] diff --git a/lib/extractor/ntr_text_fragmentation/chunking/models/chunk.py b/lib/extractor/ntr_text_fragmentation/chunking/models/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ed21797fe4a4f8bc06c27e0bc2260e312944a76a --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/chunking/models/chunk.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from ...models.linker_entity import Entity, register_entity + + +@register_entity +@dataclass +class Chunk(Entity): + """ + Чанк документа. + """ diff --git a/lib/extractor/ntr_text_fragmentation/chunking/models/custom_chunk.py b/lib/extractor/ntr_text_fragmentation/chunking/models/custom_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..f54881d188aa7db98a612cceff79424cf5fd4858 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/chunking/models/custom_chunk.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass + +from ...models.linker_entity import Entity, register_entity + + +@register_entity +@dataclass +class CustomChunk(Entity): + """ + Чанк документа, полученный в результате применения пользовательской стратегии + чанкинга. + """ diff --git a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py index 72bd3bf174484cee47cf1e77675050df527458ea..1f53b5df507929835d79155d1ea71f9d83e4e88a 100644 --- a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py @@ -3,9 +3,13 @@ """ from .fixed_size import FixedSizeChunk -from .fixed_size_chunking import FixedSizeChunkingStrategy +from .fixed_size_chunking import ( + FixedSizeChunkingStrategy, + FIXED_SIZE, +) __all__ = [ "FixedSizeChunk", "FixedSizeChunkingStrategy", + "FIXED_SIZE", ] diff --git a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py index 8b00d808dd0e1e2cc29bc6dcfe6bcecffd8274b3..8fc380e8a7461c49155f3688f1a06a8e51ba0092 100644 --- a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py +++ b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py @@ -2,11 +2,11 @@ Класс для представления чанка фиксированного размера. """ -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any -from ....models.chunk import Chunk -from ....models.linker_entity import LinkerEntity, register_entity +from ....models import Entity, LinkerEntity, register_entity +from ...models.chunk import Chunk @register_entity @@ -15,21 +15,14 @@ class FixedSizeChunk(Chunk): """ Представляет чанк фиксированного размера. - Расширяет базовый класс Chunk дополнительными полями, связанными с токенами - и перекрытиями, а также добавляет методы для сборки документа с учетом - границ предложений. + Расширяет базовый класс Chunk дополнительными полями, связанными с токенами, + границами предложений и перекрытиями. """ - token_count: int = 0 - - # Информация о границах предложений и нахлестах - left_sentence_part: str = "" # Часть предложения слева от text - right_sentence_part: str = "" # Часть предложения справа от text - overlap_left: str = "" # Нахлест слева (без учета границ предложений) - overlap_right: str = "" # Нахлест справа (без учета границ предложений) - - # Метаданные для дополнительной информации - metadata: dict[str, Any] = field(default_factory=dict) + left_sentence_part: str | None = None + right_sentence_part: str | None = None + overlap_left: str | None = None + overlap_right: str | None = None def __str__(self) -> str: """ @@ -38,106 +31,64 @@ class FixedSizeChunk(Chunk): Returns: Строка с информацией о чанке. """ + text_preview = ( + f"{self.text[:30]}..." if self.text and len(self.text) > 30 else self.text + ) return ( f"FixedSizeChunk(id={self.id}, chunk_index={self.chunk_index}, " - f"tokens={self.token_count}, " - f"text='{self.text[:30]}{'...' if len(self.text) > 30 else ''}'" - f")" + f"tokens={self.token_count}, text='{text_preview}')" ) - def get_adjacent_chunks_indices(self, max_distance: int = 1) -> list[int]: - """ - Возвращает индексы соседних чанков в пределах указанного расстояния. - - Args: - max_distance: Максимальное расстояние от текущего чанка - - Returns: - Список индексов соседних чанков - """ - indices = [] - for i in range(1, max_distance + 1): - # Добавляем предыдущие чанки - if self.chunk_index - i >= 0: - indices.append(self.chunk_index - i) - # Добавляем следующие чанки - indices.append(self.chunk_index + i) - - return sorted(indices) - @classmethod - def deserialize(cls, entity: LinkerEntity) -> 'FixedSizeChunk': + def _deserialize_to_me(cls, data: Entity) -> "FixedSizeChunk": """ - Десериализует FixedSizeChunk из объекта LinkerEntity. + Десериализует FixedSizeChunk из объекта Entity (LinkerEntity). + + Использует паттерн: сначала ищет поле в атрибутах объекта `data` + (на случай, если он уже частично десериализован или является подклассом), + затем ищет поле в `data.metadata` с префиксом '_'. Args: - entity: Объект LinkerEntity для преобразования в FixedSizeChunk + data: Объект Entity (LinkerEntity) для десериализации. Returns: - Десериализованный объект FixedSizeChunk - """ - metadata = entity.metadata or {} - - # Извлекаем параметры из метаданных - # Сначала проверяем в метаданных под ключом _chunk_index - chunk_index = metadata.get('_chunk_index') - if chunk_index is None: - # Затем пробуем получить как атрибут объекта - chunk_index = getattr(entity, 'chunk_index', None) - if chunk_index is None: - # Если и там нет, пробуем обычный поиск по метаданным - chunk_index = metadata.get('chunk_index') - - # Преобразуем к int, если значение найдено - if chunk_index is not None: - try: - chunk_index = int(chunk_index) - except (ValueError, TypeError): - chunk_index = None - - start_token = metadata.get('start_token', 0) - end_token = metadata.get('end_token', 0) - token_count = metadata.get( - '_token_count', metadata.get('token_count', end_token - start_token + 1) - ) - - # Извлекаем параметры для границ предложений и нахлестов - # Сначала ищем в метаданных с префиксом _ - left_sentence_part = metadata.get('_left_sentence_part') - if left_sentence_part is None: - # Затем пробуем получить как атрибут объекта - left_sentence_part = getattr(entity, 'left_sentence_part', '') - - right_sentence_part = metadata.get('_right_sentence_part') - if right_sentence_part is None: - right_sentence_part = getattr(entity, 'right_sentence_part', '') - - overlap_left = metadata.get('_overlap_left') - if overlap_left is None: - overlap_left = getattr(entity, 'overlap_left', '') - - overlap_right = metadata.get('_overlap_right') - if overlap_right is None: - overlap_right = getattr(entity, 'overlap_right', '') + Новый экземпляр FixedSizeChunk с данными из Entity. + Raises: + TypeError: Если data не является экземпляром LinkerEntity или его подкласса. + """ + if not isinstance(data, LinkerEntity): + raise TypeError( + f"Ожидался LinkerEntity или его подкласс, получен {type(data)}" + ) + + metadata = data.metadata or {} + + # Извлечение специфичных полей с использованием паттерна getattr/metadata.get + def get_field(field_name: str, default: Any = None) -> Any: + value = getattr(data, field_name, None) + if value is None: + value = metadata.get(f"_{field_name}", default) + return value + # Создаем чистые метаданные без служебных полей clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')} # Создаем и возвращаем новый экземпляр FixedSizeChunk return cls( - id=entity.id, - name=entity.name, - text=entity.text, - in_search_text=entity.in_search_text, + id=data.id, + name=data.name, + text=data.text, + in_search_text=data.in_search_text, metadata=clean_metadata, - source_id=entity.source_id, - target_id=entity.target_id, - number_in_relation=entity.number_in_relation, - chunk_index=chunk_index, - token_count=token_count, - left_sentence_part=left_sentence_part, - right_sentence_part=right_sentence_part, - overlap_left=overlap_left, - overlap_right=overlap_right, - type="FixedSizeChunk", + source_id=data.source_id, + target_id=data.target_id, # owner_id + number_in_relation=data.number_in_relation, + groupper=data.groupper, + type=cls.__name__, # Устанавливаем конкретный тип + # Специфичные поля FixedSizeChunk + left_sentence_part=get_field('left_sentence_part', ""), + right_sentence_part=get_field('right_sentence_part', ""), + overlap_left=get_field('overlap_left', ""), + overlap_right=get_field('overlap_right', ""), ) diff --git a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py index a3ccb9542b6103f6df605674a6dda52c901897f9..81348be3d3ec5af591c9754d36da8f3e1231bda4 100644 --- a/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py +++ b/lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py @@ -2,46 +2,38 @@ Стратегия чанкинга фиксированного размера. """ +import logging import re -from typing import NamedTuple, TypeVar +from io import StringIO from uuid import uuid4 from ntr_fileparser import ParsedDocument, ParsedTextBlock -from ...chunking.chunking_strategy import ChunkingStrategy from ...models import DocumentAsEntity, LinkerEntity +from ...repositories import EntityRepository +from ..chunking_strategy import ChunkingStrategy +from ..chunking_registry import register_chunking_strategy +from ..models import Chunk from .fixed_size.fixed_size_chunk import FixedSizeChunk -T = TypeVar('T') +logger = logging.getLogger(__name__) - -class _FixedSizeChunkingStrategyParams(NamedTuple): - words_per_chunk: int = 50 - overlap_words: int = 25 - respect_sentence_boundaries: bool = True +FIXED_SIZE = "fixed_size" +@register_chunking_strategy(FIXED_SIZE) class FixedSizeChunkingStrategy(ChunkingStrategy): """ - Стратегия чанкинга, разбивающая текст на чанки фиксированного размера. - - Преимущества: - - Простое и предсказуемое разбиение - - Равные по размеру чанки + Стратегия чанкинга, разбивающая текст на чанки фиксированного размера словами. - Недостатки: - - Может разрезать предложения и абзацы в середине (компенсируется сборкой - как для модели поиска, так и для LLM) - - Не учитывает смысловую структуру текста + Поддерживает перекрытие между чанками и опциональный учет границ предложений + для более качественной сборки текста в `dechunk`. - Особенности реализации: - - В поле `text` чанков хранится текст без нахлеста (для удобства сборки) - - В поле `in_search_text` хранится текст с нахлестом (для улучшения векторизации) + При чанкинге создает экземпляры `FixedSizeChunk`. + При сборке (`dechunk`) использует специфичную логику с `left/right_sentence_part`. """ - name = "fixed_size" - description = ( - "Стратегия чанкинга, разбивающая текст на чанки фиксированного размера." - ) + DEFAULT_GROUPPER: str = "chunk" # Группа для связывания и сортировки чанков def __init__( self, @@ -50,519 +42,276 @@ class FixedSizeChunkingStrategy(ChunkingStrategy): respect_sentence_boundaries: bool = True, ): """ - Инициализация стратегии чанкинга с фиксированным размером. + Инициализация стратегии. Args: - words_per_chunk: Количество слов в чанке - overlap_words: Количество слов перекрытия между чанками - respect_sentence_boundaries: Флаг учета границ предложений + words_per_chunk: Целевое количество слов в чанке (включая перекрытие). + overlap_words: Количество слов перекрытия между чанками. + respect_sentence_boundaries: Учитывать ли границы предложений при + формировании `left/right_sentence_part` для улучшения сборки. """ - - self.params = _FixedSizeChunkingStrategyParams( - words_per_chunk=words_per_chunk, - overlap_words=overlap_words, - respect_sentence_boundaries=respect_sentence_boundaries, - ) + if overlap_words >= words_per_chunk: + raise ValueError("overlap_words должен быть меньше words_per_chunk") + if words_per_chunk <= 0 or overlap_words < 0: + raise ValueError("words_per_chunk должен быть > 0, overlap_words >= 0") + + self.words_per_chunk = words_per_chunk + self.overlap_words = overlap_words + self.respect_sentence_boundaries = respect_sentence_boundaries + self._step = self.words_per_chunk - self.overlap_words + + # Регулярное выражение для поиска конца предложения (точка, ?, ! перед пробелом или концом строки) + self._sentence_end_pattern = re.compile(r'[.!?](?:\s|$)') + # Регулярное выражение для очистки текста при сборке + self._re_multi_newline = re.compile(r'\n{3,}') + self._re_multi_space = re.compile(r' +') + self._re_space_punct = re.compile(r' ([.,!?:;)])') + self._re_space_newline = re.compile(r' +\n') + self._re_newline_space = re.compile(r'\n +') def chunk( - self, - document: ParsedDocument | str, - doc_entity: DocumentAsEntity | None = None, + self, document: ParsedDocument, doc_entity: DocumentAsEntity ) -> list[LinkerEntity]: """ - Разбивает документ на чанки фиксированного размера. + Разбивает документ на чанки FixedSizeChunk. Args: - document: Документ для разбиения (ParsedDocument или текст) - doc_entity: Сущность документа (опционально) + document: Документ для чанкинга. + doc_entity: Сущность документа-владельца. Returns: - Список LinkerEntity - чанки, связи и прочие сущности + Список созданных FixedSizeChunk. """ - doc = self._prepare_document(document) - words = self._extract_words(doc) + words = self._extract_words(document) + total_words = len(words) - # Если документ пустой, возвращаем пустой список - if not words: + if total_words == 0: + logger.debug(f"Документ {doc_entity.name} не содержит слов для чанкинга.") return [] - doc_entity = self._ensure_document_entity(doc, doc_entity) - doc_name = doc_entity.name - - chunks = [] - links = [] - - step = self._calculate_step() - total_words = len(words) - - # Начинаем с первого слова и идем шагами (не полным размером чанка) - for i in range(0, total_words, step): - # Создаем обычный чанк - chunk_text = self._prepare_chunk_text(words, i, step) - in_search_text = self._prepare_chunk_text( - words, i, self.params.words_per_chunk + result_chunks: list[FixedSizeChunk] = [] + chunk_index = 0 + + # Идем по словам с шагом, равным размеру чанка минус перекрытие + for i in range(0, total_words, self._step): + start_idx = i + # Конец основной части чанка (без правого перекрытия) + step_end_idx = min(start_idx + self._step, total_words) + # Конец чанка с правым перекрытием (для in_search_text и подсчета токенов) + chunk_end_idx = min(start_idx + self.words_per_chunk, total_words) + + # Текст чанка без перекрытия (то, что будет соединяться в dechunk) + chunk_text = self._prepare_chunk_text(words, start_idx, step_end_idx) + # Текст для поиска (с правым перекрытием) + in_search_text = self._prepare_chunk_text(words, start_idx, chunk_end_idx) + + # Границы предложений и нахлесты + left_part, right_part, left_overlap, right_overlap = ( + self._calculate_boundaries(words, start_idx, chunk_end_idx, total_words) ) - chunk = self._create_chunk( - chunk_text, - in_search_text, - i, - i + self.params.words_per_chunk, - len(chunks), - words, - total_words, - doc_name, + chunk_instance = self._create_chunk_instance( + doc_entity=doc_entity, + chunk_index=chunk_index, + chunk_text=chunk_text, + in_search_text=in_search_text, + token_count=( + chunk_end_idx - start_idx + ), # Кол-во слов в чанке с правым нахлестом + left_sentence_part=left_part, + right_sentence_part=right_part, + overlap_left=left_overlap, + overlap_right=right_overlap, ) + result_chunks.append(chunk_instance) + chunk_index += 1 - chunks.append(chunk) - links.append(self._create_link(doc_entity, chunk)) - - # Возвращаем все сущности - return [doc_entity] + chunks + links - - def _find_nearest_sentence_boundary( - self, text: str, position: int - ) -> tuple[int, str, str]: - """ - Находит ближайшую границу предложения к указанной позиции. - - Args: - text: Полный текст для поиска границ - position: Позиция, для которой ищем ближайшую границу - - Returns: - tuple из (позиция границы, левая часть текста, правая часть текста) - """ - # Регулярное выражение для поиска конца предложения - sentence_end_pattern = r'[.!?](?:\s|$)' - - # Ищем все совпадения в тексте - matches = list(re.finditer(sentence_end_pattern, text)) - - if not matches: - # Если совпадений нет, возвращаем исходную позицию - return position, text[:position], text[position:] - - # Находим ближайшую границу предложения - nearest_pos = position - min_distance = float('inf') - - for match in matches: - end_pos = match.end() - distance = abs(end_pos - position) - - if distance < min_distance: - min_distance = distance - nearest_pos = end_pos - - # Возвращаем позицию и соответствующие части текста - return nearest_pos, text[:nearest_pos], text[nearest_pos:] - - def _find_sentence_boundary(self, text: str, is_left_boundary: bool) -> str: - """ - Находит часть текста на границе предложения. - - Args: - text: Текст для обработки - is_left_boundary: True для левой границы, False для правой - - Returns: - Часть предложения на границе - """ - # Регулярное выражение для поиска конца предложения - sentence_end_pattern = r'[.!?](?:\s|$)' - matches = list(re.finditer(sentence_end_pattern, text)) - - if not matches: - return text + logger.info( + f"Документ {doc_entity.name} разбит на {len(result_chunks)} FixedSizeChunk." + ) + return result_chunks - if is_left_boundary: - # Для левой границы берем часть после последней границы предложения - last_match = matches[-1] - return text[last_match.end() :].strip() - else: - # Для правой границы берем часть до первой границы предложения - first_match = matches[0] - return text[: first_match.end()].strip() + def _extract_words(self, document: ParsedDocument) -> list[str]: + """Извлекает слова из документа, добавляя '\n' как маркер конца параграфа.""" + words = [] + for paragraph in document.paragraphs: + if isinstance(paragraph, ParsedTextBlock) and paragraph.text: + paragraph_words = paragraph.text.split() + # Добавляем только непустые слова + words.extend(w for w in paragraph_words if w) + # Добавляем маркер конца параграфа, только если были слова + if paragraph_words: + words.append("\n") # Используем '\n' как специальный "токен" + # Удаляем последний '\n', если он есть (не нужен после последнего параграфа) + if words and words[-1] == "\n": + words.pop() + return words - def dechunk( - self, - filtered_chunks: list[LinkerEntity], - repository: 'EntityRepository' = None, # type: ignore + def _prepare_chunk_text( + self, words: list[str], start_idx: int, end_idx: int ) -> str: - """ - Собирает документ из чанков и связей. - - Args: - filtered_chunks: Список отфильтрованных чанков - repository: Репозиторий сущностей для получения дополнительной информации (может быть None) - - Returns: - Восстановленный текст документа - """ - if not filtered_chunks: + """Собирает текст из среза слов, корректно обрабатывая маркеры '\n'.""" + chunk_words = words[start_idx:end_idx] + if not chunk_words: return "" - # Проверяем тип и десериализуем FixedSizeChunk - chunks = [] - for chunk in filtered_chunks: - if chunk.type == "FixedSizeChunk": - chunks.append(FixedSizeChunk.deserialize(chunk)) - else: - chunks.append(chunk) - - # Сортируем чанки по индексу - sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index or 0) - - # Инициализируем результирующий текст - result_text = "" - - # Группируем последовательные чанки - current_group = [] - groups = [] - - for i, chunk in enumerate(sorted_chunks): - current_index = chunk.chunk_index or 0 - - # Если первый чанк или продолжение последовательности - if i == 0 or current_index == (sorted_chunks[i - 1].chunk_index or 0) + 1: - current_group.append(chunk) - else: - # Закрываем текущую группу и начинаем новую - if current_group: - groups.append(current_group) - current_group = [chunk] - - # Добавляем последнюю группу - if current_group: - groups.append(current_group) - - # Обрабатываем каждую группу - for group_index, group in enumerate(groups): - # Добавляем многоточие между непоследовательными группами - if group_index > 0: - result_text += "\n\n...\n\n" - - # Обрабатываем группу соседних чанков - group_text = "" - - # Добавляем левую недостающую часть к первому чанку группы - first_chunk = group[0] - - # Добавляем левую часть предложения к первому чанку группы - if ( - hasattr(first_chunk, 'left_sentence_part') - and first_chunk.left_sentence_part - ): - group_text += first_chunk.left_sentence_part - - # Добавляем текст всех чанков группы - for i, chunk in enumerate(group): - current_text = chunk.text.strip() if hasattr(chunk, 'text') else "" - if not current_text: - continue - - # Проверяем, нужно ли добавить пробел между предыдущим текстом и текущим чанком - if group_text: - # Если текущий чанк начинается с новой строки, не добавляем пробел - if current_text.startswith("\n"): - pass - # Если предыдущий текст заканчивается переносом строки, также не добавляем пробел - elif group_text.endswith("\n"): - pass - # Если предыдущий текст заканчивается знаком препинания без пробела, добавляем пробел - elif group_text.rstrip()[-1] not in [ - "\n", - " ", - ".", - ",", - "!", - "?", - ":", - ";", - "-", - "–", - "—", - ]: - group_text += " " - - # Добавляем текст чанка - group_text += current_text - - # Добавляем правую недостающую часть к последнему чанку группы - last_chunk = group[-1] - - # Добавляем правую часть предложения к последнему чанку группы - if ( - hasattr(last_chunk, 'right_sentence_part') - and last_chunk.right_sentence_part - ): - right_part = last_chunk.right_sentence_part.strip() - if right_part: - # Проверяем нужен ли пробел перед правой частью - if ( - group_text - and group_text[-1] not in ["\n", " "] - and right_part[0] - not in ["\n", " ", ".", ",", "!", "?", ":", ";", "-", "–", "—"] - ): - group_text += " " - group_text += right_part - - # Добавляем текст группы к результату - if ( - result_text - and result_text[-1] not in ["\n", " "] - and group_text - and group_text[0] not in ["\n", " "] - ): - result_text += " " - result_text += group_text - - # Постобработка текста: удаляем лишние пробелы и символы переноса строк - - # Заменяем множественные переносы строк на двойные (для разделения абзацев) - result_text = re.sub(r'\n{3,}', '\n\n', result_text) - - # Заменяем множественные пробелы на одиночные - result_text = re.sub(r' +', ' ', result_text) - - # Убираем пробелы перед знаками препинания - result_text = re.sub(r' ([.,!?:;)])', r'\1', result_text) - - # Убираем пробелы перед переносами строк и после переносов строк - result_text = re.sub(r' +\n', '\n', result_text) - result_text = re.sub(r'\n +', '\n', result_text) - - # Убираем лишние переносы строк и пробелы в начале и конце текста - result_text = result_text.strip() - - return result_text - - def _get_sorted_chunks( - self, chunks: list[LinkerEntity], links: list[LinkerEntity] - ) -> list[LinkerEntity]: - """ - Получает отсортированные чанки на основе связей или поля chunk_index. - - Args: - chunks: Список чанков для сортировки - links: Список связей для определения порядка - - Returns: - Отсортированные чанки - """ - # Сортируем чанки по порядку в связях - if links: - # Получаем словарь для быстрого доступа к чанкам по ID - chunk_dict = {c.id: c for c in chunks} - - # Сортируем по порядку в связях - sorted_chunks = [] - for link in sorted(links, key=lambda l: l.number_in_relation or 0): - if link.target_id in chunk_dict: - sorted_chunks.append(chunk_dict[link.target_id]) - - return sorted_chunks - - # Если нет связей, сортируем по chunk_index - return sorted(chunks, key=lambda c: c.chunk_index or 0) - - def _prepare_document(self, document: ParsedDocument | str) -> ParsedDocument: - """ - Обрабатывает входные данные и возвращает ParsedDocument. - - Args: - document: Документ (ParsedDocument или текст) - - Returns: - Обработанный документ типа ParsedDocument - """ - if isinstance(document, ParsedDocument): - return document - elif isinstance(document, str): - # Простая обработка текстового документа - return ParsedDocument( - paragraphs=[ - ParsedTextBlock(text=paragraph) - for paragraph in document.split('\n') - ] - ) - - def _extract_words(self, doc: ParsedDocument) -> list[str]: - """ - Извлекает все слова из документа. - - Args: - doc: Документ для извлечения слов - - Returns: - Список слов документа - """ - words = [] - for paragraph in doc.paragraphs: - # Добавляем слова из параграфа - paragraph_words = paragraph.text.split() - words.extend(paragraph_words) - # Добавляем маркер конца параграфа как отдельный элемент - words.append("\n") - return words - - def _ensure_document_entity( + with StringIO() as buffer: + first_word = True + for word in chunk_words: + if word == "\n": + buffer.write("\n") + first_word = True # После переноса строки пробел не нужен + else: + if not first_word: + buffer.write(" ") + buffer.write(word) + first_word = False + return buffer.getvalue() + + def _calculate_boundaries( self, - doc: ParsedDocument, - doc_entity: LinkerEntity | None, - ) -> LinkerEntity: - """ - Создает сущность документа, если не предоставлена. + words: list[str], + chunk_start_idx: int, + chunk_end_idx: int, + total_words: int, + ) -> tuple[str, str, str, str]: + """Вычисляет границы предложений и тексты перекрытий.""" + left_sentence_part = "" + right_sentence_part = "" - Args: - doc: Документ - doc_entity: Сущность документа (может быть None) + # Границы для перекрытий + overlap_left_start = max(0, chunk_start_idx - self.overlap_words) + overlap_right_end = min(total_words, chunk_end_idx + self.overlap_words) - Returns: - Сущность документа - """ - if doc_entity is None: - return LinkerEntity( - id=uuid4(), - name=doc.name, - text=doc.name, - metadata={"type": doc.type}, - type="Document", - ) - return doc_entity + # Текст левого перекрытия (для поиска границ и как fallback) + left_overlap_text = self._prepare_chunk_text( + words, overlap_left_start, chunk_start_idx + ) + # Текст правого перекрытия (для поиска границ и как fallback) + right_overlap_text = self._prepare_chunk_text( + words, chunk_end_idx, overlap_right_end + ) - def _calculate_step(self) -> int: - """ - Вычисляет шаг для создания чанков. + if self.respect_sentence_boundaries: + # Ищем границу предложения в левом перекрытии + left_sentence_part = self._find_sentence_boundary(left_overlap_text, True) + # Ищем границу предложения в правом перекрытии + right_sentence_part = self._find_sentence_boundary( + right_overlap_text, False + ) - Returns: - Размер шага между началами чанков - """ - return self.params.words_per_chunk - self.params.overlap_words + return ( + left_sentence_part, + right_sentence_part, + left_overlap_text, + right_overlap_text, + ) - def _prepare_chunk_text( - self, - words: list[str], - start_idx: int, - length: int, - ) -> str: + def _find_sentence_boundary(self, text: str, find_left_part: bool) -> str: """ - Подготавливает текст чанка и текст для поиска. - - Args: - words: Список слов документа - start_idx: Индекс начала чанка - end_idx: Длина текста в словах - - Returns: - Итоговый текст + Находит часть текста на границе предложения. + Если find_left_part=True, ищет часть ПОСЛЕ последнего знака препинания. + Если find_left_part=False, ищет часть ДО первого знака препинания. """ - # Извлекаем текст чанка без нахлеста с сохранением структуры параграфов - end_idx = min(start_idx + length, len(words)) - chunk_words = words[start_idx:end_idx] - chunk_text = "" + if not text: + return "" - for word in chunk_words: - if word == "\n": - # Если это маркер конца параграфа, добавляем перенос строки - chunk_text += "\n" - else: - # Иначе добавляем слово с пробелом - if chunk_text and chunk_text[-1] != "\n": - chunk_text += " " - chunk_text += word + matches = list(self._sentence_end_pattern.finditer(text)) - return chunk_text + if not matches: + # Если нет знаков конца предложения, то для левой части ничего не берем, + # а для правой берем всё (т.к. непонятно, где предложение заканчивается). + return "" if find_left_part else text.strip() + + if find_left_part: + # Ищем часть после последнего знака + last_match_end = matches[-1].end() + return text[last_match_end:].strip() + else: + # Ищем часть до первого знака (включая сам знак) + first_match_end = matches[0].end() + return text[:first_match_end].strip() - def _create_chunk( + def _create_chunk_instance( self, + doc_entity: DocumentAsEntity, + chunk_index: int, chunk_text: str, in_search_text: str, - start_idx: int, - end_idx: int, - chunk_index: int, - words: list[str], - total_words: int, - doc_name: str, + token_count: int, + left_sentence_part: str, + right_sentence_part: str, + overlap_left: str, + overlap_right: str, ) -> FixedSizeChunk: - """ - Создает чанк фиксированного размера. - - Args: - chunk_text: Текст чанка без нахлеста - in_search_text: Текст чанка с нахлестом - start_idx: Индекс первого слова в чанке - end_idx: Индекс последнего слова в чанке - chunk_index: Индекс чанка в документе - words: Список всех слов документа - total_words: Общее количество слов в документе - doc_name: Имя документа - - Returns: - FixedSizeChunk: Созданный чанк - """ - # Определяем нахлесты без учета границ предложений - overlap_left = " ".join( - words[max(0, start_idx - self.params.overlap_words) : start_idx] - ) - overlap_right = " ".join( - words[end_idx : min(total_words, end_idx + self.params.overlap_words)] - ) - - # Определяем границы предложений - left_sentence_part = "" - right_sentence_part = "" - - if self.params.respect_sentence_boundaries: - # Находим ближайшую границу предложения слева - left_text = " ".join( - words[max(0, start_idx - self.params.overlap_words) : start_idx] - ) - left_sentence_part = self._find_sentence_boundary(left_text, True) - - # Находим ближайшую границу предложения справа - right_text = " ".join( - words[end_idx : min(total_words, end_idx + self.params.overlap_words)] - ) - right_sentence_part = self._find_sentence_boundary(right_text, False) - - # Создаем чанк с учетом границ предложений + """Создает экземпляр FixedSizeChunk с необходимыми атрибутами.""" return FixedSizeChunk( id=uuid4(), - name=f"{doc_name}_chunk_{chunk_index}", + name=f"{doc_entity.name}_chunk_{chunk_index}", text=chunk_text, - chunk_index=chunk_index, in_search_text=in_search_text, - token_count=end_idx - start_idx + 1, + metadata={}, # Все нужные поля теперь атрибуты + source_id=None, # Является компонентом, а не связью + target_id=doc_entity.id, # Указывает на владельца (документ) + number_in_relation=chunk_index, # Порядковый номер для сортировки + groupper=self.DEFAULT_GROUPPER, # Группа для сортировки/соседей + # Специфичные поля left_sentence_part=left_sentence_part, right_sentence_part=right_sentence_part, overlap_left=overlap_left, overlap_right=overlap_right, - metadata={}, - type=FixedSizeChunk.__name__, ) - def _create_link( - self, doc_entity: LinkerEntity, chunk: LinkerEntity - ) -> LinkerEntity: + @classmethod + def _build_sequenced_chunks( + cls, + repository: EntityRepository, + group: list[Chunk], + ) -> str: """ - Создает связь между документом и чанком. + Собирает текст для НЕПРЕРЫВНОЙ последовательности FixedSizeChunk. + + Использует `left_sentence_part` первого чанка, `text` всех чанков + и `right_sentence_part` последнего чанка. Переопределяет базовый метод. Args: - doc_entity: Сущность документа - chunk: Сущность чанка + repository: Репозиторий для получения сущностей. + group: Список последовательных FixedSizeChunk. Гарантируется непустым. Returns: - Объект связи + Собранный текст для данной группы. """ - return LinkerEntity( - id=uuid4(), - name="document_to_chunk", - text="", - metadata={}, - source_id=doc_entity.id, - target_id=chunk.id, - type="Link", - ) + # Важно: Проверяем, что все чанки в группе - это FixedSizeChunk + # Это важно, так как мы обращаемся к специфичным атрибутам + if not all(isinstance(c, FixedSizeChunk) for c in group): + logger.warning( + "В _build_sequenced_chunks передан список, содержащий не FixedSizeChunk. Используется базовая сборка." + ) + # Вызываем базовую реализацию, если типы не совпадают + return super()._build_sequenced_chunks(repository, group) + + # Гарантированно работаем с FixedSizeChunk + typed_group: list[FixedSizeChunk] = group # type: ignore + + parts = [] + first_chunk = typed_group[0] + last_chunk = typed_group[-1] + + # Добавляем левую часть предложения (если есть) + if first_chunk.left_sentence_part: + parts.append(first_chunk.left_sentence_part.strip()) + + # Добавляем текст всех чанков группы + for chunk in typed_group: + if chunk.text: + parts.append(chunk.text.strip()) + + # Добавляем правую часть предложения (если есть) + if last_chunk.right_sentence_part: + parts.append(last_chunk.right_sentence_part.strip()) + + # Соединяем все части через пробел, удаляя пустые строки + # Очистка _clean_final_text будет вызвана в конце базового dechunk + group_text = " ".join(filter(None, parts)) + + return group_text diff --git a/lib/extractor/ntr_text_fragmentation/chunking/text_to_text_base.py b/lib/extractor/ntr_text_fragmentation/chunking/text_to_text_base.py new file mode 100644 index 0000000000000000000000000000000000000000..4899c73c4110fef17112feb4d81abd5ad1569875 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/chunking/text_to_text_base.py @@ -0,0 +1,45 @@ +from abc import abstractmethod + +from ntr_fileparser import ParsedDocument + +from ..models import LinkerEntity, DocumentAsEntity +from .models import CustomChunk +from .chunking_strategy import ChunkingStrategy + + +class TextToTextBaseStrategy(ChunkingStrategy): + """ + Базовый класс для всех стратегий чанкинга, которые преобразуют текст в текст. + Наследуясь от этого класса, не забывайте зарегистрировать стратегию через + декоратор @register_chunking_strategy. + """ + + def chunk( + self, document: ParsedDocument, doc_entity: DocumentAsEntity + ) -> list[LinkerEntity]: + text = self._get_text(document) + texts = self._chunk(text, doc_entity) + return [ + CustomChunk( + text=chunk_text, + in_search_text=chunk_text, + doc_entity=doc_entity, + number_in_relation=i, + groupper=self.__class__.__name__, + ) + for i, chunk_text in enumerate(texts) + ] + + def _get_text(self, document: ParsedDocument) -> str: + return "\n".join( + [ + f"{block.text} {block.number_in_relation}" + for block in document.paragraphs + ] + ) + + @abstractmethod + def _chunk(self, text: str, doc_entity: DocumentAsEntity) -> list[LinkerEntity]: + raise NotImplementedError( + "Метод _chunk должен быть реализован в классе-наследнике" + ) diff --git a/lib/extractor/ntr_text_fragmentation/core/__init__.py b/lib/extractor/ntr_text_fragmentation/core/__init__.py index 8520f88d256d37ddeec87769b3269be23316dbb6..c5285d9151e3e02d9b5441a3e8418f9141ffad1f 100644 --- a/lib/extractor/ntr_text_fragmentation/core/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/core/__init__.py @@ -2,8 +2,10 @@ Основные классы для разбиения и сборки документов. """ -from .destructurer import Destructurer -from .entity_repository import EntityRepository, InMemoryEntityRepository +from .extractor import EntitiesExtractor from .injection_builder import InjectionBuilder -__all__ = ["Destructurer", "InjectionBuilder", "EntityRepository", "InMemoryEntityRepository"] +__all__ = [ + "EntitiesExtractor", + "InjectionBuilder", +] diff --git a/lib/extractor/ntr_text_fragmentation/core/extractor.py b/lib/extractor/ntr_text_fragmentation/core/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5db3a09170579395c24ef7b5767da849d904e8 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/core/extractor.py @@ -0,0 +1,207 @@ +""" +Модуль для деструктуризации документа. +""" + +import logging +from typing import Any, NamedTuple +from uuid import uuid4 + +from ntr_fileparser import ParsedDocument, ParsedTextBlock + +from ..additors import TablesProcessor +from ..chunking import ChunkingStrategy, FIXED_SIZE, chunking_registry +from ..models import DocumentAsEntity, LinkerEntity + + +def _check_namedtuple(obj: Any) -> bool: + return hasattr(type(obj), '_fields') and isinstance(obj, tuple) + + +logger = logging.getLogger(__name__) + + +class EntitiesExtractor: + """ + Оркестратор процесса извлечения информации из документа. + + Координирует разбиение документа на чанки и обработку + дополнительных сущностей (например, таблиц) с использованием + зарегистрированных стратегий и процессоров. + """ + + def __init__( + self, + strategy_name: str = FIXED_SIZE, + strategy_params: dict[str, Any] | tuple = {}, + process_tables: bool = True, + ): + """ + Инициализация деструктуризатора. + + Args: + strategy_name: Имя стратегии чанкинга для использования + strategy_params: Параметры для выбранной стратегии чанкинга + process_tables: Флаг обработки таблиц + """ + self.strategy: ChunkingStrategy | None = None + self._strategy_name: str | None = None + self.tables_processor: TablesProcessor | None = None + + self.configure(strategy_name, strategy_params, process_tables) + + def configure( + self, + strategy_name: str | None = None, + strategy_params: dict[str, Any] | tuple = {}, + process_tables: bool | None = None, + ) -> 'EntitiesExtractor': + """ + Переконфигурирование деструктуризатора. + + Args: + strategy_name: Имя стратегии чанкинга + strategy_params: Параметры для выбранной стратегии чанкинга, которыми нужно перезаписать дефолтные + process_tables: Обрабатывать ли таблицы + + Returns: + Destructurer: Возвращает сам себя для удобства использования в цепочке вызовов + """ + if strategy_name is not None: + self.configure_chunking(strategy_name, strategy_params) + if process_tables is not None: + self.configure_tables_extraction(process_tables) + + return self + + def configure_chunking( + self, + strategy_name: str = FIXED_SIZE, + strategy_params: dict[str, Any] | tuple | None = None, + ) -> 'EntitiesExtractor': + """ + Переконфигурирование стратегии чанкинга. + + Args: + strategy_name: Имя стратегии чанкинга + strategy_params: Параметры для выбранной стратегии чанкинга, которыми нужно перезаписать дефолтные + + Returns: + Destructurer: Возвращает сам себя + """ + if strategy_name not in chunking_registry: + raise ValueError( + f"Неизвестная стратегия: {strategy_name}. " + f"Доступные стратегии: {chunking_registry.get_names()}" + f"Для регистрации новой стратегии используйте метод `register_chunking_strategy`" + ) + + strategy_class = chunking_registry[strategy_name] + if _check_namedtuple(strategy_params): + strategy_params = strategy_params._asdict() + elif strategy_params is None: + strategy_params = {} + try: + self.strategy = strategy_class(**strategy_params) + self._strategy_name = strategy_name + except TypeError as e: + raise ValueError( + f"Ошибка при попытке инициализировать стратегию {strategy_class.__name__}: {e}. " + f"Параметры: {strategy_params}" + f"Пожалуйста, проверьте правильность параметров и их соответствие типу стратегии." + ) + + logger.info( + f"Стратегия чанкинга установлена: {strategy_name} с параметрами: {strategy_params}" + ) + + return self + + def configure_tables_extraction( + self, + process_tables: bool = True, + ) -> 'EntitiesExtractor': + """ + Переконфигурирование процессора таблиц. + + Args: + process_tables: Флаг обработки таблиц + + Returns: + Destructurer: Возвращает сам себя для удобства использования в цепочке вызовов + """ + self.tables_processor = TablesProcessor() + logger.info(f"Процессор таблиц установлен: {process_tables}") + return self + + def extract(self, document: ParsedDocument | str) -> list[LinkerEntity]: + """ + Основной метод извлечения информации из документа. + Чанкает и извлекает из документа всё, что можно из него извлечь. + Возвращает список сущностей. + + Args: + document: Документ для извлечения информации. Если передать строку, она будет \ + автоматически преобразована в `ParsedDocument` + + Returns: + list[LinkerEntity]: список сущностей (документ, чанки, таблицы, связи) + + Raises: + RuntimeError: Если стратегия не была сконфигурирована + """ + if isinstance(document, str): + document = ParsedDocument( + name='unknown', + type='PlainText', + paragraphs=[ + ParsedTextBlock(text=paragraph) + for paragraph in document.split('\n') + ], + ) + + doc_entity = self._create_document_entity(document) + entities: list[LinkerEntity] = [doc_entity] + + if self.strategy is not None: + logger.info( + f"Чанкирование документа {document.name} с помощью стратегии {self.strategy.__class__.__name__}..." + ) + entities += self._chunk(document, doc_entity) + + if self.tables_processor is not None: + logger.info(f"Обработка таблиц в документе {document.name}...") + entities += self.tables_processor.extract(document, doc_entity) + + logger.info(f"Извлечение информации из документа {document.name} завершено.") + entities = [entity.serialize() for entity in entities] + + return entities + + def _chunk( + self, + document: ParsedDocument, + doc_entity: DocumentAsEntity, + ) -> list[LinkerEntity]: + if self.strategy is None: + raise RuntimeError("Стратегия чанкинга не выставлена") + + doc_entity.chunking_strategy_ref = self._strategy_name + + return self.strategy.chunk(document, doc_entity) + + def _create_document_entity(self, document: ParsedDocument) -> DocumentAsEntity: + """ + Создает сущность документа. + + Args: + document: Документ для создания сущности + + Returns: + DocumentAsEntity: Сущность документа + """ + return DocumentAsEntity( + id=uuid4(), + name=document.name or "Document", + text="", + metadata={"source_type": document.type}, + ) diff --git a/lib/extractor/ntr_text_fragmentation/core/injection_builder.py b/lib/extractor/ntr_text_fragmentation/core/injection_builder.py index f18f94df1126ba96b2e86050b5fd2e8c6fc85b41..8f5014feee9278b0db45782369ebda73ba4c1c74 100644 --- a/lib/extractor/ntr_text_fragmentation/core/injection_builder.py +++ b/lib/extractor/ntr_text_fragmentation/core/injection_builder.py @@ -1,26 +1,30 @@ """ -Класс для сборки документа из чанков. +Класс для сборки документа из деструктурированных сущностей (чанков, таблиц). """ -from collections import defaultdict -from typing import Optional, Type +import logging from uuid import UUID -from ..chunking.chunking_strategy import ChunkingStrategy -from ..models.chunk import Chunk -from ..models.linker_entity import LinkerEntity -from .entity_repository import EntityRepository, InMemoryEntityRepository +from ..additors import TablesProcessor +from ..chunking import chunking_registry +from ..models import DocumentAsEntity, LinkerEntity +from ..repositories import EntityRepository, GroupedEntities, InMemoryEntityRepository + +# Настраиваем базовый логгер +logger = logging.getLogger(__name__) class InjectionBuilder: """ - Класс для сборки документов из чанков и связей. + Класс для сборки документов из отфильтрованного набора сущностей. Отвечает за: - - Сборку текста из чанков с учетом порядка - - Ранжирование документов на основе весов чанков - - Добавление соседних чанков для улучшения сборки - - Сборку данных из таблиц и других сущностей + - Получение десериализованных сущностей по их ID. + - Группировку сущностей по документам, к которым они относятся. + - Вызов соответствующего метода сборки (`dechunk` или `build`) + у стратегии/процессора, передавая им документ, репозиторий и + список *всех* отфильтрованных сущностей, относящихся к этому документу. + - Агрегацию результатов сборки для нескольких документов. """ def __init__( @@ -32,382 +36,143 @@ class InjectionBuilder: Инициализация сборщика инъекций. Args: - repository: Репозиторий сущностей (если None, используется InMemoryEntityRepository) - entities: Список всех сущностей (опционально, для обратной совместимости) - """ - # Для обратной совместимости - if repository is None and entities is not None: + repository: Репозиторий для доступа к сущностям. + entities: Список сущностей для инициализации дефолтного репозитория, если не указан repository. + Использование одновременно repository и entities не допускается. + """ + if repository is None and entities is None: + raise ValueError("Необходимо указать либо repository, либо entities.") + if repository is not None and entities is not None: + raise ValueError( + "Использование одновременно repository и entities не допускается." + ) + if repository is None: repository = InMemoryEntityRepository(entities) - - self.repository = repository or InMemoryEntityRepository() - self.strategy_map: dict[str, Type[ChunkingStrategy]] = {} - - def register_strategy( - self, - doc_type: str, - strategy: Type[ChunkingStrategy], - ) -> None: - """ - Регистрирует стратегию для определенного типа документа. - - Args: - doc_type: Тип документа - strategy: Стратегия чанкинга - """ - self.strategy_map[doc_type] = strategy + self.repository = repository + self.tables_processor = TablesProcessor() def build( self, - filtered_entities: list[LinkerEntity] | list[UUID], - chunk_scores: dict[str, float] | None = None, + entities: list[UUID] | list[LinkerEntity], + scores: list[float] | None = None, include_tables: bool = True, - max_documents: Optional[int] = None, - ) -> str: - """ - Собирает текст из всех документов, связанных с предоставленными чанками. - - Args: - filtered_entities: Список чанков или их идентификаторов - chunk_scores: Словарь весов чанков {chunk_id: score} - include_tables: Флаг для включения таблиц в результат - max_documents: Максимальное количество документов (None = все) - - Returns: - Собранный текст со всеми документами - """ - # Преобразуем входные данные в список идентификаторов - entity_ids = [ - entity.id if isinstance(entity, LinkerEntity) else entity - for entity in filtered_entities - ] - - if not entity_ids: - return "" - - # Получаем сущности по их идентификаторам - entities = self.repository.get_entities_by_ids(entity_ids) - - # Десериализуем сущности в их специализированные типы - deserialized_entities = [] - for entity in entities: - # Используем статический метод десериализации - deserialized_entity = LinkerEntity.deserialize(entity) - deserialized_entities.append(deserialized_entity) - - # Фильтруем сущности на чанки и таблицы - chunks = [e for e in deserialized_entities if "Chunk" in e.type] - tables = [e for e in deserialized_entities if "Table" in e.type] - - # Группируем таблицы по документам - table_ids = {table.id for table in tables} - doc_tables = self._group_tables_by_document(table_ids) - - if not chunks and not tables: - return "" - - # Получаем идентификаторы чанков - chunk_ids = [chunk.id for chunk in chunks] - - # Получаем связи для чанков (чанки являются целями связей) - links = self.repository.get_related_entities( - chunk_ids, - relation_name="document_to_chunk", - as_target=True, - ) - - # Группируем чанки по документам - doc_chunks = self._group_chunks_by_document(chunks, links) - - # Получаем все документы для чанков и таблиц - doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys()) - docs = self.repository.get_entities_by_ids(doc_ids) - - # Десериализуем документы - deserialized_docs = [] - for doc in docs: - deserialized_doc = LinkerEntity.deserialize(doc) - deserialized_docs.append(deserialized_doc) - - # Вычисляем веса документов на основе весов чанков - doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores) - - # Сортируем документы по весам (по убыванию) - sorted_docs = sorted( - deserialized_docs, - key=lambda d: doc_scores.get(str(d.id), 0.0), - reverse=True - ) - - # Ограничиваем количество документов, если указано - if max_documents: - sorted_docs = sorted_docs[:max_documents] - - # Собираем текст для каждого документа - result_parts = [] - for doc in sorted_docs: - doc_text = self._build_document_text( - doc, - doc_chunks.get(doc.id, []), - doc_tables.get(doc.id, []), - include_tables - ) - if doc_text: - result_parts.append(doc_text) - - # Объединяем результаты - return "\n\n".join(result_parts) - - def _build_document_text( - self, - doc: LinkerEntity, - chunks: list[LinkerEntity], - tables: list[LinkerEntity], - include_tables: bool + neighbors_max_distance: int = 1, + max_documents: int | None = None, + document_prefix: str = "[Источник] - ", ) -> str: """ - Собирает текст документа из чанков и таблиц. - - Args: - doc: Сущность документа - chunks: Список чанков документа - tables: Список таблиц документа - include_tables: Флаг для включения таблиц - - Returns: - Собранный текст документа - """ - # Получаем стратегию чанкинга - strategy_name = doc.metadata.get("chunking_strategy", "fixed_size") - strategy = self._get_strategy_instance(strategy_name) - - # Собираем текст из чанков - chunks_text = strategy.dechunk(chunks, self.repository) if chunks else "" - - # Собираем текст из таблиц, если нужно - tables_text = "" - if include_tables and tables: - # Сортируем таблицы по индексу, если он есть - sorted_tables = sorted( - tables, - key=lambda t: t.metadata.get("table_index", 0) if t.metadata else 0 - ) - - # Собираем текст таблиц - tables_text = "\n\n".join(table.text for table in sorted_tables if hasattr(table, 'text')) - - # Формируем результат - result = f"[Источник] - {doc.name}\n" - if chunks_text: - result += chunks_text - if tables_text: - if chunks_text: - result += "\n\n" - result += tables_text - - return result - - def _group_chunks_by_document( - self, - chunks: list[LinkerEntity], - links: list[LinkerEntity] - ) -> dict[UUID, list[LinkerEntity]]: - """ - Группирует чанки по документам. - - Args: - chunks: Список чанков - links: Список связей между документами и чанками - - Returns: - Словарь {doc_id: [chunks]} - """ - result = defaultdict(list) - - # Создаем словарь для быстрого доступа к чанкам по ID - chunk_dict = {chunk.id: chunk for chunk in chunks} - - # Группируем чанки по документам на основе связей - for link in links: - if link.target_id in chunk_dict and link.source_id: - result[link.source_id].append(chunk_dict[link.target_id]) - - return result - - def _group_tables_by_document( - self, - table_ids: set[UUID] - ) -> dict[UUID, list[LinkerEntity]]: - """ - Группирует таблицы по документам. - - Args: - table_ids: Множество идентификаторов таблиц - - Returns: - Словарь {doc_id: [tables]} - """ - result = defaultdict(list) - - table_ids = [str(table_id) for table_id in table_ids] - - # Получаем связи для таблиц (таблицы являются целями связей) - if not table_ids: - return result - - links = self.repository.get_related_entities( - table_ids, - relation_name="document_to_table", - as_target=True, - ) - - # Получаем сами таблицы - tables = self.repository.get_entities_by_ids(table_ids) - - # Десериализуем таблицы - deserialized_tables = [] - for table in tables: - deserialized_table = LinkerEntity.deserialize(table) - deserialized_tables.append(deserialized_table) - - # Создаем словарь для быстрого доступа к таблицам по ID - table_dict = {str(table.id): table for table in deserialized_tables} - - # Группируем таблицы по документам на основе связей - for link in links: - if link.target_id in table_dict and link.source_id: - result[link.source_id].append(table_dict[link.target_id]) - - return result - - def _calculate_document_scores( - self, - doc_chunks: dict[UUID, list[LinkerEntity]], - chunk_scores: Optional[dict[str, float]] - ) -> dict[str, float]: - """ - Вычисляет веса документов на основе весов чанков. - - Args: - doc_chunks: Словарь {doc_id: [chunks]} - chunk_scores: Словарь весов чанков {chunk_id: score} - - Returns: - Словарь весов документов {doc_id: score} - """ - if not chunk_scores: - return {str(doc_id): 1.0 for doc_id in doc_chunks.keys()} - - result = {} - for doc_id, chunks in doc_chunks.items(): - # Берем максимальный вес среди чанков документа - chunk_weights = [chunk_scores.get(str(c.id), 0.0) for c in chunks] - result[str(doc_id)] = max(chunk_weights) if chunk_weights else 0.0 - - return result - - def add_neighboring_chunks( - self, entities: list[LinkerEntity] | list[UUID], max_distance: int = 1 - ) -> list[LinkerEntity]: - """ - Добавляет соседние чанки к отфильтрованному списку чанков. + Собирает текст документов на основе *отфильтрованного* списка ID сущностей + (чанков, строк таблиц и т.д.). Args: - entities: Список сущностей или их идентификаторов - max_distance: Максимальное расстояние для поиска соседей + entities: Список ID сущностей (UUID), которые были отобраны + (например, в результате поиска) и должны войти в контекст. + scores: Список оценок для каждой сущности из логики больше - лучше. + include_tables: Включать ли таблицы из соответствующих документов. + max_documents: Максимальное количество документов для включения в результат + (сортировка документов пока не реализована). Returns: - Расширенный список сущностей - """ - # Преобразуем входные данные в список идентификаторов - entity_ids = [ - entity.id if isinstance(entity, LinkerEntity) else entity - for entity in entities - ] + Собранный текст из указанных сущностей (и, возможно, таблиц) + сгруппированный по документам. - if not entity_ids: - return [] + Raises: + ValueError: Если entity_ids пуст или содержит невалидные UUID. + """ + if any(isinstance(eid, UUID) for eid in entities): + entities = self.repository.get_entities_by_ids(entities) - # Получаем исходные сущности - original_entities = self.repository.get_entities_by_ids(entity_ids) + if not entities: + logger.warning("Не удалось получить ни одной сущности по переданным ID.") + return "" - # Фильтруем только чанки - chunk_entities = [e for e in original_entities if isinstance(e, Chunk)] + entities = [e.deserialize() for e in entities] - if not chunk_entities: - return original_entities + logger.info(f"Получено {len(entities)} сущностей для сборки.") - # Получаем идентификаторы чанков - chunk_ids = [chunk.id for chunk in chunk_entities] + if neighbors_max_distance > 0: + neighbors = self.repository.get_neighboring_entities( + entities, neighbors_max_distance + ) + neighbors = [e.deserialize() for e in neighbors] + entities.extend(neighbors) - # Получаем соседние чанки - neighboring_chunks = self.repository.get_neighboring_chunks( - chunk_ids, max_distance - ) + logger.info(f"Получено {len(entities)} сущностей для сборки с соседями.") - # Объединяем исходные сущности с соседними чанками - result = list(original_entities) - for chunk in neighboring_chunks: - if chunk not in result: - result.append(chunk) + if scores is None: + logger.info( + "Оценки не предоставлены, используем порядковые номера в обратном порядке." + ) + scores = [float(i) for i in range(len(entities), 0, -1)] - # Получаем документы и связи для всех чанков - all_chunk_ids = [chunk.id for chunk in result if isinstance(chunk, Chunk)] + id_to_score = {entity.id: score for entity, score in zip(entities, scores)} - docs = self.repository.get_document_for_chunks(all_chunk_ids) - links = self.repository.get_related_entities( - all_chunk_ids, relation_name="document_to_chunk", as_target=True + groups: list[GroupedEntities[DocumentAsEntity]] = ( + self.repository.group_entities_hierarchically( + entities=entities, + root_type=DocumentAsEntity, + ) ) - # Добавляем документы и связи в результат - for doc in docs: - if doc not in result: - result.append(doc) - - for link in links: - if link not in result: - result.append(link) - - return result - - def _get_strategy_instance(self, strategy_name: str) -> ChunkingStrategy: - """ - Создает экземпляр стратегии чанкинга по имени. - - Args: - strategy_name: Имя стратегии + logger.info(f"Сгруппировано {len(groups)} документов.") - Returns: - Экземпляр соответствующей стратегии - """ - # Используем словарь для маппинга имен стратегий на их классы - strategies = { - "fixed_size": "..chunking.specific_strategies.fixed_size_chunking.FixedSizeChunkingStrategy", + document_scores = { + group.composer.id: max( + id_to_score[eid.id] for eid in group.entities if eid.id in id_to_score + ) + for group in groups + if any(eid.id in id_to_score for eid in group.entities) } - # Если стратегия зарегистрирована в self.strategy_map, используем её - if strategy_name in self.strategy_map: - return self.strategy_map[strategy_name]() - - # Если стратегия известна, импортируем и инициализируем её - if strategy_name in strategies: - import importlib - - module_path, class_name = strategies[strategy_name].rsplit(".", 1) - try: - # Конвертируем относительный путь в абсолютный - abs_module_path = f"ntr_text_fragmentation{module_path[2:]}" - module = importlib.import_module(abs_module_path) - strategy_class = getattr(module, class_name) - return strategy_class() - except (ImportError, AttributeError) as e: - # Если импорт не удался, используем стратегию по умолчанию - from ..chunking.specific_strategies.fixed_size_chunking import \ - FixedSizeChunkingStrategy - - return FixedSizeChunkingStrategy() + groups = sorted( + groups, key=lambda x: document_scores[x.composer.id], reverse=True + ) + groups = list(groups)[:max_documents] - # По умолчанию используем стратегию с фиксированным размером - from ..chunking.specific_strategies.fixed_size_chunking import \ - FixedSizeChunkingStrategy + builded_documents = [ + self._build_document(group, include_tables, document_prefix).replace( + "\n", "\n\n" + ) + for group in groups + ] + return "\n\n".join(builded_documents) - return FixedSizeChunkingStrategy() + def _build_document( + self, + group: GroupedEntities, + include_tables: bool = True, + document_prefix: str = "[Источник] - ", + ) -> str: + document = group.composer + entities = group.entities + + name = document.name + + strategy = document.chunking_strategy_ref + builded_chunks = None + builded_tables = None + if strategy is None: + logger.warning(f"Стратегия чанкинга не указана для документа {name}") + else: + strategy_class = chunking_registry.get(strategy) + builded_chunks = strategy_class.dechunk(self.repository, entities) + + if include_tables: + builded_tables = self.tables_processor.build(self.repository, entities) + + result_text = f"## {document_prefix}{name}\n\n" + if builded_chunks: + result_text += f'### Текст\n{builded_chunks}\n\n' + if builded_tables: + result_text += f'### Таблицы\n{builded_tables}\n\n' + + return result_text + + def _deserialize_all(self, groups: list[GroupedEntities]) -> list[GroupedEntities]: + return [ + GroupedEntities( + composer=group.composer.deserialize(), + entities=[entity.deserialize() for entity in group.entities], + ) + for group in groups + ] diff --git a/lib/extractor/ntr_text_fragmentation/integrations/__init__.py b/lib/extractor/ntr_text_fragmentation/integrations/__init__.py index a615b54c85be9b31d1b45546c4e3314314e49bd7..a37025f5e5f4e580a94a3047e59d7217b9a10b81 100644 --- a/lib/extractor/ntr_text_fragmentation/integrations/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/integrations/__init__.py @@ -2,8 +2,9 @@ Модуль интеграций с внешними хранилищами данных и ORM системами. """ -from .sqlalchemy_repository import SQLAlchemyEntityRepository +from ..repositories.in_memory_repository import InMemoryEntityRepository +# SQLAlchemy не импортируется, чтобы не тащить лишние зависимости в основной код __all__ = [ - "SQLAlchemyEntityRepository", + "InMemoryEntityRepository", ] diff --git a/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/__init__.py b/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b9f7b51ec35295d0565b4b5cc8f8ce25849f0f --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/__init__.py @@ -0,0 +1,6 @@ +from .sqlalchemy_repository import SQLAlchemyEntityRepository + +__all__ = [ + "SQLAlchemyEntityRepository", +] + diff --git a/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/sqlalchemy_repository.py b/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/sqlalchemy_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f7d128bf8700ac7657d02092a3e725ab5cb7f3 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/sqlalchemy_repository.py @@ -0,0 +1,455 @@ +""" +Реализация EntityRepository для работы с SQLAlchemy. +""" +# Добавляем импорт logging и создаем логгер +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Type +from uuid import UUID + +logger = logging.getLogger(__name__) + +from sqlalchemy import Column, and_, or_, select +from sqlalchemy.orm import Session, sessionmaker + +from ...models import LinkerEntity +from ...repositories.entity_repository import EntityRepository, GroupedEntities + +Base = Any + + +class SQLAlchemyEntityRepository(EntityRepository, ABC): + """ + Абстрактная реализация EntityRepository для работы с базой данных через SQLAlchemy. + + Требует определения методов `_entity_model_class` и + `_map_db_entity_to_linker_entity` в дочерних классах для работы с конкретной + моделью SQLAlchemy и маппинга на LinkerEntity. + """ + + def __init__(self, db_session_factory: sessionmaker[Session]): + """ + Инициализирует репозиторий с фабрикой сессий SQLAlchemy. + + Args: + db_session_factory: Фабрика сессий SQLAlchemy (sessionmaker). + """ + self.db = db_session_factory + + @property + @abstractmethod + def _entity_model_class(self) -> Type[Base]: + """Возвращает класс модели SQLAlchemy, используемый этим репозиторием.""" + pass + + @abstractmethod + def _map_db_entity_to_linker_entity(self, db_entity: Base) -> LinkerEntity: + """Преобразует объект модели SQLAlchemy в объект LinkerEntity.""" + pass + + def _get_id_column(self) -> Column: + """Возвращает колонку ID (uuid или id) из модели.""" + entity_model = self._entity_model_class + # SQLAlchemy 2.0 style attribute access if Base is DeclarativeBase + id_column = getattr(entity_model, 'uuid', getattr(entity_model, 'id', None)) + if id_column is None: + raise AttributeError(f"Модель {entity_model.__name__} не имеет атрибута/колонки 'id' или 'uuid'") + # Ensure it's a Column object if using older style mapping + # If using 2.0 MappedAsDataclass, this might need adjustment + # For now, assuming it returns something comparable + return id_column + + def _normalize_entities( + self, entities: Iterable[UUID] | Iterable[LinkerEntity] + ) -> list[UUID]: + """Преобразует входные данные в список UUID.""" + result = [] + if entities is None: + return result + for entity in entities: + if isinstance(entity, UUID): + result.append(entity) + elif isinstance(entity, LinkerEntity): + result.append(entity.id) + return result + + def get_entities_by_ids( + self, entity_ids: Iterable[UUID] + ) -> List[LinkerEntity]: + """ + Получить сущности по списку идентификаторов UUID. + + Args: + entity_ids: Итерируемый объект с UUID сущностей. + + Returns: + Список найденных сущностей LinkerEntity. + """ + ids_list = list(entity_ids) + if not ids_list: + return [] + + string_ids = [str(eid) for eid in ids_list] + entity_model = self._entity_model_class + id_column = self._get_id_column() + + with self.db() as session: + db_entities = session.execute( + select(entity_model).where(id_column.in_(string_ids)) + ).scalars().all() + + return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities] + + def group_entities_hierarchically( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + root_type: Type[LinkerEntity], + max_levels: int = 10, + sort: bool = True, + ) -> list[GroupedEntities[LinkerEntity]]: + """ + Группирует сущности по корневым элементам иерархии. + + Ищет родительские связи (где сущность является target_id), поднимаясь + вверх по иерархии до `max_levels` или до нахождения `root_type`. + + Args: + entities: Список идентификаторов UUID или сущностей LinkerEntity. + root_type: Класс корневого типа (например, DocumentAsEntity). + max_levels: Максимальная глубина поиска вверх по иерархии. + sort: Флаг для сортировки сущностей в группах. + + Returns: + Список групп сущностей `GroupedEntities`. + """ + entity_ids_list = self._normalize_entities(entities) + if not entity_ids_list: + return [] + + entity_model = self._entity_model_class + id_column = self._get_id_column() + root_type_str = root_type.__name__ + logger.info(f"[group_hierarchically] Искомый тип корня: '{root_type_str}'") + + entity_type_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None)) + source_id_column = getattr(entity_model, 'source_id', None) + target_id_column = getattr(entity_model, 'target_id', None) + + if not all([entity_type_column, source_id_column, target_id_column]): + raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'entity_type'/'type', 'source_id', 'target_id'") + + entity_to_root_cache: Dict[str, str | None] = {} + fetched_entities: Dict[str, Base] = {} + + with self.db() as session: + + def _fetch_entity(entity_id_str: str) -> Base | None: + """Загружает сущность из БД, если её еще нет в fetched_entities.""" + if entity_id_str not in fetched_entities: + entity = session.get(entity_model, entity_id_str) + if entity is None: + stmt = select(entity_model).where(id_column == entity_id_str) + entity = session.execute(stmt).scalar_one_or_none() + fetched_entities[entity_id_str] = entity # Store entity or None + return fetched_entities[entity_id_str] + + def _find_root(entity_id_str: str, level: int) -> str | None: + """Рекурсивный поиск корневой сущности.""" + if level > max_levels or not entity_id_str: + return None + if entity_id_str in entity_to_root_cache: + return entity_to_root_cache[entity_id_str] + + db_entity = _fetch_entity(entity_id_str) + if not db_entity: + logger.warning(f"[_find_root] Не удалось найти сущность с ID {entity_id_str}") + entity_to_root_cache[entity_id_str] = None + return None + + current_entity_type = getattr(db_entity, entity_type_column.name) + + if current_entity_type == root_type_str: + # logger.debug(f"[_find_root] Сущность {entity_id_str} сама является корнем типа '{root_type_str}'") + entity_to_root_cache[entity_id_str] = entity_id_str + return entity_id_str + + parent_id_str = getattr(db_entity, target_id_column.name, None) + + root_id = None + if parent_id_str: + # logger.debug(f"[_find_root] Сущность {entity_id_str} указывает на родителя {parent_id_str} через target_id.") + root_id = _find_root(str(parent_id_str), level + 1) + + entity_to_root_cache[entity_id_str] = root_id + return root_id + + roots_map: Dict[str, str | None] = {} + for start_entity_id in entity_ids_list: + start_entity_id_str = str(start_entity_id) + if start_entity_id_str not in roots_map: + found_root = _find_root(start_entity_id_str, 0) + roots_map[start_entity_id_str] = found_root + # if found_root: + # logger.debug(f"[group_hierarchically] Найден корень {found_root} для сущности {start_entity_id_str}") + + + groups: Dict[str, List[Base]] = defaultdict(list) + initial_db_entities = session.execute( + select(entity_model).where(id_column.in_([str(eid) for eid in entity_ids_list])) + ).scalars().all() + + found_roots_count = 0 + grouped_entities_count = 0 + for db_entity in initial_db_entities: + entity_id_str = str(getattr(db_entity, id_column.name)) + root_id = roots_map.get(entity_id_str) + if root_id: + groups[root_id].append(db_entity) + grouped_entities_count += 1 + if len(groups[root_id]) == 1: + found_roots_count += 1 + + logger.info(f"[group_hierarchically] Сгруппировано {grouped_entities_count} сущностей в {len(groups)} групп (найдено {found_roots_count} уникальных корней).") + + result: list[GroupedEntities[LinkerEntity]] = [] + for root_id_str, db_entities_list in groups.items(): + root_db_entity = fetched_entities.get(root_id_str) + if root_db_entity: + composer = self._map_db_entity_to_linker_entity(root_db_entity) + grouped_linker_entities = [ + self._map_db_entity_to_linker_entity(db_e) for db_e in db_entities_list + ] + + if sort: + grouped_linker_entities.sort( + key=lambda entity: ( + str(getattr(entity, 'groupper', getattr(entity, 'entity_type', getattr(entity, 'type', '')))), + int(getattr(entity, 'number_in_relation', getattr(entity, 'chunk_index', float('inf')))) + ) + ) + # logger.debug(f"[group_hierarchically] Отсортирована группа для корня {root_id_str}") + + result.append(GroupedEntities(composer=composer, entities=grouped_linker_entities)) + + logger.info(f"[group_hierarchically] Сформировано {len(result)} объектов GroupedEntities.") + return result + + + def get_neighboring_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + max_distance: int = 1, + ) -> list[LinkerEntity]: + """ + Получить соседние сущности для указанных сущностей. + + Соседство определяется на основе общего "родителя" (target_id сущности, + если source_id is None) и близости по `number_in_relation` или + `chunk_index` в рамках одной группы (`entity_type` или `type`). + + Args: + entities: Список идентификаторов UUID или сущностей LinkerEntity. + max_distance: Максимальное расстояние (по порядку) между соседями. + + Returns: + Список соседних сущностей LinkerEntity. + """ + entity_ids_list = self._normalize_entities(entities) + if not entity_ids_list or max_distance < 1: + return [] + + string_entity_ids = {str(eid) for eid in entity_ids_list} + entity_model = self._entity_model_class + id_column = self._get_id_column() + source_id_column = getattr(entity_model, 'source_id', None) + target_id_column = getattr(entity_model, 'target_id', None) + order_column = getattr(entity_model, 'number_in_relation', None) + group_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None)) + + if not all([source_id_column, target_id_column, order_column, group_column]): + raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'source_id', 'target_id', 'chunk_index'/'number_in_relation', 'entity_type'/'type'") + + neighbor_entities_map: Dict[str, Base] = {} + + with self.db() as session: + initial_entities_query = select(entity_model).where(id_column.in_(list(string_entity_ids))) + initial_db_entities = session.execute(initial_entities_query).scalars().all() + + valid_initial_entities_info: Dict[tuple[str, str], list[Dict[str, Any]]] = defaultdict(list) + parent_group_keys: set[tuple[str, str]] = set() + + for db_entity in initial_db_entities: + entity_id_str = str(getattr(db_entity, id_column.name)) + source_id = getattr(db_entity, source_id_column.name, None) + target_id = getattr(db_entity, target_id_column.name, None) + group_value = getattr(db_entity, group_column.name, None) + order_value = getattr(db_entity, order_column.name, None) + + if source_id is None and target_id is not None and group_value is not None and order_value is not None: + parent_id_str = str(target_id) + group_key = (parent_id_str, str(group_value)) + parent_group_keys.add(group_key) + valid_initial_entities_info[group_key].append({ + "id": entity_id_str, + "order": order_value + }) + + if not parent_group_keys: + return [] + + sibling_conditions = [] + for parent_id, group_val in parent_group_keys: + sibling_conditions.append( + and_( + target_id_column == parent_id, + group_column == group_val, + source_id_column.is_(None), + order_column.isnot(None) + ) + ) + + potential_siblings_query = select(entity_model).where(or_(*sibling_conditions)) + potential_siblings = session.execute(potential_siblings_query).scalars().all() + + for sibling_entity in potential_siblings: + sibling_id_str = str(getattr(sibling_entity, id_column.name)) + + if sibling_id_str in string_entity_ids: + continue + + sibling_target_id = getattr(sibling_entity, target_id_column.name) + sibling_group = getattr(sibling_entity, group_column.name) + sibling_order = getattr(sibling_entity, order_column.name) + + if sibling_target_id is None or sibling_group is None or sibling_order is None: + logger.warning(f"Потенциальный сиблинг {sibling_id_str} не имеет target_id, группы или порядка, хотя был выбран запросом.") + continue + + sibling_parent_id_str = str(sibling_target_id) + sibling_group_str = str(sibling_group) + group_key = (sibling_parent_id_str, sibling_group_str) + + if group_key in valid_initial_entities_info: + for initial_info in valid_initial_entities_info[group_key]: + initial_order = initial_info["order"] + distance = abs(sibling_order - initial_order) + + if 0 < distance <= max_distance and sibling_id_str not in neighbor_entities_map: + neighbor_entities_map[sibling_id_str] = sibling_entity + break + + + return [self._map_db_entity_to_linker_entity(ne) for ne in neighbor_entities_map.values()] + + + def get_related_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + relation_type: Type[LinkerEntity] | None = None, + as_source: bool = False, + as_target: bool = False, + as_owner: bool = False, # Добавлено + ) -> List[LinkerEntity]: + """ + Получить сущности, связанные с указанными, а также сами связи. + + Args: + entities: Список идентификаторов UUID или сущностей LinkerEntity. + relation_type: Опциональный класс связи для фильтрации (например, CompositionLink). + as_source: Искать связи, где entities - источники (`source_id`). + as_target: Искать связи, где entities - цели (`target_id`). + as_owner: Искать связи, где entities - владельцы (`source_id`, предполагая связь владения). + + Returns: + Список связанных сущностей LinkerEntity и самих связей. + """ + entity_ids_list = self._normalize_entities(entities) + if not entity_ids_list: + return [] + + if not as_source and not as_target and not as_owner: + as_source = True + as_target = True + + string_ids = [str(eid) for eid in entity_ids_list] + entity_model = self._entity_model_class + id_column = self._get_id_column() + source_id_column = getattr(entity_model, 'source_id', None) + target_id_column = getattr(entity_model, 'target_id', None) + entity_type_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None)) + + if not all([source_id_column, target_id_column, entity_type_column]): + raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'source_id', 'target_id', 'entity_type'/'type'") + + related_db_objects_map: Dict[str, Base] = {} + + relation_type_str = None + if relation_type: + relation_type_str = relation_type.__name__ + + with self.db() as session: + def _add_related(db_objects: Iterable[Base]): + """Helper function to add objects and fetch related source/target entities.""" + ids_to_fetch = set() + for db_obj in db_objects: + obj_id = str(getattr(db_obj, id_column.name)) + if obj_id not in related_db_objects_map: + related_db_objects_map[obj_id] = db_obj + source_id = getattr(db_obj, source_id_column.name, None) + target_id = getattr(db_obj, target_id_column.name, None) + if source_id: + ids_to_fetch.add(str(source_id)) + if target_id: + ids_to_fetch.add(str(target_id)) + + ids_to_fetch.difference_update(related_db_objects_map.keys()) + if ids_to_fetch: + fetched = session.execute( + select(entity_model).where(id_column.in_(list(ids_to_fetch))) + ).scalars().all() + for fetched_obj in fetched: + fetched_id = str(getattr(fetched_obj, id_column.name)) + if fetched_id not in related_db_objects_map: + related_db_objects_map[fetched_id] = fetched_obj + + if as_source or as_owner: + conditions = [source_id_column.in_(string_ids)] + if relation_type_str: + conditions.append(entity_type_column == relation_type_str) + source_links_query = select(entity_model).where(and_(*conditions)) + source_links = session.execute(source_links_query).scalars().all() + _add_related(source_links) + + if as_target: + conditions = [target_id_column.in_(string_ids)] + if relation_type_str: + conditions.append(entity_type_column == relation_type_str) + target_links_query = select(entity_model).where(and_(*conditions)) + target_links = session.execute(target_links_query).scalars().all() + _add_related(target_links) + + final_map: Dict[UUID, LinkerEntity] = {} + + for db_obj in related_db_objects_map.values(): + linker_entity = self._map_db_entity_to_linker_entity(db_obj) + if relation_type: + is_link = linker_entity.is_link() + is_relevant_link = False + if is_link: + link_source_uuid = linker_entity.source_id + link_target_uuid = linker_entity.target_id + original_uuids = {UUID(s_id) for s_id in string_ids} + + if (as_source or as_owner) and link_source_uuid in original_uuids: + is_relevant_link = True + elif as_target and link_target_uuid in original_uuids: + is_relevant_link = True + + if is_relevant_link and not isinstance(linker_entity, relation_type): + continue + + if linker_entity.id not in final_map: + final_map[linker_entity.id] = linker_entity + + return list(final_map.values()) diff --git a/lib/extractor/ntr_text_fragmentation/models/__init__.py b/lib/extractor/ntr_text_fragmentation/models/__init__.py index 738b25982483c288bf71d1ab863557a9e37ee654..4314f0a324fec195d8bb28316b1d90ff14f084c7 100644 --- a/lib/extractor/ntr_text_fragmentation/models/__init__.py +++ b/lib/extractor/ntr_text_fragmentation/models/__init__.py @@ -2,12 +2,13 @@ Модуль моделей данных. """ -from .chunk import Chunk from .document import DocumentAsEntity -from .linker_entity import LinkerEntity +from .linker_entity import LinkerEntity, Entity, Link, register_entity __all__ = [ - "LinkerEntity", - "DocumentAsEntity", - "Chunk", -] \ No newline at end of file + "LinkerEntity", + "DocumentAsEntity", + "Entity", + "Link", + "register_entity", +] diff --git a/lib/extractor/ntr_text_fragmentation/models/document.py b/lib/extractor/ntr_text_fragmentation/models/document.py index 9d661a195b938d8ef99a959dec0b86b5afb259f1..d500e251561bf0b6c26f0dbf20beef7551f0c762 100644 --- a/lib/extractor/ntr_text_fragmentation/models/document.py +++ b/lib/extractor/ntr_text_fragmentation/models/document.py @@ -2,39 +2,51 @@ Класс для представления документа как сущности. """ -from dataclasses import dataclass +from dataclasses import dataclass, field -from .linker_entity import LinkerEntity, register_entity +from .linker_entity import Entity, register_entity @register_entity @dataclass -class DocumentAsEntity(LinkerEntity): +class DocumentAsEntity(Entity): """ Класс для представления документа как сущности в системе извлечения и сборки. + Содержит ссылки на классы стратегии чанкинга и обработчика таблиц, + использовавшихся при деструктуризации. """ doc_type: str = "unknown" - + + chunking_strategy_ref: str | None = None + + type: str = field(default="DocumentAsEntity") + @classmethod - def deserialize(cls, data: LinkerEntity) -> 'DocumentAsEntity': + def _deserialize_to_me(cls, data: Entity) -> 'DocumentAsEntity': """ Десериализует DocumentAsEntity из объекта LinkerEntity. - + Args: data: Объект LinkerEntity для преобразования в DocumentAsEntity - + Returns: Десериализованный объект DocumentAsEntity """ + if not isinstance(data, Entity): + raise TypeError(f"Ожидался LinkerEntity, получен {type(data)}") + metadata = data.metadata or {} - - # Получаем тип документа из метаданных или используем значение по умолчанию - doc_type = metadata.get('_doc_type', 'unknown') - + + # Получаем поля из атрибутов или метаданных + doc_type = getattr(data, 'doc_type', metadata.get('_doc_type', 'unknown')) + strategy_ref = getattr( + data, 'chunking_strategy_ref', metadata.get('_chunking_strategy_ref', None) + ) + # Создаем чистые метаданные без служебных полей clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')} - + return cls( id=data.id, name=data.name, @@ -44,6 +56,8 @@ class DocumentAsEntity(LinkerEntity): source_id=data.source_id, target_id=data.target_id, number_in_relation=data.number_in_relation, - type="DocumentAsEntity", - doc_type=doc_type + groupper=data.groupper, + type=cls.__name__, + doc_type=doc_type, + chunking_strategy_ref=strategy_ref, ) diff --git a/lib/extractor/ntr_text_fragmentation/models/linker_entity.py b/lib/extractor/ntr_text_fragmentation/models/linker_entity.py index 05619db99e101ffbbd9b4ccc48dd46b6918ffa63..8adcd14e66ce4c92f229fb0c10638b3946ae493b 100644 --- a/lib/extractor/ntr_text_fragmentation/models/linker_entity.py +++ b/lib/extractor/ntr_text_fragmentation/models/linker_entity.py @@ -2,56 +2,75 @@ Базовый абстрактный класс для всех сущностей с поддержкой триплетного подхода. """ +import logging import uuid -from abc import abstractmethod from dataclasses import dataclass, field, fields from uuid import UUID +logger = logging.getLogger(__name__) + @dataclass class LinkerEntity: """ Общий класс для всех сущностей в системе извлечения и сборки. Поддерживает триплетный подход, где каждая сущность может опционально связывать две другие сущности. - + Attributes: id (UUID): Уникальный идентификатор сущности. name (str): Название сущности. text (str): Текстое представление сущности. in_search_text (str | None): Текст для поиска. Если задан, используется в __str__, иначе используется обычное представление. metadata (dict): Метаданные сущности. - source_id (UUID | None): Опциональный идентификатор исходной сущности. + source_id (UUID | None): Опциональный идентификатор исходной сущности. Если указан, эта сущность является связью. - target_id (UUID | None): Опциональный идентификатор целевой сущности. + target_id (UUID | None): Опциональный идентификатор целевой сущности. Если указан, эта сущность является связью. - number_in_relation (int | None): Используется в случае связей один-ко-многим, + number_in_relation (int | None): Используется в случае связей один-ко-многим, указывает номер целевой сущности в списке. type (str): Тип сущности. """ - id: UUID - name: str - text: str - metadata: dict # JSON с метаданными + id: UUID = field(default_factory=uuid.uuid4) + name: str = field(default="") + text: str = field(default="") + metadata: dict = field(default_factory=dict) in_search_text: str | None = None source_id: UUID | None = None target_id: UUID | None = None number_in_relation: int | None = None - type: str = field(default_factory=lambda: "Entity") + groupper: str | None = None + type: str | None = None + + @property + def owner_id(self) -> UUID | None: + """ + Возвращает идентификатор владельца сущности. + """ + if self.is_link(): + return None + return self.target_id + + @owner_id.setter + def owner_id(self, value: UUID | None): + """ + Устанавливает идентификатор владельца сущности. + """ + if self.is_link(): + raise ValueError("Связь не может иметь владельца") + self.target_id = value def __post_init__(self): if self.id is None: self.id = uuid.uuid4() - - # Проверяем корректность полей связи - if (self.source_id is not None and self.target_id is None) or \ - (self.source_id is None and self.target_id is not None): - raise ValueError("source_id и target_id должны быть либо оба указаны, либо оба None") + + if self.type is None: + self.type = self.__class__.__name__ def is_link(self) -> bool: """ Проверяет, является ли сущность связью (имеет и source_id, и target_id). - + Returns: bool: True, если сущность является связью, иначе False """ @@ -85,7 +104,7 @@ class LinkerEntity: and self.text == other.text and self.type == other.type ) - + # Если мы имеем дело со связями, также проверяем поля связи if self.is_link() or other.is_link(): return ( @@ -93,62 +112,51 @@ class LinkerEntity: and self.source_id == other.source_id and self.target_id == other.target_id ) - + return basic_equality def serialize(self) -> 'LinkerEntity': """ - Сериализует сущность в простейшую форму сущности, передавая все дополнительные поля в метаданные. + Сериализует сущность в базовый класс `LinkerEntity`, сохраняя все дополнительные поля в метаданные. """ - # Получаем список полей базового класса - known_fields = {field.name for field in fields(LinkerEntity)} + base_fields = {f.name for f in fields(LinkerEntity)} + current_fields = {f.name for f in fields(self.__class__)} + extra_field_names = current_fields - base_fields - # Получаем все атрибуты текущего объекта - dict_entity = {} - for attr_name in dir(self): - # Пропускаем служебные атрибуты, методы и уже известные поля - if ( - attr_name.startswith('_') - or attr_name in known_fields - or callable(getattr(self, attr_name)) - ): - continue - - # Добавляем дополнительные поля в словарь - dict_entity[attr_name] = getattr(self, attr_name) + # Собираем только дополнительные поля, определенные в подклассе + extra_fields_dict = {name: getattr(self, name) for name in extra_field_names} # Преобразуем имена дополнительных полей, добавляя префикс "_" - dict_entity = {f'_{name}': value for name, value in dict_entity.items()} + prefixed_extra_fields = { + f'_{name}': value for name, value in extra_fields_dict.items() + } - # Объединяем с существующими метаданными - dict_entity = {**dict_entity, **self.metadata} + # Объединяем с существующими метаданными (если они были установлены вручную) + final_metadata = {**prefixed_extra_fields, **self.metadata} result_type = self.type if result_type == "Entity": result_type = self.__class__.__name__ - # Создаем базовый объект LinkerEntity с новыми метаданными + # Создаем базовый объект LinkerEntity return LinkerEntity( id=self.id, name=self.name, text=self.text, in_search_text=self.in_search_text, - metadata=dict_entity, + metadata=final_metadata, # Используем собранные метаданные source_id=self.source_id, target_id=self.target_id, number_in_relation=self.number_in_relation, + groupper=self.groupper, type=result_type, ) - @classmethod - @abstractmethod - def deserialize(cls, data: 'LinkerEntity') -> 'Self': + def deserialize(self) -> 'LinkerEntity': """ - Десериализует сущность из простейшей формы сущности, учитывая все дополнительные поля в метаданных. + Десериализует сущность в нужный тип на основе поля type. """ - raise NotImplementedError( - f"Метод deserialize для класса {cls.__class__.__name__} не реализован" - ) + return self._deserialize(self) # Реестр для хранения всех наследников LinkerEntity _entity_classes = {} @@ -157,61 +165,85 @@ class LinkerEntity: def register_entity_class(cls, entity_class): """ Регистрирует класс-наследник в реестре. - + Args: entity_class: Класс для регистрации """ entity_type = entity_class.__name__ cls._entity_classes[entity_type] = entity_class - # Также регистрируем по типу, если он отличается от имени класса if hasattr(entity_class, 'type') and isinstance(entity_class.type, str): cls._entity_classes[entity_class.type] = entity_class - + @classmethod - def deserialize(cls, data: 'LinkerEntity') -> 'LinkerEntity': + def _deserialize(cls, data: 'LinkerEntity') -> 'LinkerEntity': """ Десериализует сущность в нужный тип на основе поля type. - + Args: data: Сериализованная сущность типа LinkerEntity - + Returns: Десериализованная сущность правильного типа """ # Получаем тип сущности entity_type = data.type - + # Проверяем реестр классов if entity_type in cls._entity_classes: try: - return cls._entity_classes[entity_type].deserialize(data) - except (AttributeError, NotImplementedError) as e: - # Если метод не реализован, возвращаем исходную сущность + return cls._entity_classes[entity_type]._deserialize_to_me(data) + except Exception as e: + logger.error(f"Ошибка при вызове _deserialize_to_me для {entity_type}: {e}", exc_info=True) return data - - # Если тип не найден в реестре, просто возвращаем исходную сущность - # Больше не используем опасное сканирование sys.modules + return data + @classmethod + def _deserialize_to_me(cls, data: 'LinkerEntity') -> 'LinkerEntity': + """ + Десериализует сущность в нужный тип на основе поля type. + """ + return cls( + id=data.id, + name=data.name, + text=data.text, + in_search_text=data.in_search_text, + metadata=data.metadata, + source_id=data.source_id, + target_id=data.target_id, + number_in_relation=data.number_in_relation, + type=data.type, + groupper=data.groupper, + ) + + +# Алиасы для удобства +Link = LinkerEntity +Entity = LinkerEntity + # Декоратор для регистрации производных классов def register_entity(cls): """ Декоратор для регистрации классов-наследников LinkerEntity. - + Пример использования: - + @register_entity class MyEntity(LinkerEntity): type = "my_entity" - + Args: cls: Класс, который нужно зарегистрировать - + Returns: Исходный класс (без изменений) """ # Регистрируем класс в реестр, используя его имя или указанный тип - entity_type = getattr(cls, 'type', cls.__name__) + entity_type = cls.__name__ LinkerEntity._entity_classes[entity_type] = cls + + if hasattr(cls, 'type') and isinstance(cls.type, str): + LinkerEntity._entity_classes[cls.type] = cls + return cls diff --git a/lib/extractor/ntr_text_fragmentation/repositories/__init__.py b/lib/extractor/ntr_text_fragmentation/repositories/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5e48542d731661a79b419f3cd7fff1aeb891ac --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/repositories/__init__.py @@ -0,0 +1,8 @@ +from .entity_repository import EntityRepository, GroupedEntities +from .in_memory_repository import InMemoryEntityRepository + +__all__ = [ + "EntityRepository", + "GroupedEntities", + "InMemoryEntityRepository", +] diff --git a/lib/extractor/ntr_text_fragmentation/repositories/entity_repository.py b/lib/extractor/ntr_text_fragmentation/repositories/entity_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..dc66c5b71d8c30bca333feb18245a3b1c20956de --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/repositories/entity_repository.py @@ -0,0 +1,106 @@ +""" +Интерфейс репозитория сущностей. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Iterable, Type, TypeVar +from uuid import UUID + +from ..models import LinkerEntity + +T = TypeVar('T', bound=LinkerEntity) + + +@dataclass +class GroupedEntities(Generic[T]): + composer: T + entities: list[LinkerEntity] + + +class EntityRepository(ABC): + """ + Абстрактный интерфейс для доступа к хранилищу сущностей. + Позволяет InjectionBuilder получать нужные сущности независимо от их хранилища. + + Этот интерфейс определяет только методы для получения сущностей. + Логика сохранения и изменения сущностей остается за пределами этого интерфейса + и должна быть реализована в конкретных классах, расширяющих данный интерфейс. + """ + + @abstractmethod + def get_entities_by_ids( + self, + entity_ids: Iterable[UUID], + ) -> list[LinkerEntity]: + """ + Получить сущности по списку идентификаторов. + Может возвращать экземпляры подклассов LinkerEntity. + """ + pass + + @abstractmethod + def group_entities_hierarchically( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + root_type: Type[LinkerEntity], + max_levels: int = 10, + sort: bool = True, + ) -> list[GroupedEntities[LinkerEntity]]: + """ + Группирует сущности по корневым элементам иерархии, поддерживая + многоуровневые связи (например, строка → подтаблица → таблица → документ). + + Args: + entities: Список идентификаторов или сущностей для группировки + root_type: Корневой тип сущностей для группировки (например, DocumentAsEntity) + max_levels: Максимальная глубина поиска корневого элемента + sort: Флаг для сортировки сущностей в группах по их позициям + + Returns: + Список групп сущностей, объединенных по корневому объекту + """ + pass + + @abstractmethod + def get_neighboring_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + max_distance: int = 1, + ) -> list[LinkerEntity]: + """ + Получить соседние сущности для указанных сущностей. + Порядок определяется через CompositionLink и number_in_relation. + + Args: + entities: Список идентификаторов сущностей (UUID) или самих сущностей (LinkerEntity). + max_distance: Максимальное расстояние между сущностями (по умолчанию 1). + + Returns: + Список соседних сущностей. + """ + pass + + @abstractmethod + def get_related_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + relation_type: Type[LinkerEntity] | None = None, + as_source: bool = False, + as_target: bool = False, + as_owner: bool = False, + ) -> list[LinkerEntity]: + """ + Получить сущности, связанные с указанными. Возвращает как сущности, так и связи к ним ведущие. + + Args: + entities: Список идентификаторов сущностей (UUID) или самих сущностей (LinkerEntity). + relation_type: Опциональный тип связи для фильтрации (например, CompositionLink) + as_source: Искать связи, где entities - источники + as_target: Искать связи, где entities - цели + as_owner: Искать связи, где entities - владельцы (связи-композиции) + + Returns: + Список связанных сущностей и самих связей + """ + pass diff --git a/lib/extractor/ntr_text_fragmentation/repositories/in_memory_repository.py b/lib/extractor/ntr_text_fragmentation/repositories/in_memory_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..c856237eac5c7ea9de10c881e1f20405c4f14ed4 --- /dev/null +++ b/lib/extractor/ntr_text_fragmentation/repositories/in_memory_repository.py @@ -0,0 +1,337 @@ +import logging +from collections import defaultdict +from typing import Iterable, Type +from uuid import UUID + +from ..models import LinkerEntity +from .entity_repository import EntityRepository, GroupedEntities + +logger = logging.getLogger(__name__) + + +class InMemoryEntityRepository(EntityRepository): + """ + Реализация EntityRepository, хранящая все сущности в памяти. + Обеспечивает обратную совместимость и используется для тестирования. + """ + + def __init__(self, entities: list[LinkerEntity] | None = None): + """ + Инициализация репозитория с начальным списком сущностей. + + Args: + entities: Начальный список сущностей + """ + self.entities = entities or [] + self.entities_by_id: dict[UUID, LinkerEntity] = {} + self.relations_by_source: dict[UUID, list[LinkerEntity]] = defaultdict(list) + self.relations_by_target: dict[UUID, list[LinkerEntity]] = defaultdict(list) + self.compositions: dict[UUID, list[LinkerEntity]] = defaultdict(list) + + self._build_indices() + + def _build_indices(self) -> None: + """ + Строит индексы для быстрого доступа. + Использует LinkerEntity.deserialize для возможной типизации связей. + """ + self.entities_by_id.clear() + self.relations_by_source.clear() + self.relations_by_target.clear() + self.compositions.clear() + + for entity in self.entities: + try: + deserialized_entity = LinkerEntity._deserialize(entity) + except Exception as e: + logger.warning(f"Error deserializing entity: {e}") + deserialized_entity = entity + + self.entities_by_id[deserialized_entity.id] = deserialized_entity + + if deserialized_entity.is_link(): + self.relations_by_source[deserialized_entity.source_id].append( + deserialized_entity + ) + self.relations_by_target[deserialized_entity.target_id].append( + deserialized_entity + ) + + if deserialized_entity.owner_id is not None: + self.compositions[deserialized_entity.owner_id].append( + deserialized_entity + ) + + logger.info(f"Построены индексы для {len(self.entities)} сущностей.") + logger.info(f"Всего сущностей: {len(self.entities_by_id)}") + logger.info(f"Всего связей: {len(self.relations_by_source)}") + logger.info(f"Всего композиций: {len(self.compositions)}") + + def _normalize_entities( + self, entities: Iterable[UUID] | Iterable[LinkerEntity] + ) -> list[UUID]: + """ + Преобразует входные данные в список UUID. + + Args: + entities: Итерируемый объект с UUID или LinkerEntity + + Returns: + list[UUID]: Список идентификаторов + """ + result = [] + for entity in entities: + if isinstance(entity, UUID): + result.append(entity) + elif isinstance(entity, LinkerEntity): + result.append(entity.id) + return result + + def get_entities_by_ids( + self, entities: Iterable[UUID] | Iterable[LinkerEntity] + ) -> list[LinkerEntity]: + """ + Получить сущности по списку идентификаторов или сущностей. + + Args: + entities: Список идентификаторов или сущностей + + Returns: + list[LinkerEntity]: Список найденных сущностей + """ + entity_ids = self._normalize_entities(entities) + return [ + self.entities_by_id[eid] for eid in entity_ids if eid in self.entities_by_id + ] + + def group_entities_hierarchically( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + root_type: Type[LinkerEntity], + max_levels: int = 10, + sort: bool = True, + ) -> list[GroupedEntities]: + """ + Группирует сущности по корневым элементам иерархии, поддерживая + многоуровневые связи (например, строка → подтаблица → таблица → документ). + + Args: + entities: Список идентификаторов или сущностей для группировки + root_type: Корневой тип сущностей для группировки (например, DocumentAsEntity) + max_levels: Максимальная глубина поиска корневого элемента + sort: Флаг для сортировки сущностей в группах по их позициям + + Returns: + Список групп сущностей, объединенных по корневому объекту + """ + entity_ids = self._normalize_entities(entities) + + # Словарь для хранения найденных корневых элементов для каждой сущности + entity_to_root: dict[UUID, UUID] = {} + + # Если включена сортировка, соберем информацию о позициях сущностей + entity_positions: dict[UUID, tuple[str, int]] = {} + if sort: + for entity_id in entity_ids: + entity = self.entities_by_id.get(entity_id) + if entity: + entity_positions[entity_id] = ( + entity.groupper, + entity.number_in_relation, + ) + + # Функция для нахождения корневого элемента для сущности + def find_root( + entity_id: UUID, visited: set | None = None, level: int = 0 + ) -> UUID | None: + # Проверка на максимальную глубину поиска + if level >= max_levels: + return None + + # Инициализация множества посещенных узлов для отслеживания пути + if visited is None: + visited = set() + + # Проверка, не обрабатывали ли мы уже эту сущность + if entity_id in visited: + return None + + # Добавляем текущую сущность в посещенные + visited.add(entity_id) + + # Проверяем, есть ли уже найденный корень для этой сущности + if entity_id in entity_to_root: + return entity_to_root[entity_id] + + # Проверяем, является ли сама сущность корневым типом + entity = self.entities_by_id.get(entity_id) + + if entity and isinstance(entity, root_type): + return entity_id + + # Получаем родительскую сущность через owner_id + if entity and entity.owner_id: + parent_root = find_root(entity.owner_id, visited, level + 1) + if parent_root: + return parent_root + + return None + + # Находим корневой элемент для каждой сущности + for entity_id in entity_ids: + root_id = find_root(entity_id) + if root_id: + entity_to_root[entity_id] = root_id + + logger.info(f"Найдены корневые элементы для {len(entity_to_root)} сущностей из общего количества {len(entity_ids)}.") + + # Группируем сущности по корневым элементам + root_to_entities: dict[UUID, list[LinkerEntity]] = defaultdict(list) + + for entity_id in entity_ids: + if entity_id in entity_to_root: + root_id = entity_to_root[entity_id] + entity = self.entities_by_id.get(entity_id) + if entity: + root_to_entities[root_id].append(entity) + + # Формируем результат + result = [] + for root_id, entities_list in root_to_entities.items(): + root_entity = self.entities_by_id.get(root_id) + if root_entity: + # Сортируем сущности при формировании групп, если нужно + if sort: + entities_list.sort( + key=lambda entity: entity_positions.get( + entity.id, + ("", float('inf')), # Сущности без позиции в конец + ) + ) + + result.append( + GroupedEntities(composer=root_entity, entities=entities_list) + ) + + return result + + def get_neighboring_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + max_distance: int = 1, + ) -> list[LinkerEntity]: + """ + Получает соседние сущности в пределах указанного расстояния в рамках одной композиционной группы. + + Args: + entities: Список идентификаторов или сущностей + max_distance: Максимальное расстояние между сущностями + + Returns: + list[LinkerEntity]: Список соседних сущностей + """ + entity_ids = self._normalize_entities(entities) + + if not entity_ids: + return [] + + entities = self.get_entities_by_ids(entity_ids) + entities = [entity for entity in entities if entity.owner_id is not None] + + # Ищем соседей + neighbors = { + entity.owner_id: [ + sibling + for sibling in self.compositions.get(entity.owner_id, []) + if ( + sibling.groupper == entity.groupper + and abs(sibling.number_in_relation - entity.number_in_relation) + <= max_distance + ) + ] + for entity in entities + } + + neighbors = { + owner_id: sorted( + neighbors, key=lambda x: (x.groupper, x.number_in_relation) + ) + for owner_id, neighbors in neighbors.items() + } + + neighbor_ids = { + owner_id: [neighbor.id for neighbor in neighbors] + for owner_id, neighbors in neighbors.items() + } + + # Собираем все ID соседей и исключаем исходные сущности + all_neighbor_ids = set(sum(neighbor_ids.values(), [])) - set(entity_ids) + + return self.get_entities_by_ids(all_neighbor_ids) + + def get_related_entities( + self, + entities: Iterable[UUID] | Iterable[LinkerEntity], + relation_type: Type[LinkerEntity] | None = None, + as_source: bool = False, + as_target: bool = False, + as_owner: bool = False, + ) -> list[LinkerEntity]: + """ + Получает связанные сущности и их связи. + + Args: + entities: Список идентификаторов или сущностей + relation_type: Тип связи для фильтрации + as_source: Искать связи, где entities являются источниками + as_target: Искать связи, где entities являются целями + as_owner: Искать связи, где entities являются владельцами (связи-композиции) + + Returns: + list[LinkerEntity]: Список связанных сущностей и их связей + """ + entity_ids = self._normalize_entities(entities) + result = set() + + # Если не указано направление, ищем в обе стороны + if not as_source and not as_target and not as_owner: + as_source = True + as_target = True + as_owner = True + + # Поиск связей, где entities являются источниками + if as_source: + for entity_id in entity_ids: + for relation in self.relations_by_source.get(entity_id, []): + if relation_type is None or isinstance(relation, relation_type): + result.add(relation) + if relation.target_id in self.entities_by_id: + result.add(self.entities_by_id[relation.target_id]) + + # Поиск связей, где entities являются целями + if as_target: + for entity_id in entity_ids: + for relation in self.relations_by_target.get(entity_id, []): + if relation_type is None or isinstance(relation, relation_type): + result.add(relation) + if relation.source_id in self.entities_by_id: + result.add(self.entities_by_id[relation.source_id]) + + # Поиск связей, где entities являются владельцами + if as_owner: + for entity_id in entity_ids: + for child in self.compositions.get(entity_id, []): + if relation_type is None or isinstance(child, relation_type): + result.add(child) + + return list(result) + + def add_entities(self, entities: list[LinkerEntity]) -> None: + """Добавляет сущности в репозиторий и перестраивает индексы.""" + self.entities.extend(entities) + self._build_indices() + + def set_entities(self, entities: list[LinkerEntity]) -> None: + """Устанавливает сущности в репозиторий и перестраивает индексы.""" + self.entities = entities + self._build_indices() diff --git a/lib/extractor/scripts/test_chunking.py b/lib/extractor/scripts/test_chunking.py new file mode 100644 index 0000000000000000000000000000000000000000..e870ef1985249f50cf5ff4c7094c0dad1cd077c9 --- /dev/null +++ b/lib/extractor/scripts/test_chunking.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python +""" +Скрипт для визуального тестирования процесса чанкинга и сборки документа. + +Этот скрипт: +1. Считывает test_input/test.docx с помощью UniversalParser +2. Чанкит документ через Destructurer с fixed_size-стратегией +3. Сохраняет результат чанкинга в test_output/test.csv +4. Выбирает 20-30 случайных чанков из CSV +5. Создает InjectionBuilder с InMemoryEntityRepository +6. Собирает текст из выбранных чанков +7. Сохраняет результат в test_output/test_builded.txt +""" + +import json +import logging +import os +import random +from pathlib import Path +from typing import List +from uuid import UUID + +import pandas as pd +from ntr_fileparser import UniversalParser +from ntr_text_fragmentation import (DocumentAsEntity, EntitiesExtractor, + InjectionBuilder, InMemoryEntityRepository, + LinkerEntity) + + +def setup_logging() -> None: + """Настройка логгирования.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - [%(pathname)s:%(lineno)d] - %(message)s", + ) + + +def ensure_directories() -> None: + """Проверка наличия необходимых директорий.""" + for directory in ["test_input", "test_output"]: + Path(directory).mkdir(parents=True, exist_ok=True) + + +def save_entities_to_csv(entities: List[LinkerEntity], csv_path: str) -> None: + """ + Сохраняет сущности в CSV файл. + + Args: + entities: Список сущностей + csv_path: Путь для сохранения CSV файла + """ + data = [] + for entity in entities: + # Базовые поля для всех типов сущностей + entity_dict = { + "id": str(entity.id), + "type": entity.type, + "name": entity.name, + "text": entity.text, + "metadata": json.dumps(entity.metadata or {}, ensure_ascii=False), + "in_search_text": entity.in_search_text, + "source_id": str(entity.source_id) if entity.source_id else None, + "target_id": str(entity.target_id) if entity.target_id else None, + "number_in_relation": entity.number_in_relation, + "groupper": entity.groupper, + "type": entity.type, + } + + # Дополнительные поля специфичные для подклассов (если они есть в __dict__) + # Это не самый надежный способ, но для скрипта визуализации может подойти + # Сериализация LinkerEntity теперь должна сама класть доп поля в metadata + # for key, value in entity.__dict__.items(): + # if key not in entity_dict and not key.startswith('_'): + # entity_dict[key] = value + + data.append(entity_dict) + + df = pd.DataFrame(data) + # Указываем кодировку UTF-8 при записи CSV + df.to_csv(csv_path, index=False, encoding='utf-8') + logging.info(f"Сохранено {len(entities)} сущностей в {csv_path}") + + +def load_entities_from_csv(csv_path: str) -> List[LinkerEntity]: + """ + Загружает сущности из CSV файла. + + Args: + csv_path: Путь к CSV файлу + + Returns: + Список сущностей + """ + df = pd.read_csv(csv_path) + entities = [] + + for _, row in df.iterrows(): + # Обработка метаданных + metadata_str = row.get("metadata", "{}") + try: + # Используем json.loads для парсинга JSON строки + metadata = ( + json.loads(metadata_str) + if pd.notna(metadata_str) and metadata_str + else {} + ) + except json.JSONDecodeError: # Ловим ошибку JSON + logging.warning( + f"Не удалось распарсить метаданные JSON: {metadata_str}. Используется пустой словарь." + ) + metadata = {} + + # Общие поля для всех типов сущностей + # Преобразуем ID обратно в UUID + entity_id = row['id'] + if isinstance(entity_id, str): + try: + entity_id = UUID(entity_id) + except ValueError: + logging.warning( + f"Неверный формат UUID для id: {entity_id}. Пропускаем сущность." + ) + continue + + common_args = { + "id": entity_id, + "name": row["name"] if pd.notna(row.get("name")) else "", + "text": row["text"] if pd.notna(row.get("text")) else "", + "metadata": metadata, + "in_search_text": ( + row["in_search_text"] if pd.notna(row.get('in_search_text')) else None + ), + "type": ( + row["type"] if pd.notna(row.get('type')) else LinkerEntity.__name__ + ), # Используем базовый тип, если не указан + "groupper": row["groupper"] if pd.notna(row.get("groupper")) else None, + } + + # Добавляем поля связи, если они есть, преобразуя в UUID + source_id_str = row.get("source_id") + target_id_str = row.get("target_id") + + if pd.notna(source_id_str): + try: + common_args["source_id"] = UUID(source_id_str) + except ValueError: + logging.warning( + f"Неверный формат UUID для source_id: {source_id_str}. Пропускаем поле." + ) + if pd.notna(target_id_str): + try: + common_args["target_id"] = UUID(target_id_str) + except ValueError: + logging.warning( + f"Неверный формат UUID для target_id: {target_id_str}. Пропускаем поле." + ) + + if pd.notna(row.get("number_in_relation")): + try: + common_args["number_in_relation"] = int(row["number_in_relation"]) + except ValueError: + logging.warning( + f"Неверный формат для number_in_relation: {row['number_in_relation']}. Пропускаем поле." + ) + + # Пытаемся десериализовать в конкретный тип, если он известен + entity_class = LinkerEntity._entity_classes.get( + common_args["type"], LinkerEntity + ) + try: + # Создаем экземпляр, передавая только те аргументы, которые ожидает класс + # (используя LinkerEntity._deserialize_to_me как пример, но нужно убедиться, + # что он принимает все нужные поля или имеет **kwargs) + # Пока создаем базовый LinkerEntity, т.к. подклассы могут требовать специфичные поля + # которых нет в CSV или в common_args + entity = LinkerEntity(**common_args) + # Если нужно строгое восстановление типов, потребуется более сложная логика + # с проверкой полей каждого подкласса + except TypeError as e: + logging.warning( + f"Ошибка создания экземпляра {entity_class.__name__} для ID {common_args['id']}: {e}. Создан базовый LinkerEntity." + ) + entity = LinkerEntity(**common_args) # Откат к базовому классу + + entities.append(entity) + + logging.info(f"Загружено {len(entities)} сущностей из {csv_path}") + return entities + + +def main() -> None: + """Основная функция скрипта.""" + setup_logging() + ensure_directories() + + # Пути к файлам + input_doc_path = "test_input/test2.docx" + output_csv_path = "test_output/test2.csv" + output_text_path = "test_output/test2.md" + + # Проверка наличия входного файла + if not os.path.exists(input_doc_path): + logging.error(f"Файл {input_doc_path} не найден!") + return + + logging.info(f"Парсинг документа {input_doc_path}") + + try: + # Шаг 1: Парсинг документа дважды, как если бы это были два разных документа + parser = UniversalParser() + document1 = parser.parse_by_path(input_doc_path) + document2 = parser.parse_by_path(input_doc_path) + + # Меняем название второго документа, чтобы отличить его + document2.name = document2.name + "_copy" if document2.name else "copy_doc" + + # Шаг 2: Чанкинг и извлечение таблиц с использованием EntitiesExtractor + all_entities = [] + + # Обработка первого документа + logging.info("Начало процесса деструктуризации первого документа") + # Инициализируем экстрактор без документа (используем дефолтные настройки или настроим позже) + extractor1 = EntitiesExtractor() + # Настройка чанкинга + extractor1.configure_chunking( + strategy_name="fixed_size", + strategy_params={ + "words_per_chunk": 50, + "overlap_words": 25, + "respect_sentence_boundaries": True, # Добавлено по запросу + }, + ) + # Настройка извлечения таблиц + extractor1.configure_tables_extraction(process_tables=True) + # Выполнение деструктуризации + entities1 = extractor1.extract(document1) + + # Находим ID документа 1 + doc1_entity = next((e for e in entities1 if e.type == DocumentAsEntity.__name__), None) + if not doc1_entity: + logging.error("Не удалось найти DocumentAsEntity для первого документа!") + return + doc1_id = doc1_entity.id + logging.info(f"ID первого документа: {doc1_id}") + + logging.info(f"Получено {len(entities1)} сущностей из первого документа") + all_entities.extend(entities1) + + # Обработка второго документа + logging.info("Начало процесса деструктуризации второго документа") + # Инициализируем экстрактор без документа + extractor2 = EntitiesExtractor() + # Настройка чанкинга (те же параметры) + extractor2.configure_chunking( + strategy_name="fixed_size", + strategy_params={ + "words_per_chunk": 50, + "overlap_words": 25, + "respect_sentence_boundaries": True, + }, + ) + # Настройка извлечения таблиц + extractor2.configure_tables_extraction(process_tables=True) + # Выполнение деструктуризации + entities2 = extractor2.extract(document2) + + # Находим ID документа 2 + doc2_entity = next((e for e in entities2 if e.type == DocumentAsEntity.__name__), None) + if not doc2_entity: + logging.error("Не удалось найти DocumentAsEntity для второго документа!") + return + doc2_id = doc2_entity.id + logging.info(f"ID второго документа: {doc2_id}") + + logging.info(f"Получено {len(entities2)} сущностей из второго документа") + all_entities.extend(entities2) + + logging.info( + f"Всего получено {len(all_entities)} сущностей из обоих документов" + ) + + # Шаг 3: Сохранение результатов чанкинга в CSV + save_entities_to_csv(all_entities, output_csv_path) + + # Шаг 4: Загрузка сущностей из CSV и выбор случайных чанков + loaded_entities = load_entities_from_csv(output_csv_path) + + # Шаг 5: Создание InjectionBuilder с InMemoryEntityRepository + # Сначала создаем репозиторий со ВСЕМИ загруженными сущностями + repository = InMemoryEntityRepository(loaded_entities) + builder = InjectionBuilder(repository=repository) + + # Фильтрация только чанков (сущностей с in_search_text) + # Убедимся, что работаем с десериализованными сущностями из репозитория + # (Репозиторий уже десериализует при инициализации, если нужно) + all_entities_from_repo = repository.get_entities_by_ids( + [e.id for e in loaded_entities] + ) + # Выбираем все сущности с in_search_text + selectable_entities = [ + e for e in all_entities_from_repo if e.in_search_text is not None + ] + + # Выбор случайных сущностей (от 20 до 30, но не более доступных) + num_entities_to_select = min(random.randint(100, 500), len(selectable_entities)) + if num_entities_to_select > 0: + selected_entities = random.sample( + selectable_entities, num_entities_to_select + ) + selected_ids = [entity.id for entity in selected_entities] + logging.info( + f"Выбрано {len(selected_ids)} случайных ID сущностей (с in_search_text) для сборки" + ) + + # Дополнительная статистика по документам + # Используем репозиторий для получения информации о владельцах + selected_entities_details = repository.get_entities_by_ids(selected_ids) + # Считаем на основе owner_id + doc1_entities_count = sum(1 for e in selected_entities_details if e.owner_id == doc1_id) + doc2_entities_count = sum(1 for e in selected_entities_details if e.owner_id == doc2_id) + other_owner_count = len(selected_entities_details) - (doc1_entities_count + doc2_entities_count) + + logging.info( + f"Из них {doc1_entities_count} принадлежат первому документу (ID: {doc1_id}), " + f"{doc2_entities_count} второму (ID: {doc2_id}) (на основе owner_id). " + f"{other_owner_count} имеют другого владельца (вероятно, таблицы/строки)." + ) + + else: + logging.warning("Не найдено сущностей с in_search_text для выбора.") + selected_ids = [] + selected_entities = [] # Добавлено для ясности + + # Шаг 6: Сборка текста из выбранных ID + logging.info("Начало сборки текста из выбранных ID") + # Передаем ID, а не сущности, т.к. builder сам их получит из репозитория + assembled_text = builder.build( + selected_ids, include_tables=True + ) # Включаем таблицы + + # Шаг 7: Сохранение результата в файл + with open(output_text_path, "w", encoding="utf-8") as f: + f.write(assembled_text.replace('\n', '\n\n')) + + logging.info(f"Результат сборки сохранен в {output_text_path}") + + except Exception as e: + logging.error(f"Произошла ошибка: {e}", exc_info=True) + + +if __name__ == "__main__": + main() diff --git a/lib/extractor/tests/chunking/test_chunking_registry.py b/lib/extractor/tests/chunking/test_chunking_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6c1490b4a0b72d74df4a706c4bf125e826c7e1 --- /dev/null +++ b/lib/extractor/tests/chunking/test_chunking_registry.py @@ -0,0 +1,122 @@ +""" +Unit-тесты для реестра стратегий чанкинга _ChunkingRegistry. +""" + +import pytest +from ntr_text_fragmentation.chunking import (ChunkingStrategy, + _ChunkingRegistry, + chunking_registry, + register_chunking_strategy) + + +# Фикстуры +class MockStrategy(ChunkingStrategy): + """Мок-стратегия для тестов.""" + + def chunk(self, document, doc_entity): + pass + + @classmethod + def dechunk(cls, repository, filtered_entities): + pass + + +@pytest.fixture +def clean_registry() -> _ChunkingRegistry: + """Фикстура для получения чистого экземпляра реестра.""" + # Создаем новый экземпляр, чтобы не влиять на глобальный chunking_registry + return _ChunkingRegistry() + + +@pytest.fixture +def populated_registry(clean_registry: _ChunkingRegistry) -> _ChunkingRegistry: + """Фикстура для реестра с зарегистрированными стратегиями.""" + clean_registry.register("mock1", MockStrategy) + clean_registry.register("mock2", MockStrategy) + return clean_registry + + +# Тесты +def test_register(clean_registry: _ChunkingRegistry): + """Тест регистрации стратегии.""" + assert len(clean_registry) == 0 + clean_registry.register("test_strategy", MockStrategy) + assert len(clean_registry) == 1 + assert "test_strategy" in clean_registry + assert clean_registry.get("test_strategy") is MockStrategy + +def test_get(populated_registry: _ChunkingRegistry): + """Тест получения стратегии по имени.""" + strategy = populated_registry.get("mock1") + assert strategy is MockStrategy + + # Тест получения несуществующей стратегии + with pytest.raises(KeyError): + populated_registry.get("nonexistent") + +def test_getitem(populated_registry: _ChunkingRegistry): + """Тест получения стратегии через __getitem__.""" + strategy = populated_registry["mock1"] + assert strategy is MockStrategy + + # Тест получения несуществующей стратегии + with pytest.raises(KeyError): + _ = populated_registry["nonexistent"] + +def test_get_names(populated_registry: _ChunkingRegistry): + """Тест получения списка имен зарегистрированных стратегий.""" + names = populated_registry.get_names() + assert isinstance(names, list) + assert len(names) == 2 + assert "mock1" in names + assert "mock2" in names + +def test_len(populated_registry: _ChunkingRegistry): + """Тест получения количества зарегистрированных стратегий.""" + assert len(populated_registry) == 2 + +def test_contains(populated_registry: _ChunkingRegistry): + """Тест проверки наличия стратегии.""" + assert "mock1" in populated_registry + assert "nonexistent" not in populated_registry + # Проверка по самому классу стратегии (экземпляры не хранятся) + assert MockStrategy in populated_registry + class AnotherStrategy(ChunkingStrategy): # type: ignore + def chunk(self, document, doc_entity): pass + @classmethod + def dechunk(cls, repository, filtered_entities): pass + assert AnotherStrategy not in populated_registry + +def test_decorator_register(): + """Тест декоратора register_chunking_strategy.""" + # Сохраняем текущее состояние глобального реестра + original_registry_state = chunking_registry._chunking_strategies.copy() + original_len = len(chunking_registry) + + @register_chunking_strategy("decorated_strategy") + class DecoratedStrategy(ChunkingStrategy): + def chunk(self, document, doc_entity): + pass + @classmethod + def dechunk(cls, repository, filtered_entities): + pass + + assert len(chunking_registry) == original_len + 1 + assert "decorated_strategy" in chunking_registry + assert chunking_registry.get("decorated_strategy") is DecoratedStrategy + + # Тест регистрации с именем по умолчанию (имя класса) + @register_chunking_strategy() + class DefaultNameStrategy(ChunkingStrategy): + def chunk(self, document, doc_entity): + pass + @classmethod + def dechunk(cls, repository, filtered_entities): + pass + + assert len(chunking_registry) == original_len + 2 + assert "DefaultNameStrategy" in chunking_registry + assert chunking_registry.get("DefaultNameStrategy") is DefaultNameStrategy + + # Восстанавливаем исходное состояние глобального реестра + chunking_registry._chunking_strategies = original_registry_state \ No newline at end of file diff --git a/lib/extractor/tests/chunking/test_fixed_size_chunking.py b/lib/extractor/tests/chunking/test_fixed_size_chunking.py index 36b939e72d7d539f7b21ef6fc7f14eb02a6fe3db..9eed9f330e5e5cc77285d44bbaceac8ef0aec3ea 100644 --- a/lib/extractor/tests/chunking/test_fixed_size_chunking.py +++ b/lib/extractor/tests/chunking/test_fixed_size_chunking.py @@ -1,334 +1,355 @@ -from uuid import UUID +""" +Unit-тесты для стратегии чанкинга FixedSizeChunkingStrategy. +""" + +import re +from uuid import uuid4 import pytest from ntr_fileparser import ParsedDocument, ParsedTextBlock - -from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import \ - FixedSizeChunkingStrategy -from ntr_text_fragmentation.models import DocumentAsEntity - - -class TestFixedSizeChunkingStrategy: - """Набор тестов для проверки стратегии чанкинга фиксированного размера.""" - - @pytest.fixture - def sample_document(self): - """Фикстура для создания тестового документа.""" - paragraphs = [ - ParsedTextBlock( - text="Это первый параграф тестового документа. Он содержит два предложения." - ), - ParsedTextBlock(text="Это второй параграф с одним предложением."), - ParsedTextBlock( - text="Третий параграф. Содержит еще два предложения. И оно короткое." - ), +from ntr_text_fragmentation.chunking.specific_strategies.fixed_size.fixed_size_chunk import \ + FixedSizeChunk +from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import ( + FIXED_SIZE, FixedSizeChunkingStrategy) +from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \ + InMemoryEntityRepository +from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity + + +# --- Фикстуры --- +@pytest.fixture +def sample_text() -> str: + """Пример текста для тестов.""" + return ( + "Это первое предложение. Второе предложение немного длиннее. " + "Третье! Четвертое? И пятое.\n" + "Новый параграф начинается здесь. Он содержит еще одно предложение. " + "И заканчивается тут.\n" + "Последний параграф." + ) + + +@pytest.fixture +def parsed_document(sample_text: str) -> ParsedDocument: + """Фикстура для ParsedDocument.""" + paragraphs = [ + ParsedTextBlock(text=p) for p in sample_text.split('\n') if p + ] + return ParsedDocument(name="test_doc.txt", type="txt", paragraphs=paragraphs) + + +@pytest.fixture +def doc_entity() -> DocumentAsEntity: + """Фикстура для DocumentAsEntity.""" + return DocumentAsEntity(id=uuid4(), name="test_doc") + + +@pytest.fixture(scope="module") +def default_strategy() -> FixedSizeChunkingStrategy: + """Стратегия с настройками по умолчанию.""" + return FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=3) + + +@pytest.fixture(scope="module") +def no_sentence_boundary_strategy() -> FixedSizeChunkingStrategy: + """Стратегия без учета границ предложений.""" + return FixedSizeChunkingStrategy( + words_per_chunk=10, overlap_words=3, respect_sentence_boundaries=False + ) + + +@pytest.fixture +def extracted_words(parsed_document: ParsedDocument, default_strategy: FixedSizeChunkingStrategy) -> list[str]: + """Извлеченные слова из sample_text.""" + # Используем приватный метод для консистентности + return default_strategy._extract_words(parsed_document) + + +@pytest.fixture +def chunked_entities( + default_strategy: FixedSizeChunkingStrategy, + parsed_document: ParsedDocument, + doc_entity: DocumentAsEntity +) -> list[LinkerEntity]: + """Результат чанкинга документа стратегией по умолчанию.""" + return default_strategy.chunk(parsed_document, doc_entity) + + +# --- Тесты инициализации и вспомогательных методов --- +class TestFixedSizeChunkingStrategyInitAndHelpers: + """Тесты инициализации и приватных методов FixedSizeChunkingStrategy.""" + + def test_init_validation(self): + """Тест валидации параметров при инициализации.""" + with pytest.raises(ValueError, match="overlap_words должен быть меньше words_per_chunk"): + FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=10) + with pytest.raises(ValueError, match="overlap_words должен быть меньше words_per_chunk"): + FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=15) + with pytest.raises(ValueError, match="words_per_chunk должен быть > 0"): + FixedSizeChunkingStrategy(words_per_chunk=0, overlap_words=0) + with pytest.raises(ValueError, match="overlap_words >= 0"): + FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=-1) + + def test_extract_words(self, parsed_document: ParsedDocument, default_strategy: FixedSizeChunkingStrategy): + """Тест метода _extract_words.""" + words = default_strategy._extract_words(parsed_document) + assert words == [ + "Это", "первое", "предложение.", "Второе", "предложение", "немного", "длиннее.", + "Третье!", "Четвертое?", "И", "пятое.", "\n", + "Новый", "параграф", "начинается", "здесь.", "Он", "содержит", "еще", "одно", "предложение.", + "И", "заканчивается", "тут.", "\n", + "Последний", "параграф." ] - - return ParsedDocument( - name="test_document.txt", type="text", paragraphs=paragraphs + # Проверяем на пустом документе + empty_doc = ParsedDocument() + assert default_strategy._extract_words(empty_doc) == [] + + def test_prepare_chunk_text(self, extracted_words: list[str], default_strategy: FixedSizeChunkingStrategy): + """Тест метода _prepare_chunk_text.""" + # Первые 5 слов + text = default_strategy._prepare_chunk_text(extracted_words, 0, 5) + assert text == "Это первое предложение. Второе предложение" + # Слова с переносом строки + text = default_strategy._prepare_chunk_text(extracted_words, 10, 15) + assert text == "пятое.\nНовый параграф начинается" + # Пустой срез + text = default_strategy._prepare_chunk_text(extracted_words, 5, 5) + assert text == "" + + def test_find_sentence_boundary(self, default_strategy: FixedSizeChunkingStrategy): + """Тест метода _find_sentence_boundary.""" + # Ищем левую часть (после знака препинания) + text1 = "Some text. This part should be found." + assert default_strategy._find_sentence_boundary(text1, True) == "This part should be found." + text2 = "No punctuation here" + assert default_strategy._find_sentence_boundary(text2, True) == "" + text3 = "Ends with dot." + assert default_strategy._find_sentence_boundary(text3, True) == "" + text4 = "Multiple sentences. Second one? Third one!" + assert default_strategy._find_sentence_boundary(text4, True) == "" + + # Ищем правую часть (до знака препинания) + text5 = "Find this part. Rest of text." + assert default_strategy._find_sentence_boundary(text5, False) == "Find this part." + text6 = "No punctuation here" + assert default_strategy._find_sentence_boundary(text6, False) == "No punctuation here" + text7 = "Ends with dot." + assert default_strategy._find_sentence_boundary(text7, False) == "Ends with dot." + text8 = "Multiple sentences. Second one? Third one!" + assert default_strategy._find_sentence_boundary(text8, False) == "Multiple sentences." + text9 = "" + assert default_strategy._find_sentence_boundary(text9, False) == "" + + def test_calculate_boundaries(self, extracted_words: list[str], default_strategy: FixedSizeChunkingStrategy): + """Тест метода _calculate_boundaries (с respect_sentence_boundaries=True).""" + # Пример для первого чанка (индексы 0-10, overlap=3) + # overlap_left_start=0, chunk_start=0, chunk_end=10, overlap_right_end=13 + left_part, right_part, left_overlap, right_overlap = default_strategy._calculate_boundaries( + extracted_words, 0, 10, len(extracted_words) ) - - @pytest.fixture - def doc_entity(self): - """Фикстура для создания сущности документа.""" - return DocumentAsEntity( - id=UUID('12345678-1234-5678-1234-567812345678'), - name="Тестовый документ", - text="", - metadata={"type": "text"}, - type="Document", + assert left_overlap == "" # Нет левого оверлапа + assert left_part == "" # Т.к. левый оверлап пустой + # Правый оверлап: слова с 10 по 13 (исключая 13) -> "пятое. \n Новый" + assert right_overlap == "пятое.\nНовый" + # Правая часть предложения: ищем до первого знака в right_overlap -> "пятое." + assert right_part == "пятое." + + # Пример для второго чанка (индексы 7-17, step=7) + # overlap_left_start=4, chunk_start=7, chunk_end=17, overlap_right_end=20 + left_part, right_part, left_overlap, right_overlap = default_strategy._calculate_boundaries( + extracted_words, 7, 17, len(extracted_words) + ) + # Левый оверлап: слова 4-7 -> "предложение немного длиннее." + assert left_overlap == "предложение немного длиннее." + # Левая часть предложения: ищем после последнего знака в left_overlap -> "" + assert left_part == "" + # Правый оверлап: слова 17-20 -> "содержит еще одно" + assert right_overlap == "содержит еще одно" + # Правая часть предложения: ищем до первого знака -> нет знаков, берем всё -> "содержит еще одно" + assert right_part == "содержит еще одно" + + def test_calculate_boundaries_no_respect(self, extracted_words: list[str], no_sentence_boundary_strategy: FixedSizeChunkingStrategy): + """Тест _calculate_boundaries с respect_sentence_boundaries=False.""" + left_part, right_part, left_overlap, right_overlap = no_sentence_boundary_strategy._calculate_boundaries( + extracted_words, 7, 17, len(extracted_words) ) + assert left_overlap == "предложение немного длиннее." + assert right_overlap == "содержит еще одно" + # left/right_sentence_part должны быть пустыми, т.к. respect_sentence_boundaries=False + assert left_part == "" + assert right_part == "" + + def test_clean_final_text(self, default_strategy: FixedSizeChunkingStrategy): + """Тест метода _clean_final_text.""" + text = " Too many spaces. \n\n\nMultiple newlines. space before punct . ( parentheses ) \n space after newline " + cleaned = FixedSizeChunkingStrategy._clean_final_text(text) + expected = "Too many spaces.\n\nMultiple newlines. space before punct.(parentheses)\nspace after newline" + assert cleaned == expected + + +# --- Тесты метода chunk --- +class TestFixedSizeChunkingStrategyChunk: + """Тесты основного метода chunk.""" + + def test_chunk_defaults(self, chunked_entities: list[LinkerEntity], doc_entity: DocumentAsEntity): + """Тест чанкинга с настройками по умолчанию.""" + # Ожидаемое кол-во слов = 29. chunk=10, overlap=3, step=7. + # Чанки: 0-10, 7-17, 14-24, 21-29(обрезано), 28-29(остаток) + assert len(chunked_entities) == 5 + + # Проверка первого чанка + chunk0 = chunked_entities[0] + assert isinstance(chunk0, FixedSizeChunk) + assert chunk0.owner_id == doc_entity.id + assert chunk0.number_in_relation == 0 + assert chunk0.groupper == "chunk" + assert chunk0.text == "Это первое предложение. Второе предложение немного длиннее." # step=7 слов + assert chunk0.in_search_text == "Это первое предложение. Второе предложение немного длиннее. Третье! Четвертое? И" # chunk=10 слов + assert chunk0.token_count == 10 + assert chunk0.left_sentence_part == "" + assert chunk0.right_sentence_part == "пятое." # Из правого overlap "пятое.\nНовый" + assert chunk0.overlap_left == "" + assert chunk0.overlap_right == "пятое.\nНовый" + + # Проверка второго чанка + chunk1 = chunked_entities[1] + assert isinstance(chunk1, FixedSizeChunk) + assert chunk1.number_in_relation == 1 + assert chunk1.text == "Третье! Четвертое? И пятое.\nНовый параграф начинается" # Индексы 7-14 (step=7) + assert chunk1.in_search_text == "Третье! Четвертое? И пятое.\nНовый параграф начинается здесь. Он" # Индексы 7-17 (chunk=10) + assert chunk1.token_count == 10 + assert chunk1.left_sentence_part == "" # Из левого overlap "предложение немного длиннее." + assert chunk1.right_sentence_part == "содержит еще одно" # Из правого overlap "содержит еще одно предложение." + assert chunk1.overlap_left == "предложение немного длиннее." + assert chunk1.overlap_right == "содержит еще одно предложение." + + # Проверка последнего чанка (остаток) + chunk4 = chunked_entities[4] + assert isinstance(chunk4, FixedSizeChunk) + assert chunk4.number_in_relation == 4 + assert chunk4.text == "параграф." # Индекс 28 (step=1) + assert chunk4.in_search_text == "параграф." # Индекс 28 (chunk=1, остаток) + assert chunk4.token_count == 1 + assert chunk4.left_sentence_part == "" # Из левого overlap "\nПоследний" + assert chunk4.right_sentence_part == "" # Правый overlap пустой + assert chunk4.overlap_left == "\nПоследний" + assert chunk4.overlap_right == "" + + def test_chunk_no_sentence_boundary( + self, + no_sentence_boundary_strategy: FixedSizeChunkingStrategy, + parsed_document: ParsedDocument, + doc_entity: DocumentAsEntity + ): + """Тест чанкинга без учета границ предложений.""" + chunks = no_sentence_boundary_strategy.chunk(parsed_document, doc_entity) + assert len(chunks) == 5 + chunk0 = chunks[0] + assert isinstance(chunk0, FixedSizeChunk) + # left/right_sentence_part должны быть пустыми + assert chunk0.left_sentence_part == "" + assert chunk0.right_sentence_part == "" + assert chunk0.overlap_right == "пятое.\nНовый" # overlap сам по себе остается + + chunk1 = chunks[1] + assert isinstance(chunk1, FixedSizeChunk) + assert chunk1.left_sentence_part == "" + assert chunk1.right_sentence_part == "" + assert chunk1.overlap_left == "предложение немного длиннее." + assert chunk1.overlap_right == "содержит еще одно предложение." + + def test_chunk_empty_document(self, default_strategy: FixedSizeChunkingStrategy, doc_entity: DocumentAsEntity): + """Тест чанкинга пустого документа.""" + empty_doc = ParsedDocument() + chunks = default_strategy.chunk(empty_doc, doc_entity) + assert chunks == [] + + def test_chunk_short_document(self, default_strategy: FixedSizeChunkingStrategy, doc_entity: DocumentAsEntity): + """Тест чанкинга очень короткого документа.""" + short_doc = ParsedDocument(paragraphs=[ParsedTextBlock(text="One two three.")]) + chunks = default_strategy.chunk(short_doc, doc_entity) + assert len(chunks) == 1 + chunk0 = chunks[0] + assert isinstance(chunk0, FixedSizeChunk) + assert chunk0.text == "One two three." + assert chunk0.in_search_text == "One two three." + assert chunk0.token_count == 3 + assert chunk0.left_sentence_part == "" + assert chunk0.right_sentence_part == "" # Нет правого оверлапа + + +# --- Тесты метода dechunk --- +class TestFixedSizeChunkingStrategyDechunk: + """Тесты classmethod dechunk.""" @pytest.fixture - def large_document(self): - """Фикстура для создания большого тестового документа.""" - paragraphs = [ - ParsedTextBlock( - text="Это первый параграф большого документа. Он содержит несколько предложений разной длины." - ), - ParsedTextBlock( - text="Второй параграф начинается с короткого предложения. А затем идет длинное предложение, которое содержит много слов и должно быть разбито на несколько чанков, потому что оно не помещается в один чанк стандартного размера." - ), - ParsedTextBlock( - text="Третий параграф содержит несколько предложений. Каждое предложение имеет свою структуру. И все они должны корректно обрабатываться." - ), - ParsedTextBlock( - text="Четвертый параграф начинается с длинного предложения, которое также должно быть разбито на несколько чанков, так как оно содержит много слов и не помещается в один чанк стандартного размера. А затем идет короткое предложение." - ), - ParsedTextBlock( - text="Пятый параграф. Содержит разные предложения. С разной пунктуацией. И разной структурой." - ), - ParsedTextBlock( - text="Шестой параграф начинается с короткого предложения. Затем идет длинное предложение, которое должно быть разбито на несколько чанков, потому что оно содержит много слов и не помещается в один чанк стандартного размера. И заканчивается коротким предложением." - ), - ParsedTextBlock( - text="Седьмой параграф содержит несколько предложений разной длины. Каждое предложение имеет свою структуру. И все они должны корректно обрабатываться." - ), - ParsedTextBlock( - text="Восьмой параграф начинается с длинного предложения, которое также должно быть разбито на несколько чанков, так как оно содержит много слов и не помещается в один чанк стандартного размера. А затем идет короткое предложение." - ), - ParsedTextBlock( - text="Девятый параграф. Содержит разные предложения. С разной пунктуацией. И разной структурой." - ), - ParsedTextBlock( - text="Десятый параграф начинается с короткого предложения. Затем идет длинное предложение, которое должно быть разбито на несколько чанков, потому что оно содержит много слов и не помещается в один чанк стандартного размера. И заканчивается коротким предложением." - ), + def mock_repository(self, chunked_entities: list[LinkerEntity]) -> InMemoryEntityRepository: + """Мок-репозиторий с чанками.""" + # В dechunk репозиторий пока не используется, но передадим его + return InMemoryEntityRepository(chunked_entities) + + def test_dechunk_full_sequence(self, mock_repository: InMemoryEntityRepository, chunked_entities: list[LinkerEntity]): + """Тест сборки полной последовательности чанков.""" + # Передаем все чанки + assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, chunked_entities) + + # Ожидаем, что текст будет собран с использованием left/right_sentence_part + # chunk0.left + chunk0.text + chunk1.text + ... + chunkN.text + chunkN.right + # chunk0.left_sentence_part = "" + # chunk4.right_sentence_part = "" + expected_parts = [ + chunked_entities[0].text, # "Это первое предложение. Второе предложение немного длиннее." + chunked_entities[1].text, # "Третье! Четвертое? И пятое.\nНовый параграф начинается" + chunked_entities[2].text, # "здесь. Он содержит еще одно предложение." + chunked_entities[3].text, # "И заканчивается тут.\nПоследний" + chunked_entities[4].text, # "параграф." ] - - return ParsedDocument( - name="large_test_document.txt", type="text", paragraphs=paragraphs - ) - - def test_basic_chunking_and_dechunking(self, sample_document, doc_entity): - """Тест базового сценария нарезки и сборки документа.""" - strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2) - - # Разбиваем документ - entities = strategy.chunk(sample_document, doc_entity) - - # Выделяем только чанки - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Собираем документ обратно - result_text = strategy.dechunk(chunks) - - # Проверяем, что текст не пустой - assert result_text - - # Проверяем, что все слова из оригинального документа присутствуют в результате - original_text = " ".join([p.text for p in sample_document.paragraphs]) - original_words = set(original_text.split()) - result_words = set(result_text.split()) - - # Все оригинальные слова должны быть в результате - assert original_words.issubset(result_words) - - # Проверяем, что длина результата примерно равна длине исходного текста - assert abs(len(result_text.split()) - len(original_text.split())) < 5 - - def test_chunking_with_different_sentence_lengths(self, doc_entity): - """Тест нарезки документа с предложениями разной длины.""" - # Создаем документ с предложениями разной длины - text = ( - "Короткое предложение. " - "Это предложение средней длины с несколькими словами. " - "А это очень длинное предложение, которое содержит много слов и должно быть разбито на несколько чанков, " - "потому что оно не помещается в один чанк стандартного размера. " - "И снова короткое." - ) - doc = ParsedDocument( - name="test_document.txt", - type="text", - paragraphs=[ParsedTextBlock(text=text)], - ) - - strategy = FixedSizeChunkingStrategy(words_per_chunk=15, overlap_words=5) - - # Разбиваем документ - entities = strategy.chunk(doc, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Проверяем, что длинное предложение было разбито на несколько чанков - assert len(chunks) > 1 - - # Собираем документ обратно - result_text = strategy.dechunk(chunks) - - # Проверяем корректность сборки - original_words = set(text.split()) - result_words = set(result_text.split()) - assert original_words.issubset(result_words) - - # Проверяем, что все предложения сохранились - original_sentences = set(s.strip() for s in text.split('.')) - result_sentences = set(s.strip() for s in result_text.split('.')) - assert original_sentences.issubset(result_sentences) - - def test_empty_document(self, doc_entity): - """Тест обработки пустого документа.""" - doc = ParsedDocument(name="empty.txt", type="text", paragraphs=[]) - - strategy = FixedSizeChunkingStrategy() - - # Разбиваем документ - entities = strategy.chunk(doc, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Проверяем, что чанков нет - assert len(chunks) == 0 - - # Проверяем, что сборка пустого документа возвращает пустую строку - result_text = strategy.dechunk(chunks) - assert result_text == "" - - def test_special_characters_and_punctuation(self, doc_entity): - """Тест обработки текста со специальными символами и пунктуацией.""" - text = ( - "Текст с разными символами: !@#$%^&*(). " - "Скобки (внутри) и [квадратные]. " - "Кавычки «елочки» и \"прямые\". " - "Тире — и дефис-. " - "Многоточие... и запятые, в разных местах." - ) - doc = ParsedDocument( - name="test_document.txt", - type="text", - paragraphs=[ParsedTextBlock(text=text)], - ) - - strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2) - - # Разбиваем документ - entities = strategy.chunk(doc, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Собираем документ обратно - result_text = strategy.dechunk(chunks) - - # Проверяем, что все специальные символы сохранились - special_chars = set('!@#$%^&*()[]«»"—...') - result_chars = set(result_text) - assert special_chars.issubset(result_chars) - - # Проверяем, что текст совпадает с оригиналом - assert result_text == text - - def test_large_document_chunking(self, large_document, doc_entity): - """Тест нарезки и сборки большого документа с множеством параграфов.""" - strategy = FixedSizeChunkingStrategy(words_per_chunk=20, overlap_words=5) - - # Разбиваем документ - entities = strategy.chunk(large_document, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Проверяем, что документ был разбит на несколько чанков - assert len(chunks) > 1 - - # Собираем документ обратно - result_text = strategy.dechunk(chunks) - - # Получаем оригинальный текст - original_paragraphs = [p.text for p in large_document.paragraphs] - - # Проверяем, что все параграфы сохранились - result_paragraphs = result_text.split('\n') - assert len(result_paragraphs) == len(original_paragraphs) - - # Проверяем, что каждый параграф совпадает с оригиналом - for orig, res in zip(original_paragraphs, result_paragraphs): - assert orig.strip() == res.strip() - - def test_exact_text_comparison(self, sample_document, doc_entity): - """Тест точного сравнения текстов после нарезки и сборки.""" - strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2) - - # Разбиваем документ - entities = strategy.chunk(sample_document, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Собираем документ обратно - result_text = strategy.dechunk(chunks) - - # Получаем оригинальный текст по параграфам - original_paragraphs = [p.text for p in sample_document.paragraphs] - - # Проверяем, что все параграфы сохранились - result_paragraphs = result_text.split('\n') - assert len(result_paragraphs) == len(original_paragraphs) - - # Проверяем, что каждый параграф совпадает с оригиналом - for orig, res in zip(original_paragraphs, result_paragraphs): - assert orig.strip() == res.strip() - - def test_non_sequential_chunks(self, large_document, doc_entity): - """Тест обработки непоследовательных чанков с вставкой многоточий.""" - strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2) - - # Разбиваем документ - entities = strategy.chunk(large_document, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Проверяем, что получили достаточное количество чанков - assert len(chunks) >= 5, "Для теста нужно не менее 5 чанков" - - # Отсортируем чанки по индексу - sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index or 0) - - # Выберем несколько несмежных чанков (например, 0, 1, 3, 4, 7) - selected_indices = [0, 1, 3, 4, 7] - selected_chunks = [sorted_chunks[i] for i in selected_indices if i < len(sorted_chunks)] - - # Перемешаем чанки, чтобы убедиться, что сортировка работает - import random - random.shuffle(selected_chunks) - - # Собираем документ из несмежных чанков - result_text = strategy.dechunk(selected_chunks) - - # Проверяем наличие многоточий между непоследовательными чанками - assert "\n\n...\n\n" in result_text, "В тексте должно быть многоточие между непоследовательными чанками" - - # Подсчитываем количество многоточий, должно быть 2 группы разрыва (между 1-3 и 4-7) - ellipsis_count = result_text.count("\n\n...\n\n") - assert ellipsis_count == 2, f"Ожидалось 2 многоточия, получено {ellipsis_count}" - - # Проверяем, что чанки с индексами 0 и 1 идут без многоточия между ними - # Для этого находим текст первого чанка и проверяем, что после него нет многоточия - first_chunk_text = sorted_chunks[0].text - second_chunk_text = sorted_chunks[1].text - - # Проверяем, что текст первого чанка не заканчивается многоточием - first_chunk_position = result_text.find(first_chunk_text) - second_chunk_position = result_text.find(second_chunk_text, first_chunk_position) - - # Текст между первым и вторым чанком не должен содержать многоточие - text_between = result_text[first_chunk_position + len(first_chunk_text):second_chunk_position] - assert "\n\n...\n\n" not in text_between, "Не должно быть многоточия между последовательными чанками" - - def test_overlap_addition_in_dechunk(self, large_document, doc_entity): - """Тест добавления нахлеста при сборке чанков.""" - strategy = FixedSizeChunkingStrategy(words_per_chunk=15, overlap_words=5) - - # Разбиваем документ - entities = strategy.chunk(large_document, doc_entity) - chunks = [e for e in entities if e.type == "FixedSizeChunk"] - - # Отбираем несколько чанков с непустыми overlap_left и overlap_right - overlapping_chunks = [] - for chunk in chunks: - if hasattr(chunk, 'overlap_left') and hasattr(chunk, 'overlap_right'): - if chunk.overlap_left and chunk.overlap_right: - overlapping_chunks.append(chunk) - if len(overlapping_chunks) >= 3: - break - - # Проверяем, что нашли подходящие чанки - assert len(overlapping_chunks) > 0, "Не найдены чанки с нахлестом" - - # Собираем чанки - result_text = strategy.dechunk(overlapping_chunks) - - # Проверяем, что нахлесты включены в результат - for chunk in overlapping_chunks: - if hasattr(chunk, 'overlap_left') and chunk.overlap_left: - # Хотя бы часть нахлеста должна присутствовать в тексте - # Берем первые три слова нахлеста для проверки - overlap_words = chunk.overlap_left.split()[:3] - if overlap_words: - overlap_sample = " ".join(overlap_words) - assert overlap_sample in result_text, f"Левый нахлест не найден в результате: {overlap_sample}" - - if hasattr(chunk, 'overlap_right') and chunk.overlap_right: - # Аналогично проверяем правый нахлест - overlap_words = chunk.overlap_right.split()[:3] - if overlap_words: - overlap_sample = " ".join(overlap_words) - assert overlap_sample in result_text, f"Правый нахлест не найден в результате: {overlap_sample}" - - # Проверяем обработку предложений - for chunk in overlapping_chunks: - if hasattr(chunk, 'left_sentence_part') and chunk.left_sentence_part: - assert chunk.left_sentence_part in result_text, "Левая часть предложения не найдена в результате" - - if hasattr(chunk, 'right_sentence_part') and chunk.right_sentence_part: - assert chunk.right_sentence_part in result_text, "Правая часть предложения не найдена в результате" \ No newline at end of file + expected_raw = " ".join(expected_parts) + # Применяем очистку, как в методе _build_sequenced_chunks + expected_cleaned = FixedSizeChunkingStrategy._clean_final_text(expected_raw) + + # Сравниваем с очищенным результатом + assert assembled_text == expected_cleaned + # Проверим, что переносы строк сохранились где надо + assert "пятое.\nНовый" in assembled_text + assert "тут.\nПоследний" in assembled_text + + def test_dechunk_with_gap(self, mock_repository: InMemoryEntityRepository, chunked_entities: list[LinkerEntity]): + """Тест сборки с пропуском чанка.""" + # Удаляем chunk 1 (индекс 1) + filtered_chunks = [chunked_entities[0]] + chunked_entities[2:] + assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, filtered_chunks) + + # Группа 1: chunk 0 + group1_parts = [ + chunked_entities[0].left_sentence_part, # "" + chunked_entities[0].text, # "Это первое предложение. Второе предложение немного длиннее." + chunked_entities[0].right_sentence_part # "пятое." + ] + group1_text = FixedSizeChunkingStrategy._clean_final_text(" ".join(filter(None, group1_parts))) + + # Группа 2: chunks 2, 3, 4 + group2_parts = [ + chunked_entities[2].left_sentence_part, # "здесь. Он" + chunked_entities[2].text, # "здесь. Он содержит еще одно предложение." + chunked_entities[3].text, # "И заканчивается тут.\nПоследний" + chunked_entities[4].text, # "параграф." + chunked_entities[4].right_sentence_part # "" + ] + group2_text = FixedSizeChunkingStrategy._clean_final_text(" ".join(filter(None, group2_parts))) + + expected_text = f"{group1_text}\n...\n{group2_text}" + assert assembled_text == expected_text + + def test_dechunk_not_fixed_size_chunk(self, mock_repository: InMemoryEntityRepository, doc_entity: DocumentAsEntity): + """Тест сборки, если передан не FixedSizeChunk.""" + # Создаем обычный LinkerEntity вместо FixedSizeChunk + non_fsc = LinkerEntity(id=uuid4(), name="not a chunk", text="some text", target_id=doc_entity.id, number_in_relation=0) + assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, [non_fsc]) + # Ожидаем просто текст из .text + assert assembled_text == "some text" + + def test_dechunk_empty_list(self, mock_repository: InMemoryEntityRepository): + """Тест сборки пустого списка чанков.""" + assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, []) + assert assembled_text == "" \ No newline at end of file diff --git a/lib/extractor/tests/conftest.py b/lib/extractor/tests/conftest.py index 52aa6b42907ce102ac3b186f12c9df225fd59b0c..b5f5112adcf258827c7660b41562de6877474b4e 100644 --- a/lib/extractor/tests/conftest.py +++ b/lib/extractor/tests/conftest.py @@ -9,47 +9,3 @@ import pytest from ntr_text_fragmentation.models.linker_entity import LinkerEntity from tests.custom_entity import CustomEntity # Импортируем наш кастомный класс - -@pytest.fixture -def sample_entity(): - """ - Фикстура, возвращающая экземпляр LinkerEntity с предустановленными значениями. - """ - return LinkerEntity( - id=UUID('12345678-1234-5678-1234-567812345678'), - name="Тестовая сущность", - text="Текст тестовой сущности", - metadata={"test_key": "test_value"} - ) - - -@pytest.fixture -def sample_custom_entity(): - """ - Фикстура, возвращающая экземпляр CustomEntity с предустановленными значениями. - """ - return CustomEntity( - id=UUID('87654321-8765-4321-8765-432187654321'), - name="Тестовый кастомный объект", - text="Текст кастомного объекта", - metadata={"original_key": "original_value"}, - in_search_text="Текст для поиска кастомного объекта", - custom_field1="custom_value", - custom_field2=42 - ) - - -@pytest.fixture -def sample_link(): - """ - Фикстура, возвращающая экземпляр LinkerEntity с предустановленными значениями связи. - """ - return LinkerEntity( - id=UUID('98765432-9876-5432-9876-543298765432'), - name="Тестовая связь", - text="Текст тестовой связи", - metadata={"test_key": "test_value"}, - source_id=UUID('12345678-1234-5678-1234-567812345678'), - target_id=UUID('87654321-8765-4321-8765-432187654321'), - type="Link" - ) \ No newline at end of file diff --git a/lib/extractor/tests/core/test_extractor.py b/lib/extractor/tests/core/test_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a7c6b5aa34e4c59d2b4541d077ee249ac68430 --- /dev/null +++ b/lib/extractor/tests/core/test_extractor.py @@ -0,0 +1,265 @@ +""" +Unit-тесты для EntitiesExtractor. +""" + +from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 + +import pytest +from ntr_fileparser import ParsedDocument, ParsedTextBlock +# Импортируем конкретную стратегию и процессор для мокирования +from ntr_text_fragmentation.additors.tables_processor import TablesProcessor +from ntr_text_fragmentation.chunking import chunking_registry +from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import ( + FIXED_SIZE, FixedSizeChunkingStrategy) +from ntr_text_fragmentation.core.extractor import EntitiesExtractor +from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity + + +# --- Фикстуры --- +@pytest.fixture +def mock_document() -> ParsedDocument: + """Мок ParsedDocument.""" + return ParsedDocument( + name="mock_doc.pdf", + type="pdf", + paragraphs=[ParsedTextBlock(text="Paragraph 1."), ParsedTextBlock(text="Paragraph 2.")], + # Можно добавить таблицы и т.д., если нужно тестировать их обработку + ) + +@pytest.fixture +def mock_chunk() -> LinkerEntity: + """Мок сущности чанка.""" + return LinkerEntity(id=uuid4(), name="mock_chunk", text="chunk text", type="Chunk") + +@pytest.fixture +def mock_table_entity() -> LinkerEntity: + """Мок сущности таблицы.""" + return LinkerEntity(id=uuid4(), name="mock_table", text="table text", type="TableEntity") + +@pytest.fixture +def mock_strategy_instance(mock_chunk: LinkerEntity) -> MagicMock: + """Мок экземпляра стратегии чанкинга.""" + instance = MagicMock(spec=FixedSizeChunkingStrategy) + # Мокируем метод chunk, чтобы он возвращал предопределенный чанк + instance.chunk.return_value = [mock_chunk] + return instance + +@pytest.fixture +def mock_tables_processor_instance(mock_table_entity: LinkerEntity) -> MagicMock: + """Мок экземпляра процессора таблиц.""" + instance = MagicMock(spec=TablesProcessor) + # Мокируем метод extract, чтобы он возвращал предопределенную сущность таблицы + instance.extract.return_value = [mock_table_entity] + return instance + +@pytest.fixture(autouse=True) +def mock_registry_and_processors( + mock_strategy_instance: MagicMock, + mock_tables_processor_instance: MagicMock +): + """Мокирует реестр стратегий и конструкторы процессоров.""" + # Мокируем реестр, чтобы он возвращал наш мок-класс стратегии + mock_strategy_class = MagicMock(return_value=mock_strategy_instance) + with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class}, clear=True): + # Мокируем конструктор TablesProcessor, чтобы он возвращал наш мок-экземпляр + with patch('ntr_text_fragmentation.core.extractor.TablesProcessor', return_value=mock_tables_processor_instance): + yield + + +# --- Тесты --- # +class TestEntitiesExtractor: + """Тесты для EntitiesExtractor.""" + + def test_init_defaults(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock): + """Тест инициализации с настройками по умолчанию.""" + extractor = EntitiesExtractor() + # По умолчанию используется FIXED_SIZE стратегия и process_tables=True + assert extractor.strategy is mock_strategy_instance + assert extractor._strategy_name == FIXED_SIZE + assert extractor.tables_processor is mock_tables_processor_instance + + def test_init_custom_strategy(self, mock_strategy_instance: MagicMock): + """Тест инициализации с указанием стратегии и параметров.""" + strategy_params = {'words_per_chunk': 100, 'overlap_words': 10} + # Ожидаем, что конструктор мок-стратегии будет вызван с этими параметрами + extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, strategy_params=strategy_params, process_tables=False) + + # Проверяем, что конструктор мок-стратегии был вызван с правильными параметрами + mock_strategy_class = chunking_registry[FIXED_SIZE] + mock_strategy_class.assert_called_once_with(**strategy_params) + assert extractor.strategy is mock_strategy_instance + assert extractor._strategy_name == FIXED_SIZE + assert extractor.tables_processor is None # process_tables=False + + def test_init_invalid_strategy_name(self): + """Тест инициализации с невалидным именем стратегии.""" + with pytest.raises(ValueError, match="Неизвестная стратегия: invalid_strategy"): + EntitiesExtractor(strategy_name="invalid_strategy") + + def test_configure_chunking(self, mock_strategy_instance: MagicMock): + """Тест переконфигурации стратегии чанкинга.""" + extractor = EntitiesExtractor(process_tables=False) # Изначально без стратегии + assert extractor.strategy is None + + params = {'words_per_chunk': 20} + extractor.configure_chunking(strategy_name=FIXED_SIZE, strategy_params=params) + + mock_strategy_class = chunking_registry[FIXED_SIZE] + mock_strategy_class.assert_called_once_with(**params) + assert extractor.strategy is mock_strategy_instance + assert extractor._strategy_name == FIXED_SIZE + + def test_configure_chunking_invalid_params(self): + """Тест ошибки при неверных параметрах для стратегии.""" + # Настроим мок-класс стратегии, чтобы он вызывал TypeError при инициализации + mock_strategy_class_error = MagicMock(side_effect=TypeError("Invalid param")) + with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class_error}): + extractor = EntitiesExtractor(process_tables=False) + with pytest.raises(ValueError, match="Ошибка при попытке инициализировать стратегию"): + extractor.configure_chunking(strategy_name=FIXED_SIZE, strategy_params={"invalid": 1}) + + def test_configure_tables_extraction(self, mock_tables_processor_instance: MagicMock): + """Тест переконфигурации извлечения таблиц.""" + extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False) # Изначально без таблиц + assert extractor.tables_processor is None + + extractor.configure_tables_extraction(process_tables=True) + assert extractor.tables_processor is mock_tables_processor_instance + + extractor.configure_tables_extraction(process_tables=False) + # Экземпляр процессора создается, но не используется, если process_tables=False в destructure + # Однако configure_tables_extraction устанавливает его. Проверим это. + # Ожидаем, что конструктор TablesProcessor будет вызван при configure_tables_extraction(True) + # и при configure_tables_extraction(False) он не обнулится? + # Судя по коду configure_tables_extraction, он всегда создает новый TablesProcessor. + # Давайте уточним тест, что процессор создается. + with patch('ntr_text_fragmentation.core.extractor.TablesProcessor') as mock_constructor: + extractor.configure_tables_extraction(process_tables=True) + mock_constructor.assert_called_once() + assert extractor.tables_processor is not None + + mock_constructor.reset_mock() + extractor.configure_tables_extraction(process_tables=False) + # Повторный вызов конструктора! Это может быть неэффективно, но тест должен отражать код. + mock_constructor.assert_called_once() + assert extractor.tables_processor is not None # Экземпляр остается + + def test_configure_chaining(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock): + """Тест цепочки вызовов configure.""" + extractor = EntitiesExtractor(strategy_name=None, process_tables=None) # Полностью пустой + assert extractor.strategy is None + assert extractor.tables_processor is None + + returned_extractor = extractor.configure(strategy_name=FIXED_SIZE, process_tables=True) + + assert returned_extractor is extractor # Должен возвращать себя + assert extractor.strategy is mock_strategy_instance + assert extractor.tables_processor is mock_tables_processor_instance + + # Переконфигурируем только стратегию + new_params = {"words_per_chunk": 50} + extractor.configure(strategy_name=FIXED_SIZE, strategy_params=new_params) + mock_strategy_class = chunking_registry[FIXED_SIZE] + # Конструктор стратегии вызывается повторно + mock_strategy_class.assert_called_with(**new_params) + assert extractor.tables_processor is mock_tables_processor_instance # Процессор таблиц не изменился + + # Переконфигурируем только таблицы + extractor.configure(process_tables=False) + # tables_processor создается, но использоваться не будет + assert extractor.tables_processor is not None + assert extractor.strategy is mock_strategy_instance # Стратегия не изменилась + + def test_destructure_calls_chunk_and_extract(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock, mock_chunk: LinkerEntity, mock_table_entity: LinkerEntity): + """Тест, что destructure вызывает chunk и extract.""" + extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=True) + entities = extractor.destructure(mock_document) + + # Проверяем вызов chunk стратегии + mock_strategy_instance.chunk.assert_called_once() + # Проверяем аргументы вызова chunk + call_args, _ = mock_strategy_instance.chunk.call_args + assert call_args[0] is mock_document + assert isinstance(call_args[1], DocumentAsEntity) + assert call_args[1].name == "mock_doc.pdf" # Имя документа передано + # Проверяем, что стратегия записалась в DocumentAsEntity + assert call_args[1].chunking_strategy_ref == FIXED_SIZE + + # Проверяем вызов extract процессора таблиц + mock_tables_processor_instance.extract.assert_called_once() + # Проверяем аргументы вызова extract + call_args_table, _ = mock_tables_processor_instance.extract.call_args + assert call_args_table[0] is mock_document + assert isinstance(call_args_table[1], DocumentAsEntity) + assert call_args_table[1].name == "mock_doc.pdf" + + # Проверяем результат: должен содержать DocumentAsEntity, результат chunk, результат extract + assert len(entities) == 3 + entity_types = {type(e) for e in entities} + # Все сущности сериализованы в LinkerEntity + assert entity_types == {LinkerEntity} + + entity_ids = {e.id for e in entities} + # Проверяем наличие ID моков (после сериализации ID сохраняются) + assert mock_chunk.id in entity_ids + assert mock_table_entity.id in entity_ids + # Проверяем наличие ID документа (он создается внутри) + doc_entity_id = next(e.id for e in entities if e.type == "DocumentAsEntity") + assert isinstance(doc_entity_id, UUID) + + def test_destructure_only_chunking(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock, mock_chunk: LinkerEntity): + """Тест destructure только с чанкингом.""" + extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False) + entities = extractor.destructure(mock_document) + + mock_strategy_instance.chunk.assert_called_once() + mock_tables_processor_instance.extract.assert_not_called() # extract не должен вызываться + + assert len(entities) == 2 # DocumentAsEntity + chunk + entity_types = {e.type for e in entities} + assert "DocumentAsEntity" in entity_types + assert "Chunk" in entity_types + + def test_destructure_no_strategy_no_tables(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock): + """Тест destructure без стратегии и без таблиц.""" + # Убираем стратегию из реестра на время теста + with patch.dict(chunking_registry._chunking_strategies, {}, clear=True): + extractor = EntitiesExtractor(strategy_name=None, process_tables=False) + entities = extractor.destructure(mock_document) + + mock_strategy_instance.chunk.assert_not_called() + mock_tables_processor_instance.extract.assert_not_called() + + assert len(entities) == 1 # Только DocumentAsEntity + assert entities[0].type == "DocumentAsEntity" + assert entities[0].name == "mock_doc.pdf" + + def test_destructure_with_string_input(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock): + """Тест destructure с входной строкой.""" + input_string = "Это тестовая строка.\nВторая строка." + extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False) + entities = extractor.destructure(input_string) + + # Проверяем, что chunk был вызван с созданным ParsedDocument + mock_strategy_instance.chunk.assert_called_once() + call_args, _ = mock_strategy_instance.chunk.call_args + assert isinstance(call_args[0], ParsedDocument) + assert call_args[0].name == "unknown" + assert call_args[0].type == "PlainText" + assert len(call_args[0].paragraphs) == 2 + assert call_args[0].paragraphs[0].text == "Это тестовая строка." + assert isinstance(call_args[1], DocumentAsEntity) + + mock_tables_processor_instance.extract.assert_not_called() + + assert len(entities) == 2 # Document + chunk + + def test_destructure_runtime_error_no_strategy(self, mock_document: ParsedDocument): + """Тест RuntimeError, если стратегия не сконфигурирована, но вызывается _chunk.""" + # Этот тест немного искусственный, т.к. destructure не вызовет _chunk, если strategy is None + # Но проверим сам метод _chunk на всякий случай + extractor = EntitiesExtractor(strategy_name=None, process_tables=False) + doc_entity = extractor._create_document_entity(mock_document) + with pytest.raises(RuntimeError, match="Стратегия чанкинга не выставлена"): + extractor._chunk(mock_document, doc_entity) \ No newline at end of file diff --git a/lib/extractor/tests/core/test_in_memory_repository.py b/lib/extractor/tests/core/test_in_memory_repository.py new file mode 100644 index 0000000000000000000000000000000000000000..a252af08e565da8a33cd4bff5c8a8f5f58aec19a --- /dev/null +++ b/lib/extractor/tests/core/test_in_memory_repository.py @@ -0,0 +1,433 @@ +""" +Unit-тесты для InMemoryEntityRepository. +""" + +from uuid import UUID, uuid4 + +import pytest +# Импортируем все необходимые типы сущностей +from ntr_text_fragmentation.additors.tables.models import (SubTableEntity, + TableEntity, + TableRowEntity) +from ntr_text_fragmentation.chunking.models import Chunk +from ntr_text_fragmentation.chunking.specific_strategies.fixed_size.fixed_size_chunk import \ + FixedSizeChunk +from ntr_text_fragmentation.core.entity_repository import GroupedEntities +from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \ + InMemoryEntityRepository +from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity + +# --- Фикстуры --- + +@pytest.fixture +def doc1_id() -> UUID: + return uuid4() + +@pytest.fixture +def doc2_id() -> UUID: + return uuid4() + +@pytest.fixture +def table1_id() -> UUID: + return uuid4() + +@pytest.fixture +def subtable1_id() -> UUID: + return uuid4() + +@pytest.fixture +def doc1(doc1_id: UUID) -> DocumentAsEntity: + return DocumentAsEntity(id=doc1_id, name="doc1") + +@pytest.fixture +def doc2(doc2_id: UUID) -> DocumentAsEntity: + return DocumentAsEntity(id=doc2_id, name="doc2") + +@pytest.fixture +def chunk1_doc1(doc1_id: UUID) -> Chunk: + return Chunk(id=uuid4(), name="chunk1_doc1", text="text1", target_id=doc1_id, number_in_relation=0, groupper="chunk") + +@pytest.fixture +def chunk2_doc1(doc1_id: UUID) -> Chunk: + return Chunk(id=uuid4(), name="chunk2_doc1", text="text2", target_id=doc1_id, number_in_relation=1, groupper="chunk") + +@pytest.fixture +def chunk3_doc1(doc1_id: UUID) -> Chunk: + # Пропускаем номер 2 для теста соседей и группировки + return Chunk(id=uuid4(), name="chunk3_doc1", text="text3", target_id=doc1_id, number_in_relation=3, groupper="chunk") + +@pytest.fixture +def chunk1_doc2(doc2_id: UUID) -> Chunk: + return Chunk(id=uuid4(), name="chunk1_doc2", text="text_doc2", target_id=doc2_id, number_in_relation=0, groupper="chunk") + +@pytest.fixture +def table1_doc1(doc1_id: UUID, table1_id: UUID) -> TableEntity: + return TableEntity(id=table1_id, name="table1", target_id=doc1_id, number_in_relation=0, groupper="table") + +@pytest.fixture +def row1_table1(table1_id: UUID) -> TableRowEntity: + return TableRowEntity(id=uuid4(), name="row1_table1", cells=["r1c1", "r1c2"], target_id=table1_id, number_in_relation=0, groupper="row") + +@pytest.fixture +def row2_table1(table1_id: UUID) -> TableRowEntity: + return TableRowEntity(id=uuid4(), name="row2_table1", cells=["r2c1", "r2c2"], target_id=table1_id, number_in_relation=1, groupper="row") + +@pytest.fixture +def subtable1_table1(table1_id: UUID, subtable1_id: UUID) -> SubTableEntity: + return SubTableEntity(id=subtable1_id, name="subtable1", target_id=table1_id, number_in_relation=2, groupper="subtable") + +@pytest.fixture +def row1_subtable1(subtable1_id: UUID) -> TableRowEntity: + return TableRowEntity(id=uuid4(), name="row1_subtable1", cells=["sr1c1"], target_id=subtable1_id, number_in_relation=0, groupper="subrow") + +@pytest.fixture +def link1(chunk1_doc1: Chunk, chunk2_doc1: Chunk) -> LinkerEntity: + # Пример кастомной связи + return LinkerEntity(id=uuid4(), name="link1", source_id=chunk1_doc1.id, target_id=chunk2_doc1.id, type="CustomLink") + +@pytest.fixture +def all_entities( + doc1: DocumentAsEntity, + doc2: DocumentAsEntity, + chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk, + chunk1_doc2: Chunk, + table1_doc1: TableEntity, + row1_table1: TableRowEntity, row2_table1: TableRowEntity, + subtable1_table1: SubTableEntity, + row1_subtable1: TableRowEntity, + link1: LinkerEntity +) -> list[LinkerEntity]: + return [ + doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk1_doc2, + table1_doc1, row1_table1, row2_table1, subtable1_table1, row1_subtable1, link1 + ] + +@pytest.fixture +def empty_repository() -> InMemoryEntityRepository: + return InMemoryEntityRepository() + +@pytest.fixture +def populated_repository(all_entities: list[LinkerEntity]) -> InMemoryEntityRepository: + return InMemoryEntityRepository(all_entities) + + +# --- Тесты --- # +class TestInMemoryEntityRepository: + """Тесты для InMemoryEntityRepository.""" + + def test_init_empty(self, empty_repository: InMemoryEntityRepository): + """Тест инициализации пустого репозитория.""" + assert empty_repository.entities == [] + assert empty_repository.entities_by_id == {} + assert empty_repository.relations_by_source == {} + assert empty_repository.relations_by_target == {} + assert empty_repository.compositions == {} + + def test_init_with_entities(self, populated_repository: InMemoryEntityRepository, all_entities: list[LinkerEntity], doc1_id: UUID, chunk1_doc1: Chunk, link1: LinkerEntity): + """Тест инициализации с сущностями и построения индексов.""" + assert len(populated_repository.entities) == len(all_entities) + assert len(populated_repository.entities_by_id) == len(all_entities) + assert doc1_id in populated_repository.entities_by_id + assert chunk1_doc1.id in populated_repository.entities_by_id + + # Проверка индекса compositions (owner_id) + assert doc1_id in populated_repository.compositions + assert len(populated_repository.compositions[doc1_id]) == 4 # chunk1, chunk2, chunk3, table1 + doc1_children_ids = {e.id for e in populated_repository.compositions[doc1_id]} + assert chunk1_doc1.id in doc1_children_ids + + # Проверка индекса relations + assert link1.source_id in populated_repository.relations_by_source + assert link1 in populated_repository.relations_by_source[link1.source_id] + assert link1.target_id in populated_repository.relations_by_target + assert link1 in populated_repository.relations_by_target[link1.target_id] + + def test_add_entities(self, empty_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk): + """Тест добавления сущностей.""" + empty_repository.add_entities([chunk1_doc1]) + assert len(empty_repository.entities) == 1 + assert chunk1_doc1.id in empty_repository.entities_by_id + assert chunk1_doc1.owner_id in empty_repository.compositions + + empty_repository.add_entities([chunk2_doc1]) + assert len(empty_repository.entities) == 2 + assert chunk2_doc1.id in empty_repository.entities_by_id + assert len(empty_repository.compositions[chunk1_doc1.owner_id]) == 2 + + def test_set_entities(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk): + """Тест установки (замены) сущностей.""" + initial_count = len(populated_repository.entities) + populated_repository.set_entities([chunk1_doc1, chunk2_doc1]) + assert len(populated_repository.entities) == 2 + assert len(populated_repository.entities_by_id) == 2 + assert chunk1_doc1.id in populated_repository.entities_by_id + assert chunk2_doc1.id in populated_repository.entities_by_id + assert len(populated_repository.compositions) == 1 # Только один owner_id + assert len(populated_repository.compositions[chunk1_doc1.owner_id]) == 2 + assert len(populated_repository.relations_by_source) == 0 # Старые связи удалены + + def test_get_entities_by_ids(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk1_doc2: Chunk): + """Тест получения сущностей по ID.""" + ids_to_get = [chunk1_doc1.id, chunk1_doc2.id, uuid4()] # Последний ID не существует + result = populated_repository.get_entities_by_ids(ids_to_get) + assert len(result) == 2 + result_ids = {e.id for e in result} + assert chunk1_doc1.id in result_ids + assert chunk1_doc2.id in result_ids + + # Тест с передачей самих сущностей + result_from_entities = populated_repository.get_entities_by_ids([chunk1_doc1, chunk1_doc2]) + assert len(result_from_entities) == 2 + assert chunk1_doc1 in result_from_entities + assert chunk1_doc2 in result_from_entities + + # Тест с пустым списком ID + assert populated_repository.get_entities_by_ids([]) == [] + + # --- Тесты group_entities_hierarchically --- + def test_group_by_doc_simple(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk, table1_doc1: TableEntity): + """Тест простой группировки по документу.""" + entities_to_group = [chunk1_doc1.id, chunk2_doc1.id, chunk3_doc1.id, table1_doc1.id] + groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity) + + assert len(groups) == 1 + group = groups[0] + assert isinstance(group, GroupedEntities) + assert group.composer == doc1 + assert len(group.entities) == 4 + grouped_ids = {e.id for e in group.entities} + assert chunk1_doc1.id in grouped_ids + assert chunk2_doc1.id in grouped_ids + assert chunk3_doc1.id in grouped_ids + assert table1_doc1.id in grouped_ids + # Проверка сортировки по number_in_relation + assert [e.number_in_relation for e in group.entities] == [0, 0, 1, 3] + + def test_group_by_doc_multi_level(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, row1_table1: TableRowEntity, row1_subtable1: TableRowEntity): + """Тест иерархической группировки (строка -> подтаблица -> таблица -> документ).""" + entities_to_group = [row1_table1.id, row1_subtable1.id] + groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity) + + assert len(groups) == 1 + group = groups[0] + assert group.composer == doc1 + assert len(group.entities) == 2 + grouped_ids = {e.id for e in group.entities} + assert row1_table1.id in grouped_ids + assert row1_subtable1.id in grouped_ids + # Проверка сортировки (строка подтаблицы идет после строки таблицы) + # row1_table1: owner=table1 (num=0), row1_subtable1: owner=subtable1 (num=0) -> owner=table1 (num=2) + # Порядок сортировки по умолчанию не определен для разных уровней иерархии, только внутри одного owner + # Позиция определяется как (groupper, number_in_relation) + assert group.entities[0].id == row1_table1.id + assert group.entities[1].id == row1_subtable1.id + + def test_group_by_doc_multiple_docs(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, doc2: DocumentAsEntity, chunk1_doc1: Chunk, chunk1_doc2: Chunk): + """Тест группировки сущностей из разных документов.""" + entities_to_group = [chunk1_doc1.id, chunk1_doc2.id] + groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity) + + assert len(groups) == 2 + composers = {g.composer for g in groups} + assert doc1 in composers + assert doc2 in composers + + for group in groups: + if group.composer == doc1: + assert len(group.entities) == 1 + assert group.entities[0] == chunk1_doc1 + elif group.composer == doc2: + assert len(group.entities) == 1 + assert group.entities[0] == chunk1_doc2 + + def test_group_no_sort(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk3_doc1: Chunk): + """Тест группировки без сортировки.""" + # Передаем в обратном порядке + entities_to_group = [chunk3_doc1.id, chunk1_doc1.id] + groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity, sort=False) + assert len(groups) == 1 + group = groups[0] + # Порядок должен сохраниться как в entities_to_group + assert group.entities[0].id == chunk3_doc1.id + assert group.entities[1].id == chunk1_doc1.id + + def test_group_empty(self, populated_repository: InMemoryEntityRepository): + """Тест группировки пустого списка.""" + groups = populated_repository.group_entities_hierarchically([], DocumentAsEntity) + assert groups == [] + + def test_group_max_levels(self, populated_repository: InMemoryEntityRepository, row1_subtable1: TableRowEntity): + """Тест ограничения глубины поиска родителя.""" + # Уровень 1: SubTable, Уровень 2: Table, Уровень 3: Document + # Ищем корень DocumentAsEntity (нужно 3 уровня) + groups_ok = populated_repository.group_entities_hierarchically([row1_subtable1.id], DocumentAsEntity, max_levels=3) + assert len(groups_ok) == 1 + assert groups_ok[0].composer.type == "DocumentAsEntity" + + # Ищем с max_levels=2 (должен не найти Document) + groups_fail = populated_repository.group_entities_hierarchically([row1_subtable1.id], DocumentAsEntity, max_levels=2) + assert len(groups_fail) == 0 + + # Ищем корень TableEntity (нужно 2 уровня) + groups_table_ok = populated_repository.group_entities_hierarchically([row1_subtable1.id], TableEntity, max_levels=2) + assert len(groups_table_ok) == 1 + assert groups_table_ok[0].composer.type == "TableEntity" + + groups_table_fail = populated_repository.group_entities_hierarchically([row1_subtable1.id], TableEntity, max_levels=1) + assert len(groups_table_fail) == 0 + + # --- Тесты get_neighboring_entities --- + def test_get_neighbors_distance_1(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk): + """Тест получения соседей с distance=1.""" + # Соседи для chunk2 (индекс 1) + neighbors = populated_repository.get_neighboring_entities([chunk2_doc1.id], max_distance=1) + neighbor_ids = {e.id for e in neighbors} + assert len(neighbors) == 2 + assert chunk1_doc1.id in neighbor_ids # Сосед слева (индекс 0) + assert chunk3_doc1.id not in neighbor_ids # Сосед справа (индекс 3) - далеко + + # Соседи для chunk1 (индекс 0) + neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=1) + neighbor_ids = {e.id for e in neighbors} + assert len(neighbors) == 1 + assert chunk2_doc1.id in neighbor_ids # Сосед справа (индекс 1) + + # Соседи для chunk3 (индекс 3) + neighbors = populated_repository.get_neighboring_entities([chunk3_doc1.id], max_distance=1) + neighbor_ids = {e.id for e in neighbors} + # Сосед слева chunk2 (индекс 1) слишком далеко (diff = 2) + assert len(neighbors) == 0 + + def test_get_neighbors_distance_2(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk): + """Тест получения соседей с distance=2.""" + neighbors = populated_repository.get_neighboring_entities([chunk2_doc1.id], max_distance=2) + neighbor_ids = {e.id for e in neighbors} + assert len(neighbors) == 2 + assert chunk1_doc1.id in neighbor_ids # Сосед слева (индекс 0, diff=1) + assert chunk3_doc1.id in neighbor_ids # Сосед справа (индекс 3, diff=2) + + def test_get_neighbors_multiple_entities(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk): + """Тест получения соседей для нескольких сущностей.""" + neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id, chunk3_doc1.id], max_distance=1) + neighbor_ids = {e.id for e in neighbors} + # Сосед chunk1 -> chunk2 + # Соседей у chunk3 нет (с distance=1) + assert len(neighbors) == 1 + assert chunk2_doc1.id in neighbor_ids + + def test_get_neighbors_different_owners(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk1_doc2: Chunk): + """Тест: соседи ищутся только в рамках одного owner.""" + neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=5) # Большая дистанция + # Должен найти только chunk2_doc1 и chunk3_doc1, но не chunk1_doc2 + neighbor_ids = {e.id for e in neighbors} + assert len(neighbors) == 2 + assert chunk1_doc2.id not in neighbor_ids + + def test_get_neighbors_different_groupers(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, table1_doc1: TableEntity): + """Тест: соседи ищутся только в рамках одного groupper.""" + # У chunk1_doc1 groupper='chunk', у table1_doc1 groupper='table' + # Оба принадлежат doc1, number_in_relation = 0 у обоих + neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=1) + neighbor_ids = {e.id for e in neighbors} + assert table1_doc1.id not in neighbor_ids # Не должен найти таблицу + + def test_get_neighbors_no_owner(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity): + """Тест: сущности без owner_id не должны иметь соседей.""" + neighbors = populated_repository.get_neighboring_entities([doc1.id], max_distance=1) + assert len(neighbors) == 0 + + def test_get_neighbors_empty(self, populated_repository: InMemoryEntityRepository): + """Тест получения соседей для пустого списка.""" + neighbors = populated_repository.get_neighboring_entities([], max_distance=1) + assert neighbors == [] + + # --- Тесты get_related_entities --- + def test_get_related_as_source(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, link1: LinkerEntity): + """Тест поиска связей, где сущность - источник.""" + related = populated_repository.get_related_entities([chunk1_doc1.id], as_source=True) + related_ids = {e.id for e in related} + assert len(related) == 2 # Сама связь + цель связи + assert link1.id in related_ids + assert link1.target_id in related_ids # chunk2_doc1.id + + def test_get_related_as_target(self, populated_repository: InMemoryEntityRepository, chunk2_doc1: Chunk, link1: LinkerEntity): + """Тест поиска связей, где сущность - цель.""" + related = populated_repository.get_related_entities([chunk2_doc1.id], as_target=True) + related_ids = {e.id for e in related} + assert len(related) == 2 # Сама связь + источник связи + assert link1.id in related_ids + assert link1.source_id in related_ids # chunk1_doc1.id + + def test_get_related_as_owner(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, table1_doc1: TableEntity): + """Тест поиска дочерних сущностей (по owner_id).""" + related = populated_repository.get_related_entities([doc1.id], as_owner=True) + related_ids = {e.id for e in related} + # Ожидаем chunk1, chunk2, chunk3, table1 + assert len(related) == 4 + assert chunk1_doc1.id in related_ids + assert table1_doc1.id in related_ids + + def test_get_related_all_directions(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk2_doc1: Chunk, link1: LinkerEntity): + """Тест поиска связей во всех направлениях (по умолчанию).""" + related_c1 = populated_repository.get_related_entities([chunk1_doc1.id]) # source для link1, child для doc1 + related_c1_ids = {e.id for e in related_c1} + assert len(related_c1) == 2 # link1, chunk2_doc1 (target) + assert link1.id in related_c1_ids + assert chunk2_doc1.id in related_c1_ids + + related_c2 = populated_repository.get_related_entities([chunk2_doc1.id]) # target для link1, child для doc1 + related_c2_ids = {e.id for e in related_c2} + assert len(related_c2) == 2 # link1, chunk1_doc1 (source) + assert link1.id in related_c2_ids + assert chunk1_doc1.id in related_c2_ids + + related_doc = populated_repository.get_related_entities([doc1.id]) # owner для chunk1/2/3, table1 + related_doc_ids = {e.id for e in related_doc} + assert len(related_doc) == 4 # chunk1, chunk2, chunk3, table1 + assert chunk1_doc1.id in related_doc_ids + + def test_get_related_filter_by_type(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, table1_doc1: TableEntity): + """Тест фильтрации связей по типу.""" + # Ищем только чанки, принадлежащие doc1 + related_chunks = populated_repository.get_related_entities([doc1.id], as_owner=True, relation_type=Chunk) + related_ids = {e.id for e in related_chunks} + assert len(related_chunks) == 3 + assert chunk1_doc1.id in related_ids + assert table1_doc1.id not in related_ids + + # Ищем только таблицы, принадлежащие doc1 + related_tables = populated_repository.get_related_entities([doc1.id], as_owner=True, relation_type=TableEntity) + assert len(related_tables) == 1 + assert related_tables[0].id == table1_doc1.id + + # Ищем только связи типа CustomLink, где chunk1 - источник + related_custom_link = populated_repository.get_related_entities([chunk1_doc1.id], as_source=True, relation_type=LinkerEntity) # Используем базовый тип, т.к. CustomLink не регистрировали + related_custom_link_ids = {e.id for e in related_custom_link} + assert len(related_custom_link) == 2 + assert link1.id in related_custom_link_ids + + def test_get_related_multiple_entities_input(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, link1: LinkerEntity): + """Тест поиска связей для нескольких сущностей одновременно.""" + related = populated_repository.get_related_entities([chunk1_doc1.id, chunk2_doc1.id], as_source=True) + related_ids = {e.id for e in related} + # chunk1 -> link1 -> chunk2 + # chunk2 -> нет связей как source + assert len(related) == 2 # link1, chunk2 + assert link1.id in related_ids + assert link1.target_id in related_ids + + def test_get_related_no_relations(self, populated_repository: InMemoryEntityRepository, doc2: DocumentAsEntity): + """Тест поиска связей для сущности без связей.""" + related = populated_repository.get_related_entities([doc2.id]) # У doc2 есть только дочерний chunk1_doc2 + related_ids = {e.id for e in related} + assert len(related) == 1 # Находит только дочерний chunk1_doc2 + assert chunk1_doc2.id in related_ids + + # Ищем только source/target связи для doc2 + related_links = populated_repository.get_related_entities([doc2.id], as_source=True, as_target=True) + assert len(related_links) == 0 + + def test_get_related_empty_input(self, populated_repository: InMemoryEntityRepository): + """Тест поиска связей для пустого списка.""" + related = populated_repository.get_related_entities([]) + assert related == [] \ No newline at end of file diff --git a/lib/extractor/tests/core/test_injection_builder.py b/lib/extractor/tests/core/test_injection_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bb44bccd776c0cea20be2fcaea73d9f4ee1f31 --- /dev/null +++ b/lib/extractor/tests/core/test_injection_builder.py @@ -0,0 +1,412 @@ +""" +Unit-тесты для InjectionBuilder. +""" + +from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 + +import pytest +from ntr_text_fragmentation.additors import TablesProcessor +from ntr_text_fragmentation.chunking import (ChunkingStrategy, + chunking_registry, + register_chunking_strategy) +from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import ( + FIXED_SIZE, FixedSizeChunkingStrategy) +from ntr_text_fragmentation.core.entity_repository import EntityRepository +from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \ + InMemoryEntityRepository +from ntr_text_fragmentation.core.injection_builder import InjectionBuilder +from ntr_text_fragmentation.models import (Chunk, DocumentAsEntity, + LinkerEntity, TableEntity, + TableRowEntity) + +# --- Фикстуры --- # + +# Используем реальные ID для связей в фикстурах +DOC1_ID = uuid4() +DOC2_ID = uuid4() +TABLE1_ID = uuid4() +CHUNK1_ID = uuid4() +CHUNK2_ID = uuid4() +CHUNK3_ID = uuid4() +CHUNK4_ID = uuid4() +ROW1_ID = uuid4() +ROW2_ID = uuid4() + + +@pytest.fixture +def doc1() -> DocumentAsEntity: + return DocumentAsEntity(id=DOC1_ID, name="Document 1", chunking_strategy_ref=FIXED_SIZE) + +@pytest.fixture +def doc2() -> DocumentAsEntity: + return DocumentAsEntity(id=DOC2_ID, name="Document 2", chunking_strategy_ref=FIXED_SIZE) + +# Используем FixedSizeChunk для тестирования dechunk +@pytest.fixture +def chunk1_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk: + return FixedSizeChunkingStrategy.FixedSizeChunk( + id=CHUNK1_ID, name="c1d1", text="Chunk 1 text.", owner_id=doc1.id, + number_in_relation=0, groupper="chunk", + right_sentence_part="More from chunk 1." + ) + +@pytest.fixture +def chunk2_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk: + return FixedSizeChunkingStrategy.FixedSizeChunk( + id=CHUNK2_ID, name="c2d1", text="Chunk 2 is here.", owner_id=doc1.id, + number_in_relation=1, groupper="chunk", + left_sentence_part="Continuation from chunk 1.", right_sentence_part="End of chunk 2." + ) + +@pytest.fixture +def chunk3_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk: + # Пропуск для теста разрыва + return FixedSizeChunkingStrategy.FixedSizeChunk( + id=CHUNK3_ID, name="c3d1", text="Chunk 3 after gap.", owner_id=doc1.id, + number_in_relation=3, groupper="chunk", + left_sentence_part="Start after gap." + ) + +@pytest.fixture +def chunk4_doc2(doc2: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk: + return FixedSizeChunkingStrategy.FixedSizeChunk( + id=CHUNK4_ID, name="c4d2", text="Chunk 4 from doc 2.", owner_id=doc2.id, + number_in_relation=0, groupper="chunk" + ) + +@pytest.fixture +def table1_doc1(doc1: DocumentAsEntity) -> TableEntity: + return TableEntity(id=TABLE1_ID, name="t1d1", text="Table 1 representation", owner_id=doc1.id, number_in_relation=0, groupper="table") + +@pytest.fixture +def row1_table1(table1_doc1: TableEntity) -> TableRowEntity: + return TableRowEntity(id=ROW1_ID, name="r1t1", cells=["a", "b"], owner_id=table1_doc1.id, number_in_relation=0, groupper="row") + +@pytest.fixture +def row2_table1(table1_doc1: TableEntity) -> TableRowEntity: + return TableRowEntity(id=ROW2_ID, name="r2t1", cells=["c", "d"], owner_id=table1_doc1.id, number_in_relation=1, groupper="row") + + +@pytest.fixture +def all_test_entities( + doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk4_doc2, + table1_doc1, row1_table1, row2_table1 +) -> list[LinkerEntity]: + # Собираем все созданные сущности + return [ + doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk4_doc2, + table1_doc1, row1_table1, row2_table1 + ] + +@pytest.fixture +def serialized_entities(all_test_entities: list[LinkerEntity]) -> list[LinkerEntity]: + # Сериализуем, как они хранятся в репозитории + return [e.serialize() for e in all_test_entities] + +@pytest.fixture +def test_repository(serialized_entities: list[LinkerEntity]) -> InMemoryEntityRepository: + # Репозиторий для тестов InjectionBuilder + return InMemoryEntityRepository(serialized_entities) + +# --- Моки --- # + +@pytest.fixture +def mock_strategy_class() -> MagicMock: + """Мок класса стратегии чанкинга.""" + mock_cls = MagicMock(spec=FixedSizeChunkingStrategy) + # Мокируем classmethod dechunk + mock_cls.dechunk.return_value = "[Dechunked Text]" + return mock_cls + +@pytest.fixture +def mock_tables_processor_build() -> MagicMock: + """Мок статического метода TablesProcessor.build.""" + with patch.object(TablesProcessor, 'build', return_value="[Built Tables]") as mock_build: + yield mock_build + +@pytest.fixture(autouse=True) +def setup_mocks(mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock): + """Автоматически применяет моки для реестра и процессора таблиц.""" + # Регистрируем мок-стратегию в реестре + with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class}, clear=True): + yield # Позволяем тестам выполниться с моками + + +# --- Тесты --- # +class TestInjectionBuilder: + """Тесты для InjectionBuilder.""" + + def test_init_with_repository(self, test_repository: InMemoryEntityRepository): + """Тест инициализации с репозиторием.""" + builder = InjectionBuilder(repository=test_repository) + assert builder.repository is test_repository + + def test_init_with_entities(self, serialized_entities: list[LinkerEntity]): + """Тест инициализации со списком сущностей.""" + builder = InjectionBuilder(entities=serialized_entities) + assert isinstance(builder.repository, InMemoryEntityRepository) + assert builder.repository.entities == serialized_entities + + def test_init_errors(self, test_repository: InMemoryEntityRepository, serialized_entities: list[LinkerEntity]): + """Тест ошибок при инициализации.""" + with pytest.raises(ValueError, match="Необходимо указать либо repository, либо entities"): + InjectionBuilder() + with pytest.raises(ValueError, match="Использование одновременно repository и entities не допускается"): + InjectionBuilder(repository=test_repository, entities=serialized_entities) + + def test_build_simple(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, chunk2_doc1: Chunk): + """Тест простого сценария сборки (только чанки одного документа).""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id, chunk2_doc1.id] + + # Ожидаем, что dechunk будет вызван с правильными аргументами + mock_strategy_class.dechunk.return_value = "Chunk 1 text. Chunk 2 is here." + + result = builder.build(selected_ids, include_tables=False) + + # Проверка вызова get_entities_by_ids + # (Не можем легко проверить без мокирования самого репозитория, но проверим косвенно через вызовы dechunk/build) + + # Проверка вызова dechunk + mock_strategy_class.dechunk.assert_called_once() + call_args, _ = mock_strategy_class.dechunk.call_args + assert call_args[0] is test_repository + # Переданные сущности должны быть десериализованы и отсортированы + passed_entities = call_args[1] + assert len(passed_entities) == 2 + assert all(isinstance(e, Chunk) for e in passed_entities) + assert passed_entities[0].id == chunk1_doc1.id + assert passed_entities[1].id == chunk2_doc1.id + + # Проверка вызова TablesProcessor.build (не должен вызываться) + mock_tables_processor_build.assert_not_called() + + # Проверка результата + expected = ( + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text. Chunk 2 is here.\n\n" + ) + assert result == expected + + def test_build_with_tables(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, row1_table1: TableRowEntity): + """Тест сборки с чанками и таблицами.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id, row1_table1.id] + + mock_strategy_class.dechunk.return_value = "Chunk 1 text." + mock_tables_processor_build.return_value = "Table Row 1: [a, b]" + + result = builder.build(selected_ids, include_tables=True) + + mock_strategy_class.dechunk.assert_called_once() + # dechunk вызывается только с чанками + dechunk_args, _ = mock_strategy_class.dechunk.call_args + assert len(dechunk_args[1]) == 1 + assert dechunk_args[1][0].id == chunk1_doc1.id + + mock_tables_processor_build.assert_called_once() + # build вызывается со всеми сущностями группы + build_args, _ = mock_tables_processor_build.call_args + assert build_args[0] is test_repository + assert len(build_args[1]) == 2 + build_entity_ids = {e.id for e in build_args[1]} + assert chunk1_doc1.id in build_entity_ids + assert row1_table1.id in build_entity_ids + + expected = ( + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text.\n\n" + "### Таблицы\nTable Row 1: [a, b]\n\n" + ) + assert result == expected + + def test_build_include_tables_false(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, row1_table1: TableRowEntity): + """Тест сборки с include_tables=False.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id, row1_table1.id] # Передаем строку таблицы + + mock_strategy_class.dechunk.return_value = "Chunk 1 text." + + result = builder.build(selected_ids, include_tables=False) + + # dechunk вызывается только с чанком + mock_strategy_class.dechunk.assert_called_once() + dechunk_args, _ = mock_strategy_class.dechunk.call_args + assert len(dechunk_args[1]) == 1 + assert dechunk_args[1][0].id == chunk1_doc1.id + + # TablesProcessor.build не должен вызываться + mock_tables_processor_build.assert_not_called() + + expected = ( + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text.\n\n" + # Секции таблиц нет + ) + assert result == expected + + def test_build_with_neighbors(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk2_doc1: Chunk): + """Тест сборки с добавлением соседей.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id] + neighbors_distance = 1 + + # Мокируем get_neighboring_entities, чтобы убедиться, что он вызывается + with patch.object(test_repository, 'get_neighboring_entities', wraps=test_repository.get_neighboring_entities) as mock_get_neighbors: + mock_strategy_class.dechunk.return_value = "Chunk 1 text. Chunk 2 is here." + result = builder.build(selected_ids, neighbors_max_distance=neighbors_distance) + + mock_get_neighbors.assert_called_once() + call_args, _ = mock_get_neighbors.call_args + # Первым аргументом должны быть десериализованные сущности из selected_ids + assert len(call_args[0]) == 1 + assert isinstance(call_args[0][0], Chunk) + assert call_args[0][0].id == chunk1_doc1.id + # Второй аргумент - max_distance + assert call_args[1] == neighbors_distance + + # Проверяем, что dechunk вызван с chunk1 и его соседом chunk2 + mock_strategy_class.dechunk.assert_called_once() + dechunk_args, _ = mock_strategy_class.dechunk.call_args + assert len(dechunk_args[1]) == 2 + dechunk_ids = {e.id for e in dechunk_args[1]} + assert chunk1_doc1.id in dechunk_ids + assert chunk2_doc1.id in dechunk_ids + + expected = ( + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text. Chunk 2 is here.\n\n" + ) + assert result == expected + + def test_build_multiple_docs_no_limit(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk): + """Тест сборки сущностей из разных документов без лимита.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id, chunk4_doc2.id] + + # Настроим возвращаемые значения для dechunk (вызывается дважды) + mock_strategy_class.dechunk.side_effect = [ + "Chunk 1 text.", # Для doc1 + "Chunk 4 from doc 2." # Для doc2 + ] + + result = builder.build(selected_ids, include_tables=False) + + assert mock_strategy_class.dechunk.call_count == 2 + # TablesProcessor.build не должен вызываться + mock_tables_processor_build.assert_not_called() + + # Порядок документов определяется дефолтными скорами (по убыванию индекса) + # chunk4_doc2 (score=2.0) > chunk1_doc1 (score=1.0) + expected = ( + "## [Источник] - Document 2\n\n" + "### Текст\nChunk 4 from doc 2.\n\n" + "\n\n" + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text.\n\n" + ) + assert result == expected + + def test_build_multiple_docs_with_scores(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk): + """Тест сборки сущностей из разных документов с заданными скорами.""" + builder = InjectionBuilder(repository=test_repository) + selected_entities = [ + test_repository.entities_by_id[chunk4_doc2.id], # doc2 + test_repository.entities_by_id[chunk1_doc1.id] # doc1 + ] + scores = [0.5, 0.9] # doc1 > doc2 + + mock_strategy_class.dechunk.side_effect = [ + "Chunk 1 text.", # doc1 + "Chunk 4 from doc 2." # doc2 + ] + + result = builder.build(selected_entities, scores=scores, include_tables=False) + + # Проверяем порядок документов в результате (doc1 должен быть первым) + expected = ( + "## [Источник] - Document 1\n\n" + "### Текст\nChunk 1 text.\n\n" + "\n\n" + "## [Источник] - Document 2\n\n" + "### Текст\nChunk 4 from doc 2.\n\n" + ) + assert result == expected + + def test_build_max_documents(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk): + """Тест сборки с ограничением max_documents.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id, chunk4_doc2.id] + + # doc2 (score 2.0) > doc1 (score 1.0) + mock_strategy_class.dechunk.return_value = "Chunk 4 from doc 2." + + result = builder.build(selected_ids, max_documents=1, include_tables=False) + + # Должен быть вызван dechunk только один раз для документа с наивысшим скором (doc2) + mock_strategy_class.dechunk.assert_called_once() + + expected = ( + "## [Источник] - Document 2\n\n" + "### Текст\nChunk 4 from doc 2.\n\n" + ) + assert result == expected + + def test_build_custom_prefix(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk): + """Тест сборки с кастомным префиксом документа.""" + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id] + custom_prefix = "Source Doc: " + + mock_strategy_class.dechunk.return_value = "Chunk 1 text." + + result = builder.build(selected_ids, document_prefix=custom_prefix, include_tables=False) + + expected = ( + f"## {custom_prefix}Document 1\n\n" + "### Текст\nChunk 1 text.\n\n" + ) + assert result == expected + + def test_build_empty_entities(self, test_repository: InMemoryEntityRepository): + """Тест сборки с пустым списком сущностей.""" + builder = InjectionBuilder(repository=test_repository) + result = builder.build([]) + assert result == "" + + def test_build_unknown_ids(self, test_repository: InMemoryEntityRepository): + """Тест сборки с неизвестными ID.""" + builder = InjectionBuilder(repository=test_repository) + result = builder.build([uuid4(), uuid4()]) # Передаем несуществующие ID + assert result == "" + + def test_build_no_strategy_for_doc(self, test_repository: InMemoryEntityRepository, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk): + """Тест сборки, если у документа нет chunking_strategy_ref.""" + # Убираем ссылку на стратегию у документа + doc1_entity = test_repository.entities_by_id[chunk1_doc1.owner_id] + original_ref = doc1_entity.chunking_strategy_ref + doc1_entity.chunking_strategy_ref = None + + builder = InjectionBuilder(repository=test_repository) + selected_ids = [chunk1_doc1.id] + + mock_tables_processor_build.return_value = "[Tables]" + + # dechunk не должен вызываться + with patch.object(chunking_registry[FIXED_SIZE], 'dechunk') as mock_dechunk: + result = builder.build(selected_ids, include_tables=True) + mock_dechunk.assert_not_called() + + # build для таблиц должен вызваться + mock_tables_processor_build.assert_called_once() + + expected = ( + "## [Источник] - Document 1\n\n" + # Секции Текст нет + "### Таблицы\n[Tables]\n\n" + ) + assert result == expected + + # Восстанавливаем ссылку + doc1_entity.chunking_strategy_ref = original_ref \ No newline at end of file diff --git a/lib/extractor/tests/custom_entity.py b/lib/extractor/tests/custom_entity.py index ca96e042f843e368436f54c1aef5297b94ee6d75..328b769f27313e83c34a97ce26cd210cbc25c695 100644 --- a/lib/extractor/tests/custom_entity.py +++ b/lib/extractor/tests/custom_entity.py @@ -2,105 +2,3 @@ from uuid import UUID from ntr_text_fragmentation.models.linker_entity import (LinkerEntity, register_entity) - - -@register_entity -class CustomEntity(LinkerEntity): - """Пользовательский класс-наследник LinkerEntity для тестирования сериализации и десериализации.""" - - def __init__( - self, - id: UUID, - name: str, - text: str, - metadata: dict, - custom_field1: str, - custom_field2: int, - in_search_text: str | None = None, - source_id: UUID | None = None, - target_id: UUID | None = None, - number_in_relation: int | None = None, - type: str = "CustomEntity" - ): - super().__init__( - id=id, - name=name, - text=text, - metadata=metadata, - in_search_text=in_search_text, - source_id=source_id, - target_id=target_id, - number_in_relation=number_in_relation, - type=type - ) - self.custom_field1 = custom_field1 - self.custom_field2 = custom_field2 - - def deserialize(self, entity: LinkerEntity) -> 'CustomEntity': - """Реализация метода десериализации для кастомного класса.""" - custom_field1 = entity.metadata.get('_custom_field1', '') - custom_field2 = entity.metadata.get('_custom_field2', 0) - - # Создаем чистые метаданные без служебных полей - clean_metadata = {k: v for k, v in entity.metadata.items() - if not k.startswith('_')} - - return CustomEntity( - id=entity.id, - name=entity.name, - text=entity.text, - in_search_text=entity.in_search_text, - metadata=clean_metadata, - source_id=entity.source_id, - target_id=entity.target_id, - number_in_relation=entity.number_in_relation, - custom_field1=custom_field1, - custom_field2=custom_field2 - ) - - @classmethod - def deserialize(cls, entity: LinkerEntity) -> 'CustomEntity': - """ - Классовый метод для десериализации. - Необходим для работы с реестром классов. - - Args: - entity: Сериализованная сущность - - Returns: - Десериализованный экземпляр CustomEntity - """ - custom_field1 = entity.metadata.get('_custom_field1', '') - custom_field2 = entity.metadata.get('_custom_field2', 0) - - # Создаем чистые метаданные без служебных полей - clean_metadata = {k: v for k, v in entity.metadata.items() - if not k.startswith('_')} - - return CustomEntity( - id=entity.id, - name=entity.name, - text=entity.text, - in_search_text=entity.in_search_text, - metadata=clean_metadata, - source_id=entity.source_id, - target_id=entity.target_id, - number_in_relation=entity.number_in_relation, - custom_field1=custom_field1, - custom_field2=custom_field2 - ) - - def __eq__(self, other): - """Переопределяем метод сравнения для проверки равенства объектов.""" - if not isinstance(other, CustomEntity): - return False - - # Используем базовое сравнение из LinkerEntity, которое уже учитывает поля связи - base_equality = super().__eq__(other) - - # Дополнительно проверяем кастомные поля - return ( - base_equality - and self.custom_field1 == other.custom_field1 - and self.custom_field2 == other.custom_field2 - ) \ No newline at end of file diff --git a/lib/extractor/tests/models/test_linker_entity.py b/lib/extractor/tests/models/test_linker_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..07166f6b4fb8d06cc5e246f4d4ae32c5f2789b8e --- /dev/null +++ b/lib/extractor/tests/models/test_linker_entity.py @@ -0,0 +1,251 @@ +""" +Unit-тесты для базового класса LinkerEntity и его механизма сериализации/десериализации. +""" + +import uuid +from dataclasses import dataclass, field +from uuid import UUID, uuid4 + +import pytest +from ntr_text_fragmentation.models import LinkerEntity, register_entity +from tests.custom_entity import \ + CustomEntity # Используем существующий кастомный класс + + +# Фикстуры +@pytest.fixture +def base_entity() -> LinkerEntity: + """Фикстура для базовой сущности.""" + return LinkerEntity(id=uuid4(), name="Base Name", text="Base Text") + + +@pytest.fixture +def link_entity() -> LinkerEntity: + """Фикстура для сущности-связи.""" + return LinkerEntity( + id=uuid4(), + name="Link Name", + source_id=uuid4(), + target_id=uuid4(), + number_in_relation=1, + ) + + +@pytest.fixture +def custom_entity_instance() -> CustomEntity: + """Фикстура для кастомной сущности.""" + return CustomEntity( + id=uuid4(), + name="Custom Name", + text="Custom Text", + custom_field="custom_value", + metadata={"existing_meta": "meta_value"}, + ) + + +@pytest.fixture +def serialized_custom_entity( + custom_entity_instance: CustomEntity, +) -> LinkerEntity: + """Фикстура для сериализованной кастомной сущности.""" + return custom_entity_instance.serialize() + + +# Тесты +class TestLinkerEntity: + """Тесты для класса LinkerEntity.""" + + def test_initialization_defaults(self): + """Тест инициализации с значениями по умолчанию.""" + entity = LinkerEntity() + assert isinstance(entity.id, UUID) + assert entity.name == "" + assert entity.text == "" + assert entity.metadata == {} + assert entity.in_search_text is None + assert entity.source_id is None + assert entity.target_id is None + assert entity.number_in_relation is None + assert entity.groupper is None + assert entity.type == "LinkerEntity" # Имя класса по умолчанию + + def test_initialization_with_values(self, base_entity: LinkerEntity): + """Тест инициализации с заданными значениями.""" + entity_id = base_entity.id + assert base_entity.name == "Base Name" + assert base_entity.text == "Base Text" + assert base_entity.id == entity_id + + def test_is_link(self, base_entity: LinkerEntity, link_entity: LinkerEntity): + """Тест метода is_link().""" + assert not base_entity.is_link() + assert link_entity.is_link() + + def test_owner_id_property(self, base_entity: LinkerEntity, link_entity: LinkerEntity): + """Тест свойства owner_id.""" + # У обычной сущности owner_id это target_id + owner_uuid = uuid4() + base_entity.target_id = owner_uuid + assert base_entity.owner_id == owner_uuid + + # У связи нет owner_id + assert link_entity.owner_id is None + + # Попытка установить owner_id для связи должна вызвать ошибку + with pytest.raises(ValueError, match="Связь не может иметь владельца"): + link_entity.owner_id = uuid4() + + # Установка owner_id для обычной сущности + new_owner_id = uuid4() + base_entity.owner_id = new_owner_id + assert base_entity.target_id == new_owner_id + + def test_str_representation(self, base_entity: LinkerEntity): + """Тест строкового представления __str__.""" + assert str(base_entity) == "Base Name: Base Text" + + base_entity.in_search_text = "Search text representation" + assert str(base_entity) == "Search text representation" + + def test_equality(self, base_entity: LinkerEntity): + """Тест сравнения __eq__.""" + entity_copy = LinkerEntity( + id=base_entity.id, name="Base Name", text="Base Text" + ) + different_entity = LinkerEntity(name="Different Name") + + assert base_entity == entity_copy + assert base_entity != different_entity + assert base_entity != "not an entity" + + def test_equality_links(self, link_entity: LinkerEntity): + """Тест сравнения связей.""" + link_copy = LinkerEntity( + id=link_entity.id, + name="Link Name", + source_id=link_entity.source_id, + target_id=link_entity.target_id, + number_in_relation=1, + ) + different_link = LinkerEntity( + id=link_entity.id, + name="Link Name", + source_id=uuid4(), # Другой source_id + target_id=link_entity.target_id, + number_in_relation=1, + ) + non_link = LinkerEntity(id=link_entity.id) + + assert link_entity == link_copy + assert link_entity != different_link + assert link_entity != non_link + + # --- Тесты сериализации/десериализации --- + + def test_serialize_base_entity(self, base_entity: LinkerEntity): + """Тест сериализации базовой сущности.""" + serialized = base_entity.serialize() + assert isinstance(serialized, LinkerEntity) + # Проверяем, что это не тот же самый объект, а копия базового типа + assert serialized is not base_entity + assert type(serialized) is LinkerEntity + assert serialized.id == base_entity.id + assert serialized.name == base_entity.name + assert serialized.text == base_entity.text + assert serialized.metadata == {} # Нет доп. полей + assert serialized.type == "LinkerEntity" # Сохраняем тип + + def test_serialize_custom_entity( + self, + custom_entity_instance: CustomEntity, + serialized_custom_entity: LinkerEntity, + ): + """Тест сериализации кастомной сущности.""" + serialized = serialized_custom_entity # Используем фикстуру + + assert isinstance(serialized, LinkerEntity) + assert type(serialized) is LinkerEntity + assert serialized.id == custom_entity_instance.id + assert serialized.name == custom_entity_instance.name + assert serialized.text == custom_entity_instance.text + # Проверяем, что кастомное поле и исходные метаданные попали в metadata + assert "_custom_field" in serialized.metadata + assert serialized.metadata["_custom_field"] == "custom_value" + assert "existing_meta" in serialized.metadata + assert serialized.metadata["existing_meta"] == "meta_value" + # Тип должен быть именем кастомного класса + assert serialized.type == "CustomEntity" + + def test_deserialize_custom_entity( + self, serialized_custom_entity: LinkerEntity + ): + """Тест десериализации в кастомный тип.""" + # Используем класс CustomEntity для десериализации, так как он зарегистрирован + deserialized = LinkerEntity._deserialize(serialized_custom_entity) + + assert isinstance(deserialized, CustomEntity) + assert deserialized.id == serialized_custom_entity.id + assert deserialized.name == serialized_custom_entity.name + assert deserialized.text == serialized_custom_entity.text + # Проверяем восстановление кастомного поля + assert deserialized.custom_field == "custom_value" + # Проверяем восстановление исходных метаданных + assert "existing_meta" in deserialized.metadata + assert deserialized.metadata["existing_meta"] == "meta_value" + assert deserialized.type == "CustomEntity" # Тип сохраняется + + def test_deserialize_base_entity(self, base_entity: LinkerEntity): + """Тест десериализации базовой сущности (должна вернуться сама).""" + serialized = base_entity.serialize() # Сериализуем базовую + deserialized = LinkerEntity._deserialize(serialized) + assert deserialized is serialized # Возвращается исходный объект LinkerEntity + assert type(deserialized) is LinkerEntity + + def test_deserialize_unregistered_type(self): + """Тест десериализации незарегистрированного типа (должен вернуться исходный объект).""" + unregistered_entity = LinkerEntity(id=uuid4(), type="UnregisteredType") + deserialized = LinkerEntity._deserialize(unregistered_entity) + assert deserialized is unregistered_entity + assert deserialized.type == "UnregisteredType" + + def test_deserialize_to_me_on_custom_class( + self, serialized_custom_entity: LinkerEntity + ): + """Тест прямого вызова _deserialize_to_me на кастомном классе.""" + # Вызываем метод десериализации непосредственно у CustomEntity + deserialized = CustomEntity._deserialize_to_me(serialized_custom_entity) + + assert isinstance(deserialized, CustomEntity) + assert deserialized.id == serialized_custom_entity.id + assert deserialized.custom_field == "custom_value" + assert deserialized.metadata["existing_meta"] == "meta_value" + + def test_deserialize_to_me_type_error(self): + """Тест ошибки TypeError в _deserialize_to_me при неверном типе данных.""" + with pytest.raises(TypeError): + # Пытаемся десериализовать не LinkerEntity + CustomEntity._deserialize_to_me("not_an_entity") # type: ignore + + def test_register_entity_decorator(self): + """Тест работы декоратора @register_entity.""" + + @register_entity + @dataclass + class TempEntity(LinkerEntity): + temp_field: str = "temp" + type: str = "Temporary" # Явно указываем тип для регистрации + + assert "Temporary" in LinkerEntity._entity_classes + assert LinkerEntity._entity_classes["Temporary"] is TempEntity + + # Проверяем, что он десериализуется + instance = TempEntity(id=uuid4(), name="Temp instance", temp_field="value") + serialized = instance.serialize() + assert serialized.type == "Temporary" + deserialized = LinkerEntity._deserialize(serialized) + assert isinstance(deserialized, TempEntity) + assert deserialized.temp_field == "value" + + # Удаляем временный класс из реестра, чтобы не влиять на другие тесты + del LinkerEntity._entity_classes["Temporary"] + assert "Temporary" not in LinkerEntity._entity_classes \ No newline at end of file diff --git a/routes/dataset.py b/routes/dataset.py index ff5cb2c39aa57beacb5685413c36df3003ea6016..87bf0da3d43c617fe2dff40f099904b6166243af 100644 --- a/routes/dataset.py +++ b/routes/dataset.py @@ -54,8 +54,7 @@ def try_create_default_dataset(dataset_service: DatasetService): else: dataset_service.create_dataset_from_directory( is_default=True, - directory_with_documents=dataset_service.config.db_config.files.xmls_path_default, - directory_with_ready_dataset=dataset_service.config.db_config.files.start_path, + directory_with_documents=dataset_service.config.db_config.files.documents_path, ) @router.get('/try_init_default_dataset') diff --git a/routes/entity.py b/routes/entity.py index e667cc661ccdbcb482e50e40a16140fdb6bb85d8..03d73998bc125d2bdb9367b39a97819f4997b439 100644 --- a/routes/entity.py +++ b/routes/entity.py @@ -1,18 +1,23 @@ from typing import Annotated +from uuid import UUID import numpy as np from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from common import auth import common.dependencies as DI +from common import auth from components.dbo.chunk_repository import ChunkRepository from components.services.entity import EntityService -from schemas.entity import (EntityNeighborsRequest, EntityNeighborsResponse, - EntitySearchRequest, EntitySearchResponse, - EntitySearchWithTextRequest, - EntitySearchWithTextResponse, EntityTextRequest, - EntityTextResponse) +from schemas.entity import ( + ChunkInfo, + EntitySearchRequest, + EntitySearchResponse, + EntitySearchWithTextRequest, + EntitySearchWithTextResponse, + EntityTextRequest, + EntityTextResponse, +) router = APIRouter(prefix="/entity", tags=["Entity"]) @@ -21,30 +26,30 @@ router = APIRouter(prefix="/entity", tags=["Entity"]) async def search_entities( request: EntitySearchRequest, entity_service: Annotated[EntityService, Depends(DI.get_entity_service)], - current_user: Annotated[any, Depends(auth.get_current_user)] + current_user: Annotated[any, Depends(auth.get_current_user)], ) -> EntitySearchResponse: """ Поиск похожих сущностей по векторному сходству (только ID). - + Args: request: Параметры поиска entity_service: Сервис для работы с сущностями - + Returns: Результаты поиска (ID и оценки), отсортированные по убыванию сходства """ try: - _, scores, ids = entity_service.search_similar( + _, scores, ids = entity_service.search_similar_old( request.query, request.dataset_id, ) - + # Проверяем, что scores и ids - корректные numpy массивы if not isinstance(scores, np.ndarray): scores = np.array(scores) if not isinstance(ids, np.ndarray): ids = np.array(ids) - + # Сортируем результаты по убыванию оценок # Проверим, что массивы не пустые if len(scores) > 0: @@ -56,15 +61,14 @@ async def search_entities( else: sorted_scores = [] sorted_ids = [] - + return EntitySearchResponse( scores=sorted_scores, entity_ids=sorted_ids, ) except Exception as e: raise HTTPException( - status_code=500, - detail=f"Error during entity search: {str(e)}" + status_code=500, detail=f"Error during entity search: {str(e)}" ) @@ -72,60 +76,60 @@ async def search_entities( async def search_entities_with_text( request: EntitySearchWithTextRequest, entity_service: Annotated[EntityService, Depends(DI.get_entity_service)], - current_user: Annotated[any, Depends(auth.get_current_user)] + current_user: Annotated[any, Depends(auth.get_current_user)], ) -> EntitySearchWithTextResponse: """ Поиск похожих сущностей по векторному сходству с возвратом текстов. - + Args: request: Параметры поиска entity_service: Сервис для работы с сущностями - + Returns: Результаты поиска с текстами чанков, отсортированные по убыванию сходства """ try: # Получаем результаты поиска - _, scores, entity_ids = entity_service.search_similar( - request.query, - request.dataset_id + _, scores, entity_ids = entity_service.search_similar_old( + request.query, request.dataset_id ) - + # Проверяем, что scores и entity_ids - корректные numpy массивы if not isinstance(scores, np.ndarray): scores = np.array(scores) if not isinstance(entity_ids, np.ndarray): entity_ids = np.array(entity_ids) - + # Сортируем результаты по убыванию оценок # Проверим, что массивы не пустые if len(scores) > 0: # Преобразуем индексы в список, чтобы избежать проблем с индексацией sorted_indices = scores.argsort()[::-1].tolist() sorted_scores = [float(scores[i]) for i in sorted_indices] - sorted_ids = [str(entity_ids[i]) for i in sorted_indices] # Преобразуем в строки - + sorted_ids = [UUID(entity_ids[i]) for i in sorted_indices] + # Получаем тексты чанков - chunks = entity_service.chunk_repository.get_chunks_by_ids(sorted_ids) - + chunks = entity_service.chunk_repository.get_entities_by_ids(sorted_ids) + # Формируем ответ return EntitySearchWithTextResponse( chunks=[ - { - "id": str(chunk.id), # Преобразуем UUID в строку - "text": chunk.text, - "score": score - } + ChunkInfo( + id=str(chunk.id), # Преобразуем UUID в строку + text=chunk.text, + score=score, + type=chunk.type, + in_search_text=chunk.in_search_text, + ) for chunk, score in zip(chunks, sorted_scores) ] ) else: return EntitySearchWithTextResponse(chunks=[]) - + except Exception as e: raise HTTPException( - status_code=500, - detail=f"Error during entity search with text: {str(e)}" + status_code=500, detail=f"Error during entity search with text: {str(e)}" ) @@ -133,84 +137,36 @@ async def search_entities_with_text( async def build_entity_text( request: EntityTextRequest, entity_service: Annotated[EntityService, Depends(DI.get_entity_service)], - current_user: Annotated[any, Depends(auth.get_current_user)] + current_user: Annotated[any, Depends(auth.get_current_user)], ) -> EntityTextResponse: """ Сборка текста из сущностей. - + Args: request: Параметры сборки текста entity_service: Сервис для работы с сущностями - + Returns: Собранный текст """ try: - # Получаем объекты LinkerEntity по ID - entities = entity_service.chunk_repository.get_chunks_by_ids(request.entities) - - if not entities: + if not request.entities: raise HTTPException( - status_code=404, - detail="No entities found with provided IDs" + status_code=404, detail="No entities found with provided IDs" ) - + # Собираем текст text = entity_service.build_text( - entities=entities, + entities=request.entities, chunk_scores=request.chunk_scores, include_tables=request.include_tables, max_documents=request.max_documents, ) - - return EntityTextResponse(text=text) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error building entity text: {str(e)}" - ) - -@router.post("/neighbors", response_model=EntityNeighborsResponse) -async def get_neighboring_chunks( - request: EntityNeighborsRequest, - entity_service: Annotated[EntityService, Depends(DI.get_entity_service)], - current_user: Annotated[any, Depends(auth.get_current_user)] -) -> EntityNeighborsResponse: - """ - Получение соседних чанков для заданных сущностей. - - Args: - request: Параметры запроса соседей - entity_service: Сервис для работы с сущностями - - Returns: - Список сущностей с соседями - """ - try: - # Получаем объекты LinkerEntity по ID - entities = entity_service.chunk_repository.get_chunks_by_ids(request.entities) - - if not entities: - raise HTTPException( - status_code=404, - detail="No entities found with provided IDs" - ) - - # Получаем соседние чанки - entities_with_neighbors = entity_service.add_neighboring_chunks( - entities, - max_distance=request.max_distance, - ) - - # Преобразуем LinkerEntity в строки - return EntityNeighborsResponse( - entities=[str(entity.id) for entity in entities_with_neighbors] - ) + return EntityTextResponse(text=text) except Exception as e: raise HTTPException( - status_code=500, - detail=f"Error getting neighboring chunks: {str(e)}" + status_code=500, detail=f"Error building entity text: {str(e)}" ) @@ -218,7 +174,7 @@ async def get_neighboring_chunks( async def get_entity_info( dataset_id: int, db: Annotated[Session, Depends(DI.get_db)], - current_user: Annotated[any, Depends(auth.get_current_user)] + current_user: Annotated[any, Depends(auth.get_current_user)], ) -> dict: """ Получить информацию о сущностях в датасете. @@ -231,40 +187,65 @@ async def get_entity_info( Returns: dict: Информация о сущностях """ + # Создаем репозиторий, передавая sessionmaker chunk_repository = ChunkRepository(db) - entities, embeddings = chunk_repository.get_searching_entities(dataset_id) - - if not entities: - raise HTTPException(status_code=404, detail=f"No entities found for dataset {dataset_id}") - + + # Получаем общее количество сущностей + total_entities_count = chunk_repository.count_entities_by_dataset_id(dataset_id) + + # Получаем сущности, готовые к поиску (с текстом и эмбеддингом) + searchable_entities, searchable_embeddings = ( + chunk_repository.get_searching_entities(dataset_id) + ) + + # Проверка, найдены ли сущности, готовые к поиску + # Можно оставить проверку, чтобы не возвращать пустые примеры, если таких нет, + # но основная ошибка 404 должна базироваться на total_entities_count + if total_entities_count == 0: + raise HTTPException( + status_code=404, detail=f"No entities found for dataset {dataset_id}" + ) + # Собираем статистику stats = { - "total_entities": len(entities), - "entities_with_embeddings": len([e for e in embeddings if e is not None]), - "embedding_shapes": [e.shape if e is not None else None for e in embeddings], - "unique_embedding_shapes": set(str(e.shape) if e is not None else None for e in embeddings), - "entity_types": set(e.type for e in entities), + "total_entities": total_entities_count, # Реальное общее число + "searchable_entities": len( + searchable_entities + ), # Число сущностей с текстом и эмбеддингом + "entities_with_embeddings": len( + [e for e in searchable_embeddings if e is not None] + ), + "embedding_shapes": [ + e.shape if e is not None else None for e in searchable_embeddings + ], + "unique_embedding_shapes": set( + str(e.shape) if e is not None else None for e in searchable_embeddings + ), + # Статистику по типам лучше считать на основе searchable_entities, т.к. для них есть объекты + "entity_types": set(e.type for e in searchable_entities), "entities_per_type": { - t: len([e for e in entities if e.type == t]) - for t in set(e.type for e in entities) - } + t: len([e for e in searchable_entities if e.type == t]) + for t in set(e.type for e in searchable_entities) + }, } - - # Примеры сущностей + + # Примеры сущностей берем из searchable_entities examples = [ { - "id": str(e.id), # Преобразуем UUID в строку + "id": str(e.id), "name": e.name, "type": e.type, - "has_embedding": embeddings[i] is not None, - "embedding_shape": str(embeddings[i].shape) if embeddings[i] is not None else None, + "has_embedding": searchable_embeddings[i] is not None, + "embedding_shape": ( + str(searchable_embeddings[i].shape) + if searchable_embeddings[i] is not None + else None + ), "text_length": len(e.text), - "in_search_text_length": len(e.in_search_text) if e.in_search_text else 0 + "in_search_text_length": len(e.in_search_text) if e.in_search_text else 0, } - for i, e in enumerate(entities[:5]) # Берем только первые 5 для примера + # Берем примеры из сущностей, готовых к поиску + for i, e in enumerate(searchable_entities[:5]) ] - - return { - "stats": stats, - "examples": examples - } \ No newline at end of file + + return {"stats": stats, "examples": examples} diff --git a/routes/llm.py b/routes/llm.py index 7bfea2eb9939dc3ffd156e5e05818f045a6f6903..88d5c77a029cc8f0f44c33ef1c250f558c1e68bf 100644 --- a/routes/llm.py +++ b/routes/llm.py @@ -2,21 +2,20 @@ import json import logging import os from typing import Annotated, AsyncGenerator, List, Optional -from uuid import UUID -from common import auth -from components.services.dialogue import DialogueService -from fastapi.responses import StreamingResponse - -from components.services.dataset import DatasetService -from components.services.entity import EntityService from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse import common.dependencies as DI -from common.configuration import Configuration, Query -from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message +from common import auth +from common.configuration import Configuration +from components.llm.common import (ChatRequest, LlmParams, LlmPredictParams, + Message) from components.llm.deepinfra_api import DeepInfraApi from components.llm.utils import append_llm_response_to_history +from components.services.dataset import DatasetService +from components.services.dialogue import DialogueService +from components.services.entity import EntityService from components.services.llm_config import LLMConfigService from components.services.llm_prompt import LlmPromptService @@ -71,13 +70,16 @@ def insert_search_results_to_message( return False def try_insert_search_results( - chat_request: ChatRequest, search_results: str, entities: List[str] + chat_request: ChatRequest, search_results: List[str], entities: List[List[str]] ) -> bool: + i = 0 for msg in reversed(chat_request.history): if msg.role == "user" and not msg.searchResults: - msg.searchResults = search_results - msg.searchEntities = entities - return True + msg.searchResults = search_results[i] + msg.searchEntities = entities[i] + i += 1 + if i == len(search_results): + return True return False def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest: @@ -132,21 +134,25 @@ async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prom dataset = dataset_service.get_current_dataset() if dataset is None: raise HTTPException(status_code=400, detail="Dataset not found") - _, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id) - chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids) - text_chunks = entity_service.build_text(chunks, scores) + previous_entities = [msg.searchEntities for msg in request.history if msg.searchEntities is not None] + previous_entities, chunk_ids, scores = entity_service.search_similar(qe_result.search_query, + dataset.id, previous_entities) + text_chunks = entity_service.build_text(chunk_ids, scores) + all_text_chunks = [text_chunks] + [entity_service.build_text(entities) for entities in previous_entities] + all_entities = [chunk_ids] + previous_entities + search_results_event = { "event": "search_results", "data": { "text": text_chunks, - "ids": chunk_ids.tolist() + "ids": chunk_ids } } yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n" # new_message = f'\n{text_chunks}\n\n{last_query.content}' - try_insert_search_results(request, text_chunks, chunk_ids.tolist()) + try_insert_search_results(request, all_text_chunks, all_entities) except Exception as e: logger.error(f"Error in SSE chat stream while searching: {str(e)}", stack_info=True) yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n" @@ -245,9 +251,12 @@ async def chat( if dataset is None: raise HTTPException(status_code=400, detail="Dataset not found") logger.info(f"qe_result.search_query: {qe_result.search_query}") - _, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id) + previous_entities = [msg.searchEntities for msg in request.history] + previous_entities, chunk_ids, scores = entity_service.search_similar( + qe_result.search_query, dataset.id, previous_entities + ) - chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids) + chunks = entity_service.chunk_repository.get_entities_by_ids(chunk_ids) logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}") logger.info(f"scores: {scores[:3]}...{scores[-3:]}") diff --git a/schemas/entity.py b/schemas/entity.py index 3775bd163099dbc3c6a21502c1c383c521b1539f..b37619df266652c9a33b7a259c11e5963b2590b0 100644 --- a/schemas/entity.py +++ b/schemas/entity.py @@ -26,6 +26,8 @@ class ChunkInfo(BaseModel): id: str text: str score: float + type: str + in_search_text: str class EntitySearchWithTextResponse(BaseModel):