muryshev's picture
update
308de05
raw
history blame
4.66 kB
"""
Абстрактный базовый класс для стратегий чанкинга.
"""
import logging
from abc import ABC, abstractmethod
from ntr_fileparser import ParsedDocument
from ..models import DocumentAsEntity, LinkerEntity
from ..repositories import EntityRepository
from .models import Chunk
logger = logging.getLogger(__name__)
class ChunkingStrategy(ABC):
"""Абстрактный класс для стратегий чанкинга."""
@abstractmethod
def chunk(
self,
document: ParsedDocument,
doc_entity: DocumentAsEntity,
) -> list[LinkerEntity]:
"""
Разбивает документ на чанки в соответствии со стратегией.
Args:
document: ParsedDocument для извлечения текста и структуры.
doc_entity: Сущность документа-владельца, к которой будут привязаны чанки.
Returns:
Список сущностей (чанки)
"""
raise NotImplementedError("Стратегия чанкинга должна реализовать метод chunk")
@abstractmethod
async def chunk_async(
self,
document: ParsedDocument,
doc_entity: DocumentAsEntity,
) -> list[LinkerEntity]:
"""
Асинхронно разбивает документ на чанки в соответствии со стратегией.
Args:
document: ParsedDocument для извлечения текста и структуры.
doc_entity: Сущность документа-владельца, к которой будут привязаны чанки.
Returns:
Список сущностей (чанки)
"""
logger.warning(
"Асинхронная стратегия чанкинга не реализована, вызывается синхронная"
)
return self.chunk(document, doc_entity)
@classmethod
def dechunk(
cls,
repository: EntityRepository,
filtered_entities: list[LinkerEntity],
) -> str:
"""
Собирает текст из отфильтрованных чанков к одному документу.
Args:
repository: Репозиторий (может понадобиться для получения доп. информации,
хотя в текущей реализации не используется).
filtered_entities: Список отфильтрованных сущностей (чанков),
относящихся к одному документу.
Returns:
Собранный текст из чанков.
"""
chunks = [e for e in filtered_entities if isinstance(e, Chunk)]
chunks.sort(key=lambda x: x.number_in_relation)
groups: list[list[Chunk]] = []
for chunk in chunks:
if len(groups) == 0:
groups.append([chunk])
continue
last_chunk = groups[-1][-1]
if chunk.number_in_relation == last_chunk.number_in_relation + 1:
groups[-1].append(chunk)
else:
groups.append([chunk])
result = ""
previous_last_index = 0
for group in groups:
if previous_last_index is not None:
missing_chunks = group[0].number_in_relation - previous_last_index - 1
missing_string = f'\n_<...Пропущено {missing_chunks} фрагментов...>_\n'
else:
missing_string = '\n_<...>_\n'
result += missing_string + cls._build_sequenced_chunks(repository, group)
previous_last_index = group[-1].number_in_relation
return result.strip()
@classmethod
def _build_sequenced_chunks(
cls,
repository: EntityRepository,
group: list[Chunk],
) -> str:
"""
Строит текст для последовательных чанков.
Стоит переопределить в конкретной стратегии, если она предполагает сложную логику
"""
return " ".join([cls._build_chunk(chunk) for chunk in group])
@classmethod
def _build_chunk(cls, chunk: Chunk) -> str:
"""Строит текст для одного чанка."""
return chunk.text