muryshev's picture
update
86c402d
raw
history blame
10.1 kB
from uuid import UUID
import numpy as np
from ntr_text_fragmentation import LinkerEntity
from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository
from sqlalchemy import and_, select
from sqlalchemy.orm import Session
from components.dbo.models.entity import EntityModel
class ChunkRepository(SQLAlchemyEntityRepository):
def __init__(self, db: Session):
super().__init__(db)
def _entity_model_class(self):
return EntityModel
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel):
"""
Преобразует сущность из базы данных в LinkerEntity.
Args:
db_entity: Сущность из базы данных
Returns:
LinkerEntity
"""
# Преобразуем строковые ID в UUID
entity = LinkerEntity(
id=UUID(db_entity.uuid), # Преобразуем строку в UUID
name=db_entity.name,
text=db_entity.text,
type=db_entity.entity_type,
in_search_text=db_entity.in_search_text,
metadata=db_entity.metadata_json,
source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID
target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID
number_in_relation=db_entity.number_in_relation,
)
return LinkerEntity.deserialize(entity)
def add_entities(
self,
entities: list[LinkerEntity],
dataset_id: int,
embeddings: dict[str, np.ndarray],
):
"""
Добавляет сущности в базу данных.
Args:
entities: Список сущностей для добавления
dataset_id: ID датасета
embeddings: Словарь эмбеддингов {entity_id: embedding}
"""
with self.db() as session:
for entity in entities:
# Преобразуем UUID в строку для хранения в базе
entity_id = str(entity.id)
if entity_id in embeddings:
embedding = embeddings[entity_id]
else:
embedding = None
session.add(
EntityModel(
uuid=str(entity.id), # UUID в строку
name=entity.name,
text=entity.text,
entity_type=entity.type,
in_search_text=entity.in_search_text,
metadata_json=entity.metadata,
source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку
target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку
number_in_relation=entity.number_in_relation,
chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index
dataset_id=dataset_id,
embedding=embedding,
)
)
session.commit()
def get_searching_entities(
self,
dataset_id: int,
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
with self.db() as session:
models = (
session.query(EntityModel)
.filter(EntityModel.in_search_text is not None)
.filter(EntityModel.dataset_id == dataset_id)
.all()
)
return (
[self._map_db_entity_to_linker_entity(model) for model in models],
[model.embedding for model in models],
)
def get_chunks_by_ids(
self,
chunk_ids: list[str],
) -> list[LinkerEntity]:
"""
Получение чанков по их ID.
Args:
chunk_ids: Список ID чанков
Returns:
Список чанков
"""
# Преобразуем все ID в строки для единообразия
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
with self.db() as session:
models = (
session.query(EntityModel)
.filter(EntityModel.uuid.in_(str_chunk_ids))
.all()
)
return [self._map_db_entity_to_linker_entity(model) for model in models]
def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]:
"""
Получить сущности по списку идентификаторов.
Args:
entity_ids: Список идентификаторов сущностей
Returns:
Список сущностей, соответствующих указанным идентификаторам
"""
if not entity_ids:
return []
# Преобразуем UUID в строки
str_entity_ids = [str(entity_id) for entity_id in entity_ids]
with self.db() as session:
entity_model = self._entity_model_class()
db_entities = session.execute(
select(entity_model).where(entity_model.uuid.in_(str_entity_ids))
).scalars().all()
return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]:
"""
Получить соседние чанки для указанных чанков.
Args:
chunk_ids: Список идентификаторов чанков
max_distance: Максимальное расстояние до соседа
Returns:
Список соседних чанков
"""
if not chunk_ids:
return []
# Преобразуем UUID в строки
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
with self.db() as session:
entity_model = self._entity_model_class()
result = []
# Сначала получаем указанные чанки, чтобы узнать их индексы и документы
chunks = session.execute(
select(entity_model).where(
and_(
entity_model.uuid.in_(str_chunk_ids),
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
)
)
).scalars().all()
if not chunks:
return []
# Находим документы для чанков через связи
doc_ids = set()
chunk_indices = {}
for chunk in chunks:
chunk_indices[chunk.uuid] = chunk.chunk_index
# Находим связь от документа к чанку
links = session.execute(
select(entity_model).where(
and_(
entity_model.target_id == chunk.uuid,
entity_model.name == "document_to_chunk"
)
)
).scalars().all()
for link in links:
doc_ids.add(link.source_id)
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
return []
# Для каждого документа находим все его чанки
for doc_id in doc_ids:
# Находим все связи от документа к чанкам
links = session.execute(
select(entity_model).where(
and_(
entity_model.source_id == doc_id,
entity_model.name == "document_to_chunk"
)
)
).scalars().all()
doc_chunk_ids = [link.target_id for link in links]
# Получаем все чанки документа
doc_chunks = session.execute(
select(entity_model).where(
and_(
entity_model.uuid.in_(doc_chunk_ids),
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
)
)
).scalars().all()
# Для каждого чанка в документе проверяем, является ли он соседом
for doc_chunk in doc_chunks:
if doc_chunk.uuid in str_chunk_ids:
continue
if doc_chunk.chunk_index is None:
continue
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
is_neighbor = False
for orig_chunk_id, orig_index in chunk_indices.items():
if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance:
is_neighbor = True
break
if is_neighbor:
result.append(self._map_db_entity_to_linker_entity(doc_chunk))
return result