Spaces:
Runtime error
Runtime error
| from uuid import UUID | |
| import numpy as np | |
| from ntr_text_fragmentation import LinkerEntity | |
| from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository | |
| from sqlalchemy import and_, select | |
| from sqlalchemy.orm import Session | |
| from components.dbo.models.entity import EntityModel | |
| class ChunkRepository(SQLAlchemyEntityRepository): | |
| def __init__(self, db: Session): | |
| super().__init__(db) | |
| def _entity_model_class(self): | |
| return EntityModel | |
| def _map_db_entity_to_linker_entity(self, db_entity: EntityModel): | |
| """ | |
| Преобразует сущность из базы данных в LinkerEntity. | |
| Args: | |
| db_entity: Сущность из базы данных | |
| Returns: | |
| LinkerEntity | |
| """ | |
| # Преобразуем строковые ID в UUID | |
| entity = LinkerEntity( | |
| id=UUID(db_entity.uuid), # Преобразуем строку в UUID | |
| name=db_entity.name, | |
| text=db_entity.text, | |
| type=db_entity.entity_type, | |
| in_search_text=db_entity.in_search_text, | |
| metadata=db_entity.metadata_json, | |
| source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID | |
| target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID | |
| number_in_relation=db_entity.number_in_relation, | |
| ) | |
| return LinkerEntity.deserialize(entity) | |
| def add_entities( | |
| self, | |
| entities: list[LinkerEntity], | |
| dataset_id: int, | |
| embeddings: dict[str, np.ndarray], | |
| ): | |
| """ | |
| Добавляет сущности в базу данных. | |
| Args: | |
| entities: Список сущностей для добавления | |
| dataset_id: ID датасета | |
| embeddings: Словарь эмбеддингов {entity_id: embedding} | |
| """ | |
| with self.db() as session: | |
| for entity in entities: | |
| # Преобразуем UUID в строку для хранения в базе | |
| entity_id = str(entity.id) | |
| if entity_id in embeddings: | |
| embedding = embeddings[entity_id] | |
| else: | |
| embedding = None | |
| session.add( | |
| EntityModel( | |
| uuid=str(entity.id), # UUID в строку | |
| name=entity.name, | |
| text=entity.text, | |
| entity_type=entity.type, | |
| in_search_text=entity.in_search_text, | |
| metadata_json=entity.metadata, | |
| source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку | |
| target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку | |
| number_in_relation=entity.number_in_relation, | |
| chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index | |
| dataset_id=dataset_id, | |
| embedding=embedding, | |
| ) | |
| ) | |
| session.commit() | |
| def get_searching_entities( | |
| self, | |
| dataset_id: int, | |
| ) -> tuple[list[LinkerEntity], list[np.ndarray]]: | |
| with self.db() as session: | |
| models = ( | |
| session.query(EntityModel) | |
| .filter(EntityModel.in_search_text is not None) | |
| .filter(EntityModel.dataset_id == dataset_id) | |
| .all() | |
| ) | |
| return ( | |
| [self._map_db_entity_to_linker_entity(model) for model in models], | |
| [model.embedding for model in models], | |
| ) | |
| def get_chunks_by_ids( | |
| self, | |
| chunk_ids: list[str], | |
| ) -> list[LinkerEntity]: | |
| """ | |
| Получение чанков по их ID. | |
| Args: | |
| chunk_ids: Список ID чанков | |
| Returns: | |
| Список чанков | |
| """ | |
| # Преобразуем все ID в строки для единообразия | |
| str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] | |
| with self.db() as session: | |
| models = ( | |
| session.query(EntityModel) | |
| .filter(EntityModel.uuid.in_(str_chunk_ids)) | |
| .all() | |
| ) | |
| return [self._map_db_entity_to_linker_entity(model) for model in models] | |
| def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]: | |
| """ | |
| Получить сущности по списку идентификаторов. | |
| Args: | |
| entity_ids: Список идентификаторов сущностей | |
| Returns: | |
| Список сущностей, соответствующих указанным идентификаторам | |
| """ | |
| if not entity_ids: | |
| return [] | |
| # Преобразуем UUID в строки | |
| str_entity_ids = [str(entity_id) for entity_id in entity_ids] | |
| with self.db() as session: | |
| entity_model = self._entity_model_class() | |
| db_entities = session.execute( | |
| select(entity_model).where(entity_model.uuid.in_(str_entity_ids)) | |
| ).scalars().all() | |
| return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities] | |
| def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]: | |
| """ | |
| Получить соседние чанки для указанных чанков. | |
| Args: | |
| chunk_ids: Список идентификаторов чанков | |
| max_distance: Максимальное расстояние до соседа | |
| Returns: | |
| Список соседних чанков | |
| """ | |
| if not chunk_ids: | |
| return [] | |
| # Преобразуем UUID в строки | |
| str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] | |
| with self.db() as session: | |
| entity_model = self._entity_model_class() | |
| result = [] | |
| # Сначала получаем указанные чанки, чтобы узнать их индексы и документы | |
| chunks = session.execute( | |
| select(entity_model).where( | |
| and_( | |
| entity_model.uuid.in_(str_chunk_ids), | |
| entity_model.entity_type == "Chunk" # Используем entity_type вместо type | |
| ) | |
| ) | |
| ).scalars().all() | |
| if not chunks: | |
| return [] | |
| # Находим документы для чанков через связи | |
| doc_ids = set() | |
| chunk_indices = {} | |
| for chunk in chunks: | |
| chunk_indices[chunk.uuid] = chunk.chunk_index | |
| # Находим связь от документа к чанку | |
| links = session.execute( | |
| select(entity_model).where( | |
| and_( | |
| entity_model.target_id == chunk.uuid, | |
| entity_model.name == "document_to_chunk" | |
| ) | |
| ) | |
| ).scalars().all() | |
| for link in links: | |
| doc_ids.add(link.source_id) | |
| if not doc_ids or not any(idx is not None for idx in chunk_indices.values()): | |
| return [] | |
| # Для каждого документа находим все его чанки | |
| for doc_id in doc_ids: | |
| # Находим все связи от документа к чанкам | |
| links = session.execute( | |
| select(entity_model).where( | |
| and_( | |
| entity_model.source_id == doc_id, | |
| entity_model.name == "document_to_chunk" | |
| ) | |
| ) | |
| ).scalars().all() | |
| doc_chunk_ids = [link.target_id for link in links] | |
| # Получаем все чанки документа | |
| doc_chunks = session.execute( | |
| select(entity_model).where( | |
| and_( | |
| entity_model.uuid.in_(doc_chunk_ids), | |
| entity_model.entity_type == "Chunk" # Используем entity_type вместо type | |
| ) | |
| ) | |
| ).scalars().all() | |
| # Для каждого чанка в документе проверяем, является ли он соседом | |
| for doc_chunk in doc_chunks: | |
| if doc_chunk.uuid in str_chunk_ids: | |
| continue | |
| if doc_chunk.chunk_index is None: | |
| continue | |
| # Проверяем, является ли чанк соседом какого-либо из исходных чанков | |
| is_neighbor = False | |
| for orig_chunk_id, orig_index in chunk_indices.items(): | |
| if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance: | |
| is_neighbor = True | |
| break | |
| if is_neighbor: | |
| result.append(self._map_db_entity_to_linker_entity(doc_chunk)) | |
| return result | |