Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
""" | |
Скрипт для оценки качества различных стратегий чанкинга. | |
Сравнивает стратегии на основе релевантности чанков к вопросам. | |
""" | |
import argparse | |
import json | |
import os | |
import sys | |
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
import torch | |
from fuzzywuzzy import fuzz | |
from sklearn.metrics.pairwise import cosine_similarity | |
from tqdm import tqdm | |
from transformers import AutoModel, AutoTokenizer | |
# Константы для настройки | |
DATA_FOLDER = "data/docs" # Путь к папке с документами | |
MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации | |
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами | |
BATCH_SIZE = 8 # Размер батча для векторизации | |
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu" # Устройство для вычислений | |
SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения | |
OUTPUT_DIR = "data" # Директория для сохранения результатов | |
TOP_CHUNKS_DIR = "data/top_chunks" # Директория для сохранения топ-чанков | |
TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100] # Значения N для оценки | |
# Параметры стратегий чанкинга | |
FIXED_SIZE_CONFIG = { | |
"words_per_chunk": 50, # Количество слов в чанке | |
"overlap_words": 25 # Количество слов перекрытия | |
} | |
sys.path.insert(0, str(Path(__file__).parent.parent)) | |
from ntr_fileparser import UniversalParser | |
from ntr_text_fragmentation import Destructurer | |
def _average_pool( | |
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor | |
) -> torch.Tensor: | |
""" | |
Расчёт усредненного эмбеддинга по всем токенам | |
Args: | |
last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой | |
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены | |
Returns: | |
torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size) | |
""" | |
last_hidden = last_hidden_states.masked_fill( | |
~attention_mask[..., None].bool(), 0.0 | |
) | |
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | |
def parse_args(): | |
""" | |
Парсит аргументы командной строки. | |
Returns: | |
Аргументы командной строки | |
""" | |
parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга") | |
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER, | |
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})") | |
parser.add_argument("--model-name", type=str, default=MODEL_NAME, | |
help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})") | |
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH, | |
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})") | |
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, | |
help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})") | |
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD, | |
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})") | |
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR, | |
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})") | |
parser.add_argument("--force-recompute", action="store_true", | |
help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные") | |
parser.add_argument("--use-sentence-transformers", action="store_true", | |
help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)") | |
parser.add_argument("--device", type=str, default=DEVICE, | |
help=f"Устройство для вычислений (по умолчанию: {DEVICE})") | |
# Параметры для fixed_size стратегии | |
parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"], | |
help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})") | |
parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"], | |
help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})") | |
return parser.parse_args() | |
def read_documents(folder_path: str) -> dict: | |
""" | |
Читает все документы из указанной папки. | |
Args: | |
folder_path: Путь к папке с документами | |
Returns: | |
Словарь {имя_файла: parsed_document} | |
""" | |
print(f"Чтение документов из {folder_path}...") | |
parser = UniversalParser() | |
documents = {} | |
for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"): | |
try: | |
doc_name = file_path.stem | |
documents[doc_name] = parser.parse_by_path(str(file_path)) | |
except Exception as e: | |
print(f"Ошибка при чтении файла {file_path}: {e}") | |
return documents | |
def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame: | |
""" | |
Обрабатывает документы со стратегией fixed_size для чанкинга. | |
Args: | |
documents: Словарь с распарсенными документами | |
fixed_size_config: Конфигурация для fixed_size стратегии | |
Returns: | |
DataFrame с чанками | |
""" | |
print("Обработка документов стратегией fixed_size...") | |
all_data = [] | |
for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"): | |
# Стратегия fixed_size для чанкинга | |
destructurer = Destructurer(document) | |
destructurer.configure('fixed_size', | |
words_per_chunk=fixed_size_config["words_per_chunk"], | |
overlap_words=fixed_size_config["overlap_words"]) | |
fixed_size_entities, _ = destructurer.destructure() | |
# Обрабатываем только сущности для поиска | |
for entity in fixed_size_entities: | |
if hasattr(entity, 'use_in_search') and entity.use_in_search: | |
entity_data = { | |
'id': str(entity.id), | |
'doc_name': doc_name, | |
'name': entity.name, | |
'text': entity.text, | |
'type': entity.type, | |
'strategy': 'fixed_size', | |
'metadata': json.dumps(entity.metadata, ensure_ascii=False) | |
} | |
all_data.append(entity_data) | |
# Создаем DataFrame | |
df = pd.DataFrame(all_data) | |
# Фильтруем по типу, исключая Document | |
df = df[df['type'] != 'Document'] | |
return df | |
def load_questions_dataset(file_path: str) -> pd.DataFrame: | |
""" | |
Загружает датасет с вопросами из Excel-файла. | |
Args: | |
file_path: Путь к Excel-файлу | |
Returns: | |
DataFrame с вопросами и пунктами | |
""" | |
print(f"Загрузка датасета из {file_path}...") | |
df = pd.read_excel(file_path) | |
print(f"Загружен датасет со столбцами: {df.columns.tolist()}") | |
# Преобразуем NaN в пустые строки для текстовых полей | |
text_columns = ['question', 'text', 'item_type'] | |
for col in text_columns: | |
if col in df.columns: | |
df[col] = df[col].fillna('') | |
return df | |
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE): | |
""" | |
Инициализирует модель и токенизатор. | |
Args: | |
model_name: Название предобученной модели | |
use_sentence_transformers: Использовать ли библиотеку sentence_transformers | |
device: Устройство для вычислений | |
Returns: | |
Кортеж (модель, токенизатор) или объект SentenceTransformer | |
""" | |
print(f"Загрузка модели {model_name} на устройство {device}...") | |
if use_sentence_transformers: | |
try: | |
from sentence_transformers import SentenceTransformer | |
model = SentenceTransformer(model_name, device=device) | |
return model, None | |
except ImportError: | |
print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers") | |
raise | |
else: | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModel.from_pretrained(model_name).to(device) | |
model.eval() | |
return model, tokenizer | |
def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray: | |
""" | |
Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers. | |
Args: | |
texts: Список текстов | |
model: Модель для векторизации или SentenceTransformer | |
tokenizer: Токенизатор (None для sentence_transformers) | |
batch_size: Размер батча | |
use_sentence_transformers: Использовать ли библиотеку sentence_transformers | |
device: Устройство для вычислений | |
Returns: | |
Массив эмбеддингов | |
""" | |
if use_sentence_transformers: | |
# Используем sentence_transformers для получения эмбеддингов | |
all_embeddings = [] | |
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"): | |
batch_texts = texts[i:i+batch_size] | |
# Получаем эмбеддинги с помощью sentence_transformers | |
embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False) | |
all_embeddings.append(embeddings) | |
return np.vstack(all_embeddings) | |
else: | |
# Используем стандартный подход с average pooling | |
all_embeddings = [] | |
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"): | |
batch_texts = texts[i:i+batch_size] | |
# Токенизация с обрезкой и padding | |
encoding = tokenizer( | |
batch_texts, | |
padding=True, | |
truncation=True, | |
max_length=512, | |
return_tensors="pt" | |
).to(device) | |
# Получаем эмбеддинги с average pooling | |
with torch.no_grad(): | |
outputs = model(**encoding) | |
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"]) | |
all_embeddings.append(embeddings.cpu().numpy()) | |
return np.vstack(all_embeddings) | |
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float: | |
""" | |
Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio. | |
Args: | |
chunk_text: Текст чанка | |
punct_text: Текст пункта | |
Returns: | |
Коэффициент перекрытия от 0 до 1 | |
""" | |
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение) | |
if chunk_text in punct_text: | |
return 1.0 | |
# Если пункт входит в чанк, возвращаем соотношение длин | |
if punct_text in chunk_text: | |
return len(punct_text) / len(chunk_text) | |
# Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает | |
# случаи, когда один текст является подстрокой другого, даже с небольшими различиями | |
partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0 | |
return partial_ratio_score | |
def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str): | |
""" | |
Сохраняет эмбеддинги и соответствующие данные в файлы. | |
Args: | |
embeddings: Массив эмбеддингов | |
data: DataFrame с данными | |
filename: Базовое имя файла | |
output_dir: Директория для сохранения | |
""" | |
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy") | |
data_path = os.path.join(output_dir, f"{filename}_data.csv") | |
# Сохраняем эмбеддинги | |
np.save(embeddings_path, embeddings) | |
print(f"Эмбеддинги сохранены в {embeddings_path}") | |
# Сохраняем данные | |
data.to_csv(data_path, index=False) | |
print(f"Данные сохранены в {data_path}") | |
def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]: | |
""" | |
Загружает эмбеддинги и соответствующие данные из файлов. | |
Args: | |
filename: Базовое имя файла | |
output_dir: Директория, где хранятся файлы | |
Returns: | |
Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены | |
""" | |
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy") | |
data_path = os.path.join(output_dir, f"{filename}_data.csv") | |
if os.path.exists(embeddings_path) and os.path.exists(data_path): | |
print(f"Загрузка данных из {embeddings_path} и {data_path}...") | |
embeddings = np.load(embeddings_path) | |
data = pd.read_csv(data_path) | |
return embeddings, data | |
return None, None | |
def save_top_chunks_for_question( | |
question_id: int, | |
question_text: str, | |
question_puncts: list[str], | |
top_chunks: pd.DataFrame, | |
similarities: dict, | |
overlap_data: list, | |
output_dir: str | |
): | |
""" | |
Сохраняет топ-чанки для конкретного вопроса в JSON-файл. | |
Args: | |
question_id: ID вопроса | |
question_text: Текст вопроса | |
question_puncts: Список пунктов, относящихся к вопросу | |
top_chunks: DataFrame с топ-чанками | |
similarities: Словарь с косинусными схожестями для чанков | |
overlap_data: Данные о перекрытии чанков с пунктами | |
output_dir: Директория для сохранения | |
""" | |
# Подготавливаем результаты для сохранения | |
chunks_data = [] | |
for i, (idx, chunk) in enumerate(top_chunks.iterrows()): | |
# Получаем данные о перекрытии для текущего чанка | |
chunk_overlaps = overlap_data[i] if i < len(overlap_data) else [] | |
# Преобразуем numpy типы в стандартные типы Python | |
similarity = float(similarities.get(idx, 0.0)) | |
# Формируем данные чанка | |
chunk_data = { | |
'chunk_id': chunk['id'], | |
'doc_name': chunk['doc_name'], | |
'text': chunk['text'], | |
'similarity': similarity, | |
'overlaps': chunk_overlaps | |
} | |
chunks_data.append(chunk_data) | |
# Преобразуем numpy.int64 в int для question_id | |
question_id = int(question_id) | |
# Формируем общий результат | |
result = { | |
'question_id': question_id, | |
'question_text': question_text, | |
'puncts': question_puncts, | |
'chunks': chunks_data | |
} | |
# Создаем имя файла | |
filename = f"question_{question_id}_top_chunks.json" | |
filepath = os.path.join(output_dir, filename) | |
# Класс для сериализации numpy типов | |
class NumpyEncoder(json.JSONEncoder): | |
def default(self, obj): | |
if isinstance(obj, np.integer): | |
return int(obj) | |
if isinstance(obj, np.floating): | |
return float(obj) | |
if isinstance(obj, np.ndarray): | |
return obj.tolist() | |
return super().default(obj) | |
# Сохраняем в JSON с кастомным энкодером | |
with open(filepath, 'w', encoding='utf-8') as f: | |
json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder) | |
print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}") | |
def evaluate_for_top_n_with_mapping( | |
questions_df: pd.DataFrame, | |
chunks_df: pd.DataFrame, | |
question_embeddings: np.ndarray, | |
chunk_embeddings: np.ndarray, | |
question_id_to_idx: dict, | |
top_n: int, | |
similarity_threshold: float, | |
top_chunks_dir: str = None | |
) -> tuple[dict[str, float], pd.DataFrame]: | |
""" | |
Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс. | |
Args: | |
questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет) | |
chunks_df: DataFrame с чанками | |
question_embeddings: Эмбеддинги вопросов | |
chunk_embeddings: Эмбеддинги чанков | |
question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов | |
top_n: Количество чанков в топе для каждого вопроса | |
similarity_threshold: Порог для нечеткого сравнения | |
top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем) | |
Returns: | |
Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам) | |
""" | |
print(f"Оценка для top-{top_n}...") | |
# Вычисляем косинусную близость между вопросами и чанками | |
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings) | |
# Счетчики для метрик на основе текста | |
total_puncts = 0 | |
found_puncts = 0 | |
total_chunks = 0 | |
relevant_chunks = 0 | |
# Счетчики для метрик на основе документов | |
total_docs_required = 0 | |
found_relevant_docs = 0 | |
total_docs_found = 0 | |
# Для сохранения метрик по отдельным вопросам | |
question_metrics = [] | |
# Выводим информацию о столбцах для отладки | |
print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}") | |
# Группируем вопросы по id (у нас 20 уникальных вопросов) | |
for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"): | |
# Получаем строки для текущего вопроса из исходного датасета | |
question_rows = questions_df[questions_df['id'] == question_id] | |
# Проверяем, есть ли вопрос с таким id в нашем маппинге | |
if question_id not in question_id_to_idx: | |
print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге") | |
continue | |
# Если нет строк с таким id, пропускаем | |
if len(question_rows) == 0: | |
continue | |
# Получаем индекс вопроса в массиве эмбеддингов | |
question_idx = question_id_to_idx[question_id] | |
# Получаем текст вопроса | |
question_text = question_rows['question'].iloc[0] | |
# Получаем все пункты для этого вопроса | |
puncts = question_rows['text'].tolist() | |
question_total_puncts = len(puncts) | |
total_puncts += question_total_puncts | |
# Получаем связанные документы | |
relevant_docs = [] | |
if 'filename' in question_rows.columns: | |
relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)] | |
question_total_docs_required = len(relevant_docs) | |
total_docs_required += question_total_docs_required | |
print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}") | |
else: | |
print(f"Столбец 'filename' отсутствует. Используем все документы.") | |
relevant_docs = chunks_df['doc_name'].unique().tolist() | |
question_total_docs_required = len(relevant_docs) | |
total_docs_required += question_total_docs_required | |
# Если для вопроса нет релевантных документов, пропускаем | |
if not relevant_docs: | |
print(f"Для вопроса {question_id} нет связанных документов") | |
continue | |
# Флаги для отслеживания найденных пунктов | |
punct_found = [False] * question_total_puncts | |
# Для отслеживания найденных документов | |
docs_found_for_question = set() | |
# Для хранения всех чанков вопроса для ограничения top_n | |
all_question_chunks = [] | |
all_question_similarities = [] | |
# Собираем чанки для всех документов по этому вопросу | |
for filename in relevant_docs: | |
if not filename or pd.isna(filename): | |
continue | |
# Фильтруем чанки по имени файла | |
doc_chunks = chunks_df[chunks_df['doc_name'] == filename] | |
if doc_chunks.empty: | |
print(f"Предупреждение: документ {filename} не содержит чанков") | |
continue | |
# Индексы чанков для текущего файла | |
doc_chunk_indices = doc_chunks.index.tolist() | |
# Получаем значения близости для чанков текущего файла | |
doc_similarities = [ | |
similarity_matrix[question_idx, chunks_df.index.get_loc(idx)] | |
for idx in doc_chunk_indices | |
] | |
# Добавляем чанки и их схожести к общему списку для вопроса | |
for i, idx in enumerate(doc_chunk_indices): | |
all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]])) | |
all_question_similarities.append(doc_similarities[i]) | |
# Сортируем все чанки по убыванию схожести и берем top_n | |
sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1] | |
top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices] | |
top_chunks = [all_question_chunks[i][1] for i in sorted_indices] | |
# Увеличиваем счетчик общего числа чанков | |
question_total_chunks = len(top_chunks) | |
total_chunks += question_total_chunks | |
# Для сохранения данных топ-чанков | |
all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks]) | |
all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])} | |
all_chunk_overlaps = [] | |
# Для каждого чанка проверяем его релевантность к пунктам | |
question_relevant_chunks = 0 | |
for i, chunk in enumerate(top_chunks): | |
is_relevant = False | |
chunk_overlaps = [] | |
# Добавляем документ в найденные | |
docs_found_for_question.add(chunk['doc_name']) | |
# Проверяем перекрытие с каждым пунктом | |
for j, punct in enumerate(puncts): | |
overlap = calculate_chunk_overlap(chunk['text'], punct) | |
# Если нужно сохранить топ-чанки и top_n == 20 | |
if top_chunks_dir and top_n == 20: | |
chunk_overlaps.append({ | |
'punct_index': j, | |
'punct_text': punct[:100] + '...' if len(punct) > 100 else punct, | |
'overlap': overlap | |
}) | |
# Если перекрытие больше порога, чанк релевантен | |
if overlap >= similarity_threshold: | |
is_relevant = True | |
punct_found[j] = True | |
if is_relevant: | |
question_relevant_chunks += 1 | |
# Если нужно сохранить топ-чанки и top_n == 20 | |
if top_chunks_dir and top_n == 20: | |
all_chunk_overlaps.append(chunk_overlaps) | |
# Если нужно сохранить топ-чанки и top_n == 20 | |
if top_chunks_dir and top_n == 20 and not all_top_chunks.empty: | |
save_top_chunks_for_question( | |
question_id, | |
question_text, | |
puncts, | |
all_top_chunks, | |
all_chunk_similarities, | |
all_chunk_overlaps, | |
top_chunks_dir | |
) | |
# Подсчитываем метрики для текущего вопроса | |
question_found_puncts = sum(punct_found) | |
found_puncts += question_found_puncts | |
relevant_chunks += question_relevant_chunks | |
# Обновляем метрики для документов | |
question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs) | |
found_relevant_docs += question_found_relevant_docs | |
question_total_docs_found = len(docs_found_for_question) | |
total_docs_found += question_total_docs_found | |
# Вычисляем метрики для текущего вопроса | |
question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0 | |
question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0 | |
question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0 | |
question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0 | |
question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0 | |
question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0 | |
# Сохраняем метрики вопроса | |
question_metrics.append({ | |
'question_id': question_id, | |
'question_text': question_text, | |
'top_n': top_n, | |
'text_precision': question_text_precision, | |
'text_recall': question_text_recall, | |
'text_f1': question_text_f1, | |
'doc_precision': question_doc_precision, | |
'doc_recall': question_doc_recall, | |
'doc_f1': question_doc_f1, | |
'found_puncts': question_found_puncts, | |
'total_puncts': question_total_puncts, | |
'relevant_chunks': question_relevant_chunks, | |
'total_chunks': question_total_chunks, | |
'found_relevant_docs': question_found_relevant_docs, | |
'total_docs_required': question_total_docs_required, | |
'total_docs_found': question_total_docs_found | |
}) | |
# Вычисляем метрики для текста | |
text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0 | |
text_recall = found_puncts / total_puncts if total_puncts > 0 else 0 | |
text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0 | |
# Вычисляем метрики для документов | |
doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0 | |
doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0 | |
doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0 | |
aggregated_metrics = { | |
'top_n': top_n, | |
'text_precision': text_precision, | |
'text_recall': text_recall, | |
'text_f1': text_f1, | |
'doc_precision': doc_precision, | |
'doc_recall': doc_recall, | |
'doc_f1': doc_f1, | |
'found_puncts': found_puncts, | |
'total_puncts': total_puncts, | |
'relevant_chunks': relevant_chunks, | |
'total_chunks': total_chunks, | |
'found_relevant_docs': found_relevant_docs, | |
'total_docs_required': total_docs_required, | |
'total_docs_found': total_docs_found | |
} | |
return aggregated_metrics, pd.DataFrame(question_metrics) | |
def main(): | |
""" | |
Основная функция скрипта. | |
""" | |
args = parse_args() | |
# Устанавливаем устройство из аргументов | |
device = args.device | |
# Создаем выходной каталог, если его нет | |
os.makedirs(args.output_dir, exist_ok=True) | |
# Создаем директорию для топ-чанков | |
top_chunks_dir = os.path.join(args.output_dir, "top_chunks") | |
os.makedirs(top_chunks_dir, exist_ok=True) | |
# Загружаем датасет с вопросами | |
questions_df = load_questions_dataset(args.dataset_path) | |
# Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели | |
strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}" | |
chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}" | |
questions_filename = f"questions_{args.model_name.replace('/', '_')}" | |
# Пытаемся загрузить сохраненные эмбеддинги и данные | |
chunk_embeddings, chunks_df = None, None | |
question_embeddings, questions_df_with_embeddings = None, None | |
if not args.force_recompute: | |
chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir) | |
question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir) | |
# Если не удалось загрузить данные или включен режим принудительного пересчета | |
if chunk_embeddings is None or chunks_df is None: | |
# Читаем и обрабатываем документы | |
documents = read_documents(args.data_folder) | |
# Формируем конфигурацию для стратегии fixed_size | |
fixed_size_config = { | |
"words_per_chunk": args.words_per_chunk, | |
"overlap_words": args.overlap_words | |
} | |
# Получаем DataFrame с чанками | |
chunks_df = process_documents(documents, fixed_size_config) | |
# Настраиваем модель и токенизатор | |
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device) | |
# Получаем эмбеддинги для чанков | |
chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device) | |
# Сохраняем эмбеддинги и данные | |
save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir) | |
# Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета | |
if question_embeddings is None or questions_df_with_embeddings is None: | |
# Получаем уникальные вопросы (по id) | |
unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']] | |
# Настраиваем модель и токенизатор (если еще не настроены) | |
if 'model' not in locals() or 'tokenizer' not in locals(): | |
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device) | |
# Получаем эмбеддинги для вопросов | |
question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device) | |
# Сохраняем эмбеддинги и данные | |
save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir) | |
# Устанавливаем questions_df_with_embeddings для дальнейшего использования | |
questions_df_with_embeddings = unique_questions | |
# Создаем словарь соответствия id вопроса и его индекса в эмбеддингах | |
question_id_to_idx = { | |
row['id']: i | |
for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows()) | |
} | |
# Оцениваем стратегию чанкинга для разных значений top_n | |
aggregated_results = [] | |
all_question_metrics = [] | |
for top_n in TOP_N_VALUES: | |
metrics, question_metrics = evaluate_for_top_n_with_mapping( | |
questions_df, # Исходный датасет с связью между вопросами и документами | |
chunks_df, # Датасет с чанками | |
question_embeddings, # Эмбеддинги вопросов | |
chunk_embeddings, # Эмбеддинги чанков | |
question_id_to_idx, # Маппинг id вопроса к индексу в эмбеддингах | |
top_n, # Количество чанков в топе | |
args.similarity_threshold, # Порог для определения перекрытия | |
top_chunks_dir if top_n == 20 else None # Сохраняем топ-чанки только для top_n=20 | |
) | |
aggregated_results.append(metrics) | |
all_question_metrics.append(question_metrics) | |
# Объединяем все метрики по вопросам | |
all_question_metrics_df = pd.concat(all_question_metrics) | |
# Создаем DataFrame с агрегированными результатами | |
aggregated_results_df = pd.DataFrame(aggregated_results) | |
# Сохраняем результаты | |
results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv" | |
results_path = os.path.join(args.output_dir, results_filename) | |
aggregated_results_df.to_csv(results_path, index=False) | |
# Сохраняем метрики по вопросам | |
question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx" | |
question_metrics_path = os.path.join(args.output_dir, question_metrics_filename) | |
all_question_metrics_df.to_excel(question_metrics_path, index=False) | |
print(f"\nРезультаты сохранены в {results_path}") | |
print(f"Метрики по вопросам сохранены в {question_metrics_path}") | |
print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}") | |
print("\nМетрики для различных значений top_n:") | |
print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']]) | |
if __name__ == "__main__": | |
main() |