muryshev's picture
update
86c402d
raw
history blame
39.5 kB
#!/usr/bin/env python
"""
Скрипт для оценки качества различных стратегий чанкинга.
Сравнивает стратегии на основе релевантности чанков к вопросам.
"""
import argparse
import json
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from fuzzywuzzy import fuzz
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# Константы для настройки
DATA_FOLDER = "data/docs" # Путь к папке с документами
MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами
BATCH_SIZE = 8 # Размер батча для векторизации
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu" # Устройство для вычислений
SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения
OUTPUT_DIR = "data" # Директория для сохранения результатов
TOP_CHUNKS_DIR = "data/top_chunks" # Директория для сохранения топ-чанков
TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100] # Значения N для оценки
# Параметры стратегий чанкинга
FIXED_SIZE_CONFIG = {
"words_per_chunk": 50, # Количество слов в чанке
"overlap_words": 25 # Количество слов перекрытия
}
sys.path.insert(0, str(Path(__file__).parent.parent))
from ntr_fileparser import UniversalParser
from ntr_text_fragmentation import Destructurer
def _average_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Расчёт усредненного эмбеддинга по всем токенам
Args:
last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены
Returns:
torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size)
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def parse_args():
"""
Парсит аргументы командной строки.
Returns:
Аргументы командной строки
"""
parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга")
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
parser.add_argument("--model-name", type=str, default=MODEL_NAME,
help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})")
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})")
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
parser.add_argument("--force-recompute", action="store_true",
help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные")
parser.add_argument("--use-sentence-transformers", action="store_true",
help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)")
parser.add_argument("--device", type=str, default=DEVICE,
help=f"Устройство для вычислений (по умолчанию: {DEVICE})")
# Параметры для fixed_size стратегии
parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"],
help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})")
parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"],
help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})")
return parser.parse_args()
def read_documents(folder_path: str) -> dict:
"""
Читает все документы из указанной папки.
Args:
folder_path: Путь к папке с документами
Returns:
Словарь {имя_файла: parsed_document}
"""
print(f"Чтение документов из {folder_path}...")
parser = UniversalParser()
documents = {}
for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"):
try:
doc_name = file_path.stem
documents[doc_name] = parser.parse_by_path(str(file_path))
except Exception as e:
print(f"Ошибка при чтении файла {file_path}: {e}")
return documents
def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame:
"""
Обрабатывает документы со стратегией fixed_size для чанкинга.
Args:
documents: Словарь с распарсенными документами
fixed_size_config: Конфигурация для fixed_size стратегии
Returns:
DataFrame с чанками
"""
print("Обработка документов стратегией fixed_size...")
all_data = []
for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"):
# Стратегия fixed_size для чанкинга
destructurer = Destructurer(document)
destructurer.configure('fixed_size',
words_per_chunk=fixed_size_config["words_per_chunk"],
overlap_words=fixed_size_config["overlap_words"])
fixed_size_entities, _ = destructurer.destructure()
# Обрабатываем только сущности для поиска
for entity in fixed_size_entities:
if hasattr(entity, 'use_in_search') and entity.use_in_search:
entity_data = {
'id': str(entity.id),
'doc_name': doc_name,
'name': entity.name,
'text': entity.text,
'type': entity.type,
'strategy': 'fixed_size',
'metadata': json.dumps(entity.metadata, ensure_ascii=False)
}
all_data.append(entity_data)
# Создаем DataFrame
df = pd.DataFrame(all_data)
# Фильтруем по типу, исключая Document
df = df[df['type'] != 'Document']
return df
def load_questions_dataset(file_path: str) -> pd.DataFrame:
"""
Загружает датасет с вопросами из Excel-файла.
Args:
file_path: Путь к Excel-файлу
Returns:
DataFrame с вопросами и пунктами
"""
print(f"Загрузка датасета из {file_path}...")
df = pd.read_excel(file_path)
print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
# Преобразуем NaN в пустые строки для текстовых полей
text_columns = ['question', 'text', 'item_type']
for col in text_columns:
if col in df.columns:
df[col] = df[col].fillna('')
return df
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE):
"""
Инициализирует модель и токенизатор.
Args:
model_name: Название предобученной модели
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
device: Устройство для вычислений
Returns:
Кортеж (модель, токенизатор) или объект SentenceTransformer
"""
print(f"Загрузка модели {model_name} на устройство {device}...")
if use_sentence_transformers:
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, device=device)
return model, None
except ImportError:
print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers")
raise
else:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()
return model, tokenizer
def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray:
"""
Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers.
Args:
texts: Список текстов
model: Модель для векторизации или SentenceTransformer
tokenizer: Токенизатор (None для sentence_transformers)
batch_size: Размер батча
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
device: Устройство для вычислений
Returns:
Массив эмбеддингов
"""
if use_sentence_transformers:
# Используем sentence_transformers для получения эмбеддингов
all_embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"):
batch_texts = texts[i:i+batch_size]
# Получаем эмбеддинги с помощью sentence_transformers
embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False)
all_embeddings.append(embeddings)
return np.vstack(all_embeddings)
else:
# Используем стандартный подход с average pooling
all_embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"):
batch_texts = texts[i:i+batch_size]
# Токенизация с обрезкой и padding
encoding = tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
# Получаем эмбеддинги с average pooling
with torch.no_grad():
outputs = model(**encoding)
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
all_embeddings.append(embeddings.cpu().numpy())
return np.vstack(all_embeddings)
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
"""
Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio.
Args:
chunk_text: Текст чанка
punct_text: Текст пункта
Returns:
Коэффициент перекрытия от 0 до 1
"""
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
if chunk_text in punct_text:
return 1.0
# Если пункт входит в чанк, возвращаем соотношение длин
if punct_text in chunk_text:
return len(punct_text) / len(chunk_text)
# Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает
# случаи, когда один текст является подстрокой другого, даже с небольшими различиями
partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0
return partial_ratio_score
def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str):
"""
Сохраняет эмбеддинги и соответствующие данные в файлы.
Args:
embeddings: Массив эмбеддингов
data: DataFrame с данными
filename: Базовое имя файла
output_dir: Директория для сохранения
"""
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
data_path = os.path.join(output_dir, f"{filename}_data.csv")
# Сохраняем эмбеддинги
np.save(embeddings_path, embeddings)
print(f"Эмбеддинги сохранены в {embeddings_path}")
# Сохраняем данные
data.to_csv(data_path, index=False)
print(f"Данные сохранены в {data_path}")
def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]:
"""
Загружает эмбеддинги и соответствующие данные из файлов.
Args:
filename: Базовое имя файла
output_dir: Директория, где хранятся файлы
Returns:
Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены
"""
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
data_path = os.path.join(output_dir, f"{filename}_data.csv")
if os.path.exists(embeddings_path) and os.path.exists(data_path):
print(f"Загрузка данных из {embeddings_path} и {data_path}...")
embeddings = np.load(embeddings_path)
data = pd.read_csv(data_path)
return embeddings, data
return None, None
def save_top_chunks_for_question(
question_id: int,
question_text: str,
question_puncts: list[str],
top_chunks: pd.DataFrame,
similarities: dict,
overlap_data: list,
output_dir: str
):
"""
Сохраняет топ-чанки для конкретного вопроса в JSON-файл.
Args:
question_id: ID вопроса
question_text: Текст вопроса
question_puncts: Список пунктов, относящихся к вопросу
top_chunks: DataFrame с топ-чанками
similarities: Словарь с косинусными схожестями для чанков
overlap_data: Данные о перекрытии чанков с пунктами
output_dir: Директория для сохранения
"""
# Подготавливаем результаты для сохранения
chunks_data = []
for i, (idx, chunk) in enumerate(top_chunks.iterrows()):
# Получаем данные о перекрытии для текущего чанка
chunk_overlaps = overlap_data[i] if i < len(overlap_data) else []
# Преобразуем numpy типы в стандартные типы Python
similarity = float(similarities.get(idx, 0.0))
# Формируем данные чанка
chunk_data = {
'chunk_id': chunk['id'],
'doc_name': chunk['doc_name'],
'text': chunk['text'],
'similarity': similarity,
'overlaps': chunk_overlaps
}
chunks_data.append(chunk_data)
# Преобразуем numpy.int64 в int для question_id
question_id = int(question_id)
# Формируем общий результат
result = {
'question_id': question_id,
'question_text': question_text,
'puncts': question_puncts,
'chunks': chunks_data
}
# Создаем имя файла
filename = f"question_{question_id}_top_chunks.json"
filepath = os.path.join(output_dir, filename)
# Класс для сериализации numpy типов
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
# Сохраняем в JSON с кастомным энкодером
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}")
def evaluate_for_top_n_with_mapping(
questions_df: pd.DataFrame,
chunks_df: pd.DataFrame,
question_embeddings: np.ndarray,
chunk_embeddings: np.ndarray,
question_id_to_idx: dict,
top_n: int,
similarity_threshold: float,
top_chunks_dir: str = None
) -> tuple[dict[str, float], pd.DataFrame]:
"""
Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс.
Args:
questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет)
chunks_df: DataFrame с чанками
question_embeddings: Эмбеддинги вопросов
chunk_embeddings: Эмбеддинги чанков
question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов
top_n: Количество чанков в топе для каждого вопроса
similarity_threshold: Порог для нечеткого сравнения
top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем)
Returns:
Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам)
"""
print(f"Оценка для top-{top_n}...")
# Вычисляем косинусную близость между вопросами и чанками
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
# Счетчики для метрик на основе текста
total_puncts = 0
found_puncts = 0
total_chunks = 0
relevant_chunks = 0
# Счетчики для метрик на основе документов
total_docs_required = 0
found_relevant_docs = 0
total_docs_found = 0
# Для сохранения метрик по отдельным вопросам
question_metrics = []
# Выводим информацию о столбцах для отладки
print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}")
# Группируем вопросы по id (у нас 20 уникальных вопросов)
for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"):
# Получаем строки для текущего вопроса из исходного датасета
question_rows = questions_df[questions_df['id'] == question_id]
# Проверяем, есть ли вопрос с таким id в нашем маппинге
if question_id not in question_id_to_idx:
print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге")
continue
# Если нет строк с таким id, пропускаем
if len(question_rows) == 0:
continue
# Получаем индекс вопроса в массиве эмбеддингов
question_idx = question_id_to_idx[question_id]
# Получаем текст вопроса
question_text = question_rows['question'].iloc[0]
# Получаем все пункты для этого вопроса
puncts = question_rows['text'].tolist()
question_total_puncts = len(puncts)
total_puncts += question_total_puncts
# Получаем связанные документы
relevant_docs = []
if 'filename' in question_rows.columns:
relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)]
question_total_docs_required = len(relevant_docs)
total_docs_required += question_total_docs_required
print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}")
else:
print(f"Столбец 'filename' отсутствует. Используем все документы.")
relevant_docs = chunks_df['doc_name'].unique().tolist()
question_total_docs_required = len(relevant_docs)
total_docs_required += question_total_docs_required
# Если для вопроса нет релевантных документов, пропускаем
if not relevant_docs:
print(f"Для вопроса {question_id} нет связанных документов")
continue
# Флаги для отслеживания найденных пунктов
punct_found = [False] * question_total_puncts
# Для отслеживания найденных документов
docs_found_for_question = set()
# Для хранения всех чанков вопроса для ограничения top_n
all_question_chunks = []
all_question_similarities = []
# Собираем чанки для всех документов по этому вопросу
for filename in relevant_docs:
if not filename or pd.isna(filename):
continue
# Фильтруем чанки по имени файла
doc_chunks = chunks_df[chunks_df['doc_name'] == filename]
if doc_chunks.empty:
print(f"Предупреждение: документ {filename} не содержит чанков")
continue
# Индексы чанков для текущего файла
doc_chunk_indices = doc_chunks.index.tolist()
# Получаем значения близости для чанков текущего файла
doc_similarities = [
similarity_matrix[question_idx, chunks_df.index.get_loc(idx)]
for idx in doc_chunk_indices
]
# Добавляем чанки и их схожести к общему списку для вопроса
for i, idx in enumerate(doc_chunk_indices):
all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]]))
all_question_similarities.append(doc_similarities[i])
# Сортируем все чанки по убыванию схожести и берем top_n
sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1]
top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices]
top_chunks = [all_question_chunks[i][1] for i in sorted_indices]
# Увеличиваем счетчик общего числа чанков
question_total_chunks = len(top_chunks)
total_chunks += question_total_chunks
# Для сохранения данных топ-чанков
all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks])
all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])}
all_chunk_overlaps = []
# Для каждого чанка проверяем его релевантность к пунктам
question_relevant_chunks = 0
for i, chunk in enumerate(top_chunks):
is_relevant = False
chunk_overlaps = []
# Добавляем документ в найденные
docs_found_for_question.add(chunk['doc_name'])
# Проверяем перекрытие с каждым пунктом
for j, punct in enumerate(puncts):
overlap = calculate_chunk_overlap(chunk['text'], punct)
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20:
chunk_overlaps.append({
'punct_index': j,
'punct_text': punct[:100] + '...' if len(punct) > 100 else punct,
'overlap': overlap
})
# Если перекрытие больше порога, чанк релевантен
if overlap >= similarity_threshold:
is_relevant = True
punct_found[j] = True
if is_relevant:
question_relevant_chunks += 1
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20:
all_chunk_overlaps.append(chunk_overlaps)
# Если нужно сохранить топ-чанки и top_n == 20
if top_chunks_dir and top_n == 20 and not all_top_chunks.empty:
save_top_chunks_for_question(
question_id,
question_text,
puncts,
all_top_chunks,
all_chunk_similarities,
all_chunk_overlaps,
top_chunks_dir
)
# Подсчитываем метрики для текущего вопроса
question_found_puncts = sum(punct_found)
found_puncts += question_found_puncts
relevant_chunks += question_relevant_chunks
# Обновляем метрики для документов
question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs)
found_relevant_docs += question_found_relevant_docs
question_total_docs_found = len(docs_found_for_question)
total_docs_found += question_total_docs_found
# Вычисляем метрики для текущего вопроса
question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0
question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0
question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0
question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0
question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0
question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0
# Сохраняем метрики вопроса
question_metrics.append({
'question_id': question_id,
'question_text': question_text,
'top_n': top_n,
'text_precision': question_text_precision,
'text_recall': question_text_recall,
'text_f1': question_text_f1,
'doc_precision': question_doc_precision,
'doc_recall': question_doc_recall,
'doc_f1': question_doc_f1,
'found_puncts': question_found_puncts,
'total_puncts': question_total_puncts,
'relevant_chunks': question_relevant_chunks,
'total_chunks': question_total_chunks,
'found_relevant_docs': question_found_relevant_docs,
'total_docs_required': question_total_docs_required,
'total_docs_found': question_total_docs_found
})
# Вычисляем метрики для текста
text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0
text_recall = found_puncts / total_puncts if total_puncts > 0 else 0
text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0
# Вычисляем метрики для документов
doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0
doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0
doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0
aggregated_metrics = {
'top_n': top_n,
'text_precision': text_precision,
'text_recall': text_recall,
'text_f1': text_f1,
'doc_precision': doc_precision,
'doc_recall': doc_recall,
'doc_f1': doc_f1,
'found_puncts': found_puncts,
'total_puncts': total_puncts,
'relevant_chunks': relevant_chunks,
'total_chunks': total_chunks,
'found_relevant_docs': found_relevant_docs,
'total_docs_required': total_docs_required,
'total_docs_found': total_docs_found
}
return aggregated_metrics, pd.DataFrame(question_metrics)
def main():
"""
Основная функция скрипта.
"""
args = parse_args()
# Устанавливаем устройство из аргументов
device = args.device
# Создаем выходной каталог, если его нет
os.makedirs(args.output_dir, exist_ok=True)
# Создаем директорию для топ-чанков
top_chunks_dir = os.path.join(args.output_dir, "top_chunks")
os.makedirs(top_chunks_dir, exist_ok=True)
# Загружаем датасет с вопросами
questions_df = load_questions_dataset(args.dataset_path)
# Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели
strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}"
chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}"
questions_filename = f"questions_{args.model_name.replace('/', '_')}"
# Пытаемся загрузить сохраненные эмбеддинги и данные
chunk_embeddings, chunks_df = None, None
question_embeddings, questions_df_with_embeddings = None, None
if not args.force_recompute:
chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir)
question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir)
# Если не удалось загрузить данные или включен режим принудительного пересчета
if chunk_embeddings is None or chunks_df is None:
# Читаем и обрабатываем документы
documents = read_documents(args.data_folder)
# Формируем конфигурацию для стратегии fixed_size
fixed_size_config = {
"words_per_chunk": args.words_per_chunk,
"overlap_words": args.overlap_words
}
# Получаем DataFrame с чанками
chunks_df = process_documents(documents, fixed_size_config)
# Настраиваем модель и токенизатор
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
# Получаем эмбеддинги для чанков
chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
# Сохраняем эмбеддинги и данные
save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir)
# Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета
if question_embeddings is None or questions_df_with_embeddings is None:
# Получаем уникальные вопросы (по id)
unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']]
# Настраиваем модель и токенизатор (если еще не настроены)
if 'model' not in locals() or 'tokenizer' not in locals():
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
# Получаем эмбеддинги для вопросов
question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
# Сохраняем эмбеддинги и данные
save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir)
# Устанавливаем questions_df_with_embeddings для дальнейшего использования
questions_df_with_embeddings = unique_questions
# Создаем словарь соответствия id вопроса и его индекса в эмбеддингах
question_id_to_idx = {
row['id']: i
for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows())
}
# Оцениваем стратегию чанкинга для разных значений top_n
aggregated_results = []
all_question_metrics = []
for top_n in TOP_N_VALUES:
metrics, question_metrics = evaluate_for_top_n_with_mapping(
questions_df, # Исходный датасет с связью между вопросами и документами
chunks_df, # Датасет с чанками
question_embeddings, # Эмбеддинги вопросов
chunk_embeddings, # Эмбеддинги чанков
question_id_to_idx, # Маппинг id вопроса к индексу в эмбеддингах
top_n, # Количество чанков в топе
args.similarity_threshold, # Порог для определения перекрытия
top_chunks_dir if top_n == 20 else None # Сохраняем топ-чанки только для top_n=20
)
aggregated_results.append(metrics)
all_question_metrics.append(question_metrics)
# Объединяем все метрики по вопросам
all_question_metrics_df = pd.concat(all_question_metrics)
# Создаем DataFrame с агрегированными результатами
aggregated_results_df = pd.DataFrame(aggregated_results)
# Сохраняем результаты
results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv"
results_path = os.path.join(args.output_dir, results_filename)
aggregated_results_df.to_csv(results_path, index=False)
# Сохраняем метрики по вопросам
question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx"
question_metrics_path = os.path.join(args.output_dir, question_metrics_filename)
all_question_metrics_df.to_excel(question_metrics_path, index=False)
print(f"\nРезультаты сохранены в {results_path}")
print(f"Метрики по вопросам сохранены в {question_metrics_path}")
print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}")
print("\nМетрики для различных значений top_n:")
print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']])
if __name__ == "__main__":
main()