import asyncio import logging from uuid import UUID import numpy as np from ntr_text_fragmentation import LinkerEntity from ntr_text_fragmentation.integrations.sqlalchemy import \ SQLAlchemyEntityRepository from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker from components.dbo.models.entity import EntityModel logger = logging.getLogger(__name__) class ChunkRepository(SQLAlchemyEntityRepository): """ Репозиторий для работы с сущностями (чанками, документами, связями), хранящимися в базе данных с использованием SQL Alchemy. Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию для модели EntityModel. """ def __init__(self, db_session_factory: sessionmaker[Session]): """ Инициализация репозитория. Args: db_session_factory: Фабрика сессий SQLAlchemy. """ super().__init__(db_session_factory) @property def _entity_model_class(self): """Возвращает класс модели SQLAlchemy.""" return EntityModel def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity: """ Преобразует объект EntityModel из базы данных в объект LinkerEntity или его соответствующий подкласс. Args: db_entity: Сущность EntityModel из базы данных. Returns: Объект LinkerEntity или его подкласс. """ # Создаем базовый LinkerEntity со всеми данными из БД # Преобразуем строковые UUID обратно в объекты UUID base_data = LinkerEntity( id=UUID(db_entity.uuid), name=db_entity.name, text=db_entity.text, in_search_text=db_entity.in_search_text, metadata=db_entity.metadata_json or {}, source_id=UUID(db_entity.source_id) if db_entity.source_id else None, target_id=UUID(db_entity.target_id) if db_entity.target_id else None, number_in_relation=db_entity.number_in_relation, type=db_entity.entity_type, groupper=db_entity.entity_type, ) # Используем LinkerEntity._deserialize для получения объекта нужного типа # на основе поля 'type', взятого из db_entity.entity_type try: deserialized_entity = base_data.deserialize() return deserialized_entity except Exception as e: logger.error( f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}" ) return base_data def add_entities( self, entities: list[LinkerEntity], dataset_id: int, embeddings: dict[str, np.ndarray] | None = None, ): """ Добавляет список сущностей LinkerEntity в базу данных. Args: entities: Список сущностей LinkerEntity для добавления. dataset_id: ID датасета, к которому принадлежат сущности. embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID. """ embeddings = embeddings or {} with self.db() as session: db_entities_to_add = [] for entity in entities: # Преобразуем UUID в строку для хранения в базе entity_id_str = str(entity.id) embedding = embeddings.get(entity_id_str) db_entity = EntityModel( uuid=entity_id_str, name=entity.name, text=entity.text, entity_type=entity.type, in_search_text=entity.in_search_text, metadata_json=( entity.metadata if isinstance(entity.metadata, dict) else {} ), source_id=str(entity.source_id) if entity.source_id else None, target_id=str(entity.target_id) if entity.target_id else None, number_in_relation=entity.number_in_relation, dataset_id=dataset_id, embedding=embedding, ) db_entities_to_add.append(db_entity) session.add_all(db_entities_to_add) session.commit() async def add_entities_async( self, entities: list[LinkerEntity], dataset_id: int, embeddings: dict[str, np.ndarray] | None = None, ): """Асинхронно добавляет список сущностей LinkerEntity в базу данных.""" # TODO: Реализовать с использованием async-сессии await asyncio.to_thread(self.add_entities, entities, dataset_id, embeddings) def get_searching_entities( self, dataset_id: int, ) -> tuple[list[LinkerEntity], list[np.ndarray]]: """ Получает сущности из указанного датасета, которые имеют текст для поиска (in_search_text не None), вместе с их эмбеддингами. Args: dataset_id: ID датасета. Returns: Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array). Порядок эмбеддингов соответствует порядку сущностей. """ entity_model = self._entity_model_class linker_entities = [] embeddings_list = [] with self.db() as session: stmt = select(entity_model).where( entity_model.in_search_text.isnot(None), entity_model.dataset_id == dataset_id, entity_model.embedding.isnot(None) ) db_models = session.execute(stmt).scalars().all() # Переносим цикл внутрь сессии for model in db_models: # Теперь маппинг происходит при активной сессии linker_entity = self._map_db_entity_to_linker_entity(model) linker_entities.append(linker_entity) # Извлекаем эмбеддинг. # _map_db_entity_to_linker_entity может поместить его в метаданные. embedding = linker_entity.metadata.get('_embedding') if embedding is None and hasattr(model, 'embedding'): # Fallback embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии if embedding is not None: embeddings_list.append(embedding) else: # Обработка случая отсутствия эмбеддинга print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.") linker_entities.pop() # Возвращаем результаты после закрытия сессии return linker_entities, embeddings_list async def get_searching_entities_async( self, dataset_id: int, ) -> tuple[list[LinkerEntity], list[np.ndarray]]: """Асинхронно получает сущности для поиска вместе с эмбеддингами.""" # TODO: Реализовать с использованием async-сессии return await asyncio.to_thread(self.get_searching_entities, dataset_id) def get_all_entities_for_dataset(self, dataset_id: int) -> list[LinkerEntity]: """ Получает все сущности для указанного датасета. Args: dataset_id: ID датасета. Returns: Список всех LinkerEntity для данного датасета. """ entity_model = self._entity_model_class linker_entities = [] with self.db() as session: stmt = select(entity_model).where( entity_model.dataset_id == dataset_id ) db_models = session.execute(stmt).scalars().all() # Переносим цикл внутрь сессии для маппинга for model in db_models: try: linker_entity = self._map_db_entity_to_linker_entity(model) linker_entities.append(linker_entity) except Exception as e: logger.error(f"Error mapping entity {getattr(model, 'uuid', 'N/A')} in dataset {dataset_id}: {e}") logger.info(f"Loaded {len(linker_entities)} entities for dataset {dataset_id}") return linker_entities async def get_all_entities_for_dataset_async(self, dataset_id: int) -> list[LinkerEntity]: """Асинхронно получает все сущности для указанного датасета.""" # TODO: Реализовать с использованием async-сессии return await asyncio.to_thread(self.get_all_entities_for_dataset, dataset_id) def count_entities_by_dataset_id(self, dataset_id: int) -> int: """ Подсчитывает общее количество сущностей для указанного датасета. Args: dataset_id: ID датасета. Returns: Общее количество сущностей в датасете. """ entity_model = self._entity_model_class id_column = self._get_id_column() # Получаем колонку ID (uuid или id) with self.db() as session: stmt = select(func.count(id_column)).where( entity_model.dataset_id == dataset_id ) count = session.execute(stmt).scalar_one() return count async def count_entities_by_dataset_id_async(self, dataset_id: int) -> int: """Асинхронно подсчитывает общее количество сущностей для датасета.""" # TODO: Реализовать с использованием async-сессии return await asyncio.to_thread(self.count_entities_by_dataset_id, dataset_id) async def get_entities_by_ids_async(self, entity_ids: list[UUID]) -> list[LinkerEntity]: """Асинхронно получить сущности по списку ID.""" # TODO: Реализовать с использованием async-сессии return await asyncio.to_thread(self.get_entities_by_ids, entity_ids)