import asyncio  # Used for asyncio.gather during query expansion
import io  # Treat UploadFile contents as a file-like object
import logging
import re  # Extract source filenames from the assembled context
from pathlib import Path  # Filename cleaning
from typing import Any
from uuid import UUID

import pandas as pd
from fastapi import HTTPException, UploadFile
from fuzzywuzzy import fuzz

from common.configuration import Configuration
from components.llm.common import Message
from components.services.dialogue import DialogueService
from components.services.entity import EntityService

logger = logging.getLogger(__name__)
# Threshold for fuzzy filename comparison
FILENAME_SIMILARITY_THRESHOLD = 40  # Filenames are considered the same if partial_ratio >= FILENAME_SIMILARITY_THRESHOLD


class SearchMetricsService:
    """Service for computing search quality metrics from an uploaded evaluation file.

    Attributes:
        entity_service: Service for working with entities.
        config: Application configuration.
        dialogue_service: Service for working with dialogues.
    """
    def __init__(
        self,
        entity_service: EntityService,
        config: Configuration,
        dialogue_service: DialogueService,
    ):
        """Initializes the service.

        Args:
            entity_service: Service for working with entities.
            config: Application configuration.
            dialogue_service: Service for working with dialogues.
        """
        self.entity_service = entity_service
        self.config = config
        self.dialogue_service = dialogue_service

    # --- Helper for normalizing filenames ---
    def _clean_filename(self, filename: str | None) -> str:
        """Strips the extension and converts the filename to lower case."""
        if not filename:
            return ""
        return Path(str(filename)).stem.lower()

    async def _load_evaluation_data(self, file: UploadFile) -> list[dict[str, Any]]:
        """
        Loads, validates, and GROUPS the data from an XLSX file by unique question.

        For each question it keeps the list of ground-truth texts, the SET of expected
        filenames, and the reference answer.
        """
        if not file.filename or not file.filename.endswith(".xlsx"):
            raise HTTPException(
                status_code=400,
                detail="Invalid file format. Please upload an XLSX file.",
            )
        try:
            contents = await file.read()
            data = io.BytesIO(contents)
            # Read 'answer' as a string as well
            df = pd.read_excel(data, dtype={'id': str, 'question': str, 'text': str, 'filename': str, 'answer': str})
        except Exception as e:
            logger.error(f"Error reading Excel file: {e}", exc_info=True)
            raise HTTPException(
                status_code=400, detail=f"Error reading Excel file: {e}"
            )
        finally:
            await file.close()

        # 'answer' is part of the required columns
        required_columns = ["id", "question", "text", "filename", "answer"]
        missing_cols = [col for col in required_columns if col not in df.columns]
        if missing_cols:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required columns in XLSX file: {missing_cols}. Expected: 'id', 'question', 'text', 'filename', 'answer'",
            )

        grouped_data = []
        for question_id, group in df.groupby('id'):
            first_valid_question = group['question'].dropna().iloc[0] if not group['question'].dropna().empty else None
            all_texts_raw = group['text'].dropna().tolist()
            all_filenames_raw = group['filename'].dropna().tolist()
            expected_filenames_cleaned = {self._clean_filename(fn) for fn in all_filenames_raw if self._clean_filename(fn)}
            # Take the first valid reference answer in the group
            first_valid_answer = group['answer'].dropna().iloc[0] if not group['answer'].dropna().empty else None
            # Keep the cell texts as they are, without splitting them into smaller pieces
            ground_truth_texts_raw = [str(text_block) for text_block in all_texts_raw if str(text_block).strip()]  # Original, non-empty cell texts
            # Skip the group if any required piece is missing (including an empty list of texts)
            if pd.isna(question_id) or not first_valid_question or not ground_truth_texts_raw or not expected_filenames_cleaned or first_valid_answer is None:
                logger.warning(f"Skipping group for question_id '{question_id}' due to missing question, 'text', 'filename', or 'answer' data within the group, or empty 'text' cells.")
                continue
            grouped_data.append({
                "question_id": str(question_id),
                "question": str(first_valid_question),
                "ground_truth_texts": ground_truth_texts_raw,  # List of original cell texts
                "expected_filenames": expected_filenames_cleaned,
                "reference_answer": str(first_valid_answer)  # Reference answer
            })

        if not grouped_data:
            raise HTTPException(
                status_code=400,
                detail="No valid data groups found in the uploaded file after processing and grouping by 'id'."
            )
        logger.info(f"Successfully loaded and grouped {len(grouped_data)} unique questions from file.")
        return grouped_data

    def _calculate_relevance_metrics(
        self,
        retrieved_chunks: list[str],
        ground_truth_texts: list[str],
        similarity_threshold: float,
        question_id_for_log: str = "unknown"  # ID kept for warning/error logging
    ) -> tuple[float, float, float, int, int, int, int, list[int]]:
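        """Computes chunk-level precision/recall/F1 against the ground-truth texts.

        A retrieved chunk counts as relevant if its fuzzy partial_ratio against any
        ground-truth text reaches the threshold; a ground-truth item counts as found
        if at least one chunk matches it.

        Returns:
            (precision, recall, f1,
             found_ground_truth_count, total_ground_truth_count,
             relevant_chunks_count, retrieved_chunks_count,
             indices of ground-truth items that were not found)
        """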
        num_retrieved = len(retrieved_chunks)
        total_ground_truth = len(ground_truth_texts)
        if total_ground_truth == 0:
            return 0.0, 0.0, 0.0, 0, 0, 0, num_retrieved, []
        if num_retrieved == 0:
            return 0.0, 0.0, 0.0, 0, total_ground_truth, 0, 0, list(range(total_ground_truth))

        ground_truth_found = [False] * total_ground_truth
        relevant_chunks_count = 0
        fuzzy_threshold_int = similarity_threshold * 100  # fuzzywuzzy scores are on a 0-100 scale
        for chunk_text in retrieved_chunks:
            is_chunk_relevant = False
            for i, gt_text in enumerate(ground_truth_texts):
                overlap_score = fuzz.partial_ratio(chunk_text, gt_text)
                if overlap_score >= fuzzy_threshold_int:
                    is_chunk_relevant = True
                    ground_truth_found[i] = True
                    # No break on purpose: one chunk may be relevant to several ground-truth items
            if is_chunk_relevant:
                relevant_chunks_count += 1

        found_puncts_count = sum(ground_truth_found)
        precision = relevant_chunks_count / num_retrieved
        recall = found_puncts_count / total_ground_truth
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        missed_gt_indices = [i for i, found in enumerate(ground_truth_found) if not found]
        return precision, recall, f1, found_puncts_count, total_ground_truth, relevant_chunks_count, num_retrieved, missed_gt_indices

    def _calculate_assembly_punct_recall(
        self,
        assembled_context: str,
        ground_truth_texts: list[str],
        similarity_threshold: float,
        question_id_for_log: str = "unknown"  # ID kept for warning/error logging
    ) -> tuple[float, int, int]:
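        """Computes recall of ground-truth items against the assembled context.

        Each ground-truth cell is split into non-empty lines ("parts"); the item counts
        as found if any of its parts fuzzy-matches the assembled context at or above
        the threshold.

        Returns:
            (assembly_recall, found_items_count, valid_ground_truth_count)
        """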
        if not ground_truth_texts or not assembled_context:
            return 0.0, 0, 0

        assembly_found_puncts = 0
        valid_ground_truth_count = 0
        fuzzy_threshold_int = similarity_threshold * 100
        for i, punct_text in enumerate(ground_truth_texts):
            punct_parts = [part.strip() for part in punct_text.split('\n') if part.strip()]
            if not punct_parts:
                continue
            valid_ground_truth_count += 1
            is_punct_found = False
            for j, part_text in enumerate(punct_parts):
                score = fuzz.partial_ratio(assembled_context, part_text)
                if score >= fuzzy_threshold_int:
                    is_punct_found = True
                    break
            if is_punct_found:
                assembly_found_puncts += 1

        assembly_recall = assembly_found_puncts / valid_ground_truth_count if valid_ground_truth_count > 0 else 0.0
        return assembly_recall, assembly_found_puncts, valid_ground_truth_count

    def _extract_and_compare_documents(
        self,
        assembled_context: str,
        expected_filenames_cleaned: set[str]
    ) -> tuple[float, int]:
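        """Extracts source filenames from the assembled context and compares them to the expected set.

        Source names are pulled from "# [Источник] - <filename>" lines, normalized with
        _clean_filename, and fuzzy-matched against the expected filenames.

        Returns:
            (document_recall, number_of_spurious_documents)
        """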
        if not assembled_context or not expected_filenames_cleaned:
            return 0.0, 0

        # "# [Источник] - <filename>" lines mark the source documents in the assembled context
        pattern = r"#\s*\[Источник\]\s*-\s*(.*?)(?:\n|$)"
        found_filenames_raw = re.findall(pattern, assembled_context)
        found_filenames_cleaned = {self._clean_filename(fn) for fn in found_filenames_raw if self._clean_filename(fn)}
        if not found_filenames_cleaned:
            return 0.0, 0

        found_expected_count = 0
        spurious_count = 0
        matched_expected = set()
        for found_clean in found_filenames_cleaned:
            is_spurious = True
            for expected_clean in expected_filenames_cleaned:
                score = fuzz.partial_ratio(found_clean, expected_clean)
                if score >= FILENAME_SIMILARITY_THRESHOLD:
                    if expected_clean not in matched_expected:
                        found_expected_count += 1
                        matched_expected.add(expected_clean)
                    is_spurious = False
                    # No break on purpose
            if is_spurious:
                spurious_count += 1

        doc_recall = found_expected_count / len(expected_filenames_cleaned)
        return doc_recall, spurious_count

    async def _call_qe_safe(self, original_question: str) -> str | None:
        """
        Safely calls the query-expansion (QE) service for a single question.

        Args:
            original_question: The original question text.

        Returns:
            The expanded search query returned by QE if it is successful and relevant,
            otherwise None.
        """
        try:
            fake_history = [Message(role="user", content=original_question, searchResults="")]
            qe_result = await self.dialogue_service.get_qe_result(fake_history)
            logger.debug(f"QE result for '{original_question[:50]}...': {qe_result}")
            if qe_result.use_search and qe_result.search_query:
                return qe_result.search_query
            # QE decided not to search or returned an empty result
            return None
        except Exception as e:
            logger.error(f"Error during single QE call for question '{original_question[:50]}...': {e}", exc_info=True)
            # On error return None so that the original question is used instead
            return None

    async def evaluate_from_file(
        self,
        file: UploadFile,
        dataset_id: int,
        similarity_threshold: float,
        top_n_values: list[int],
        use_query_expansion: bool,
        top_worst_k: int = 5,
    ) -> dict[str, Any]:
        """
        Runs the evaluation from the uploaded file, grouping rows by question and computing assembly metrics.
        """
logger.info(f"Starting evaluation for dataset_id={dataset_id}, top_n={top_n_values}, threshold={similarity_threshold}, use_query_expansion={use_query_expansion} (Grouped by question_id)") | |
evaluation_data = await self._load_evaluation_data(file) | |
results: dict[int, dict[str, Any]] = { | |
n: { | |
'precision_list': [], 'recall_list': [], 'f1_list': [], # Для Macro/Weighted | |
'assembly_punct_recall_list': [], | |
'doc_recall_list': [], | |
'spurious_docs_list': [], | |
} for n in top_n_values | |
} | |
question_performance: dict[str, dict[str, Any | None]] = {} | |
max_top_n = max(top_n_values) if top_n_values else 0 | |
if not max_top_n: raise HTTPException(status_code=400, detail="top_n_values list cannot be empty.") | |
# +++ Инициализация НОВЫХ общих счетчиков Micro (по n) +++ | |
overall_micro_counters = { | |
n: {'found': 0, 'gt': 0, 'relevant': 0, 'retrieved': 0} | |
for n in top_n_values | |
} | |
# --- Счетчики для Micro Assembly Recall остаются --- | |
overall_assembly_found_puncts = 0 | |
overall_valid_gt_for_assembly = 0 | |
        # --- Stage 2: Prepare search queries (QE), carrying reference_answer along ---
        processed_items = []
        if use_query_expansion and evaluation_data:
            logger.info(f"Starting asynchronous QE for {len(evaluation_data)} unique questions...")
            tasks = [self._call_qe_safe(item['question']) for item in evaluation_data]
            qe_results_or_errors = await asyncio.gather(*tasks, return_exceptions=True)
            logger.info("Asynchronous QE calls finished for unique questions.")
            for i, item in enumerate(evaluation_data):
                query_for_search = item['question']
                qe_result = qe_results_or_errors[i]
                if isinstance(qe_result, str):
                    query_for_search = qe_result
                processed_items.append({
                    'question_id': item['question_id'],
                    'question': item['question'],
                    'query_for_search': query_for_search,
                    'ground_truth_texts': item['ground_truth_texts'],
                    'expected_filenames': item['expected_filenames'],
                    'reference_answer': item['reference_answer']
                })
        else:
            logger.info("QE disabled or no data. Preparing items without QE.")
            for item in evaluation_data:
                processed_items.append({
                    'question_id': item['question_id'],
                    'question': item['question'],
                    'query_for_search': item['question'],
                    'ground_truth_texts': item['ground_truth_texts'],
                    'expected_filenames': item['expected_filenames'],
                    'reference_answer': item['reference_answer']
                })
        # --- Stage 3: Loop over UNIQUE questions ---
        for item in processed_items:
            question_id = item['question_id']
            original_question_text = item['question']
            reference_answer = item['reference_answer']
            ground_truth_texts = item['ground_truth_texts']
            expected_filenames = item['expected_filenames']
            total_gt_count = len(ground_truth_texts)
            query_for_search = item['query_for_search']

            # Initialize per-question performance tracking
            if question_id not in question_performance:
                question_performance[question_id] = {
                    'f1': None,
                    'assembly_recall_for_worst': None,  # Used to sort the worst-performing questions
                    'question_text': original_question_text,
                    'reference_answer': reference_answer,
                    'missed_gt_indices': None
                }
            logger.debug(f"Processing unique QID={question_id} with {total_gt_count} ground truths. Query: \"{query_for_search}\"")
            try:
                # --- Search (performed once with k=max_top_n) ---
                logger.info(f"Searching for QID={question_id} with k={max_top_n}...")
                _, scores, ids = self.entity_service.search_similar_old(
                    query=query_for_search, dataset_id=dataset_id, k=max_top_n
                )
                # Note: 'ids' is a numpy array of UUID STRINGS

                # --- Loop over the requested top_n values ---
                for n in top_n_values:
                    current_top_n = min(n, len(ids))
                    # Chunk IDs for the current n
                    chunk_ids_for_n = ids[:current_top_n]
                    retrieved_count_for_n = len(chunk_ids_for_n)
                    # Fetch the chunk texts needed for the chunk/punct metrics
                    retrieved_chunks_texts_for_n = []
                    if chunk_ids_for_n.size > 0:
                        # Asynchronous fetch of the chunk entities
                        chunks_for_n = await self.entity_service.chunk_repository.get_entities_by_ids_async(
                            [UUID(ch_id) for ch_id in chunk_ids_for_n]
                        )
                        chunk_map_for_n = {str(ch.id): ch for ch in chunks_for_n}
                        retrieved_chunks_texts_for_n = [
                            chunk_map_for_n[ch_id].in_search_text
                            for ch_id in chunk_ids_for_n
                            if ch_id in chunk_map_for_n and hasattr(chunk_map_for_n[ch_id], 'in_search_text') and chunk_map_for_n[ch_id].in_search_text
                        ]
                    # --- Chunk/punct metrics ---
                    (
                        precision, recall, f1,
                        found_count, total_gt,
                        relevant_count, retrieved_count_calc,  # retrieved_count_calc == retrieved_count_for_n
                        missed_indices
                    ) = self._calculate_relevance_metrics(
                        retrieved_chunks_texts_for_n,  # Texts for the current n
                        ground_truth_texts,
                        similarity_threshold,
                        question_id_for_log=question_id
                    )
                    # Aggregation for Macro/Weighted
                    results[n]['precision_list'].append((precision, retrieved_count_for_n))  # Weight = retrieved_count_for_n
                    results[n]['recall_list'].append((recall, total_gt))
                    results[n]['f1_list'].append((f1, total_gt))
                    # Aggregation for Micro
                    overall_micro_counters[n]['found'] += found_count
                    overall_micro_counters[n]['gt'] += total_gt
                    overall_micro_counters[n]['relevant'] += relevant_count
                    overall_micro_counters[n]['retrieved'] += retrieved_count_for_n  # Count for the current n
                    # --- Assembly metrics ---
                    # Build the context properly via build_text
                    logger.info(f"Building context for QID={question_id}, n={n} using {len(chunk_ids_for_n)} chunk IDs...")
                    # Asynchronous call, passing the dataset id
                    assembled_context_for_n = await self.entity_service.build_text_async(
                        entities=chunk_ids_for_n.tolist(),  # Convert numpy array to list[str]
                        dataset_id=dataset_id,
                        # chunk_scores could be passed if needed for assembly, otherwise None
                        # include_tables=True,   # default
                        # max_documents=None,    # default
                    )
                    assembly_recall, single_q_assembly_found, single_q_valid_gt = self._calculate_assembly_punct_recall(
                        assembled_context_for_n,
                        ground_truth_texts,
                        similarity_threshold,
                        question_id_for_log=question_id
                    )
                    results[n]['assembly_punct_recall_list'].append(assembly_recall)
                    if n == max_top_n:
                        overall_assembly_found_puncts += single_q_assembly_found
                        overall_valid_gt_for_assembly += single_q_valid_gt

                    # --- Document metrics ---
                    doc_recall, spurious_docs = self._extract_and_compare_documents(
                        assembled_context_for_n,  # Use the properly assembled context
                        expected_filenames
                    )
                    results[n]['doc_recall_list'].append(doc_recall)
                    results[n]['spurious_docs_list'].append(spurious_docs)

                    # --- Track per-question values for the worst-performing report ---
                    if n == max_top_n:
                        question_performance[question_id]['f1'] = f1
                        question_performance[question_id]['assembly_recall_for_worst'] = assembly_recall
                        question_performance[question_id]['missed_gt_indices'] = missed_indices
            except HTTPException as http_exc:
                logger.error(f"HTTP Error processing QID={question_id}: {http_exc.detail}")
                if question_id in question_performance:
                    # On error treat the question as fully missed
                    question_performance[question_id]['f1'] = 0.0
                    question_performance[question_id]['assembly_recall_for_worst'] = 0.0
                    question_performance[question_id]['missed_gt_indices'] = list(range(total_gt_count))
                for n_err in top_n_values:
                    results[n_err]['precision_list'].append((0.0, 0))
                    results[n_err]['recall_list'].append((0.0, total_gt_count))
                    results[n_err]['f1_list'].append((0.0, total_gt_count))
                    results[n_err]['assembly_punct_recall_list'].append(0.0)
                    results[n_err]['doc_recall_list'].append(0.0)
                    results[n_err]['spurious_docs_list'].append(0)
                    # Keep the global Micro GT counter consistent even on error
                    overall_micro_counters[n_err]['gt'] += total_gt_count
            except Exception as e:
                logger.error(f"General Error processing QID={question_id}: {e}", exc_info=True)
                if question_id in question_performance:
                    # On error treat the question as fully missed
                    question_performance[question_id]['f1'] = 0.0
                    question_performance[question_id]['assembly_recall_for_worst'] = 0.0
                    question_performance[question_id]['missed_gt_indices'] = list(range(total_gt_count))
                for n_err in top_n_values:
                    results[n_err]['precision_list'].append((0.0, 0))
                    results[n_err]['recall_list'].append((0.0, total_gt_count))
                    results[n_err]['f1_list'].append((0.0, total_gt_count))
                    results[n_err]['assembly_punct_recall_list'].append(0.0)
                    results[n_err]['doc_recall_list'].append(0.0)
                    results[n_err]['spurious_docs_list'].append(0)
                    # Keep the global Micro GT counter consistent even on error
                    overall_micro_counters[n_err]['gt'] += total_gt_count
        # --- Stage 4: Final metric calculation ---
        final_metrics_results: dict[int, dict[str, float | None]] = {}
        # Note: the micro counters accumulated in the loop above are reused here on purpose
        # and must not be re-initialized at this point.
        logger.debug(f"Data before final calculation: results={results}")
        logger.debug(f"Overall micro counters before final calc: {overall_micro_counters}")
        logger.debug(f"Overall assembly counters before final calc: found={overall_assembly_found_puncts}, valid_gt={overall_valid_gt_for_assembly}")
        for n in top_n_values:
            # Unpack the per-question lists
            prec_list = results[n]['precision_list']
            rec_list = results[n]['recall_list']
            f1_list = results[n]['f1_list']
            assembly_recall_list = results[n]['assembly_punct_recall_list']
            doc_recall_list = results[n]['doc_recall_list']
            spurious_docs_list = results[n]['spurious_docs_list']

            # --- Macro averages (with an explicit empty-list check) ---
            macro_precision = sum(p for p, w in prec_list) / len(prec_list) if prec_list else None
            macro_recall = sum(r for r, w in rec_list) / len(rec_list) if rec_list else None
            macro_f1 = sum(f for f, w in f1_list) / len(f1_list) if f1_list else None

            # --- Weighted averages (with an explicit empty-list check) ---
            weighted_precision = None
            if prec_list:
                weighted_precision_num = sum(p * w for p, w in prec_list)
                weighted_precision_den = sum(w for p, w in prec_list)
                weighted_precision = weighted_precision_num / weighted_precision_den if weighted_precision_den > 0 else 0.0
            weighted_recall = None
            if rec_list:
                weighted_recall_num = sum(r * w for r, w in rec_list)
                weighted_recall_den = sum(w for r, w in rec_list)
                weighted_recall = weighted_recall_num / weighted_recall_den if weighted_recall_den > 0 else 0.0
            weighted_f1 = None
            if f1_list:
                weighted_f1_num = sum(f * w for f, w in f1_list)
                weighted_f1_den = sum(w for f, w in f1_list)
                weighted_f1 = weighted_f1_num / weighted_f1_den if weighted_f1_den > 0 else 0.0

            # --- Micro averages (use the counters accumulated in the main loop) ---
            total_found = overall_micro_counters[n]['found']
            total_gt = overall_micro_counters[n]['gt']
            total_relevant = overall_micro_counters[n]['relevant']
            total_retrieved = overall_micro_counters[n]['retrieved']
            micro_precision = total_relevant / total_retrieved if total_retrieved > 0 else 0.0
            micro_recall = total_found / total_gt if total_gt > 0 else 0.0
            micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall) if (micro_precision + micro_recall) > 0 else 0.0

            # --- Macro metrics for assembly and documents (with an explicit empty-list check) ---
            assembly_punct_recall_macro = sum(assembly_recall_list) / len(assembly_recall_list) if assembly_recall_list else None
            doc_recall_macro = sum(doc_recall_list) / len(doc_recall_list) if doc_recall_list else None
            avg_spurious_docs = sum(spurious_docs_list) / len(spurious_docs_list) if spurious_docs_list else None
            # Fill in the result for this top_n
            final_metrics_results[n] = {
                'macro_precision': macro_precision,
                'macro_recall': macro_recall,
                'macro_f1': macro_f1,
                'weighted_precision': weighted_precision,
                'weighted_recall': weighted_recall,
                'weighted_f1': weighted_f1,
                'micro_precision': micro_precision,
                'micro_recall': micro_recall,
                'micro_f1': micro_f1,
                'assembly_punct_recall_macro': assembly_punct_recall_macro,
                'doc_recall_macro': doc_recall_macro,
                'avg_spurious_docs': avg_spurious_docs,
            }
            logger.info(f"Final metrics for top_n={n}: {final_metrics_results[n]}\n")

        # --- Micro Assembly Punct Recall (accumulated across all questions at max_top_n) ---
        micro_assembly_punct_recall = (
            overall_assembly_found_puncts / overall_valid_gt_for_assembly
            if overall_valid_gt_for_assembly > 0 else 0.0
        )
        # --- Find the worst-performing questions (by assembly recall) ---
        qid_to_ground_truths = {item['question_id']: item['ground_truth_texts'] for item in processed_items}
        worst_questions_processed = []
        logger.debug(f"Debugging worst questions: question_performance = {question_performance}")
        # Sort ascending by assembly_recall_for_worst, skipping questions for which it is still None
        sorted_performance = sorted(
            [
                (qid, data) for qid, data in question_performance.items()
                if data.get('assembly_recall_for_worst') is not None
            ],
            key=lambda item: item[1]['assembly_recall_for_worst']
        )
        logger.debug(f"Debugging worst questions: sorted_performance (top {top_worst_k}) = {sorted_performance[:top_worst_k]}")
        for qid, perf_data in sorted_performance[:top_worst_k]:
            logger.debug(f"Processing worst question: QID={qid}, Data={perf_data}")
            try:
                missed_indices = perf_data.get('missed_gt_indices', [])
                logger.debug(f"QID={qid}: Got missed_indices: {missed_indices}")
                missed_texts = []
                if missed_indices is not None and qid in qid_to_ground_truths:
                    original_gts = qid_to_ground_truths[qid]
                    missed_texts = [original_gts[i] for i in missed_indices if i < len(original_gts)]
                    logger.debug(f"QID={qid}: Found {len(missed_texts)} missed texts from {len(original_gts)} original GTs.")
                elif qid not in qid_to_ground_truths:
                    logger.warning(f"QID={qid} not found in qid_to_ground_truths when processing worst questions.")
                # Build the entry before appending it
                worst_entry = {
                    'id': qid,
                    'f1': perf_data.get('f1'),  # .get() for safety
                    'assembly_recall': perf_data.get('assembly_recall_for_worst'),
                    'text': perf_data.get('question_text'),
                    'reference_answer': perf_data.get('reference_answer'),
                    'missed_ground_truths': missed_texts
                }
                logger.debug(f"QID={qid}: Appending entry: {worst_entry}")
                worst_questions_processed.append(worst_entry)
            except Exception as e:
                logger.error(f"Error processing worst question QID={qid}: {e}", exc_info=True)
                # Do not break the loop; just log the error
        # --- Build the final response ---
        metrics_for_max_n = final_metrics_results.get(max_top_n, {})
        overall_total_found_micro = overall_micro_counters[max_top_n]['found']
        overall_total_gt_micro = overall_micro_counters[max_top_n]['gt']
        logger.debug(f"Final Response Prep: max_top_n={max_top_n}")
        logger.debug(f"Final Response Prep: metrics_for_max_n={metrics_for_max_n}")
        logger.debug(f"Final Response Prep: overall_micro_counters={overall_micro_counters}")
        logger.debug(f"Final Response Prep: micro_recall_for_human_readable = {metrics_for_max_n.get('micro_recall')}")
        # The summary block intentionally uses human-readable Russian keys
        final_response = {
            # --- Human-readable summary metrics (at the top) ---
            "Найдено пунктов (всего)": overall_total_found_micro,
            "Всего пунктов (эталон)": overall_total_gt_micro,
            "% найденных пунктов (чанк присутствует в пункте)": metrics_for_max_n.get('micro_recall'),  # Micro Recall
            "% пунктов были найдены в собранной версии": micro_assembly_punct_recall,  # Micro Assembly Recall
            "В среднем для каждого вопроса найден такой % пунктов": metrics_for_max_n.get('macro_recall'),  # Macro Recall
            "В среднем для каждого вопроса найден такой % документов": metrics_for_max_n.get('doc_recall_macro'),  # Macro Doc Recall
            "В среднем для каждого вопроса найдено N лишних документов, N": metrics_for_max_n.get('avg_spurious_docs'),  # Avg Spurious Docs
            # --- Per-top_n results (in the middle) ---
            "results": final_metrics_results,
            # --- Worst-performing questions (at the bottom) ---
            "worst_performing_questions": worst_questions_processed,
        }
        return final_response