muryshev's picture
update
be03119
raw
history blame
11.4 kB
import asyncio
import logging
from uuid import UUID
import numpy as np
from ntr_text_fragmentation import LinkerEntity
from ntr_text_fragmentation.integrations.sqlalchemy import \
SQLAlchemyEntityRepository
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from components.dbo.models.entity import EntityModel
logger = logging.getLogger(__name__)
class ChunkRepository(SQLAlchemyEntityRepository):
"""
Репозиторий для работы с сущностями (чанками, документами, связями),
хранящимися в базе данных с использованием SQL Alchemy.
Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию
для модели EntityModel.
"""
def __init__(self, db_session_factory: sessionmaker[Session]):
"""
Инициализация репозитория.
Args:
db_session_factory: Фабрика сессий SQLAlchemy.
"""
super().__init__(db_session_factory)
@property
def _entity_model_class(self):
"""Возвращает класс модели SQLAlchemy."""
return EntityModel
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity:
"""
Преобразует объект EntityModel из базы данных в объект LinkerEntity
или его соответствующий подкласс.
Args:
db_entity: Сущность EntityModel из базы данных.
Returns:
Объект LinkerEntity или его подкласс.
"""
# Создаем базовый LinkerEntity со всеми данными из БД
# Преобразуем строковые UUID обратно в объекты UUID
base_data = LinkerEntity(
id=UUID(db_entity.uuid),
name=db_entity.name,
text=db_entity.text,
in_search_text=db_entity.in_search_text,
metadata=db_entity.metadata_json or {},
source_id=UUID(db_entity.source_id) if db_entity.source_id else None,
target_id=UUID(db_entity.target_id) if db_entity.target_id else None,
number_in_relation=db_entity.number_in_relation,
type=db_entity.entity_type,
groupper=db_entity.entity_type,
)
# Используем LinkerEntity._deserialize для получения объекта нужного типа
# на основе поля 'type', взятого из db_entity.entity_type
try:
deserialized_entity = base_data.deserialize()
return deserialized_entity
except Exception as e:
logger.error(
f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}"
)
return base_data
def add_entities(
self,
entities: list[LinkerEntity],
dataset_id: int,
embeddings: dict[str, np.ndarray] | None = None,
):
"""
Добавляет список сущностей LinkerEntity в базу данных.
Args:
entities: Список сущностей LinkerEntity для добавления.
dataset_id: ID датасета, к которому принадлежат сущности.
embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID.
"""
embeddings = embeddings or {}
with self.db() as session:
db_entities_to_add = []
for entity in entities:
# Преобразуем UUID в строку для хранения в базе
entity_id_str = str(entity.id)
embedding = embeddings.get(entity_id_str)
db_entity = EntityModel(
uuid=entity_id_str,
name=entity.name,
text=entity.text,
entity_type=entity.type,
in_search_text=entity.in_search_text,
metadata_json=(
entity.metadata if isinstance(entity.metadata, dict) else {}
),
source_id=str(entity.source_id) if entity.source_id else None,
target_id=str(entity.target_id) if entity.target_id else None,
number_in_relation=entity.number_in_relation,
dataset_id=dataset_id,
embedding=embedding,
)
db_entities_to_add.append(db_entity)
session.add_all(db_entities_to_add)
session.commit()
async def add_entities_async(
self,
entities: list[LinkerEntity],
dataset_id: int,
embeddings: dict[str, np.ndarray] | None = None,
):
"""Асинхронно добавляет список сущностей LinkerEntity в базу данных."""
# TODO: Реализовать с использованием async-сессии
await asyncio.to_thread(self.add_entities, entities, dataset_id, embeddings)
def get_searching_entities(
self,
dataset_id: int,
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
"""
Получает сущности из указанного датасета, которые имеют текст для поиска
(in_search_text не None), вместе с их эмбеддингами.
Args:
dataset_id: ID датасета.
Returns:
Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array).
Порядок эмбеддингов соответствует порядку сущностей.
"""
entity_model = self._entity_model_class
linker_entities = []
embeddings_list = []
with self.db() as session:
stmt = select(entity_model).where(
entity_model.in_search_text.isnot(None),
entity_model.dataset_id == dataset_id,
entity_model.embedding.isnot(None)
)
db_models = session.execute(stmt).scalars().all()
# Переносим цикл внутрь сессии
for model in db_models:
# Теперь маппинг происходит при активной сессии
linker_entity = self._map_db_entity_to_linker_entity(model)
linker_entities.append(linker_entity)
# Извлекаем эмбеддинг.
# _map_db_entity_to_linker_entity может поместить его в метаданные.
embedding = linker_entity.metadata.get('_embedding')
if embedding is None and hasattr(model, 'embedding'): # Fallback
embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии
if embedding is not None:
embeddings_list.append(embedding)
else:
# Обработка случая отсутствия эмбеддинга
print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.")
linker_entities.pop()
# Возвращаем результаты после закрытия сессии
return linker_entities, embeddings_list
async def get_searching_entities_async(
self,
dataset_id: int,
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
"""Асинхронно получает сущности для поиска вместе с эмбеддингами."""
# TODO: Реализовать с использованием async-сессии
return await asyncio.to_thread(self.get_searching_entities, dataset_id)
def get_all_entities_for_dataset(self, dataset_id: int) -> list[LinkerEntity]:
"""
Получает все сущности для указанного датасета.
Args:
dataset_id: ID датасета.
Returns:
Список всех LinkerEntity для данного датасета.
"""
entity_model = self._entity_model_class
linker_entities = []
with self.db() as session:
stmt = select(entity_model).where(
entity_model.dataset_id == dataset_id
)
db_models = session.execute(stmt).scalars().all()
# Переносим цикл внутрь сессии для маппинга
for model in db_models:
try:
linker_entity = self._map_db_entity_to_linker_entity(model)
linker_entities.append(linker_entity)
except Exception as e:
logger.error(f"Error mapping entity {getattr(model, 'uuid', 'N/A')} in dataset {dataset_id}: {e}")
logger.info(f"Loaded {len(linker_entities)} entities for dataset {dataset_id}")
return linker_entities
async def get_all_entities_for_dataset_async(self, dataset_id: int) -> list[LinkerEntity]:
"""Асинхронно получает все сущности для указанного датасета."""
# TODO: Реализовать с использованием async-сессии
return await asyncio.to_thread(self.get_all_entities_for_dataset, dataset_id)
def count_entities_by_dataset_id(self, dataset_id: int) -> int:
"""
Подсчитывает общее количество сущностей для указанного датасета.
Args:
dataset_id: ID датасета.
Returns:
Общее количество сущностей в датасете.
"""
entity_model = self._entity_model_class
id_column = self._get_id_column() # Получаем колонку ID (uuid или id)
with self.db() as session:
stmt = select(func.count(id_column)).where(
entity_model.dataset_id == dataset_id
)
count = session.execute(stmt).scalar_one()
return count
async def count_entities_by_dataset_id_async(self, dataset_id: int) -> int:
"""Асинхронно подсчитывает общее количество сущностей для датасета."""
# TODO: Реализовать с использованием async-сессии
return await asyncio.to_thread(self.count_entities_by_dataset_id, dataset_id)
async def get_entities_by_ids_async(self, entity_ids: list[UUID]) -> list[LinkerEntity]:
"""Асинхронно получить сущности по списку ID."""
# TODO: Реализовать с использованием async-сессии
return await asyncio.to_thread(self.get_entities_by_ids, entity_ids)