muryshev's picture
init
57cf043
raw
history blame
15.6 kB
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Callable
import numpy as np
from components.embedding_extraction import EmbeddingExtractor
from components.parser.abbreviations.abbreviation import Abbreviation
from components.parser.abbreviations.structures import AbbreviationsCollection
from components.parser.features.dataset_creator import DatasetCreator
from components.parser.features.documents_dataset import DatasetRow, DocumentsDataset
from components.parser.features.hierarchy_parser import Hierarchy, HierarchyParser
from components.parser.paths import DatasetPaths
from components.parser.xml import ParsedXMLs, XMLParser
from components.parser.xml.constants import ACTUAL_STATUSES
logger = logging.getLogger(__name__)
class DatasetCreationPipeline:
"""
Пайплайн для обработки XML файлов со следующими шагами:
1. Парсинг XML файлов из директории
2. Извлечение аббревиатур из распаршенного контента
3. Применение аббревиатур к текстовому и табличному контенту
4. Обработка контента с помощью HierarchyParser
5. Создание и сохранение финального датасета
"""
def __init__(
self,
dataset_id: int,
prepared_abbreviations: list[Abbreviation],
document_ids: list[int],
document_formats: list[str],
datasets_path: Path,
documents_path: Path,
vectorizer: EmbeddingExtractor | None = None,
save_intermediate_files: bool = False,
old_dataset_id: int | None = None,
) -> None:
"""
Инициализация пайплайна.
Args:
dataset_id: Идентификатор датасета
vectorizer: Векторизатор для создания эмбеддингов
prepared_abbreviations: Датафрейм с аббревиатурами, извлечёнными ранее
xml_ids: Список идентификаторов XML файлов
save_intermediate_files: Флаг, указывающий, нужно ли сохранять промежуточные файлы
old_dataset: Старый датасет, если он есть
"""
self.datasets_path = datasets_path
self.documents_path = documents_path
self.dataset_id = dataset_id
self.paths = DatasetPaths(
self.datasets_path / str(dataset_id), save_intermediate_files
)
self.document_ids = document_ids
self.document_formats = document_formats
self.prepared_abbreviations = self._group_abbreviations(prepared_abbreviations)
self.dataset_creator = DatasetCreator()
self.vectorizer = vectorizer
self.xml_parser = XMLParser()
self.hierarchy_parser = HierarchyParser()
self.abbreviations: AbbreviationsCollection | None = None
self.info: ParsedXMLs | None = None
self.dataset: DocumentsDataset | None = None
self.old_paths = (
DatasetPaths(
self.datasets_path / str(old_dataset_id),
save_intermediate_files,
)
if old_dataset_id
else None
)
logger.info(f'DatasetCreationPipeline initialized for {dataset_id}')
def run(
self,
progress_callback: Callable[[int, int], None] | None = None,
) -> DocumentsDataset:
"""
Выполнение полного пайплайна обработки.
Args:
progress_callback: Функция, которая будет вызываться при каждом шаге векторизации.
Принимает два аргумента: current и total.
current - текущий шаг векторизации.
total - общее количество шагов векторизации.
Returns:
DocumentsDataset: Векторизованный датасет.
"""
logger.info(f'Running pipeline for {self.dataset_id}')
# Создание выходной директории
Path(self.paths.root_path).mkdir(parents=True, exist_ok=True)
logger.info('Folder created')
logger.info('Processing XML files...')
# Парсинг XML файлов
parsed_xmls = self.process_xml_files()
logger.info('XML files processed')
logger.info('Saving XML info...')
self.info = [xml.only_info() for xml in parsed_xmls.xmls]
parsed_xmls.to_pandas().to_csv(
self.paths.xml_info,
index=False,
)
logger.info('XML info saved')
logger.info('Saving txt files...')
# Сохранение промежуточных txt файлов
if self.paths.save_intermediate_files:
self._save_txt_files(parsed_xmls)
logger.info('Txt files saved')
logger.info('Processing abbreviations...')
# Обработка аббревиатур
self.abbreviations = self.process_abbreviations(parsed_xmls)
logger.info('Abbreviations processed')
logger.info('Saving abbreviations...')
AbbreviationsCollection(self.abbreviations).to_pandas().to_csv(
self.paths.abbreviations,
index=False,
)
logger.info('Abbreviations saved')
logger.info('Saving txt files with abbreviations...')
# Сохранение промежуточных txt файлов с применением аббревиатур
if self.paths.save_intermediate_files:
self._save_txt_files(parsed_xmls)
logger.info('Txt files with abbreviations saved')
logger.info('Extracting hierarchies...')
hierarchies = self._extract_hierarchies(parsed_xmls)
logger.info('Hierarchies extracted')
logger.info('Saving hierarchies...')
if self.paths.save_intermediate_files:
self._save_hierarchies(hierarchies)
logger.info('Hierarchies saved')
logger.info('Creating dataset...')
dataset = self.create_dataset(parsed_xmls, hierarchies)
if self.vectorizer:
logger.info('Vectorizing dataset...')
dataset.vectorize_with(
self.vectorizer,
progress_callback=progress_callback,
)
logger.info('Dataset vectorized')
logger.info('Saving dataset...')
dataset.to_pickle(self.paths.dataset)
logger.info('Dataset saved')
return dataset
def process_xml_files(self) -> ParsedXMLs:
"""
Парсинг XML файлов из указанной директории.
Возвращает:
ParsedXMLs: Структура с данными из всех XML файлов
"""
parsed_xmls = []
for document_id, document_format in zip(
self.document_ids, self.document_formats
):
parsed_xml = XMLParser.parse(
self.documents_path / f'{document_id}.{document_format}',
include_content=True,
)
if ('состав' in parsed_xml.name.lower()) or (
'составы' in parsed_xml.name.lower()
):
continue
if parsed_xml.status not in ACTUAL_STATUSES:
continue
parsed_xml.id = document_id
parsed_xmls.append(parsed_xml)
return ParsedXMLs(parsed_xmls)
def process_abbreviations(
self,
parsed_xmls: ParsedXMLs,
) -> list[Abbreviation]:
"""
Обработка и применение аббревиатур к контенту документов.
Теперь аббревиатуры уже извлечены во время парсинга, этот метод:
1. Устанавливает document_id для извлеченных аббревиатур
2. Применяет только документно-специфичные аббревиатуры к соответствующим документам
3. Объединяет все аббревиатуры (извлеченные и предварительно подготовленные) для возврата
Args:
parsed_xmls: Структура с данными из всех XML файлов
Returns:
list[Abbreviation]: Список всех аббревиатур для датасета
"""
all_abbreviations = {}
# Итерируем по документам
for xml in parsed_xmls.xmls:
# Устанавливаем document_id для извлеченных аббревиатур, если они есть
doc_specific_abbreviations = []
if xml.abbreviations:
for abbreviation in xml.abbreviations:
abbreviation.document_id = xml.id
doc_specific_abbreviations = xml.abbreviations
# Применяем только аббревиатуры, извлеченные из этого документа
if doc_specific_abbreviations:
# Если есть аббревиатуры из документа, применяем их
xml.apply_abbreviations(doc_specific_abbreviations)
# Получаем подготовленные аббревиатуры для текущего документа
prepared_abbr = self.prepared_abbreviations.get(xml.id, [])
# Объединяем все аббревиатуры для возврата (не для применения)
combined_abbr = (doc_specific_abbreviations or []) + prepared_abbr
# Сохраняем объединенный список в document.abbreviations и в общем словаре
if combined_abbr:
xml.abbreviations = combined_abbr
all_abbreviations[xml.id] = combined_abbr
return self._ungroup_abbreviations(all_abbreviations)
def _get_already_parsed_xmls(
self,
) -> tuple[list[int], list[DatasetRow], list[np.ndarray]]:
if self.old_paths:
self.old_dataset = DocumentsDataset.from_pickle(self.old_paths.dataset)
ids = set([int(row.DocNumber) for row in self.old_dataset.rows])
ids = ids.intersection(self.xml_ids)
rows = [row for row in self.old_dataset.rows if row.DocNumber in ids]
embs = [
emb
for row, emb in zip(rows, self.old_dataset.vectors)
if row.DocNumber in ids
]
return ids, rows, embs
return [], [], []
def _extract_hierarchies(
self,
parsed_xmls: ParsedXMLs,
) -> dict[int, tuple[Hierarchy, Hierarchy]]:
"""
Извлечение иерархических структур из текстового и табличного контента.
Args:
parsed_xmls: Структура с данными из всех XML файлов
Returns:
dict[int, tuple[Hierarchy, Hierarchy]]: Словарь иерархических структур для каждого документа
"""
hierarchies = {}
for xml in parsed_xmls.xmls:
doc_id = xml.id
# Обработка текстового контента
if xml.text:
text_lines = xml.text.to_text().split('\n')
self.hierarchy_parser.parse(text_lines, doc_id, '')
text_hierarchy = self.hierarchy_parser.hierarchy()
else:
text_hierarchy = {}
# Обработка табличного контента
if xml.tables:
table_lines = xml.tables.to_text().split('\n')
self.hierarchy_parser.parse_table(table_lines, doc_id)
table_hierarchy = self.hierarchy_parser.hierarchy()
else:
table_hierarchy = {}
hierarchies[doc_id] = (text_hierarchy, table_hierarchy)
return hierarchies
def create_dataset(
self,
parsed_xmls: ParsedXMLs,
hierarchies: dict[int, tuple[Hierarchy, Hierarchy]],
) -> DocumentsDataset:
"""
Создание финального датасета с векторизацией.
Args:
parsed_xmls: Структура с данными из всех XML файлов
hierarchies: Словарь с иерархической структурой документов
Returns:
DocumentsDataset: Датасет с векторизованными текстами
"""
xmls = {xml.id: xml for xml in parsed_xmls.xmls}
self.dataset = self.dataset_creator.create_dataset(xmls, hierarchies)
return self.dataset
def _group_abbreviations(
self,
abbreviations: list[Abbreviation],
) -> dict[int, list[Abbreviation]]:
"""
Преобразует список аббревиатур в словарь, где ключи - идентификаторы документов, а значения - списки аббревиатур.
"""
doc_to_abbreviations = defaultdict(list)
for abbreviation in abbreviations:
doc_to_abbreviations[abbreviation.document_id].append(abbreviation)
return doc_to_abbreviations
def _ungroup_abbreviations(
self, abbreviations: dict[int, list[Abbreviation]]
) -> list[Abbreviation]:
"""
Преобразует словарь аббревиатур в список аббревиатур.
"""
return sum(abbreviations.values(), [])
def _save_txt_files(self, parsed_xmls: ParsedXMLs) -> None:
"""
Сохранение текстового и табличного контента в текстовые файлы.
"""
self.paths.txt_path.mkdir(parents=True, exist_ok=True)
for xml in parsed_xmls.xmls:
with open(self.paths.txt_path / f'{xml.id}.txt', 'w', encoding='utf-8') as f:
f.write(xml.text.to_text())
if xml.tables:
with open(self.paths.txt_path / f'{xml.id}_table.txt', 'w', encoding='utf-8') as f:
f.write(xml.tables.to_text())
def _save_hierarchies(
self,
hierarchies: dict[int, tuple[Hierarchy, Hierarchy]],
) -> None:
"""
Сохранение иерархий в JSON файлы.
"""
self.paths.jsons_path.mkdir(parents=True, exist_ok=True)
for doc_id, (text_hierarchy, table_hierarchy) in hierarchies.items():
if text_hierarchy:
with open(self.paths.jsons_path / f'{doc_id}.json', 'w', encoding='utf-8') as f:
json.dump(text_hierarchy, f)
if table_hierarchy:
with open(self.paths.jsons_path / f'{doc_id}_table.json', 'w', encoding='utf-8') as f:
json.dump(table_hierarchy, f)