Spaces:
Sleeping
Sleeping
File size: 59,047 Bytes
fd485d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Основной пайплайн для оценки качества RAG системы.
Этот скрипт выполняет один прогон оценки для заданных параметров:
- Чтение документов и датасетов вопросов/ответов.
- Чанкинг документов.
- Векторизация вопросов и чанков.
- Оценка релевантности чанков к пунктам из датасета (Chunk-level).
- Сборка контекста из релевантных чанков (Assembly-level).
- Оценка релевантности собранного контекста к эталонным ответам.
- Сохранение детальных метрик для данного прогона.
"""
import argparse
# Add necessary imports for caching
import hashlib
import json
import os
import pickle
import sys
from pathlib import Path
from typing import Any
from uuid import UUID, uuid4
import numpy as np
import pandas as pd
import torch
from fuzzywuzzy import fuzz
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# --- Константы (могут быть переопределены аргументами) ---
DEFAULT_DATA_FOLDER = "data/input/docs"
DEFAULT_SEARCH_DATASET_PATH = "data/input/search_dataset_texts.xlsx"
DEFAULT_QA_DATASET_PATH = "data/input/question_answering.xlsx"
DEFAULT_MODEL_NAME = "intfloat/e5-base"
DEFAULT_BATCH_SIZE = 8
DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEFAULT_SIMILARITY_THRESHOLD = 0.7
DEFAULT_OUTPUT_DIR = "data/intermediate" # Директория для промежуточных результатов
DEFAULT_WORDS_PER_CHUNK = 50
DEFAULT_OVERLAP_WORDS = 25
DEFAULT_TOP_N = 20 # Значение N по умолчанию для топа чанков
# Add chunking strategy constant
DEFAULT_CHUNKING_STRATEGY = "fixed_size"
# Add cache directory constant
DEFAULT_CACHE_DIR = "data/cache"
# --- Добавление путей к библиотекам ---
# Добавляем путь к корневой папке проекта, чтобы можно было импортировать ntr_...
SCRIPT_DIR = Path(__file__).parent.resolve()
PROJECT_ROOT = SCRIPT_DIR.parent.parent # Перейти на два уровня вверх (scripts/testing -> scripts -> project root)
LIB_EXTRACTOR_PATH = PROJECT_ROOT / "lib" / "extractor"
sys.path.insert(0, str(LIB_EXTRACTOR_PATH))
# Добавляем путь к папке с ntr_text_fragmentation
sys.path.insert(0, str(LIB_EXTRACTOR_PATH / "ntr_text_fragmentation"))
# --- Импорты из локальных модулей ---
try:
from ntr_fileparser import ParsedDocument, UniversalParser
from ntr_text_fragmentation import Destructurer
from ntr_text_fragmentation.core.entity_repository import \
InMemoryEntityRepository
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
from ntr_text_fragmentation.models.chunk import Chunk
from ntr_text_fragmentation.models.document import DocumentAsEntity
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
except ImportError as e:
print(f"Ошибка импорта локальных модулей: {e}")
print(f"Проверьте пути: Project Root: {PROJECT_ROOT}, Extractor Lib: {LIB_EXTRACTOR_PATH}")
sys.exit(1)
# --- Вспомогательные функции (аналогичные evaluate_chunking.py) ---
def _average_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Расчёт усредненного эмбеддинга по всем токенам.
(Копипаста из evaluate_chunking.py)
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
"""
Рассчитывает степень перекрытия между чанком и пунктом.
(Копипаста из evaluate_chunking.py)
"""
if not chunk_text or not punct_text:
return 0.0
# Используем partial_ratio для лучшей обработки подстрок
return fuzz.partial_ratio(chunk_text, punct_text) / 100.0
# --- Функции загрузки и обработки данных ---
def parse_args():
"""Парсит аргументы командной строки."""
parser = argparse.ArgumentParser(description="Пайплайн оценки RAG системы")
# Пути к данным
parser.add_argument("--data-folder", type=str, default=DEFAULT_DATA_FOLDER,
help=f"Папка с документами (по умолчанию: {DEFAULT_DATA_FOLDER})")
parser.add_argument("--search-dataset-path", type=str, default=DEFAULT_SEARCH_DATASET_PATH,
help=f"Путь к датасету для поиска (по умолчанию: {DEFAULT_SEARCH_DATASET_PATH})")
parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
help=f"Папка для сохранения промежуточных результатов (по умолчанию: {DEFAULT_OUTPUT_DIR})")
parser.add_argument("--run-id", type=str, default=f"run_{uuid4()}",
help="Уникальный идентификатор запуска (по умолчанию: генерируется)")
# Параметры модели и векторизации
parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME,
help=f"Название модели для векторизации (по умолчанию: {DEFAULT_MODEL_NAME})")
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
help=f"Размер батча для векторизации (по умолчанию: {DEFAULT_BATCH_SIZE})")
parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, # type: ignore
help=f"Устройство для вычислений (по умолчанию: {DEFAULT_DEVICE})")
parser.add_argument("--use-sentence-transformers", action="store_true",
help="Использовать библиотеку sentence_transformers")
# Параметры чанкинга
parser.add_argument("--chunking-strategy", type=str, default=DEFAULT_CHUNKING_STRATEGY,
choices=list(Destructurer.STRATEGIES.keys()), # Use keys from Destructurer
help=f"Стратегия чанкинга (по умолчанию: {DEFAULT_CHUNKING_STRATEGY})")
parser.add_argument("--strategy-params", type=str, default='{}', # Default to empty JSON object
help="Параметры для стратегии чанкинга в формате JSON строки (например, '{\"words_per_chunk\": 50}')")
parser.add_argument("--no-process-tables", action="store_false", dest="process_tables",
help="Отключить обработку таблиц при чанкинге")
parser.set_defaults(process_tables=True) # Default is to process tables
# Параметры оценки
parser.add_argument("--similarity-threshold", type=float, default=DEFAULT_SIMILARITY_THRESHOLD,
help=f"Порог для нечеткого сравнения чанка и пункта (по умолчанию: {DEFAULT_SIMILARITY_THRESHOLD})")
parser.add_argument("--top-n", type=int, default=DEFAULT_TOP_N,
help=f"Количество топ-чанков для рассмотрения (по умолчанию: {DEFAULT_TOP_N})")
# Add cache directory argument
parser.add_argument("--cache-dir", type=str, default=DEFAULT_CACHE_DIR,
help=f"Директория для кэширования эмбеддингов и матриц схожести (по умолчанию: {DEFAULT_CACHE_DIR})")
# Параметры сборки контекста
parser.add_argument("--use-injection", action="store_true",
help="Выполнять ли сборку контекста и её оценку")
parser.add_argument("--use-qe", action="store_true",
help="Использовать столбец query_expansion вместо question для поиска (если он есть)")
parser.add_argument("--include-neighbors", action="store_true",
help="Включать ли соседние чанки (предыдущий/следующий) при сборке контекста")
# --- Добавляем аргумент для batch_id ---
parser.add_argument("--batch-id", type=str, default="batch_default",
help="Идентификатор серии запусков (передается из run_pipelines.py)")
# TODO: Добавить другие параметры при необходимости (например, параметры InjectionBuilder)
return parser.parse_args()
def read_documents(folder_path: str) -> dict[str, ParsedDocument]:
"""
Читает все документы из указанной папки и создает сущности.
Args:
folder_path: Путь к папке с документами
Returns:
Словарь {имя_файла: объект ParsedDocument}
"""
print(f"Чтение документов из {folder_path}...")
parser = UniversalParser()
documents_map = {}
doc_files = list(Path(folder_path).glob("*.docx"))
if not doc_files:
print(f"ВНИМАНИЕ: В папке {folder_path} не найдено *.docx файлов.")
return {}
for file_path in tqdm(doc_files, desc="Чтение документов"):
try:
doc_name = file_path.stem
# Парсим документ с помощью UniversalParser
parsed_document = parser.parse_by_path(str(file_path))
# Сохраняем распарсенный документ
documents_map[doc_name] = parsed_document
except Exception as e:
print(f"Ошибка при чтении файла {file_path}: {e}")
print(f"Прочитано документов: {len(documents_map)}")
return documents_map
def load_datasets(search_dataset_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Загружает датасет для поиска и готовит данные для векторизации.
Args:
search_dataset_path: Путь к Excel с пунктами для поиска.
Returns:
Кортеж: (полный DataFrame поискового датасета, DataFrame с уникальными вопросами для векторизации).
"""
print(f"Загрузка поискового датасета из {search_dataset_path}...")
try:
search_df = pd.read_excel(search_dataset_path)
print(f"Загружен поисковый датасет: {len(search_df)} строк, столбцы: {search_df.columns.tolist()}")
# Проверяем наличие обязательных столбцов
required_columns = ['id', 'question', 'text', 'filename']
missing_cols = [col for col in required_columns if col not in search_df.columns]
if missing_cols:
print(f"Ошибка: В поисковом датасете отсутствуют обязательные столбцы: {missing_cols}")
sys.exit(1)
# Преобразуем NaN в пустые строки для текстовых полей
# Добавляем 'query_expansion', если он есть, для обработки NaN
text_columns = ['question', 'text', 'item_type', 'filename']
if 'query_expansion' in search_df.columns:
text_columns.append('query_expansion')
for col in text_columns:
if col in search_df.columns:
search_df[col] = search_df[col].fillna('')
# Если необязательный item_type отсутствует, добавляем его пустым
elif col == 'item_type':
print(f"Предупреждение: столбец '{col}' отсутствует в поисковом датасете. Добавлен пустой столбец.")
search_df[col] = ''
# Убедимся, что 'id' имеет целочисленный тип
try:
search_df['id'] = search_df['id'].astype(int)
except ValueError as e:
print(f"Ошибка при приведении типов столбца 'id' в поисковом датасете: {e}. Убедитесь, что ID являются целыми числами.")
sys.exit(1)
except FileNotFoundError:
print(f"Ошибка: Поисковый датасет не найден по пути {search_dataset_path}")
sys.exit(1)
except Exception as e:
print(f"Ошибка при чтении поискового датасета: {e}")
sys.exit(1)
# Готовим DataFrame для векторизации уникальных вопросов
# Включаем query_expansion, если он есть
cols_for_embedding = ['id', 'question']
query_expansion_exists = 'query_expansion' in search_df.columns
if query_expansion_exists:
cols_for_embedding.append('query_expansion')
print("Столбец 'query_expansion' найден в поисковом датасете.")
else:
print("Столбец 'query_expansion' не найден в поисковом датасете.")
questions_to_embed = search_df[cols_for_embedding].drop_duplicates(subset=['id']).copy()
# Если query_expansion не существует, добавляем пустой столбец для единообразия
if not query_expansion_exists:
questions_to_embed['query_expansion'] = ''
print(f"Уникальных вопросов для векторизации: {len(questions_to_embed)}")
# Теперь search_df это и есть наш "объединенный" датасет (так как QA не используется)
return search_df, questions_to_embed
def perform_chunking(
documents_map: dict[str, ParsedDocument],
chunking_strategy: str,
process_tables: bool,
strategy_params_json: str # Expect JSON string
) -> tuple[pd.DataFrame, list[LinkerEntity]]:
"""
Выполняет чанкинг для всех документов.
Args:
documents_map: Словарь {имя_файла: сущность_документа}.
chunking_strategy: Имя используемой стратегии чанкинга.
process_tables: Флаг, указывающий, нужно ли обрабатывать таблицы.
strategy_params_json: Строка JSON с параметрами для стратегии.
Returns:
Кортеж: (DataFrame с чанками для поиска, список всех созданных сущностей LinkerEntity)
"""
print("Выполнение чанкинга...")
searchable_chunks_data = [] # Данные только для чанков с in_search_text
final_entities: list[LinkerEntity] = [] # Список для ВСЕХ сущностей (доки, чанки, связи и т.д.)
# Parse strategy parameters from JSON string
try:
chunking_params = json.loads(strategy_params_json)
print(f"Параметры для стратегии '{chunking_strategy}': {chunking_params}")
except json.JSONDecodeError as e:
print(f"Ошибка парсинга JSON для strategy-params: '{strategy_params_json}'. Используются параметры по умолчанию стратегии. Ошибка: {e}")
chunking_params = {} # Use strategy defaults if JSON is invalid
print(f"Используется стратегия чанкинга: '{chunking_strategy}'")
print(f"Обработка таблиц: {'Включена' if process_tables else 'Отключена'}")
for doc_name, parsed_doc in tqdm(documents_map.items(), desc="Чанкинг документов"):
try:
# Инициализируем Destructurer ВНУТРИ цикла для КАЖДОГО документа
destructurer = Destructurer(
document=parsed_doc,
process_tables=process_tables,
strategy_name=chunking_strategy, # Передаем имя стратегии при инициализации
**chunking_params # И параметры стратегии
)
# Destructure создает DocumentAsEntity, чанки, связи и возвращает их как LinkerEntity
new_entities = destructurer.destructure()
# Добавляем ВСЕ созданные сущности (сериализованные LinkerEntity) в общий список
final_entities.extend(new_entities)
# Собираем данные для DataFrame только из тех сущностей,
# у которых есть поле in_search_text (это наши чанки для поиска)
for entity in new_entities:
# Проверяем наличие атрибута 'in_search_text', а не тип
if hasattr(entity, 'in_search_text') and entity.in_search_text:
entity_data = {
'chunk_id': str(entity.id),
'doc_name': doc_name, # Имя исходного файла
'doc_id': str(entity.source_id), # ID сущности документа (DocumentAsEntity)
'text': entity.in_search_text, # Текст для векторизации и поиска
'type': entity.type, # Тип сущности (например, 'FixedSizeChunk')
'strategy_params': json.dumps(chunking_params, ensure_ascii=False),
}
searchable_chunks_data.append(entity_data)
except Exception as e:
# Логируем ошибку и продолжаем с остальными документами
print(f"\nОшибка при чанкинге документа {doc_name}: {e}")
import traceback
traceback.print_exc() # Печатаем traceback для детальной отладки
# Создаем DataFrame только из чанков, предназначенных для поиска
chunks_df = pd.DataFrame(searchable_chunks_data)
print(f"Создано чанков для поиска: {len(chunks_df)}")
# Возвращаем DataFrame с чанками для поиска и ПОЛНЫЙ список всех LinkerEntity
return chunks_df, final_entities
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool, device: str):
"""Инициализирует модель и токенизатор."""
print(f"Загрузка модели {model_name} на устройство {device}...")
if use_sentence_transformers:
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device)
tokenizer = None # sentence_transformers не требует отдельного токенизатора
print("Используется SentenceTransformer.")
return model, tokenizer
except ImportError:
print("Ошибка: Библиотека sentence_transformers не установлена. Установите: pip install sentence-transformers")
sys.exit(1)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()
print("Используется AutoModel и AutoTokenizer из transformers.")
return model, tokenizer
except Exception as e:
print(f"Ошибка при загрузке модели {model_name} из transformers: {e}")
sys.exit(1)
def get_embeddings(
texts: list[str],
model,
tokenizer,
batch_size: int,
use_sentence_transformers: bool,
device: str
) -> np.ndarray:
"""Получает эмбеддинги для списка текстов."""
all_embeddings = []
desc = "Векторизация (Sentence Transformers)" if use_sentence_transformers else "Векторизация (Transformers)"
for i in tqdm(range(0, len(texts), batch_size), desc=desc):
batch_texts = texts[i:i+batch_size]
if not batch_texts:
continue
if use_sentence_transformers:
# Эмбеддинги через sentence_transformers
embeddings = model.encode(batch_texts, batch_size=len(batch_texts), show_progress_bar=False)
all_embeddings.append(embeddings)
else:
# Эмбеддинги через transformers с average pooling
try:
encoding = tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=512, # Стандартное ограничение для многих моделей
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**encoding)
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
all_embeddings.append(embeddings.cpu().numpy())
except Exception as e:
print(f"Ошибка при векторизации батча (индексы {i} - {i+batch_size}): {e}")
print(f"Тексты батча: {batch_texts[:2]}...")
# Добавляем нулевые векторы, чтобы не сломать vstack
# Определяем размер эмбеддинга
if all_embeddings:
embedding_dim = all_embeddings[0].shape[1]
else:
# Пытаемся получить размер из конфигурации модели
try:
embedding_dim = model.config.hidden_size
except AttributeError:
embedding_dim = 768 # Запасной вариант
print(f"Не удалось определить размер эмбеддинга, используется {embedding_dim}")
print(f"Добавление нулевых эмбеддингов размерности ({len(batch_texts)}, {embedding_dim})")
null_embeddings = np.zeros((len(batch_texts), embedding_dim), dtype=np.float32)
all_embeddings.append(null_embeddings)
if not all_embeddings:
print("ВНИМАНИЕ: Не удалось создать эмбеддинги.")
# Возвращаем пустой массив правильной формы, если возможно
try:
embedding_dim = model.config.hidden_size if not use_sentence_transformers else model.get_sentence_embedding_dimension()
except:
embedding_dim = 768
return np.empty((0, embedding_dim), dtype=np.float32)
# Объединяем эмбеддинги из всех батчей
try:
final_embeddings = np.vstack(all_embeddings)
except ValueError as e:
print(f"Ошибка при объединении эмбеддингов: {e}")
print("Размеры эмбеддингов в батчах:")
for i, emb_batch in enumerate(all_embeddings):
print(f" Батч {i}: {emb_batch.shape}")
# Попробуем определить ожидаемый размер и создать нулевой массив
if all_embeddings:
embedding_dim = all_embeddings[0].shape[1]
print(f"Возвращение нулевого массива размерности ({len(texts)}, {embedding_dim})")
return np.zeros((len(texts), embedding_dim), dtype=np.float32)
else:
return np.empty((0, 768), dtype=np.float32) # Запасной вариант
print(f"Получено эмбеддингов: {final_embeddings.shape}")
return final_embeddings
# --- Caching Helper Functions ---
def _get_params_hash(
model_name: str,
process_tables: bool | None = None,
strategy_params: dict | None = None # Expect the parsed dictionary
) -> str:
"""Создает MD5 хэш из переданных параметров."""
hasher = hashlib.md5()
hasher.update(model_name.encode())
# Add chunking strategy and table processing flag if provided
if process_tables is not None:
hasher.update(str(process_tables).encode())
# Add strategy parameters (sort items to ensure consistent hash)
if strategy_params:
sorted_params = sorted(strategy_params.items())
hasher.update(json.dumps(sorted_params).encode())
return hasher.hexdigest()
def _get_cache_path(cache_dir: Path, hash_str: str, filename: str) -> Path:
"""Формирует путь к файлу кэша, создавая поддиректории."""
# Используем первые 2 символа хэша для распределения по поддиректориям
# Это помогает избежать слишком большого количества файлов в одной директории
cache_subdir = cache_dir / hash_str[:2] / hash_str
cache_subdir.mkdir(parents=True, exist_ok=True)
return cache_subdir / filename
# --- Добавляем функцию для хэша чанкинга ---
def _get_chunking_cache_hash(
data_folder: str,
chunking_strategy: str,
process_tables: bool,
strategy_params: dict # Ожидаем словарь
) -> str:
"""Создает MD5 хэш для параметров чанкинга и папки с данными."""
hasher = hashlib.md5()
hasher.update(data_folder.encode())
hasher.update(chunking_strategy.encode())
hasher.update(str(process_tables).encode())
# Сортируем параметры для консистентности хэша
sorted_params = sorted(strategy_params.items())
hasher.update(json.dumps(sorted_params).encode())
return hasher.hexdigest()
# ---------------------------------------------
# --- Main Evaluation Function ---
def evaluate_run(
search_dataset: pd.DataFrame,
questions_to_embed: pd.DataFrame,
chunks_df: pd.DataFrame,
all_entities: list[LinkerEntity],
model: Any | None, # Принимаем None
tokenizer: Any | None, # Принимаем None
args: argparse.Namespace
) -> pd.DataFrame:
"""
Выполняет основной цикл оценки для одного набора параметров.
Args:
search_dataset: DataFrame поискового датасета.
questions_to_embed: DataFrame с уникальными вопросами для векторизации.
chunks_df: DataFrame с данными по чанкам.
all_entities: Список всех сущностей (документы, чанки, связи).
model: Модель для векторизации.
tokenizer: Токенизатор.
args: Аргументы командной строки.
Returns:
DataFrame с детальными метриками по каждому вопросу для этого запуска.
"""
print("Начало этапа оценки...")
# Переменные для модели и токенизатора, инициализируем None
loaded_model: Any | None = model
loaded_tokenizer: Any | None = tokenizer
# --- Caching Setup ---
print("Настройка кэширования...")
CACHE_DIR_PATH = Path(args.cache_dir)
model_slug = args.model_name.split('/')[-1] # Basic slug for filename clarity
# --- Определяем, какой текст использовать для эмбеддингов вопросов ---
# и устанавливаем флаг qe_active, который будет влиять на кэш
if args.use_qe and 'query_expansion' in questions_to_embed.columns and questions_to_embed['query_expansion'].notna().any(): # Check if column exists and has non-NA values
print("Используется Query Expansion (столбец 'query_expansion') для векторизации вопросов.")
query_texts_to_embed = questions_to_embed['query_expansion'].tolist()
qe_active = True
else:
print("Используется оригинальный текст вопроса (столбец 'question') для векторизации.")
query_texts_to_embed = questions_to_embed['question'].tolist()
qe_active = False
# Cache key for question embeddings (ЗАВИСИТ от модели и флага use_qe)
question_params_for_hash = {
'model_name': args.model_name,
'use_qe': qe_active # Добавляем фактическое использование QE в параметры для хэша
}
question_hash = hashlib.md5(json.dumps(question_params_for_hash, sort_keys=True).encode()).hexdigest()
question_embeddings_cache_path = _get_cache_path(
CACHE_DIR_PATH, question_hash, f"q_embeddings_{model_slug}_qe{qe_active}.npy"
)
# Cache key for chunk embeddings (depends on model and chunking)
chunk_hash = _get_params_hash(
args.model_name,
args.process_tables, # Include table flag
json.loads(args.strategy_params) # Pass parsed params dictionary
)
chunk_embeddings_cache_path = _get_cache_path(
CACHE_DIR_PATH, chunk_hash,
f"c_emb_{model_slug}_s-{args.chunking_strategy}_t{args.process_tables}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy"
)
# Cache key for similarity matrix (depends on both sets of embeddings)
similarity_hash = f"{question_hash}_{chunk_hash}" # Combine hashes
similarity_cache_path = _get_cache_path(
CACHE_DIR_PATH, similarity_hash,
f"sim_{model_slug}_qe{qe_active}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy" # Добавляем флаг QE в имя файла
)
# 1. Векторизация вопросов и чанков (с кэшем)
question_embeddings = None
needs_model_load = False # Флаг, указывающий, нужна ли загрузка модели
if question_embeddings_cache_path.exists():
try:
print(f"Загрузка кэшированных эмбеддингов вопросов из: {question_embeddings_cache_path}")
question_embeddings = np.load(question_embeddings_cache_path, allow_pickle=False)
if len(question_embeddings) != len(questions_to_embed):
print(f"Предупреждение: Размер кэша эмбеддингов вопросов не совпадает. Пересчет.")
question_embeddings = None
else:
print("Кэш эмбеддингов вопросов успешно загружен.")
except Exception as e:
print(f"Ошибка загрузки кэша эмбеддингов вопросов: {e}. Пересчет.")
question_embeddings = None
if question_embeddings is None:
needs_model_load = True # Требуется модель для генерации эмбеддингов
print("Векторизация вопросов (потребуется загрузка модели)...")
chunk_embeddings = None
if chunk_embeddings_cache_path.exists():
try:
print(f"Загрузка кэшированных эмбеддингов чанков из: {chunk_embeddings_cache_path}")
chunk_embeddings = np.load(chunk_embeddings_cache_path, allow_pickle=False)
if len(chunk_embeddings) != len(chunks_df):
print(f"Предупреждение: Размер кэша эмбеддингов чанков не совпадает. Пересчет.")
chunk_embeddings = None
else:
print("Кэш эмбеддингов чанков успешно загружен.")
except Exception as e:
print(f"Ошибка загрузки кэша эмбеддингов чанков: {e}. Пересчет.")
chunk_embeddings = None
if chunk_embeddings is None:
needs_model_load = True # Требуется модель для генерации эмбеддингов
print("Векторизация чанков (потребуется загрузка модели)...")
# --- Отложенная загрузка модели, если необходимо ---
if needs_model_load and loaded_model is None:
print("\n--- Загрузка модели и токенизатора (т.к. кэш эмбеддингов отсутствует) ---")
loaded_model, loaded_tokenizer = setup_model_and_tokenizer(
args.model_name, args.use_sentence_transformers, args.device
)
print("--- Модель и токенизатор загружены ---\n")
# --- Повторная генерация эмбеддингов, если они не загрузились из кэша ---
if question_embeddings is None:
if loaded_model is None:
print("Критическая ошибка: Модель не загружена, но требуется для векторизации вопросов!")
# Возвращаем пустой DataFrame или выбрасываем исключение
return pd.DataFrame()
print("Повторная векторизация вопросов...")
question_embeddings = get_embeddings(
query_texts_to_embed,
loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
)
if question_embeddings.shape[0] > 0:
try:
print(f"Сохранение эмбеддингов вопросов в кэш: {question_embeddings_cache_path}")
np.save(question_embeddings_cache_path, question_embeddings, allow_pickle=False)
except Exception as e:
print(f"Не удалось сохранить кэш эмбеддингов вопросов: {e}")
if chunk_embeddings is None:
if loaded_model is None:
print("Критическая ошибка: Модель не загружена, но требуется для векторизации чанков!")
return pd.DataFrame()
print("Повторная векторизация чанков...")
chunk_texts = chunks_df['text'].fillna('').astype(str).tolist()
chunk_embeddings = get_embeddings(
chunk_texts,
loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
)
if chunk_embeddings.shape[0] > 0:
try:
print(f"Сохранение эмбеддингов чанков в кэш: {chunk_embeddings_cache_path}")
np.save(chunk_embeddings_cache_path, chunk_embeddings, allow_pickle=False)
except Exception as e:
print(f"Не удалось сохранить кэш эмбеддингов чанков: {e}")
# Проверка совпадения количества эмбеддингов и данных
if len(question_embeddings) != len(questions_to_embed):
print(f"Ошибка: Количество эмбеддингов вопросов ({len(question_embeddings)}) не совпадает с количеством уникальных вопросов ({len(questions_to_embed)}).")
# Можно либо прервать выполнение, либо попытаться исправить
# Например, взять первые N эмбеддингов, но это может быть некорректно
sys.exit(1)
if len(chunk_embeddings) != len(chunks_df):
print(f"Ошибка: Количество эмбеддингов чанков ({len(chunk_embeddings)}) не совпадает с количеством чанков в DataFrame ({len(chunks_df)}).")
# Попытка исправить (если ошибка небольшая) или выход
if abs(len(chunk_embeddings) - len(chunks_df)) < 5:
print("Попытка обрезать лишние эмбеддинги/данные...")
min_len = min(len(chunk_embeddings), len(chunks_df))
chunk_embeddings = chunk_embeddings[:min_len]
chunks_df = chunks_df.iloc[:min_len]
print(f"Размеры выровнены до {min_len}")
else:
sys.exit(1)
# Создаем маппинг ID вопроса к индексу в эмбеддингах
question_id_to_idx = {
row['id']: i for i, (_, row) in enumerate(questions_to_embed.iterrows())
}
# 2. Расчет косинусной близости
print("Расчет косинусной близости...")
# Проверка на пустые эмбеддинги
if question_embeddings.shape[0] == 0 or chunk_embeddings.shape[0] == 0:
print("Ошибка: Отсутствуют эмбеддинги вопросов или чанков для расчета близости.")
# Возвращаем пустой DataFrame или обрабатываем ошибку иначе
return pd.DataFrame()
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
# 3. Инициализация InjectionBuilder (если нужно)
injection_builder = None
if args.use_injection:
print("Инициализация InjectionBuilder...")
repository = InMemoryEntityRepository(all_entities)
injection_builder = InjectionBuilder(repository)
# TODO: Зарегистрировать стратегии, если необходимо
# builder.register_strategy(...)
# 4. Цикл по уникальным вопросам для оценки
results = []
print(f"Оценка для {len(questions_to_embed)} уникальных вопросов...")
for question_id_iter, question_data in tqdm(questions_to_embed.iterrows(), total=len(questions_to_embed), desc="Оценка вопросов"): # Renamed loop variable
q_id = question_data['id']
q_text = question_data['question']
# Получаем все строки из исходного датасета для этого вопроса
question_rows = search_dataset[search_dataset['id'] == q_id] # Use search_dataset
if question_rows.empty:
print(f"Предупреждение: Нет данных в search_dataset для вопроса ID={q_id}")
continue
# Получаем пункты (relevant items)
puncts = question_rows['text'].tolist()
# reference_answer больше не используется и не извлекается
# Получаем индекс вопроса в матрице близости
if q_id not in question_id_to_idx:
print(f"Предупреждение: Вопрос ID={q_id} не найден в маппинге эмбеддингов.")
continue
question_idx = question_id_to_idx[q_id]
# --- Оценка на уровне чанков (Chunk-level) ---
chunk_level_metrics = evaluate_chunk_relevance(
q_id, question_idx, puncts,
similarity_matrix, chunks_df, args.top_n, args.similarity_threshold
)
# --- Оценка на уровне сборки (Assembly-level) ---
# Удаляем assembly_relevance, основанный на reference_answer
assembly_level_metrics = {} # Start with an empty dict for assembly metrics
assembled_context = ""
top_chunk_indices = chunk_level_metrics.get("top_chunk_ids", []) # Get indices first
neighbors_included = False # Flag to log
if args.use_injection and injection_builder and top_chunk_indices:
try:
# Преобразуем ID строк обратно в UUID чанков
top_chunk_uuids = [UUID(chunks_df.iloc[idx]['chunk_id']) for idx in top_chunk_indices]
final_chunk_uuids_for_assembly = set(top_chunk_uuids) # Start with top chunks
# --- Добавляем соседей, если нужно ---
if args.include_neighbors:
neighbors_included = True
# --- Убираем логирование индексов ---
neighbor_chunks = repository.get_neighboring_chunks(chunk_ids=top_chunk_uuids, max_distance=1)
neighbor_ids = {neighbor.id for neighbor in neighbor_chunks}
# --- Логирование до/после добавления ID соседей ---
print(f" [DEBUG QID {q_id}] Кол-во ID до добавления соседей: {len(final_chunk_uuids_for_assembly)}")
print(f" [DEBUG QID {q_id}] Кол-во найденных ID соседей: {len(neighbor_ids)}")
final_chunk_uuids_for_assembly.update(neighbor_ids)
print(f" [DEBUG QID {q_id}] Кол-во ID после добавления соседей: {len(final_chunk_uuids_for_assembly)}")
# --- Конец логирования ---
# --- Убираем логирование индексов ---
else:
# --- Убираем логирование индексов ---
pass # Ничего не делаем, если соседи не включены
# Собираем контекст
# Передаем финальный набор UUID (уникальный)
assembled_context = injection_builder.build(
filtered_entities=list(final_chunk_uuids_for_assembly),
# chunk_scores= {chunks_df.loc[idx, 'chunk_id']: sim for idx, sim in zip(top_chunk_ids_for_assembly, chunk_level_metrics.get('top_chunk_similarities',[]))} # Можно добавить веса
)
# --- Новая метрика: Assembly Punct Recall ---
# Оцениваем, сколько пунктов из датасета найдено в собранном контексте
# (по вашей идее: пункт считается найденным, если хотя бы одна его часть,
# разделенная переносом строки, найдена в контексте)
assembly_found_puncts = 0
assembly_total_puncts = len(puncts)
if assembly_total_puncts > 0 and assembled_context:
# Итерируемся по каждому исходному пункту
for punct_text in puncts:
# Разбиваем пункт на части по переносу строки
# Убираем пустые строки, которые могут появиться из-за двойных переносов
punct_parts = [part for part in punct_text.split('\n') if part.strip()]
# Если пункт пустой или состоит только из пробельных символов после разбивки,
# пропускаем его (не считаем ни найденным, ни не найденным в контексте recall)
if not punct_parts:
assembly_total_puncts -= 1 # Уменьшаем общее число пунктов для расчета recall
continue
is_punct_found = False
# Итерируемся по частям пункта
for part_text in punct_parts:
# Сравниваем КАЖДУЮ ЧАСТЬ пункта с собранным контекстом
if calculate_chunk_overlap(assembled_context, part_text.strip()) >= args.similarity_threshold:
# Если ХОТЯ БЫ ОДНА часть найдена, считаем ВЕСЬ пункт найденным
is_punct_found = True
break # Дальше части этого пункта можно не проверять
# Если флаг is_punct_found стал True, увеличиваем счетчик найденных пунктов
if is_punct_found:
assembly_found_puncts += 1
# Рассчитываем recall, только если были валидные пункты для проверки
if assembly_total_puncts > 0:
assembly_level_metrics["assembly_punct_recall"] = assembly_found_puncts / assembly_total_puncts
else:
assembly_level_metrics["assembly_punct_recall"] = 0.0 # Или можно None, если нет валидных пунктов
else:
assembly_level_metrics["assembly_punct_recall"] = 0.0
# Добавляем сам текст сборки для возможного анализа (усеченный)
assembly_level_metrics["assembled_context_preview"] = assembled_context[:500] + ("..." if len(assembled_context) > 500 else "")
except Exception as e:
print(f"Ошибка при сборке/оценке контекста для вопроса ID={q_id}: {e}")
# Записываем None или 0, чтобы не прерывать процесс
assembly_level_metrics["assembly_punct_recall"] = None # Indicate error
assembly_level_metrics["assembled_context_preview"] = f"Error during assembly: {e}"
# Собираем все метрики для вопроса
question_result = {
"run_id": args.run_id,
"batch_id": args.batch_id, # --- Добавляем batch_id в результаты ---
"question_id": q_id,
"question_text": q_text,
# Параметры запуска
"model_name": args.model_name,
"chunking_strategy": args.chunking_strategy, # Log strategy
"process_tables": args.process_tables, # Log table flag
"strategy_params": args.strategy_params, # Log JSON string
"top_n": args.top_n,
"use_injection": args.use_injection,
"use_qe": qe_active, # Log QE status
"neighbors_included": neighbors_included, # Log neighbor flag
"similarity_threshold": args.similarity_threshold,
# Метрики Chunk-level
**chunk_level_metrics,
# Метрики Assembly-level (теперь с recall по пунктам)
**assembly_level_metrics,
# Тексты для отладки (эталонный ответ удален, сборка добавлена выше)
# "assembled_context": assembled_context[:500] + "..." if assembled_context else "",
}
results.append(question_result)
print("Оценка завершена.")
return pd.DataFrame(results)
def evaluate_chunk_relevance(
question_id: int,
question_idx: int,
puncts: list[str],
similarity_matrix: np.ndarray,
chunks_df: pd.DataFrame,
top_n: int,
similarity_threshold: float
) -> dict:
"""
Оценивает релевантность чанков для одного вопроса.
(Адаптировано из evaluate_for_top_n_with_mapping в evaluate_chunking.py)
Возвращает словарь с метриками для этого вопроса.
"""
metrics = {
"chunk_text_precision": 0.0,
"chunk_text_recall": 0.0,
"chunk_text_f1": 0.0,
"found_puncts": 0,
"total_puncts": len(puncts),
"relevant_chunks": 0,
"total_chunks_in_top_n": 0,
"top_chunk_ids": [], # Индексы строк в chunks_df
"top_chunk_similarities": [],
}
if chunks_df.empty or similarity_matrix.shape[1] == 0:
print(f"Предупреждение (QID {question_id}): Нет чанков для оценки.")
return metrics
# Получаем схожести всех чанков с текущим вопросом
question_similarities = similarity_matrix[question_idx, :]
# Сортируем чанки по схожести и берем top_n
# argsort возвращает индексы элементов, которые бы отсортировали массив
# Берем последние N индексов (-top_n:) и разворачиваем ([::-1]) для убывания
# Добавляем проверку на случай если top_n > количества чанков
if top_n >= similarity_matrix.shape[1]:
sorted_chunk_indices = np.argsort(question_similarities)[::-1] # Берем все, сортируем по убыванию
else:
sorted_chunk_indices = np.argsort(question_similarities)[-top_n:][::-1]
# Ограничиваем top_n, если чанков меньше (это должно быть сделано выше, но дублируем для надежности)
actual_top_n = min(top_n, len(sorted_chunk_indices))
top_chunk_indices = sorted_chunk_indices[:actual_top_n]
# Сохраняем ID и схожести топ-чанков
metrics["top_chunk_ids"] = top_chunk_indices.tolist()
metrics["top_chunk_similarities"] = question_similarities[top_chunk_indices].tolist()
# Отбираем данные топ-чанков
top_chunks_df = chunks_df.iloc[top_chunk_indices]
metrics["total_chunks_in_top_n"] = len(top_chunks_df)
if metrics["total_chunks_in_top_n"] == 0:
return metrics # Если нет топ-чанков, метрики остаются нулевыми
# Оценка на основе текста (пунктов)
punct_found = [False] * metrics["total_puncts"]
question_relevant_chunks = 0
for i, (idx, chunk_row) in enumerate(top_chunks_df.iterrows()):
chunk_text = chunk_row['text']
is_relevant_to_punct = False
for j, punct_text in enumerate(puncts):
overlap = calculate_chunk_overlap(chunk_text, punct_text)
if overlap >= similarity_threshold:
is_relevant_to_punct = True
punct_found[j] = True
if is_relevant_to_punct:
question_relevant_chunks += 1
metrics["found_puncts"] = sum(punct_found)
metrics["relevant_chunks"] = question_relevant_chunks
if metrics["total_chunks_in_top_n"] > 0:
metrics["chunk_text_precision"] = metrics["relevant_chunks"] / metrics["total_chunks_in_top_n"]
if metrics["total_puncts"] > 0:
metrics["chunk_text_recall"] = metrics["found_puncts"] / metrics["total_puncts"]
if metrics["chunk_text_precision"] + metrics["chunk_text_recall"] > 0:
metrics["chunk_text_f1"] = (2 * metrics["chunk_text_precision"] * metrics["chunk_text_recall"] /
(metrics["chunk_text_precision"] + metrics["chunk_text_recall"]))
return metrics
# --- Основная функция ---
def main():
"""Основная функция скрипта."""
args = parse_args()
print(f"Запуск оценки с ID: {args.run_id}")
print(f"Параметры: {vars(args)}")
# --- Кэширование Чанкинга ---
CACHE_DIR_PATH = Path(args.cache_dir)
try:
# Парсим параметры стратегии один раз
parsed_strategy_params = json.loads(args.strategy_params)
except json.JSONDecodeError:
print(f"Предупреждение: Невалидный JSON в strategy_params: '{args.strategy_params}'. Используются параметры по умолчанию для хэша кэша.")
parsed_strategy_params = {}
chunking_hash = _get_chunking_cache_hash(
args.data_folder,
args.chunking_strategy,
args.process_tables,
parsed_strategy_params
)
chunks_df_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "chunks_df.parquet")
entities_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "final_entities.pkl")
chunks_df = None
all_entities = None
if chunks_df_cache_path.exists() and entities_cache_path.exists():
print(f"Найден кэш чанкинга (hash: {chunking_hash}). Загрузка...")
try:
chunks_df = pd.read_parquet(chunks_df_cache_path)
with open(entities_cache_path, 'rb') as f:
all_entities = pickle.load(f)
print(f"Кэш чанкинга успешно загружен: {len(chunks_df)} чанков, {len(all_entities)} сущностей.")
except Exception as e:
print(f"Ошибка загрузки кэша чанкинга: {e}. Выполняем чанкинг заново.")
chunks_df = None
all_entities = None
if chunks_df is None or all_entities is None:
print("Кэш чанкинга не найден или поврежден. Выполнение чтения документов и чанкинга...")
# 1. Загрузка данных
documents_map = read_documents(args.data_folder)
if not documents_map:
print("Нет документов для обработки. Завершение.")
return
# 2. Чанкинг
chunks_df, all_entities = perform_chunking(
documents_map,
args.chunking_strategy, # Pass strategy
args.process_tables, # Pass table flag
args.strategy_params # Pass JSON string parameters
)
if chunks_df.empty:
print("После чанкинга не осталось чанков для обработки. Завершение.")
return
# Сохраняем результаты чанкинга в кэш
try:
print(f"Сохранение результатов чанкинга в кэш (hash: {chunking_hash})...")
# Убедимся, что директория кэша существует (на всякий случай)
chunks_df_cache_path.parent.mkdir(parents=True, exist_ok=True)
entities_cache_path.parent.mkdir(parents=True, exist_ok=True)
chunks_df.to_parquet(chunks_df_cache_path)
with open(entities_cache_path, 'wb') as f:
pickle.dump(all_entities, f)
print("Результаты чанкинга сохранены в кэш.")
except Exception as e:
print(f"Ошибка сохранения кэша чанкинга: {e}")
# --- Конец Кэширования Чанкинга ---
# Загружаем поисковый датасет (это нужно делать всегда, т.к. он не кэшируется здесь)
search_df, questions_to_embed = load_datasets(args.search_dataset_path)
# 3. Выполнение оценки (передаем загруженные или свежесгенерированные chunks_df и all_entities)
results_df = evaluate_run(
search_df, questions_to_embed, chunks_df, all_entities,
None, None, args # Передаем None для model и tokenizer
)
# 5. Сохранение результатов
if not results_df.empty:
os.makedirs(args.output_dir, exist_ok=True)
# output_filename = f"results_{args.run_id}.csv"
# Добавляем batch_id в имя файла для лучшей группировки
output_filename = f"results_{args.batch_id}_{args.run_id}.csv"
output_path = os.path.join(args.output_dir, output_filename)
try:
results_df.to_csv(output_path, index=False, encoding='utf-8')
print(f"Детальные результаты сохранены в: {output_path}")
except Exception as e:
print(f"Ошибка при сохранении результатов в {output_path}: {e}")
else:
print("Нет результатов для сохранения.")
if __name__ == "__main__":
main()
|