Spaces:
Sleeping
Sleeping
File size: 39,523 Bytes
86c402d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 |
#!/usr/bin/env python
"""
Скрипт для оценки качества различных стратегий чанкинга.
Сравнивает стратегии на основе релевантности чанков к вопросам.
"""
import argparse
import json
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from fuzzywuzzy import fuzz
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# Константы для настройки
DATA_FOLDER = "data/docs" # Путь к папке с документами
MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами
BATCH_SIZE = 8 # Размер батча для векторизации
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu" # Устройство для вычислений
SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения
OUTPUT_DIR = "data" # Директория для сохранения результатов
TOP_CHUNKS_DIR = "data/top_chunks" # Директория для сохранения топ-чанков
TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100] # Значения N для оценки
# Параметры стратегий чанкинга
FIXED_SIZE_CONFIG = {
"words_per_chunk": 50, # Количество слов в чанке
"overlap_words": 25 # Количество слов перекрытия
}
sys.path.insert(0, str(Path(__file__).parent.parent))
from ntr_fileparser import UniversalParser
from ntr_text_fragmentation import Destructurer
def _average_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Расчёт усредненного эмбеддинга по всем токенам
Args:
last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены
Returns:
torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size)
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def parse_args():
"""
Парсит аргументы командной строки.
Returns:
Аргументы командной строки
"""
parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга")
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
parser.add_argument("--model-name", type=str, default=MODEL_NAME,
help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})")
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})")
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
parser.add_argument("--force-recompute", action="store_true",
help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные")
parser.add_argument("--use-sentence-transformers", action="store_true",
help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)")
parser.add_argument("--device", type=str, default=DEVICE,
help=f"Устройство для вычислений (по умолчанию: {DEVICE})")
# Параметры для fixed_size стратегии
parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"],
help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})")
parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"],
help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})")
return parser.parse_args()
def read_documents(folder_path: str) -> dict:
"""
Читает все документы из указанной папки.
Args:
folder_path: Путь к папке с документами
Returns:
Словарь {имя_файла: parsed_document}
"""
print(f"Чтение документов из {folder_path}...")
parser = UniversalParser()
documents = {}
for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"):
try:
doc_name = file_path.stem
documents[doc_name] = parser.parse_by_path(str(file_path))
except Exception as e:
print(f"Ошибка при чтении файла {file_path}: {e}")
return documents
def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame:
"""
Обрабатывает документы со стратегией fixed_size для чанкинга.
Args:
documents: Словарь с распарсенными документами
fixed_size_config: Конфигурация для fixed_size стратегии
Returns:
DataFrame с чанками
"""
print("Обработка документов стратегией fixed_size...")
all_data = []
for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"):
# Стратегия fixed_size для чанкинга
destructurer = Destructurer(document)
destructurer.configure('fixed_size',
words_per_chunk=fixed_size_config["words_per_chunk"],
overlap_words=fixed_size_config["overlap_words"])
fixed_size_entities, _ = destructurer.destructure()
# Обрабатываем только сущности для поиска
for entity in fixed_size_entities:
if hasattr(entity, 'use_in_search') and entity.use_in_search:
entity_data = {
'id': str(entity.id),
'doc_name': doc_name,
'name': entity.name,
'text': entity.text,
'type': entity.type,
'strategy': 'fixed_size',
'metadata': json.dumps(entity.metadata, ensure_ascii=False)
}
all_data.append(entity_data)
# Создаем DataFrame
df = pd.DataFrame(all_data)
# Фильтруем по типу, исключая Document
df = df[df['type'] != 'Document']
return df
def load_questions_dataset(file_path: str) -> pd.DataFrame:
"""
Загружает датасет с вопросами из Excel-файла.
Args:
file_path: Путь к Excel-файлу
Returns:
DataFrame с вопросами и пунктами
"""
print(f"Загрузка датасета из {file_path}...")
df = pd.read_excel(file_path)
print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
# Преобразуем NaN в пустые строки для текстовых полей
text_columns = ['question', 'text', 'item_type']
for col in text_columns:
if col in df.columns:
df[col] = df[col].fillna('')
return df
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE):
"""
Инициализирует модель и токенизатор.
Args:
model_name: Название предобученной модели
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
device: Устройство для вычислений
Returns:
Кортеж (модель, токенизатор) или объект SentenceTransformer
"""
print(f"Загрузка модели {model_name} на устройство {device}...")
if use_sentence_transformers:
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device)
return model, None
except ImportError:
print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers")
raise
else:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()
return model, tokenizer
def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray:
"""
Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers.
Args:
texts: Список текстов
model: Модель для векторизации или SentenceTransformer
tokenizer: Токенизатор (None для sentence_transformers)
batch_size: Размер батча
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
device: Устройство для вычислений
Returns:
Массив эмбеддингов
"""
if use_sentence_transformers:
# Используем sentence_transformers для получения эмбеддингов
all_embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"):
batch_texts = texts[i:i+batch_size]
# Получаем эмбеддинги с помощью sentence_transformers
embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False)
all_embeddings.append(embeddings)
return np.vstack(all_embeddings)
else:
# Используем стандартный подход с average pooling
all_embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"):
batch_texts = texts[i:i+batch_size]
# Токенизация с обрезкой и padding
encoding = tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
# Получаем эмбеддинги с average pooling
with torch.no_grad():
outputs = model(**encoding)
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
all_embeddings.append(embeddings.cpu().numpy())
return np.vstack(all_embeddings)
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
"""
Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio.
Args:
chunk_text: Текст чанка
punct_text: Текст пункта
Returns:
Коэффициент перекрытия от 0 до 1
"""
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
if chunk_text in punct_text:
return 1.0
# Если пункт входит в чанк, возвращаем соотношение длин
if punct_text in chunk_text:
return len(punct_text) / len(chunk_text)
# Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает
# случаи, когда один текст является подстрокой другого, даже с небольшими различиями
partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0
return partial_ratio_score
def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str):
"""
Сохраняет эмбеддинги и соответствующие данные в файлы.
Args:
embeddings: Массив эмбеддингов
data: DataFrame с данными
filename: Базовое имя файла
output_dir: Директория для сохранения
"""
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
data_path = os.path.join(output_dir, f"{filename}_data.csv")
# Сохраняем эмбеддинги
np.save(embeddings_path, embeddings)
print(f"Эмбеддинги сохранены в {embeddings_path}")
# Сохраняем данные
data.to_csv(data_path, index=False)
print(f"Данные сохранены в {data_path}")
def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]:
"""
Загружает эмбеддинги и соответствующие данные из файлов.
Args:
filename: Базовое имя файла
output_dir: Директория, где хранятся файлы
Returns:
Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены
"""
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
data_path = os.path.join(output_dir, f"{filename}_data.csv")
if os.path.exists(embeddings_path) and os.path.exists(data_path):
print(f"Загрузка данных из {embeddings_path} и {data_path}...")
embeddings = np.load(embeddings_path)
data = pd.read_csv(data_path)
return embeddings, data
return None, None
def save_top_chunks_for_question(
question_id: int,
question_text: str,
question_puncts: list[str],
top_chunks: pd.DataFrame,
similarities: dict,
overlap_data: list,
output_dir: str
):
"""
Сохраняет топ-чанки для конкретного вопроса в JSON-файл.
Args:
question_id: ID вопроса
question_text: Текст вопроса
question_puncts: Список пунктов, относящихся к вопросу
top_chunks: DataFrame с топ-чанками
similarities: Словарь с косинусными схожестями для чанков
overlap_data: Данные о перекрытии чанков с пунктами
output_dir: Директория для сохранения
"""
# Подготавливаем результаты для сохранения
chunks_data = []
for i, (idx, chunk) in enumerate(top_chunks.iterrows()):
# Получаем данные о перекрытии для текущего чанка
chunk_overlaps = overlap_data[i] if i < len(overlap_data) else []
# Преобразуем numpy типы в стандартные типы Python
similarity = float(similarities.get(idx, 0.0))
# Формируем данные чанка
chunk_data = {
'chunk_id': chunk['id'],
'doc_name': chunk['doc_name'],
'text': chunk['text'],
'similarity': similarity,
'overlaps': chunk_overlaps
}
chunks_data.append(chunk_data)
# Преобразуем numpy.int64 в int для question_id
question_id = int(question_id)
# Формируем общий результат
result = {
'question_id': question_id,
'question_text': question_text,
'puncts': question_puncts,
'chunks': chunks_data
}
# Создаем имя файла
filename = f"question_{question_id}_top_chunks.json"
filepath = os.path.join(output_dir, filename)
# Класс для сериализации numpy типов
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
# Сохраняем в JSON с кастомным энкодером
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}")
def evaluate_for_top_n_with_mapping(
questions_df: pd.DataFrame,
chunks_df: pd.DataFrame,
question_embeddings: np.ndarray,
chunk_embeddings: np.ndarray,
question_id_to_idx: dict,
top_n: int,
similarity_threshold: float,
top_chunks_dir: str = None
) -> tuple[dict[str, float], pd.DataFrame]:
"""
Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс.
Args:
questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет)
chunks_df: DataFrame с чанками
question_embeddings: Эмбеддинги вопросов
chunk_embeddings: Эмбеддинги чанков
question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов
top_n: Количество чанков в топе для каждого вопроса
similarity_threshold: Порог для нечеткого сравнения
top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем)
Returns:
Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам)
"""
print(f"Оценка для top-{top_n}...")
# Вычисляем косинусную близость между вопросами и чанками
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
# Счетчики для метрик на основе текста
total_puncts = 0
found_puncts = 0
total_chunks = 0
relevant_chunks = 0
# Счетчики для метрик на основе документов
total_docs_required = 0
found_relevant_docs = 0
total_docs_found = 0
# Для сохранения метрик по отдельным вопросам
question_metrics = []
# Выводим информацию о столбцах для отладки
print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}")
# Группируем вопросы по id (у нас 20 уникальных вопросов)
for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"):
# Получаем строки для текущего вопроса из исходного датасета
question_rows = questions_df[questions_df['id'] == question_id]
# Проверяем, есть ли вопрос с таким id в нашем маппинге
if question_id not in question_id_to_idx:
print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге")
continue
# Если нет строк с таким id, пропускаем
if len(question_rows) == 0:
continue
# Получаем индекс вопроса в массиве эмбеддингов
question_idx = question_id_to_idx[question_id]
# Получаем текст вопроса
question_text = question_rows['question'].iloc[0]
# Получаем все пункты для этого вопроса
puncts = question_rows['text'].tolist()
question_total_puncts = len(puncts)
total_puncts += question_total_puncts
# Получаем связанные документы
relevant_docs = []
if 'filename' in question_rows.columns:
relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)]
question_total_docs_required = len(relevant_docs)
total_docs_required += question_total_docs_required
print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}")
else:
print(f"Столбец 'filename' отсутствует. Используем все документы.")
relevant_docs = chunks_df['doc_name'].unique().tolist()
question_total_docs_required = len(relevant_docs)
total_docs_required += question_total_docs_required
# Если для вопроса нет релевантных документов, пропускаем
if not relevant_docs:
print(f"Для вопроса {question_id} нет связанных документов")
continue
# Флаги для отслеживания найденных пунктов
punct_found = [False] * question_total_puncts
# Для отслеживания найденных документов
docs_found_for_question = set()
# Для хранения всех чанков вопроса для ограничения top_n
all_question_chunks = []
all_question_similarities = []
# Собираем чанки для всех документов по этому вопросу
for filename in relevant_docs:
if not filename or pd.isna(filename):
continue
# Фильтруем чанки по имени файла
doc_chunks = chunks_df[chunks_df['doc_name'] == filename]
if doc_chunks.empty:
print(f"Предупреждение: документ {filename} не содержит чанков")
continue
# Индексы чанков для текущего файла
doc_chunk_indices = doc_chunks.index.tolist()
# Получаем значения близости для чанков текущего файла
doc_similarities = [
similarity_matrix[question_idx, chunks_df.index.get_loc(idx)]
for idx in doc_chunk_indices
]
# Добавляем чанки и их схожести к общему списку для вопроса
for i, idx in enumerate(doc_chunk_indices):
all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]]))
all_question_similarities.append(doc_similarities[i])
# Сортируем все чанки по убыванию схожести и берем top_n
sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1]
top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices]
top_chunks = [all_question_chunks[i][1] for i in sorted_indices]
# Увеличиваем счетчик общего числа чанков
question_total_chunks = len(top_chunks)
total_chunks += question_total_chunks
# Для сохранения данных топ-чанков
all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks])
all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])}
all_chunk_overlaps = []
# Для каждого чанка проверяем его релевантность к пунктам
question_relevant_chunks = 0
for i, chunk in enumerate(top_chunks):
is_relevant = False
chunk_overlaps = []
# Добавляем документ в найденные
docs_found_for_question.add(chunk['doc_name'])
# Проверяем перекрытие с каждым пунктом
for j, punct in enumerate(puncts):
overlap = calculate_chunk_overlap(chunk['text'], punct)
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20:
chunk_overlaps.append({
'punct_index': j,
'punct_text': punct[:100] + '...' if len(punct) > 100 else punct,
'overlap': overlap
})
# Если перекрытие больше порога, чанк релевантен
if overlap >= similarity_threshold:
is_relevant = True
punct_found[j] = True
if is_relevant:
question_relevant_chunks += 1
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20:
all_chunk_overlaps.append(chunk_overlaps)
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20 and not all_top_chunks.empty:
save_top_chunks_for_question(
question_id,
question_text,
puncts,
all_top_chunks,
all_chunk_similarities,
all_chunk_overlaps,
top_chunks_dir
)
# Подсчитываем метрики для текущего вопроса
question_found_puncts = sum(punct_found)
found_puncts += question_found_puncts
relevant_chunks += question_relevant_chunks
# Обновляем метрики для документов
question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs)
found_relevant_docs += question_found_relevant_docs
question_total_docs_found = len(docs_found_for_question)
total_docs_found += question_total_docs_found
# Вычисляем метрики для текущего вопроса
question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0
question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0
question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0
question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0
question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0
question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0
# Сохраняем метрики вопроса
question_metrics.append({
'question_id': question_id,
'question_text': question_text,
'top_n': top_n,
'text_precision': question_text_precision,
'text_recall': question_text_recall,
'text_f1': question_text_f1,
'doc_precision': question_doc_precision,
'doc_recall': question_doc_recall,
'doc_f1': question_doc_f1,
'found_puncts': question_found_puncts,
'total_puncts': question_total_puncts,
'relevant_chunks': question_relevant_chunks,
'total_chunks': question_total_chunks,
'found_relevant_docs': question_found_relevant_docs,
'total_docs_required': question_total_docs_required,
'total_docs_found': question_total_docs_found
})
# Вычисляем метрики для текста
text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0
text_recall = found_puncts / total_puncts if total_puncts > 0 else 0
text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0
# Вычисляем метрики для документов
doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0
doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0
doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0
aggregated_metrics = {
'top_n': top_n,
'text_precision': text_precision,
'text_recall': text_recall,
'text_f1': text_f1,
'doc_precision': doc_precision,
'doc_recall': doc_recall,
'doc_f1': doc_f1,
'found_puncts': found_puncts,
'total_puncts': total_puncts,
'relevant_chunks': relevant_chunks,
'total_chunks': total_chunks,
'found_relevant_docs': found_relevant_docs,
'total_docs_required': total_docs_required,
'total_docs_found': total_docs_found
}
return aggregated_metrics, pd.DataFrame(question_metrics)
def main():
"""
Основная функция скрипта.
"""
args = parse_args()
# Устанавливаем устройство из аргументов
device = args.device
# Создаем выходной каталог, если его нет
os.makedirs(args.output_dir, exist_ok=True)
# Создаем директорию для топ-чанков
top_chunks_dir = os.path.join(args.output_dir, "top_chunks")
os.makedirs(top_chunks_dir, exist_ok=True)
# Загружаем датасет с вопросами
questions_df = load_questions_dataset(args.dataset_path)
# Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели
strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}"
chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}"
questions_filename = f"questions_{args.model_name.replace('/', '_')}"
# Пытаемся загрузить сохраненные эмбеддинги и данные
chunk_embeddings, chunks_df = None, None
question_embeddings, questions_df_with_embeddings = None, None
if not args.force_recompute:
chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir)
question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir)
# Если не удалось загрузить данные или включен режим принудительного пересчета
if chunk_embeddings is None or chunks_df is None:
# Читаем и обрабатываем документы
documents = read_documents(args.data_folder)
# Формируем конфигурацию для стратегии fixed_size
fixed_size_config = {
"words_per_chunk": args.words_per_chunk,
"overlap_words": args.overlap_words
}
# Получаем DataFrame с чанками
chunks_df = process_documents(documents, fixed_size_config)
# Настраиваем модель и токенизатор
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
# Получаем эмбеддинги для чанков
chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
# Сохраняем эмбеддинги и данные
save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir)
# Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета
if question_embeddings is None or questions_df_with_embeddings is None:
# Получаем уникальные вопросы (по id)
unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']]
# Настраиваем модель и токенизатор (если еще не настроены)
if 'model' not in locals() or 'tokenizer' not in locals():
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
# Получаем эмбеддинги для вопросов
question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
# Сохраняем эмбеддинги и данные
save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir)
# Устанавливаем questions_df_with_embeddings для дальнейшего использования
questions_df_with_embeddings = unique_questions
# Создаем словарь соответствия id вопроса и его индекса в эмбеддингах
question_id_to_idx = {
row['id']: i
for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows())
}
# Оцениваем стратегию чанкинга для разных значений top_n
aggregated_results = []
all_question_metrics = []
for top_n in TOP_N_VALUES:
metrics, question_metrics = evaluate_for_top_n_with_mapping(
questions_df, # Исходный датасет с связью между вопросами и документами
chunks_df, # Датасет с чанками
question_embeddings, # Эмбеддинги вопросов
chunk_embeddings, # Эмбеддинги чанков
question_id_to_idx, # Маппинг id вопроса к индексу в эмбеддингах
top_n, # Количество чанков в топе
args.similarity_threshold, # Порог для определения перекрытия
top_chunks_dir if top_n == 20 else None # Сохраняем топ-чанки только для top_n=20
)
aggregated_results.append(metrics)
all_question_metrics.append(question_metrics)
# Объединяем все метрики по вопросам
all_question_metrics_df = pd.concat(all_question_metrics)
# Создаем DataFrame с агрегированными результатами
aggregated_results_df = pd.DataFrame(aggregated_results)
# Сохраняем результаты
results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv"
results_path = os.path.join(args.output_dir, results_filename)
aggregated_results_df.to_csv(results_path, index=False)
# Сохраняем метрики по вопросам
question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx"
question_metrics_path = os.path.join(args.output_dir, question_metrics_filename)
all_question_metrics_df.to_excel(question_metrics_path, index=False)
print(f"\nРезультаты сохранены в {results_path}")
print(f"Метрики по вопросам сохранены в {question_metrics_path}")
print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}")
print("\nМетрики для различных значений top_n:")
print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']])
if __name__ == "__main__":
main() |