Spaces:
Sleeping
Sleeping
from uuid import UUID | |
import numpy as np | |
from ntr_text_fragmentation import LinkerEntity | |
from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository | |
from sqlalchemy import and_, select | |
from sqlalchemy.orm import Session | |
from components.dbo.models.entity import EntityModel | |
class ChunkRepository(SQLAlchemyEntityRepository): | |
def __init__(self, db: Session): | |
super().__init__(db) | |
def _entity_model_class(self): | |
return EntityModel | |
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel): | |
""" | |
Преобразует сущность из базы данных в LinkerEntity. | |
Args: | |
db_entity: Сущность из базы данных | |
Returns: | |
LinkerEntity | |
""" | |
# Преобразуем строковые ID в UUID | |
entity = LinkerEntity( | |
id=UUID(db_entity.uuid), # Преобразуем строку в UUID | |
name=db_entity.name, | |
text=db_entity.text, | |
type=db_entity.entity_type, | |
in_search_text=db_entity.in_search_text, | |
metadata=db_entity.metadata_json, | |
source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID | |
target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID | |
number_in_relation=db_entity.number_in_relation, | |
) | |
return LinkerEntity.deserialize(entity) | |
def add_entities( | |
self, | |
entities: list[LinkerEntity], | |
dataset_id: int, | |
embeddings: dict[str, np.ndarray], | |
): | |
""" | |
Добавляет сущности в базу данных. | |
Args: | |
entities: Список сущностей для добавления | |
dataset_id: ID датасета | |
embeddings: Словарь эмбеддингов {entity_id: embedding} | |
""" | |
with self.db() as session: | |
for entity in entities: | |
# Преобразуем UUID в строку для хранения в базе | |
entity_id = str(entity.id) | |
if entity_id in embeddings: | |
embedding = embeddings[entity_id] | |
else: | |
embedding = None | |
session.add( | |
EntityModel( | |
uuid=str(entity.id), # UUID в строку | |
name=entity.name, | |
text=entity.text, | |
entity_type=entity.type, | |
in_search_text=entity.in_search_text, | |
metadata_json=entity.metadata, | |
source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку | |
target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку | |
number_in_relation=entity.number_in_relation, | |
chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index | |
dataset_id=dataset_id, | |
embedding=embedding, | |
) | |
) | |
session.commit() | |
def get_searching_entities( | |
self, | |
dataset_id: int, | |
) -> tuple[list[LinkerEntity], list[np.ndarray]]: | |
with self.db() as session: | |
models = ( | |
session.query(EntityModel) | |
.filter(EntityModel.in_search_text is not None) | |
.filter(EntityModel.dataset_id == dataset_id) | |
.all() | |
) | |
return ( | |
[self._map_db_entity_to_linker_entity(model) for model in models], | |
[model.embedding for model in models], | |
) | |
def get_chunks_by_ids( | |
self, | |
chunk_ids: list[str], | |
) -> list[LinkerEntity]: | |
""" | |
Получение чанков по их ID. | |
Args: | |
chunk_ids: Список ID чанков | |
Returns: | |
Список чанков | |
""" | |
# Преобразуем все ID в строки для единообразия | |
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] | |
with self.db() as session: | |
models = ( | |
session.query(EntityModel) | |
.filter(EntityModel.uuid.in_(str_chunk_ids)) | |
.all() | |
) | |
return [self._map_db_entity_to_linker_entity(model) for model in models] | |
def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]: | |
""" | |
Получить сущности по списку идентификаторов. | |
Args: | |
entity_ids: Список идентификаторов сущностей | |
Returns: | |
Список сущностей, соответствующих указанным идентификаторам | |
""" | |
if not entity_ids: | |
return [] | |
# Преобразуем UUID в строки | |
str_entity_ids = [str(entity_id) for entity_id in entity_ids] | |
with self.db() as session: | |
entity_model = self._entity_model_class() | |
db_entities = session.execute( | |
select(entity_model).where(entity_model.uuid.in_(str_entity_ids)) | |
).scalars().all() | |
return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities] | |
def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]: | |
""" | |
Получить соседние чанки для указанных чанков. | |
Args: | |
chunk_ids: Список идентификаторов чанков | |
max_distance: Максимальное расстояние до соседа | |
Returns: | |
Список соседних чанков | |
""" | |
if not chunk_ids: | |
return [] | |
# Преобразуем UUID в строки | |
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids] | |
with self.db() as session: | |
entity_model = self._entity_model_class() | |
result = [] | |
# Сначала получаем указанные чанки, чтобы узнать их индексы и документы | |
chunks = session.execute( | |
select(entity_model).where( | |
and_( | |
entity_model.uuid.in_(str_chunk_ids), | |
entity_model.entity_type == "Chunk" # Используем entity_type вместо type | |
) | |
) | |
).scalars().all() | |
if not chunks: | |
return [] | |
# Находим документы для чанков через связи | |
doc_ids = set() | |
chunk_indices = {} | |
for chunk in chunks: | |
chunk_indices[chunk.uuid] = chunk.chunk_index | |
# Находим связь от документа к чанку | |
links = session.execute( | |
select(entity_model).where( | |
and_( | |
entity_model.target_id == chunk.uuid, | |
entity_model.name == "document_to_chunk" | |
) | |
) | |
).scalars().all() | |
for link in links: | |
doc_ids.add(link.source_id) | |
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()): | |
return [] | |
# Для каждого документа находим все его чанки | |
for doc_id in doc_ids: | |
# Находим все связи от документа к чанкам | |
links = session.execute( | |
select(entity_model).where( | |
and_( | |
entity_model.source_id == doc_id, | |
entity_model.name == "document_to_chunk" | |
) | |
) | |
).scalars().all() | |
doc_chunk_ids = [link.target_id for link in links] | |
# Получаем все чанки документа | |
doc_chunks = session.execute( | |
select(entity_model).where( | |
and_( | |
entity_model.uuid.in_(doc_chunk_ids), | |
entity_model.entity_type == "Chunk" # Используем entity_type вместо type | |
) | |
) | |
).scalars().all() | |
# Для каждого чанка в документе проверяем, является ли он соседом | |
for doc_chunk in doc_chunks: | |
if doc_chunk.uuid in str_chunk_ids: | |
continue | |
if doc_chunk.chunk_index is None: | |
continue | |
# Проверяем, является ли чанк соседом какого-либо из исходных чанков | |
is_neighbor = False | |
for orig_chunk_id, orig_index in chunk_indices.items(): | |
if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance: | |
is_neighbor = True | |
break | |
if is_neighbor: | |
result.append(self._map_db_entity_to_linker_entity(doc_chunk)) | |
return result | |