from uuid import UUID import numpy as np from ntr_text_fragmentation import LinkerEntity from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository from sqlalchemy import and_, select from sqlalchemy.orm import Session from components.dbo.models.entity import EntityModel class ChunkRepository(SQLAlchemyEntityRepository): def __init__(self, db: Session): super().__init__(db) def _entity_model_class(self): return EntityModel def _map_db_entity_to_linker_entity(self, db_entity: EntityModel): """ Преобразует сущность из базы данных в LinkerEntity. Args: db_entity: Сущность из базы данных Returns: LinkerEntity """ # Преобразуем строковые ID в UUID entity = LinkerEntity( id=UUID(db_entity.uuid), # Преобразуем строку в UUID name=db_entity.name, text=db_entity.text, type=db_entity.entity_type, in_search_text=db_entity.in_search_text, metadata=db_entity.metadata_json, source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID number_in_relation=db_entity.number_in_relation, ) return LinkerEntity.deserialize(entity) def add_entities( self, entities: list[LinkerEntity], dataset_id: int, embeddings: dict[str, np.ndarray], ): """ Добавляет сущности в базу данных. Args: entities: Список сущностей для добавления dataset_id: ID датасета embeddings: Словарь эмбеддингов {entity_id: embedding} """ with self.db() as session: for entity in entities: # Преобразуем UUID в строку для хранения в базе entity_id = str(entity.id) if entity_id in embeddings: embedding = embeddings[entity_id] else: embedding = None session.add( EntityModel( uuid=str(entity.id), # UUID в строку name=entity.name, text=entity.text, entity_type=entity.type, in_search_text=entity.in_search_text, metadata_json=entity.metadata, source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку number_in_relation=entity.number_in_relation, chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index dataset_id=dataset_id, embedding=embedding, ) ) session.commit() def get_searching_entities( self, dataset_id: int, ) -> tuple[list[LinkerEntity], list[np.ndarray]]: with self.db() as session: models = ( session.query(EntityModel) .filter(EntityModel.in_search_text is not None) .filter(EntityModel.dataset_id == dataset_id) .all() ) return ( [self._map_db_entity_to_linker_entity(model) for model in models], [model.embedding for model in models], ) def get_chunks_by_ids( self, chunk_ids: list[str], ) -> list[LinkerEntity]: """ Получение чанков по их ID. Args: chunk_ids: Список ID чанков Returns: Список чанков """ # Преобразуем все ID в строки для единообразия str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] with self.db() as session: models = ( session.query(EntityModel) .filter(EntityModel.uuid.in_(str_chunk_ids)) .all() ) return [self._map_db_entity_to_linker_entity(model) for model in models] def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]: """ Получить сущности по списку идентификаторов. Args: entity_ids: Список идентификаторов сущностей Returns: Список сущностей, соответствующих указанным идентификаторам """ if not entity_ids: return [] # Преобразуем UUID в строки str_entity_ids = [str(entity_id) for entity_id in entity_ids] with self.db() as session: entity_model = self._entity_model_class() db_entities = session.execute( select(entity_model).where(entity_model.uuid.in_(str_entity_ids)) ).scalars().all() return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities] def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]: """ Получить соседние чанки для указанных чанков. Args: chunk_ids: Список идентификаторов чанков max_distance: Максимальное расстояние до соседа Returns: Список соседних чанков """ if not chunk_ids: return [] # Преобразуем UUID в строки str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] with self.db() as session: entity_model = self._entity_model_class() result = [] # Сначала получаем указанные чанки, чтобы узнать их индексы и документы chunks = session.execute( select(entity_model).where( and_( entity_model.uuid.in_(str_chunk_ids), entity_model.entity_type == "Chunk" # Используем entity_type вместо type ) ) ).scalars().all() if not chunks: return [] # Находим документы для чанков через связи doc_ids = set() chunk_indices = {} for chunk in chunks: chunk_indices[chunk.uuid] = chunk.chunk_index # Находим связь от документа к чанку links = session.execute( select(entity_model).where( and_( entity_model.target_id == chunk.uuid, entity_model.name == "document_to_chunk" ) ) ).scalars().all() for link in links: doc_ids.add(link.source_id) if not doc_ids or not any(idx is not None for idx in chunk_indices.values()): return [] # Для каждого документа находим все его чанки for doc_id in doc_ids: # Находим все связи от документа к чанкам links = session.execute( select(entity_model).where( and_( entity_model.source_id == doc_id, entity_model.name == "document_to_chunk" ) ) ).scalars().all() doc_chunk_ids = [link.target_id for link in links] # Получаем все чанки документа doc_chunks = session.execute( select(entity_model).where( and_( entity_model.uuid.in_(doc_chunk_ids), entity_model.entity_type == "Chunk" # Используем entity_type вместо type ) ) ).scalars().all() # Для каждого чанка в документе проверяем, является ли он соседом for doc_chunk in doc_chunks: if doc_chunk.uuid in str_chunk_ids: continue if doc_chunk.chunk_index is None: continue # Проверяем, является ли чанк соседом какого-либо из исходных чанков is_neighbor = False for orig_chunk_id, orig_index in chunk_indices.items(): if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance: is_neighbor = True break if is_neighbor: result.append(self._map_db_entity_to_linker_entity(doc_chunk)) return result