import json import logging from collections import defaultdict from pathlib import Path from typing import Callable import numpy as np from components.embedding_extraction import EmbeddingExtractor from components.parser.abbreviations.abbreviation import Abbreviation from components.parser.abbreviations.structures import AbbreviationsCollection from components.parser.features.dataset_creator import DatasetCreator from components.parser.features.documents_dataset import DatasetRow, DocumentsDataset from components.parser.features.hierarchy_parser import Hierarchy, HierarchyParser from components.parser.paths import DatasetPaths from components.parser.xml import ParsedXMLs, XMLParser from components.parser.xml.constants import ACTUAL_STATUSES logger = logging.getLogger(__name__) class DatasetCreationPipeline: """ Пайплайн для обработки XML файлов со следующими шагами: 1. Парсинг XML файлов из директории 2. Извлечение аббревиатур из распаршенного контента 3. Применение аббревиатур к текстовому и табличному контенту 4. Обработка контента с помощью HierarchyParser 5. Создание и сохранение финального датасета """ def __init__( self, dataset_id: int, prepared_abbreviations: list[Abbreviation], document_ids: list[int], document_formats: list[str], datasets_path: Path, documents_path: Path, vectorizer: EmbeddingExtractor | None = None, save_intermediate_files: bool = False, old_dataset_id: int | None = None, ) -> None: """ Инициализация пайплайна. Args: dataset_id: Идентификатор датасета vectorizer: Векторизатор для создания эмбеддингов prepared_abbreviations: Датафрейм с аббревиатурами, извлечёнными ранее xml_ids: Список идентификаторов XML файлов save_intermediate_files: Флаг, указывающий, нужно ли сохранять промежуточные файлы old_dataset: Старый датасет, если он есть """ self.datasets_path = datasets_path self.documents_path = documents_path self.dataset_id = dataset_id self.paths = DatasetPaths( self.datasets_path / str(dataset_id), save_intermediate_files ) self.document_ids = document_ids self.document_formats = document_formats self.prepared_abbreviations = self._group_abbreviations(prepared_abbreviations) self.dataset_creator = DatasetCreator() self.vectorizer = vectorizer self.xml_parser = XMLParser() self.hierarchy_parser = HierarchyParser() self.abbreviations: AbbreviationsCollection | None = None self.info: ParsedXMLs | None = None self.dataset: DocumentsDataset | None = None self.old_paths = ( DatasetPaths( self.datasets_path / str(old_dataset_id), save_intermediate_files, ) if old_dataset_id else None ) logger.info(f'DatasetCreationPipeline initialized for {dataset_id}') def run( self, progress_callback: Callable[[int, int], None] | None = None, ) -> DocumentsDataset: """ Выполнение полного пайплайна обработки. Args: progress_callback: Функция, которая будет вызываться при каждом шаге векторизации. Принимает два аргумента: current и total. current - текущий шаг векторизации. total - общее количество шагов векторизации. Returns: DocumentsDataset: Векторизованный датасет. """ logger.info(f'Running pipeline for {self.dataset_id}') # Создание выходной директории Path(self.paths.root_path).mkdir(parents=True, exist_ok=True) logger.info('Folder created') logger.info('Processing XML files...') # Парсинг XML файлов parsed_xmls = self.process_xml_files() logger.info('XML files processed') logger.info('Saving XML info...') self.info = [xml.only_info() for xml in parsed_xmls.xmls] parsed_xmls.to_pandas().to_csv( self.paths.xml_info, index=False, ) logger.info('XML info saved') logger.info('Saving txt files...') # Сохранение промежуточных txt файлов if self.paths.save_intermediate_files: self._save_txt_files(parsed_xmls) logger.info('Txt files saved') logger.info('Processing abbreviations...') # Обработка аббревиатур self.abbreviations = self.process_abbreviations(parsed_xmls) logger.info('Abbreviations processed') logger.info('Saving abbreviations...') AbbreviationsCollection(self.abbreviations).to_pandas().to_csv( self.paths.abbreviations, index=False, ) logger.info('Abbreviations saved') logger.info('Saving txt files with abbreviations...') # Сохранение промежуточных txt файлов с применением аббревиатур if self.paths.save_intermediate_files: self._save_txt_files(parsed_xmls) logger.info('Txt files with abbreviations saved') logger.info('Extracting hierarchies...') hierarchies = self._extract_hierarchies(parsed_xmls) logger.info('Hierarchies extracted') logger.info('Saving hierarchies...') if self.paths.save_intermediate_files: self._save_hierarchies(hierarchies) logger.info('Hierarchies saved') logger.info('Creating dataset...') dataset = self.create_dataset(parsed_xmls, hierarchies) if self.vectorizer: logger.info('Vectorizing dataset...') dataset.vectorize_with( self.vectorizer, progress_callback=progress_callback, ) logger.info('Dataset vectorized') logger.info('Saving dataset...') dataset.to_pickle(self.paths.dataset) logger.info('Dataset saved') return dataset def process_xml_files(self) -> ParsedXMLs: """ Парсинг XML файлов из указанной директории. Возвращает: ParsedXMLs: Структура с данными из всех XML файлов """ parsed_xmls = [] for document_id, document_format in zip( self.document_ids, self.document_formats ): parsed_xml = XMLParser.parse( self.documents_path / f'{document_id}.{document_format}', include_content=True, ) if ('состав' in parsed_xml.name.lower()) or ( 'составы' in parsed_xml.name.lower() ): continue if parsed_xml.status not in ACTUAL_STATUSES: continue parsed_xml.id = document_id parsed_xmls.append(parsed_xml) return ParsedXMLs(parsed_xmls) def process_abbreviations( self, parsed_xmls: ParsedXMLs, ) -> list[Abbreviation]: """ Обработка и применение аббревиатур к контенту документов. Теперь аббревиатуры уже извлечены во время парсинга, этот метод: 1. Устанавливает document_id для извлеченных аббревиатур 2. Применяет только документно-специфичные аббревиатуры к соответствующим документам 3. Объединяет все аббревиатуры (извлеченные и предварительно подготовленные) для возврата Args: parsed_xmls: Структура с данными из всех XML файлов Returns: list[Abbreviation]: Список всех аббревиатур для датасета """ all_abbreviations = {} # Итерируем по документам for xml in parsed_xmls.xmls: # Устанавливаем document_id для извлеченных аббревиатур, если они есть doc_specific_abbreviations = [] if xml.abbreviations: for abbreviation in xml.abbreviations: abbreviation.document_id = xml.id doc_specific_abbreviations = xml.abbreviations # Применяем только аббревиатуры, извлеченные из этого документа if doc_specific_abbreviations: # Если есть аббревиатуры из документа, применяем их xml.apply_abbreviations(doc_specific_abbreviations) # Получаем подготовленные аббревиатуры для текущего документа prepared_abbr = self.prepared_abbreviations.get(xml.id, []) # Объединяем все аббревиатуры для возврата (не для применения) combined_abbr = (doc_specific_abbreviations or []) + prepared_abbr # Сохраняем объединенный список в document.abbreviations и в общем словаре if combined_abbr: xml.abbreviations = combined_abbr all_abbreviations[xml.id] = combined_abbr return self._ungroup_abbreviations(all_abbreviations) def _get_already_parsed_xmls( self, ) -> tuple[list[int], list[DatasetRow], list[np.ndarray]]: if self.old_paths: self.old_dataset = DocumentsDataset.from_pickle(self.old_paths.dataset) ids = set([int(row.DocNumber) for row in self.old_dataset.rows]) ids = ids.intersection(self.xml_ids) rows = [row for row in self.old_dataset.rows if row.DocNumber in ids] embs = [ emb for row, emb in zip(rows, self.old_dataset.vectors) if row.DocNumber in ids ] return ids, rows, embs return [], [], [] def _extract_hierarchies( self, parsed_xmls: ParsedXMLs, ) -> dict[int, tuple[Hierarchy, Hierarchy]]: """ Извлечение иерархических структур из текстового и табличного контента. Args: parsed_xmls: Структура с данными из всех XML файлов Returns: dict[int, tuple[Hierarchy, Hierarchy]]: Словарь иерархических структур для каждого документа """ hierarchies = {} for xml in parsed_xmls.xmls: doc_id = xml.id # Обработка текстового контента if xml.text: text_lines = xml.text.to_text().split('\n') self.hierarchy_parser.parse(text_lines, doc_id, '') text_hierarchy = self.hierarchy_parser.hierarchy() else: text_hierarchy = {} # Обработка табличного контента if xml.tables: table_lines = xml.tables.to_text().split('\n') self.hierarchy_parser.parse_table(table_lines, doc_id) table_hierarchy = self.hierarchy_parser.hierarchy() else: table_hierarchy = {} hierarchies[doc_id] = (text_hierarchy, table_hierarchy) return hierarchies def create_dataset( self, parsed_xmls: ParsedXMLs, hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], ) -> DocumentsDataset: """ Создание финального датасета с векторизацией. Args: parsed_xmls: Структура с данными из всех XML файлов hierarchies: Словарь с иерархической структурой документов Returns: DocumentsDataset: Датасет с векторизованными текстами """ xmls = {xml.id: xml for xml in parsed_xmls.xmls} self.dataset = self.dataset_creator.create_dataset(xmls, hierarchies) return self.dataset def _group_abbreviations( self, abbreviations: list[Abbreviation], ) -> dict[int, list[Abbreviation]]: """ Преобразует список аббревиатур в словарь, где ключи - идентификаторы документов, а значения - списки аббревиатур. """ doc_to_abbreviations = defaultdict(list) for abbreviation in abbreviations: doc_to_abbreviations[abbreviation.document_id].append(abbreviation) return doc_to_abbreviations def _ungroup_abbreviations( self, abbreviations: dict[int, list[Abbreviation]] ) -> list[Abbreviation]: """ Преобразует словарь аббревиатур в список аббревиатур. """ return sum(abbreviations.values(), []) def _save_txt_files(self, parsed_xmls: ParsedXMLs) -> None: """ Сохранение текстового и табличного контента в текстовые файлы. """ self.paths.txt_path.mkdir(parents=True, exist_ok=True) for xml in parsed_xmls.xmls: with open(self.paths.txt_path / f'{xml.id}.txt', 'w', encoding='utf-8') as f: f.write(xml.text.to_text()) if xml.tables: with open(self.paths.txt_path / f'{xml.id}_table.txt', 'w', encoding='utf-8') as f: f.write(xml.tables.to_text()) def _save_hierarchies( self, hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], ) -> None: """ Сохранение иерархий в JSON файлы. """ self.paths.jsons_path.mkdir(parents=True, exist_ok=True) for doc_id, (text_hierarchy, table_hierarchy) in hierarchies.items(): if text_hierarchy: with open(self.paths.jsons_path / f'{doc_id}.json', 'w', encoding='utf-8') as f: json.dump(text_hierarchy, f) if table_hierarchy: with open(self.paths.jsons_path / f'{doc_id}_table.json', 'w', encoding='utf-8') as f: json.dump(table_hierarchy, f)