muryshev's picture
update
86c402d
raw
history blame
10.5 kB
#!/usr/bin/env python
"""
Скрипт для запуска экспериментов по оценке качества чанкинга с разными моделями и параметрами.
"""
import argparse
import os
import subprocess
import sys
import time
from datetime import datetime
# Конфигурация моделей
MODELS = [
"intfloat/e5-base",
"intfloat/e5-large",
"BAAI/bge-m3",
"deepvk/USER-bge-m3",
"ai-forever/FRIDA"
]
# Параметры чанкинга (отсортированы в запрошенном порядке)
CHUNKING_PARAMS = [
{"words": 50, "overlap": 25, "description": "Маленький чанкинг с нахлёстом 50%"},
{"words": 50, "overlap": 0, "description": "Маленький чанкинг без нахлёста"},
{"words": 20, "overlap": 10, "description": "Очень мелкий чанкинг с нахлёстом 50%"},
{"words": 100, "overlap": 0, "description": "Средний чанкинг без нахлёста"},
{"words": 100, "overlap": 25, "description": "Средний чанкинг с нахлёстом 25%"},
{"words": 150, "overlap": 50, "description": "Крупный чанкинг с нахлёстом 33%"},
{"words": 200, "overlap": 75, "description": "Очень крупный чанкинг с нахлёстом 37.5%"}
]
# Значение порога для нечеткого сравнения
SIMILARITY_THRESHOLD = 0.7
def parse_args():
"""Парсит аргументы командной строки."""
parser = argparse.ArgumentParser(description="Запуск экспериментов для оценки качества чанкинга")
parser.add_argument("--data-folder", type=str, default="data/docs",
help="Путь к папке с документами (по умолчанию: data/docs)")
parser.add_argument("--dataset-path", type=str, default="data/dataset.xlsx",
help="Путь к Excel-датасету с вопросами (по умолчанию: data/dataset.xlsx)")
parser.add_argument("--output-dir", type=str, default="data",
help="Директория для сохранения результатов (по умолчанию: data)")
parser.add_argument("--log-dir", type=str, default="logs",
help="Директория для сохранения логов (по умолчанию: logs)")
parser.add_argument("--skip-existing", action="store_true",
help="Пропускать эксперименты, если файлы результатов уже существуют")
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
parser.add_argument("--model", type=str, default=None,
help="Запустить эксперимент только для указанной модели")
parser.add_argument("--chunking-index", type=int, default=None,
help="Запустить эксперимент только для указанного индекса конфигурации чанкинга (0-6)")
parser.add_argument("--device", type=str, default="cuda:1",
help="Устройство для вычислений (по умолчанию: cuda:1)")
return parser.parse_args()
def run_experiment(model_name, chunking_params, args):
"""
Запускает эксперимент с определенной моделью и параметрами чанкинга.
Args:
model_name: Название модели
chunking_params: Словарь с параметрами чанкинга
args: Аргументы командной строки
"""
words = chunking_params["words"]
overlap = chunking_params["overlap"]
description = chunking_params["description"]
# Формируем имя файла результатов
results_filename = f"results_fixed_size_w{words}_o{overlap}_{model_name.replace('/', '_')}.csv"
results_path = os.path.join(args.output_dir, results_filename)
# Проверяем, существует ли файл результатов
if args.skip_existing and os.path.exists(results_path):
print(f"Пропуск: {results_path} уже существует")
return
# Создаем директорию для логов, если она не существует
os.makedirs(args.log_dir, exist_ok=True)
# Формируем имя файла лога
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"log_{model_name.replace('/', '_')}_w{words}_o{overlap}_{timestamp}.txt"
log_path = os.path.join(args.log_dir, log_filename)
# Используем тот же интерпретатор Python, что и текущий скрипт
python_executable = sys.executable
# Запускаем скрипт evaluate_chunking.py с нужными параметрами
cmd = [
python_executable, "scripts/evaluate_chunking.py",
"--data-folder", args.data_folder,
"--model-name", model_name,
"--dataset-path", args.dataset_path,
"--output-dir", args.output_dir,
"--words-per-chunk", str(words),
"--overlap-words", str(overlap),
"--similarity-threshold", str(args.similarity_threshold),
"--device", args.device,
"--force-recompute" # Принудительно пересчитываем эмбеддинги
]
# Специальная обработка для модели ai-forever/FRIDA
if model_name == "ai-forever/FRIDA":
cmd.append("--use-sentence-transformers") # Добавляем флаг для использования sentence_transformers
print(f"\n{'='*80}")
print(f"Запуск эксперимента:")
print(f" Интерпретатор Python: {python_executable}")
print(f" Модель: {model_name}")
print(f" Чанкинг: {description} (words={words}, overlap={overlap})")
print(f" Порог для нечеткого сравнения: {args.similarity_threshold}")
print(f" Устройство: {args.device}")
print(f" Результаты будут сохранены в: {results_path}")
print(f" Лог: {log_path}")
print(f"{'='*80}\n")
# Запись информации в лог
with open(log_path, "w", encoding="utf-8") as log_file:
log_file.write(f"Эксперимент запущен в: {datetime.now()}\n")
log_file.write(f"Интерпретатор Python: {python_executable}\n")
log_file.write(f"Команда: {' '.join(cmd)}\n\n")
start_time = time.time()
# Запускаем процесс и перенаправляем вывод в файл лога
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1 # Построчная буферизация
)
# Читаем вывод процесса
for line in process.stdout:
print(line, end="") # Выводим в консоль
log_file.write(line) # Записываем в файл лога
# Ждем завершения процесса
process.wait()
end_time = time.time()
duration = end_time - start_time
# Записываем информацию о завершении
log_file.write(f"\nЭксперимент завершен в: {datetime.now()}\n")
log_file.write(f"Длительность: {duration:.2f} секунд ({duration/60:.2f} минут)\n")
log_file.write(f"Код возврата: {process.returncode}\n")
if process.returncode == 0:
print(f"Эксперимент успешно завершен за {duration/60:.2f} минут")
else:
print(f"Эксперимент завершился с ошибкой (код {process.returncode})")
def main():
"""Основная функция скрипта."""
args = parse_args()
# Создаем output_dir, если он не существует
os.makedirs(args.output_dir, exist_ok=True)
# Получаем список моделей для запуска
models_to_run = [args.model] if args.model else MODELS
# Получаем список конфигураций чанкинга для запуска
chunking_configs = [CHUNKING_PARAMS[args.chunking_index]] if args.chunking_index is not None else CHUNKING_PARAMS
start_time_all = time.time()
total_experiments = len(models_to_run) * len(chunking_configs)
completed_experiments = 0
print(f"Запуск {total_experiments} экспериментов...")
# Изменен порядок: сначала идём по стратегиям, затем по моделям
for chunking_config in chunking_configs:
print(f"\n=== Стратегия чанкинга: {chunking_config['description']} (words={chunking_config['words']}, overlap={chunking_config['overlap']}) ===\n")
for model in models_to_run:
# Запускаем эксперимент
run_experiment(model, chunking_config, args)
completed_experiments += 1
remaining_experiments = total_experiments - completed_experiments
if remaining_experiments > 0:
print(f"Завершено {completed_experiments}/{total_experiments} экспериментов. Осталось: {remaining_experiments}")
end_time_all = time.time()
total_duration = end_time_all - start_time_all
print(f"\nВсе эксперименты завершены за {total_duration/60:.2f} минут")
print(f"Результаты сохранены в {args.output_dir}")
print(f"Логи сохранены в {args.log_dir}")
if __name__ == "__main__":
main()