Spaces:
Sleeping
Sleeping
import json | |
import logging | |
from collections import defaultdict | |
from pathlib import Path | |
from typing import Callable | |
import numpy as np | |
from components.embedding_extraction import EmbeddingExtractor | |
from components.parser.abbreviations.abbreviation import Abbreviation | |
from components.parser.abbreviations.structures import AbbreviationsCollection | |
from components.parser.features.dataset_creator import DatasetCreator | |
from components.parser.features.documents_dataset import DatasetRow, DocumentsDataset | |
from components.parser.features.hierarchy_parser import Hierarchy, HierarchyParser | |
from components.parser.paths import DatasetPaths | |
from components.parser.xml import ParsedXMLs, XMLParser | |
from components.parser.xml.constants import ACTUAL_STATUSES | |
logger = logging.getLogger(__name__) | |
class DatasetCreationPipeline: | |
""" | |
Пайплайн для обработки XML файлов со следующими шагами: | |
1. Парсинг XML файлов из директории | |
2. Извлечение аббревиатур из распаршенного контента | |
3. Применение аббревиатур к текстовому и табличному контенту | |
4. Обработка контента с помощью HierarchyParser | |
5. Создание и сохранение финального датасета | |
""" | |
def __init__( | |
self, | |
dataset_id: int, | |
prepared_abbreviations: list[Abbreviation], | |
document_ids: list[int], | |
document_formats: list[str], | |
datasets_path: Path, | |
documents_path: Path, | |
vectorizer: EmbeddingExtractor | None = None, | |
save_intermediate_files: bool = False, | |
old_dataset_id: int | None = None, | |
) -> None: | |
""" | |
Инициализация пайплайна. | |
Args: | |
dataset_id: Идентификатор датасета | |
vectorizer: Векторизатор для создания эмбеддингов | |
prepared_abbreviations: Датафрейм с аббревиатурами, извлечёнными ранее | |
xml_ids: Список идентификаторов XML файлов | |
save_intermediate_files: Флаг, указывающий, нужно ли сохранять промежуточные файлы | |
old_dataset: Старый датасет, если он есть | |
""" | |
self.datasets_path = datasets_path | |
self.documents_path = documents_path | |
self.dataset_id = dataset_id | |
self.paths = DatasetPaths( | |
self.datasets_path / str(dataset_id), save_intermediate_files | |
) | |
self.document_ids = document_ids | |
self.document_formats = document_formats | |
self.prepared_abbreviations = self._group_abbreviations(prepared_abbreviations) | |
self.dataset_creator = DatasetCreator() | |
self.vectorizer = vectorizer | |
self.xml_parser = XMLParser() | |
self.hierarchy_parser = HierarchyParser() | |
self.abbreviations: AbbreviationsCollection | None = None | |
self.info: ParsedXMLs | None = None | |
self.dataset: DocumentsDataset | None = None | |
self.old_paths = ( | |
DatasetPaths( | |
self.datasets_path / str(old_dataset_id), | |
save_intermediate_files, | |
) | |
if old_dataset_id | |
else None | |
) | |
logger.info(f'DatasetCreationPipeline initialized for {dataset_id}') | |
def run( | |
self, | |
progress_callback: Callable[[int, int], None] | None = None, | |
) -> DocumentsDataset: | |
""" | |
Выполнение полного пайплайна обработки. | |
Args: | |
progress_callback: Функция, которая будет вызываться при каждом шаге векторизации. | |
Принимает два аргумента: current и total. | |
current - текущий шаг векторизации. | |
total - общее количество шагов векторизации. | |
Returns: | |
DocumentsDataset: Векторизованный датасет. | |
""" | |
logger.info(f'Running pipeline for {self.dataset_id}') | |
# Создание выходной директории | |
Path(self.paths.root_path).mkdir(parents=True, exist_ok=True) | |
logger.info('Folder created') | |
logger.info('Processing XML files...') | |
# Парсинг XML файлов | |
parsed_xmls = self.process_xml_files() | |
logger.info('XML files processed') | |
logger.info('Saving XML info...') | |
self.info = [xml.only_info() for xml in parsed_xmls.xmls] | |
parsed_xmls.to_pandas().to_csv( | |
self.paths.xml_info, | |
index=False, | |
) | |
logger.info('XML info saved') | |
logger.info('Saving txt files...') | |
# Сохранение промежуточных txt файлов | |
if self.paths.save_intermediate_files: | |
self._save_txt_files(parsed_xmls) | |
logger.info('Txt files saved') | |
logger.info('Processing abbreviations...') | |
# Обработка аббревиатур | |
self.abbreviations = self.process_abbreviations(parsed_xmls) | |
logger.info('Abbreviations processed') | |
logger.info('Saving abbreviations...') | |
AbbreviationsCollection(self.abbreviations).to_pandas().to_csv( | |
self.paths.abbreviations, | |
index=False, | |
) | |
logger.info('Abbreviations saved') | |
logger.info('Saving txt files with abbreviations...') | |
# Сохранение промежуточных txt файлов с применением аббревиатур | |
if self.paths.save_intermediate_files: | |
self._save_txt_files(parsed_xmls) | |
logger.info('Txt files with abbreviations saved') | |
logger.info('Extracting hierarchies...') | |
hierarchies = self._extract_hierarchies(parsed_xmls) | |
logger.info('Hierarchies extracted') | |
logger.info('Saving hierarchies...') | |
if self.paths.save_intermediate_files: | |
self._save_hierarchies(hierarchies) | |
logger.info('Hierarchies saved') | |
logger.info('Creating dataset...') | |
dataset = self.create_dataset(parsed_xmls, hierarchies) | |
if self.vectorizer: | |
logger.info('Vectorizing dataset...') | |
dataset.vectorize_with( | |
self.vectorizer, | |
progress_callback=progress_callback, | |
) | |
logger.info('Dataset vectorized') | |
logger.info('Saving dataset...') | |
dataset.to_pickle(self.paths.dataset) | |
logger.info('Dataset saved') | |
return dataset | |
def process_xml_files(self) -> ParsedXMLs: | |
""" | |
Парсинг XML файлов из указанной директории. | |
Возвращает: | |
ParsedXMLs: Структура с данными из всех XML файлов | |
""" | |
parsed_xmls = [] | |
for document_id, document_format in zip( | |
self.document_ids, self.document_formats | |
): | |
parsed_xml = XMLParser.parse( | |
self.documents_path / f'{document_id}.{document_format}', | |
include_content=True, | |
) | |
if ('состав' in parsed_xml.name.lower()) or ( | |
'составы' in parsed_xml.name.lower() | |
): | |
continue | |
if parsed_xml.status not in ACTUAL_STATUSES: | |
continue | |
parsed_xml.id = document_id | |
parsed_xmls.append(parsed_xml) | |
return ParsedXMLs(parsed_xmls) | |
def process_abbreviations( | |
self, | |
parsed_xmls: ParsedXMLs, | |
) -> list[Abbreviation]: | |
""" | |
Обработка и применение аббревиатур к контенту документов. | |
Теперь аббревиатуры уже извлечены во время парсинга, этот метод: | |
1. Устанавливает document_id для извлеченных аббревиатур | |
2. Применяет только документно-специфичные аббревиатуры к соответствующим документам | |
3. Объединяет все аббревиатуры (извлеченные и предварительно подготовленные) для возврата | |
Args: | |
parsed_xmls: Структура с данными из всех XML файлов | |
Returns: | |
list[Abbreviation]: Список всех аббревиатур для датасета | |
""" | |
all_abbreviations = {} | |
# Итерируем по документам | |
for xml in parsed_xmls.xmls: | |
# Устанавливаем document_id для извлеченных аббревиатур, если они есть | |
doc_specific_abbreviations = [] | |
if xml.abbreviations: | |
for abbreviation in xml.abbreviations: | |
abbreviation.document_id = xml.id | |
doc_specific_abbreviations = xml.abbreviations | |
# Применяем только аббревиатуры, извлеченные из этого документа | |
if doc_specific_abbreviations: | |
# Если есть аббревиатуры из документа, применяем их | |
xml.apply_abbreviations(doc_specific_abbreviations) | |
# Получаем подготовленные аббревиатуры для текущего документа | |
prepared_abbr = self.prepared_abbreviations.get(xml.id, []) | |
# Объединяем все аббревиатуры для возврата (не для применения) | |
combined_abbr = (doc_specific_abbreviations or []) + prepared_abbr | |
# Сохраняем объединенный список в document.abbreviations и в общем словаре | |
if combined_abbr: | |
xml.abbreviations = combined_abbr | |
all_abbreviations[xml.id] = combined_abbr | |
return self._ungroup_abbreviations(all_abbreviations) | |
def _get_already_parsed_xmls( | |
self, | |
) -> tuple[list[int], list[DatasetRow], list[np.ndarray]]: | |
if self.old_paths: | |
self.old_dataset = DocumentsDataset.from_pickle(self.old_paths.dataset) | |
ids = set([int(row.DocNumber) for row in self.old_dataset.rows]) | |
ids = ids.intersection(self.xml_ids) | |
rows = [row for row in self.old_dataset.rows if row.DocNumber in ids] | |
embs = [ | |
emb | |
for row, emb in zip(rows, self.old_dataset.vectors) | |
if row.DocNumber in ids | |
] | |
return ids, rows, embs | |
return [], [], [] | |
def _extract_hierarchies( | |
self, | |
parsed_xmls: ParsedXMLs, | |
) -> dict[int, tuple[Hierarchy, Hierarchy]]: | |
""" | |
Извлечение иерархических структур из текстового и табличного контента. | |
Args: | |
parsed_xmls: Структура с данными из всех XML файлов | |
Returns: | |
dict[int, tuple[Hierarchy, Hierarchy]]: Словарь иерархических структур для каждого документа | |
""" | |
hierarchies = {} | |
for xml in parsed_xmls.xmls: | |
doc_id = xml.id | |
# Обработка текстового контента | |
if xml.text: | |
text_lines = xml.text.to_text().split('\n') | |
self.hierarchy_parser.parse(text_lines, doc_id, '') | |
text_hierarchy = self.hierarchy_parser.hierarchy() | |
else: | |
text_hierarchy = {} | |
# Обработка табличного контента | |
if xml.tables: | |
table_lines = xml.tables.to_text().split('\n') | |
self.hierarchy_parser.parse_table(table_lines, doc_id) | |
table_hierarchy = self.hierarchy_parser.hierarchy() | |
else: | |
table_hierarchy = {} | |
hierarchies[doc_id] = (text_hierarchy, table_hierarchy) | |
return hierarchies | |
def create_dataset( | |
self, | |
parsed_xmls: ParsedXMLs, | |
hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], | |
) -> DocumentsDataset: | |
""" | |
Создание финального датасета с векторизацией. | |
Args: | |
parsed_xmls: Структура с данными из всех XML файлов | |
hierarchies: Словарь с иерархической структурой документов | |
Returns: | |
DocumentsDataset: Датасет с векторизованными текстами | |
""" | |
xmls = {xml.id: xml for xml in parsed_xmls.xmls} | |
self.dataset = self.dataset_creator.create_dataset(xmls, hierarchies) | |
return self.dataset | |
def _group_abbreviations( | |
self, | |
abbreviations: list[Abbreviation], | |
) -> dict[int, list[Abbreviation]]: | |
""" | |
Преобразует список аббревиатур в словарь, где ключи - идентификаторы документов, а значения - списки аббревиатур. | |
""" | |
doc_to_abbreviations = defaultdict(list) | |
for abbreviation in abbreviations: | |
doc_to_abbreviations[abbreviation.document_id].append(abbreviation) | |
return doc_to_abbreviations | |
def _ungroup_abbreviations( | |
self, abbreviations: dict[int, list[Abbreviation]] | |
) -> list[Abbreviation]: | |
""" | |
Преобразует словарь аббревиатур в список аббревиатур. | |
""" | |
return sum(abbreviations.values(), []) | |
def _save_txt_files(self, parsed_xmls: ParsedXMLs) -> None: | |
""" | |
Сохранение текстового и табличного контента в текстовые файлы. | |
""" | |
self.paths.txt_path.mkdir(parents=True, exist_ok=True) | |
for xml in parsed_xmls.xmls: | |
with open(self.paths.txt_path / f'{xml.id}.txt', 'w', encoding='utf-8') as f: | |
f.write(xml.text.to_text()) | |
if xml.tables: | |
with open(self.paths.txt_path / f'{xml.id}_table.txt', 'w', encoding='utf-8') as f: | |
f.write(xml.tables.to_text()) | |
def _save_hierarchies( | |
self, | |
hierarchies: dict[int, tuple[Hierarchy, Hierarchy]], | |
) -> None: | |
""" | |
Сохранение иерархий в JSON файлы. | |
""" | |
self.paths.jsons_path.mkdir(parents=True, exist_ok=True) | |
for doc_id, (text_hierarchy, table_hierarchy) in hierarchies.items(): | |
if text_hierarchy: | |
with open(self.paths.jsons_path / f'{doc_id}.json', 'w', encoding='utf-8') as f: | |
json.dump(text_hierarchy, f) | |
if table_hierarchy: | |
with open(self.paths.jsons_path / f'{doc_id}_table.json', 'w', encoding='utf-8') as f: | |
json.dump(table_hierarchy, f) | |