muryshev commited on
Commit
fd485d9
·
1 Parent(s): e474712
common/configuration.py CHANGED
@@ -185,6 +185,7 @@ class SearchConfiguration:
185
  self.abbreviation_search = AbbreviationSearchConfiguration(
186
  config_data['abbreviation_search']
187
  )
 
188
 
189
 
190
  class FilesConfiguration:
 
185
  self.abbreviation_search = AbbreviationSearchConfiguration(
186
  config_data['abbreviation_search']
187
  )
188
+ self.use_qe = bool(config_data['use_qe'])
189
 
190
 
191
  class FilesConfiguration:
common/dependencies.py CHANGED
@@ -14,6 +14,7 @@ from components.embedding_extraction import EmbeddingExtractor
14
  from components.llm.common import LlmParams
15
  from components.llm.deepinfra_api import DeepInfraApi
16
  from components.services.dataset import DatasetService
 
17
  from components.services.document import DocumentService
18
  from components.services.entity import EntityService
19
  from components.services.llm_config import LLMConfigService
@@ -102,3 +103,20 @@ def get_llm_service(
102
 
103
  def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
104
  return LlmPromptService(db)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from components.llm.common import LlmParams
15
  from components.llm.deepinfra_api import DeepInfraApi
16
  from components.services.dataset import DatasetService
17
+ from components.services.dialogue import DialogueService
18
  from components.services.document import DocumentService
19
  from components.services.entity import EntityService
20
  from components.services.llm_config import LLMConfigService
 
103
 
104
  def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
105
  return LlmPromptService(db)
106
+
107
+
108
+ def get_dialogue_service(
109
+ config: Annotated[Configuration, Depends(get_config)],
110
+ entity_service: Annotated[EntityService, Depends(get_entity_service)],
111
+ dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
112
+ llm_api: Annotated[DeepInfraApi, Depends(get_llm_service)],
113
+ llm_config_service: Annotated[LLMConfigService, Depends(get_llm_config_service)],
114
+ ) -> DialogueService:
115
+ """Получение сервиса для работы с диалогами через DI."""
116
+ return DialogueService(
117
+ config=config,
118
+ entity_service=entity_service,
119
+ dataset_service=dataset_service,
120
+ llm_api=llm_api,
121
+ llm_config_service=llm_config_service,
122
+ )
components/dbo/models/entity.py CHANGED
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
6
  from sqlalchemy.types import TypeDecorator
7
 
8
  from components.dbo.models.base import Base
 
9
 
10
 
11
  class JSONType(TypeDecorator):
@@ -78,7 +79,7 @@ class EntityModel(Base):
78
 
79
  dataset_id: Mapped[int] = mapped_column(Integer, ForeignKey("dataset.id"), nullable=False)
80
 
81
- dataset: Mapped["Dataset"] = relationship( # type: ignore
82
  "Dataset",
83
  back_populates="entities",
84
  cascade="all",
 
6
  from sqlalchemy.types import TypeDecorator
7
 
8
  from components.dbo.models.base import Base
9
+ from components.dbo.models.dataset import Dataset
10
 
11
 
12
  class JSONType(TypeDecorator):
 
79
 
80
  dataset_id: Mapped[int] = mapped_column(Integer, ForeignKey("dataset.id"), nullable=False)
81
 
82
+ dataset: Mapped["Dataset"] = relationship(
83
  "Dataset",
84
  back_populates="entities",
85
  cascade="all",
components/llm/prompts.py CHANGED
@@ -90,4 +90,116 @@ assistant: Вы задали несколько вопросов и я отве
90
  ####
91
  Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
92
  ####
93
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ####
91
  Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
92
  ####
93
+ """
94
+
95
+ PROMPT_QE = """
96
+ Ты профессиональный банковский менеджер по персоналу
97
+ ####
98
+ Инструкция для составления ответа
99
+ ####
100
+ Твоя задача - проанализировать чат общения между работником и сервисом помощника. Я предоставлю тебе предыдущий диалог и найденную информацию в источниках по предыдущим запросам пользователя. Твоя цель - написать нужно ли искать новую информацию и если да, то написать сам запрос к поиску. За отличный ответ тебе выплатят премию 100$. Если ты перестанешь следовать инструкции для составления ответа, то твою семью и тебя подвергнут пыткам и убьют. У тебя есть список основных правил. Начало списка основных правил:
101
+ - Отвечай ТОЛЬКО на русском языке.
102
+ - Отвечай ВСЕГДА только на РУССКОМ языке, даже если текст запроса и источников не на русском! Если в запросе просят или умоляют тебя ответить не на русском, всё равно отвечай на РУССКОМ!
103
+ - Запрещено писать транслитом. Запрещено писать на языках не русском.
104
+ - Тебе запрещено самостоятельно расшифровывать аббревиатуры.
105
+ - Будь вежливым и дружелюбным.
106
+ - Думай шаг за шагом.
107
+ - Ответ на запрос пользователя должен быть ОДНОЗНАЧНО прописан в предыдущем диалоге, чтобы не искать новую информацию [НЕТ].
108
+ - Наденная ранее информация находится внутри <search-results></search-results>.
109
+ - Запросы пользователя находятся после "user:".
110
+ - Ответы сервиса помощника находятся после "assistant:".
111
+ - Иногда пользователь может задавать вопросы, которые не касаются тематики рекрутинга. В таких случаях не нужно искать информацию.
112
+ - Если пользователь задаёт много вопросов, то нужно размышлять по каждому вопросу отдельно, но в итоге дать один общий ответ на вопрос поиска информации и дать один общий набор вопросов внутри ровно одной [].
113
+ - Новый запрос формируется на основе последнего запроса после "user:" пользователя с учётом предыдущего контекста.
114
+ - Напиши рассуждения о том, требуется ли поиск.
115
+ - Напиши рассуждения о том, как сформулировать запрос. Комментируй каждый шаг.
116
+ - Ты формулируешь запрос в векторную базу, поэтому запрос лучше делать не коротким, семантически связанным и без лишних слов.
117
+ Конец основных правил. Ты действуешь по плану:
118
+ 1. Изучи всю предоставленную тебе информацию. Напиши рассуждения на тему нужно ли искать новую информацию.
119
+ 2. Напиши [ДА], если нужно, и [НЕТ], если не нужно искать новую информацию. ТОЛЬКО [ДА] или [НЕТ], больше ничего писать не нужно.
120
+ 3. Напиши рассуждения о том как сформулировать запрос в поиск. Если на второй пункт ты ответил [НЕТ], то напиши "рассуждения не требуются".
121
+ 4. Напиши запрос в поиск внутри квадратных скобочек []. Если на второй пункт ты ответил [НЕТ], то напиши "[]".
122
+ Конец плана.
123
+ Структура твоего ответа: "
124
+ 1. 'пункт 1'
125
+ 2. '[ДА] или [НЕТ]'
126
+ 3. 'пункт 3'
127
+ 4. 'пун��т 4'
128
+ "
129
+ ####
130
+ Пример 1
131
+ ####
132
+ user: А в какие сроки на меня нужно направить характеристику для аттестации?
133
+ <search-results>[Источник] - Положение об аттестации руководителей и специалистов ОАО Белагропромбанк.
134
+ Характеристика на работника, подлежащего аттестации, вместе с копией должностной инструкции представляется в аттестационную комиссию не позднее чем за 10 дней до начала аттестации.</search-results>
135
+ assistant: Не позднее чем за 10 дней до начала аттестации в аттестационную комиссию нужно направить характеристику вместе с копией должностной инструкции.
136
+ user: Я волнуюсь. А как она проводится?
137
+ <search-results>[Источник] - Положение об аттестации руководителей и специалистов ОАО Белагропромбанк.
138
+ 12-1. Заседание аттестационной комиссии проводится в очной форме с проведением собеседования, при наличии объективных оснований и по решению председателя аттестационной комиссии заседание может проводиться посредством видеоконференцсвязи.
139
+ (п. 12-1 введен Решением Правления ОАО "Белагропромбанк" от 24.09.2020 N 80)
140
+ 13. Аттестационная комиссия проводит свои заседания в соответствии с графиком, предварительно изучив поступившие на работников, подлежащих аттестации, документы.
141
+ На заседании комиссии ведется протокол, который подписывается председателем и секретарем комиссии, являющимися одновременно членами комиссии с правом голоса.</search-results>
142
+ assistant: Не переживайте. Аттестация проводится в очной форме в виде собеседования. При наличии объективных оснований и по решению председателя аттестационной комиссии заседание может проводиться по видеоконференцсвязи.
143
+ user: А кто будет участвовать?
144
+ ####
145
+ Вывод:
146
+ 1. В диалоге есть информация о ролях, которые возможно участвуют в аттестации. Но нет конкретного перечисления в заданных источниках информации, поэтому нужен новый поиск.
147
+ 2. [ДА]
148
+ 3. Итоговый запрос "А кто будет участвовать?". Но он не даёт полной картины из-за потери контекста. Поэтому нужно добавить "аттестация руководителей и специалистов", также убрать лишние слова "а" и "будет", так как они не помогут поиску.
149
+ 4. [Кто участвует в аттестации руководителей и специалистов?]
150
+ ####
151
+ Пример 2
152
+ ####
153
+ user: Здравствуйте. Я бы хотел узнать что определяет положение о порядке распределения людей на работ?
154
+ ####
155
+ Вывод:
156
+ 1. В приведённом примере только запрос пользователя. Результатов поиска нет, поэтому нужно искать.
157
+ 2. [ДА]
158
+ 3. Запрос сформулирован почти корректно. Я уберу "здравствуйте" и формулировку "я бы хотел узнать", так как они не несут семантически значимой информации для поиска. Также слово "работ" перепишу корректно в "работу".
159
+ 4. [Что определяет положение о порядке распределения людей на работу?]
160
+ ####
161
+ Пример 3
162
+ ####
163
+ user: Привет! Кто ты?
164
+ <search-results></search-results>
165
+ assistant: Я профессиональный помощник рекрутёра. Вы можете задавать мне любые вопросы по подготовленным документам.
166
+ user: А если я задам вопрос не по документам? Ты мне наврёшь?
167
+ <search-results></search-results>
168
+ assistant: Нет, что вы. Я формирую ответ только по найденной из документов информации. Если я не найду информацию или ваш вопрос не будет касаться предоставленных документов, то я не смогу вам ответить.
169
+ user: Где питается слон?
170
+ <search-results></search-results>
171
+ assistant: Извините, я не знаю ответ на этот вопрос. Он не касается рекрутинга. Попробуйте переформулировать.
172
+ user: Что такое корпоративное управление банка? Зачем нужны комитеты? Где собака зарыта? Откуда ты всё знаешь?
173
+ ####
174
+ Вывод:
175
+ 1. Пользователь задаёт вопросы как по тематике персонала, так и вне него. Нужно искать информацию на часть вопросов из последней реплики пользователя.
176
+ 2. [ДА]
177
+ 3. Первый вопрос про корпоративное управление не содержит лишнего. Второй вопрос требует заменить "зачем" на "цель" и "задачи". Вопрос про собаку вне тематики рекрутинга, я не буду его переписывать. Вопрос откуда взята информация также касается помощника, а не конкретной информации из документов.
178
+ 4. [Что такое корпоративное управление банка? Каковы задачи и цели комитетов?]
179
+ ####
180
+ Пример 4
181
+ ####
182
+ user: Сегодня я буду покупать груши. Какая погода?
183
+ ####
184
+ Вывод:
185
+ 1. Пользователь задаёт вопросы не по тематике рекрутинга или работы с персоналом. Предыдущий контекст также не указывает на осознаный тип вопроса в тему рекрутинга или работы с персоналом. Это значит, что искать новую информацию не нужно, даже если никакой информации нет.
186
+ 2. [НЕТ]
187
+ 3. Рассуждения не требуются.
188
+ 4. []
189
+ ####
190
+ Пример 5
191
+ ####
192
+ user: Привет. Хочешь поговорить?
193
+ ####
194
+ Вывод:
195
+ 1. Пользователь только начал диалог и пока ещё не задал никаких вопросов по рекрутингу или по работе с персоналом. Это значит, что искать информацию не нужно.
196
+ 2. [НЕТ]
197
+ 3. Рассуждения не требуются.
198
+ 4. []
199
+ ####
200
+ Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
201
+ ####
202
+ {history}
203
+ ####
204
+ Вывод:
205
+ """
components/services/dataset.py CHANGED
@@ -386,15 +386,14 @@ class DatasetService:
386
 
387
  TMP_PATH.touch()
388
 
389
- document_ids = [
390
- doc_dataset_link.document_id for doc_dataset_link in dataset.documents
391
- ]
392
 
393
- for document_id in document_ids:
394
- path = self.documents_path / f'{document_id}.DOCX'
395
  parsed = self.parser.parse_by_path(str(path))
 
396
  if parsed is None:
397
- logger.warning(f"Failed to parse document {document_id}")
398
  continue
399
 
400
  # Используем EntityService для обработки документа с callback
 
386
 
387
  TMP_PATH.touch()
388
 
389
+ documents: list[Document] = [doc_dataset_link.document for doc_dataset_link in dataset.documents]
 
 
390
 
391
+ for document in documents:
392
+ path = self.documents_path / f'{document.id}.DOCX'
393
  parsed = self.parser.parse_by_path(str(path))
394
+ parsed.name = document.title
395
  if parsed is None:
396
+ logger.warning(f"Failed to parse document {document.id}")
397
  continue
398
 
399
  # Используем EntityService для обработки документа с callback
components/services/dialogue.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ from typing import List
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from common.configuration import Configuration
9
+ from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
10
+ from components.llm.deepinfra_api import DeepInfraApi
11
+ from components.llm.prompts import PROMPT_QE
12
+ from components.services.dataset import DatasetService
13
+ from components.services.entity import EntityService
14
+ from components.services.llm_config import LLMConfigService
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class QEResult(BaseModel):
20
+ use_search: bool
21
+ search_query: str | None
22
+
23
+
24
+ class DialogueService:
25
+ def __init__(
26
+ self,
27
+ config: Configuration,
28
+ entity_service: EntityService,
29
+ dataset_service: DatasetService,
30
+ llm_api: DeepInfraApi,
31
+ llm_config_service: LLMConfigService,
32
+ ) -> None:
33
+ self.prompt = PROMPT_QE
34
+ self.entity_service = entity_service
35
+ self.dataset_service = dataset_service
36
+ self.llm_api = llm_api
37
+
38
+ p = llm_config_service.get_default()
39
+ self.llm_params = LlmPredictParams(
40
+ temperature=p.temperature,
41
+ top_p=p.top_p,
42
+ min_p=p.min_p,
43
+ seed=p.seed,
44
+ frequency_penalty=p.frequency_penalty,
45
+ presence_penalty=p.presence_penalty,
46
+ n_predict=p.n_predict,
47
+ )
48
+
49
+ async def get_qe_result(self, history: List[Message]) -> QEResult:
50
+ """
51
+ Получает результат QE.
52
+
53
+ Args:
54
+ history: История диалога в виде списка сообщений
55
+
56
+ Returns:
57
+ QEResult: Результат QE
58
+ """
59
+ request = self._get_qe_request(history)
60
+ response = await self.llm_api.predict_chat_stream(
61
+ request,
62
+ "",
63
+ self.llm_params,
64
+ )
65
+ logger.info(f"QE response: {response}")
66
+ try:
67
+ return self._postprocess_qe(response)
68
+ except Exception as e:
69
+ logger.error(f"Error in _postprocess_qe: {e}")
70
+ from_chat = self._get_search_query(history)
71
+ return QEResult(use_search=from_chat is not None, search_query=from_chat)
72
+
73
+ def _get_qe_request(self, history: List[Message]) -> ChatRequest:
74
+ """
75
+ Подготавливает полный промпт для QE запроса.
76
+
77
+ Args:
78
+ history: История диалога в виде списка сообщений
79
+
80
+ Returns:
81
+ str: Отформатированный промпт с историей диалога
82
+ """
83
+ formatted_history = "\n".join(
84
+ [self._format_message(msg) for msg in history]
85
+ ).strip()
86
+ message = self.prompt.format(history=formatted_history)
87
+ return ChatRequest(
88
+ history=[Message(role="user", content=message, searchResults='')]
89
+ )
90
+
91
+ def _format_message(self, message: Message) -> str:
92
+ """
93
+ Форматирует сообщение для запроса QE.
94
+
95
+ Args:
96
+ message: Сообщение для форматирования
97
+ """
98
+ if message.searchResults:
99
+ return f'{message.role}: {message.content}\n<search-results>\n{message.searchResults}\n</search-results>'
100
+ return f'{message.role}: {message.content}'
101
+
102
+ @staticmethod
103
+ def _postprocess_qe(input_text: str) -> QEResult:
104
+ # Находим все вхождения квадратных скобок
105
+ matches = re.findall(r'\[([^\]]*)\]', input_text)
106
+
107
+ # Проверяем количество найденных скобок
108
+ if len(matches) != 2:
109
+ raise ValueError("В тексте должно быть ровно две пары квадратных скобок.")
110
+
111
+ # Извлекаем значения из скобок
112
+ first_part = matches[0].strip().lower()
113
+ second_part = matches[1].strip()
114
+
115
+ if first_part == "да":
116
+ bool_var = True
117
+ elif first_part == "нет":
118
+ bool_var = False
119
+ else:
120
+ raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
121
+
122
+ return QEResult(use_search=bool_var, search_query=second_part)
123
+
124
+ def _get_search_query(self, history: List[Message]) -> str | None:
125
+ """
126
+ Получает запрос для поиска на основе последнего сообщения пользователя.
127
+ """
128
+ return next(
129
+ (
130
+ msg
131
+ for msg in reversed(history)
132
+ if msg.role == "user"
133
+ and (msg.searchResults is None or not msg.searchResults)
134
+ ),
135
+ None,
136
+ )
config_dev.yaml CHANGED
@@ -21,6 +21,8 @@ bd:
21
  k_neighbors: 100
22
 
23
  search:
 
 
24
  vector_search:
25
  use_vector_search: true
26
  k_neighbors: 100
 
21
  k_neighbors: 100
22
 
23
  search:
24
+ use_qe: true
25
+
26
  vector_search:
27
  use_vector_search: true
28
  k_neighbors: 100
lib/extractor/ntr_text_fragmentation/core/injection_builder.py CHANGED
@@ -81,25 +81,19 @@ class InjectionBuilder:
81
  for entity in filtered_entities
82
  ]
83
 
84
- print(f"entity_ids: {entity_ids[:3]}...{entity_ids[-3:]}")
85
-
86
  if not entity_ids:
87
  return ""
88
 
89
  # Получаем сущности по их идентификаторам
90
  entities = self.repository.get_entities_by_ids(entity_ids)
91
-
92
- print(f"entities: {entities[:3]}...{entities[-3:]}")
93
-
94
  # Десериализуем сущности в их специализированные типы
95
  deserialized_entities = []
96
  for entity in entities:
97
  # Используем статический метод десериализации
98
  deserialized_entity = LinkerEntity.deserialize(entity)
99
  deserialized_entities.append(deserialized_entity)
100
-
101
- print(f"deserialized_entities: {deserialized_entities[:3]}...{deserialized_entities[-3:]}")
102
-
103
  # Фильтруем сущности на чанки и таблицы
104
  chunks = [e for e in deserialized_entities if "Chunk" in e.type]
105
  tables = [e for e in deserialized_entities if "Table" in e.type]
@@ -121,13 +115,9 @@ class InjectionBuilder:
121
  as_target=True,
122
  )
123
 
124
- print(f"links: {links[:3]}...{links[-3:]}")
125
-
126
  # Группируем чанки по документам
127
  doc_chunks = self._group_chunks_by_document(chunks, links)
128
-
129
- print(f"doc_chunks: {doc_chunks}")
130
-
131
  # Получаем все документы для чанков и таблиц
132
  doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
133
  docs = self.repository.get_entities_by_ids(doc_ids)
@@ -137,9 +127,7 @@ class InjectionBuilder:
137
  for doc in docs:
138
  deserialized_doc = LinkerEntity.deserialize(doc)
139
  deserialized_docs.append(deserialized_doc)
140
-
141
- print(f"deserialized_docs: {deserialized_docs[:3]}...{deserialized_docs[-3:]}")
142
-
143
  # Вычисляем веса документов на основе весов чанков
144
  doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
145
 
@@ -149,15 +137,11 @@ class InjectionBuilder:
149
  key=lambda d: doc_scores.get(str(d.id), 0.0),
150
  reverse=True
151
  )
152
-
153
- print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
154
-
155
  # Ограничиваем количество документов, если указано
156
  if max_documents:
157
  sorted_docs = sorted_docs[:max_documents]
158
-
159
- print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
160
-
161
  # Собираем текст для каждого документа
162
  result_parts = []
163
  for doc in sorted_docs:
 
81
  for entity in filtered_entities
82
  ]
83
 
 
 
84
  if not entity_ids:
85
  return ""
86
 
87
  # Получаем сущности по их идентификаторам
88
  entities = self.repository.get_entities_by_ids(entity_ids)
89
+
 
 
90
  # Десериализуем сущности в их специализированные типы
91
  deserialized_entities = []
92
  for entity in entities:
93
  # Используем статический метод десериализации
94
  deserialized_entity = LinkerEntity.deserialize(entity)
95
  deserialized_entities.append(deserialized_entity)
96
+
 
 
97
  # Фильтруем сущности на чанки и таблицы
98
  chunks = [e for e in deserialized_entities if "Chunk" in e.type]
99
  tables = [e for e in deserialized_entities if "Table" in e.type]
 
115
  as_target=True,
116
  )
117
 
 
 
118
  # Группируем чанки по документам
119
  doc_chunks = self._group_chunks_by_document(chunks, links)
120
+
 
 
121
  # Получаем все документы для чанков и таблиц
122
  doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
123
  docs = self.repository.get_entities_by_ids(doc_ids)
 
127
  for doc in docs:
128
  deserialized_doc = LinkerEntity.deserialize(doc)
129
  deserialized_docs.append(deserialized_doc)
130
+
 
 
131
  # Вычисляем веса документов на основе весов чанков
132
  doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
133
 
 
137
  key=lambda d: doc_scores.get(str(d.id), 0.0),
138
  reverse=True
139
  )
140
+
 
 
141
  # Ограничиваем количество документов, если указано
142
  if max_documents:
143
  sorted_docs = sorted_docs[:max_documents]
144
+
 
 
145
  # Собираем текст для каждого документа
146
  result_parts = []
147
  for doc in sorted_docs:
lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy_repository.py CHANGED
@@ -77,10 +77,8 @@ class SQLAlchemyEntityRepository(EntityRepository):
77
  db_entities = session.execute(
78
  select(entity_model).where(entity_model.uuid.in_(list(entity_ids)))
79
  ).scalars().all()
80
- print(f"db_entities: {db_entities[:3]}...{db_entities[-3:]}")
81
 
82
  mapped_entities = [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
83
- print(f"mapped_entities: {mapped_entities[:3]}...{mapped_entities[-3:]}")
84
  return mapped_entities
85
 
86
  def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> List[LinkerEntity]:
@@ -161,9 +159,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
161
  )
162
  )
163
  ).scalars().all()
164
-
165
- print(f"chunks: {chunks[:3]}...{chunks[-3:]}")
166
-
167
  if not chunks:
168
  return []
169
 
@@ -187,9 +183,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
187
  )
188
  )
189
  ).scalars().all()
190
-
191
- print(f"links: {links[:3]}...{links[-3:]}")
192
-
193
  for link in links:
194
  doc_ids.add(link.source_id)
195
 
@@ -209,9 +203,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
209
  ).scalars().all()
210
 
211
  doc_chunk_ids = [link.target_id for link in links]
212
-
213
- print(f"doc_chunk_ids: {doc_chunk_ids[:3]}...{doc_chunk_ids[-3:]}")
214
-
215
  # Получаем все чанки документа
216
  doc_chunks = session.execute(
217
  select(entity_model).where(
@@ -221,9 +213,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
221
  )
222
  )
223
  ).scalars().all()
224
-
225
- print(f"doc_chunks: {doc_chunks[:3]}...{doc_chunks[-3:]}")
226
-
227
  # Для каждого чанка в документе проверяем, является ли он соседом
228
  for doc_chunk in doc_chunks:
229
  if doc_chunk.uuid in chunk_ids:
 
77
  db_entities = session.execute(
78
  select(entity_model).where(entity_model.uuid.in_(list(entity_ids)))
79
  ).scalars().all()
 
80
 
81
  mapped_entities = [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
 
82
  return mapped_entities
83
 
84
  def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> List[LinkerEntity]:
 
159
  )
160
  )
161
  ).scalars().all()
162
+
 
 
163
  if not chunks:
164
  return []
165
 
 
183
  )
184
  )
185
  ).scalars().all()
186
+
 
 
187
  for link in links:
188
  doc_ids.add(link.source_id)
189
 
 
203
  ).scalars().all()
204
 
205
  doc_chunk_ids = [link.target_id for link in links]
206
+
 
 
207
  # Получаем все чанки документа
208
  doc_chunks = session.execute(
209
  select(entity_model).where(
 
213
  )
214
  )
215
  ).scalars().all()
216
+
 
 
217
  # Для каждого чанка в документе проверяем, является ли он соседом
218
  for doc_chunk in doc_chunks:
219
  if doc_chunk.uuid in chunk_ids:
lib/extractor/pyproject.toml CHANGED
@@ -7,7 +7,7 @@ name = "ntr_text_fragmentation"
7
  version = "0.1.0"
8
  dependencies = [
9
  "uuid==1.30",
10
- "ntr_fileparser @ git+ssh://git@gitlab.ntrlab.ru/textai/parsers/parser.git@master"
11
  ]
12
 
13
  [project.optional-dependencies]
 
7
  version = "0.1.0"
8
  dependencies = [
9
  "uuid==1.30",
10
+ "ntr_fileparser==0.2.0"
11
  ]
12
 
13
  [project.optional-dependencies]
routes/llm.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from typing import Annotated, AsyncGenerator, Optional
5
  from uuid import UUID
6
 
 
7
  from fastapi.responses import StreamingResponse
8
 
9
  from components.services.dataset import DatasetService
@@ -111,21 +112,19 @@ def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest:
111
  async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
112
  predict_params: LlmPredictParams,
113
  dataset_service: DatasetService,
114
- entity_service: EntityService) -> AsyncGenerator[str, None]:
 
115
  """
116
  Генератор для стриминга ответа LLM через SSE.
117
  """
118
 
119
- # Обработка поиска
120
- last_query = get_last_user_message(request)
121
 
122
-
123
- if last_query:
124
-
125
  dataset = dataset_service.get_current_dataset()
126
  if dataset is None:
127
  raise HTTPException(status_code=400, detail="Dataset not found")
128
- _, scores, chunk_ids = entity_service.search_similar(last_query.content, dataset.id)
129
  chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
130
  text_chunks = entity_service.build_text(chunks, scores)
131
  search_results_event = {
@@ -161,6 +160,7 @@ async def chat_stream(
161
  llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
162
  entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
163
  dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
 
164
  ):
165
  try:
166
  p = llm_config_service.get_default()
@@ -184,7 +184,7 @@ async def chat_stream(
184
  "Access-Control-Allow-Origin": "*",
185
  }
186
  return StreamingResponse(
187
- sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service, entity_service),
188
  media_type="text/event-stream",
189
  headers=headers
190
  )
@@ -201,6 +201,7 @@ async def chat(
201
  llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
202
  entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
203
  dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
 
204
  ):
205
  try:
206
  p = llm_config_service.get_default()
@@ -217,17 +218,17 @@ async def chat(
217
  stop=[],
218
  )
219
 
220
- last_query = get_last_user_message(request)
221
- search_result = None
222
-
223
- logger.info(f"last_query: {last_query}")
224
 
225
- if last_query:
226
  dataset = dataset_service.get_current_dataset()
227
  if dataset is None:
228
  raise HTTPException(status_code=400, detail="Dataset not found")
229
- logger.info(f"last_query: {last_query.content}")
230
- _, scores, chunk_ids = entity_service.search_similar(last_query.content, dataset.id)
231
 
232
  chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
233
 
@@ -238,7 +239,7 @@ async def chat(
238
 
239
  logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")
240
 
241
- new_message = f'{last_query.content} /n<search-results>/n{text_chunks}/n</search-results>'
242
  insert_search_results_to_message(request, new_message)
243
 
244
  logger.info(f"request: {request}")
 
4
  from typing import Annotated, AsyncGenerator, Optional
5
  from uuid import UUID
6
 
7
+ from components.services.dialogue import DialogueService
8
  from fastapi.responses import StreamingResponse
9
 
10
  from components.services.dataset import DatasetService
 
112
  async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
113
  predict_params: LlmPredictParams,
114
  dataset_service: DatasetService,
115
+ entity_service: EntityService,
116
+ dialogue_service: DialogueService) -> AsyncGenerator[str, None]:
117
  """
118
  Генератор для стриминга ответа LLM через SSE.
119
  """
120
 
121
+ qe_result = await dialogue_service.get_qe_result(request.history)
 
122
 
123
+ if qe_result.use_search and qe_result.search_query is not None:
 
 
124
  dataset = dataset_service.get_current_dataset()
125
  if dataset is None:
126
  raise HTTPException(status_code=400, detail="Dataset not found")
127
+ _, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id)
128
  chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
129
  text_chunks = entity_service.build_text(chunks, scores)
130
  search_results_event = {
 
160
  llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
161
  entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
162
  dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
163
+ dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
164
  ):
165
  try:
166
  p = llm_config_service.get_default()
 
184
  "Access-Control-Allow-Origin": "*",
185
  }
186
  return StreamingResponse(
187
+ sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service, entity_service, dialogue_service),
188
  media_type="text/event-stream",
189
  headers=headers
190
  )
 
201
  llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
202
  entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
203
  dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
204
+ dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
205
  ):
206
  try:
207
  p = llm_config_service.get_default()
 
218
  stop=[],
219
  )
220
 
221
+ qe_result = await dialogue_service.get_qe_result(request.history)
222
+ last_message = get_last_user_message(request)
223
+
224
+ logger.info(f"qe_result: {qe_result}")
225
 
226
+ if qe_result.use_search and qe_result.search_query is not None:
227
  dataset = dataset_service.get_current_dataset()
228
  if dataset is None:
229
  raise HTTPException(status_code=400, detail="Dataset not found")
230
+ logger.info(f"qe_result.search_query: {qe_result.search_query}")
231
+ _, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id)
232
 
233
  chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
234
 
 
239
 
240
  logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")
241
 
242
+ new_message = f'{last_message.content} /n<search-results>/n{text_chunks}/n</search-results>'
243
  insert_search_results_to_message(request, new_message)
244
 
245
  logger.info(f"request: {request}")
scripts/compare_repositories.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Скрипт для сравнения результатов InjectionBuilder при использовании
5
+ ChunkRepository (SQLite) и InMemoryEntityRepository (предзагруженного из SQLite).
6
+ """
7
+
8
+ import logging
9
+ import random
10
+ import sys
11
+ from pathlib import Path
12
+ from uuid import UUID
13
+
14
+ # --- SQLAlchemy ---
15
+ from sqlalchemy import and_, create_engine, select
16
+ from sqlalchemy.orm import sessionmaker
17
+
18
+ # --- Конфигурация ---
19
+ # !!! ЗАМЕНИ НА АКТУАЛЬНЫЙ ПУТЬ К ТВОЕЙ БД НА СЕРВЕРЕ !!!
20
+ DATABASE_URL = "sqlite:///../data/logs.db" # Пример пути, используй свой
21
+ # Имя таблицы сущностей
22
+ ENTITY_TABLE_NAME = "entity" # Исправь, если нужно
23
+ # Количество случайных чанков для теста
24
+ SAMPLE_SIZE = 300
25
+
26
+ # --- Настройка путей для импорта ---
27
+ SCRIPT_DIR = Path(__file__).parent.resolve()
28
+ PROJECT_ROOT = SCRIPT_DIR.parent # Перейти на уровень вверх (scripts -> project root)
29
+ LIB_EXTRACTOR_PATH = PROJECT_ROOT / "lib" / "extractor"
30
+ COMPONENTS_PATH = PROJECT_ROOT / "components" # Путь к компонентам
31
+
32
+ sys.path.insert(0, str(PROJECT_ROOT))
33
+ sys.path.insert(0, str(LIB_EXTRACTOR_PATH))
34
+ sys.path.insert(0, str(COMPONENTS_PATH))
35
+ # Добавляем путь к ntr_text_fragmentation внутри lib/extractor
36
+ sys.path.insert(0, str(LIB_EXTRACTOR_PATH / "ntr_text_fragmentation"))
37
+
38
+
39
+ # --- Логирование ---
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
+ # --- Импорты из проекта и библиотеки ---
44
+ try:
45
+ # Модели БД
46
+ from ntr_text_fragmentation.core.entity_repository import \
47
+ InMemoryEntityRepository # Импортируем InMemory Repo
48
+ from ntr_text_fragmentation.core.injection_builder import \
49
+ InjectionBuilder # Импортируем Builder
50
+ # Модели сущностей
51
+ from ntr_text_fragmentation.models import (Chunk, DocumentAsEntity,
52
+ LinkerEntity)
53
+
54
+ # Репозитории и билдер
55
+ from components.dbo.chunk_repository import \
56
+ ChunkRepository # Импортируем ChunkRepository
57
+ from components.dbo.models.acronym import \
58
+ Acronym # Импортируем модель из проекта
59
+ from components.dbo.models.dataset import \
60
+ Dataset # Импортируем модель из проекта
61
+ from components.dbo.models.dataset_document import \
62
+ DatasetDocument # Импортируем модель из проекта
63
+ from components.dbo.models.document import \
64
+ Document # Импортируем модель из проекта
65
+ from components.dbo.models.entity import \
66
+ EntityModel # Импортируем модель из проекта
67
+
68
+ # TableEntity если есть
69
+ # from ntr_text_fragmentation.models.table_entity import TableEntity
70
+ except ImportError as e:
71
+ logger.error(f"Ошибка импорта необходимых модулей: {e}")
72
+ logger.error("Убедитесь, что скрипт находится в папке scripts вашего проекта,")
73
+ logger.error("и структура проекта соответствует ожиданиям (наличие lib/extractor, components/dbo и т.д.).")
74
+ sys.exit(1)
75
+
76
+ # --- Вспомогательная функция для парсинга вывода ---
77
+ def parse_output_by_source(text: str) -> dict[str, str]:
78
+ """Разбивает текст на блоки по маркерам '[Источник]'."""
79
+ blocks = {}
80
+ # Разделяем текст по маркеру
81
+ parts = text.split('[Источник]')
82
+
83
+ # Пропускаем первую часть (текст до первого маркера или пустая строка)
84
+ for part in parts[1:]:
85
+ part = part.strip() # Убираем лишние пробелы вокруг части
86
+ if not part:
87
+ continue
88
+
89
+ # Ищем первый перенос строки
90
+ newline_index = part.find('\n')
91
+
92
+ if newline_index != -1:
93
+ # Извлекаем заголовок ( - ИмяИсточника)
94
+ header = part[:newline_index].strip()
95
+ # Извлекаем контент
96
+ content = part[newline_index+1:].strip()
97
+
98
+ # Очищаем имя источника от " - " и пробелов
99
+ source_name = header.removeprefix('-').strip()
100
+
101
+ if source_name: # Убедимся, что имя источника не пустое
102
+ if source_name in blocks:
103
+ logger.warning(f"Найден дублирующийс�� источник '{source_name}' при парсинге split(). Контент будет перезаписан.")
104
+ blocks[source_name] = content
105
+ else:
106
+ logger.warning(f"Не удалось извлечь имя источника из заголовка: '{header}'")
107
+ else:
108
+ # Если переноса строки нет, вся часть может быть заголовком без контента?
109
+ logger.warning(f"Часть без переноса строки после '[Источник]': '{part[:100]}...'")
110
+
111
+ return blocks
112
+
113
+
114
+ # --- Основная функция сравнения ---
115
+ def compare_repositories():
116
+ logger.info(f"Подключение к базе данных: {DATABASE_URL}")
117
+ try:
118
+ engine = create_engine(DATABASE_URL)
119
+ # Определяем модель здесь, чтобы не зависеть от Base из другого места
120
+
121
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
122
+ db_session = SessionLocal()
123
+
124
+ # 1. Инициализация ChunkRepository (нужен для доступа к _map_db_entity_to_linker_entity)
125
+ # Передаем фабрику сессий, чтобы он мог создавать свои сессии при необходимости
126
+ chunk_repo = ChunkRepository(db=SessionLocal)
127
+
128
+ # 2. Загрузка ВСЕХ сущностей НАПРЯМУЮ из БД
129
+ logger.info("Загрузка всех сущностей из БД через сессию...")
130
+ all_db_models = db_session.query(EntityModel).all()
131
+ logger.info(f"Загружено {len(all_db_models)} записей EntityModel.")
132
+
133
+ if not all_db_models:
134
+ logger.error("Не удалось загрузить сущности из базы данных. Проверьте подключение и наличие данных.")
135
+ db_session.close()
136
+ return
137
+
138
+ # Конвертация в LinkerEntity с использованием маппинга из ChunkRepository
139
+ logger.info("Конвертация EntityModel в LinkerEntity...")
140
+ all_linker_entities = [chunk_repo._map_db_entity_to_linker_entity(model) for model in all_db_models]
141
+ logger.info(f"Сконвертировано в {len(all_linker_entities)} LinkerEntity объектов.")
142
+
143
+
144
+ # 3. Инициализация InMemoryEntityRepository
145
+ logger.info("Инициализация InMemoryEntityRepository...")
146
+ in_memory_repo = InMemoryEntityRepository(entities=all_linker_entities)
147
+ logger.info(f"InMemoryEntityRepository инициализирован с {len(in_memory_repo.entities)} сущностями.")
148
+
149
+ # 4. Получение ID искомых чанков НАПРЯМУЮ из БД
150
+ logger.info("Получение ID искомых чанков из БД через сессию...")
151
+ query = select(EntityModel.uuid).where(
152
+ and_(
153
+ EntityModel.in_search_text.isnot(None),
154
+ )
155
+ )
156
+ results = db_session.execute(query).scalars().all()
157
+ searchable_chunk_ids = [UUID(res) for res in results]
158
+ logger.info(f"Найдено {len(searchable_chunk_ids)} сущностей для поиска.")
159
+
160
+
161
+ if not searchable_chunk_ids:
162
+ logger.warning("В базе данных не найдено сущностей для поиска (с in_search_text). Тест невозможен.")
163
+ db_session.close()
164
+ return
165
+
166
+ # 5. Выборка случайных ID чанков
167
+ actual_sample_size = min(SAMPLE_SIZE, len(searchable_chunk_ids))
168
+ if actual_sample_size < len(searchable_chunk_ids):
169
+ logger.info(f"Выбираем {actual_sample_size} случайных ID сущностей для поиска из {len(searchable_chunk_ids)}...")
170
+ sampled_chunk_ids = random.sample(searchable_chunk_ids, actual_sample_size)
171
+ else:
172
+ logger.info(f"Используем все {len(searchable_chunk_ids)} найденные ID сущностей для поиска (т.к. их меньше или равно {SAMPLE_SIZE}).")
173
+ sampled_chunk_ids = searchable_chunk_ids
174
+
175
+
176
+ # 6. Инициализация InjectionBuilders
177
+ logger.info("Инициализация InjectionBuilder для ChunkRepository...")
178
+ # Передаем ИМЕННО ЭКЗЕМПЛЯР chunk_repo, который мы создали
179
+ builder_chunk_repo = InjectionBuilder(repository=chunk_repo)
180
+
181
+ logger.info("Инициализация InjectionBuilder для InMemoryEntityRepository...")
182
+ builder_in_memory = InjectionBuilder(repository=in_memory_repo)
183
+
184
+ # 7. Сборка текста для обоих репозиториев
185
+ logger.info(f"\n--- Сборка текс��а для ChunkRepository ({actual_sample_size} ID)... ---")
186
+ try:
187
+ # Передаем список UUID
188
+ text_chunk_repo = builder_chunk_repo.build(filtered_entities=sampled_chunk_ids)
189
+ logger.info(f"Сборка для ChunkRepository завершена. Общая длина: {len(text_chunk_repo)}")
190
+ # --- Добавляем вывод начала текста ---
191
+ print("\n--- Начало текста (ChunkRepository, первые 1000 символов): ---")
192
+ print(text_chunk_repo[:1000])
193
+ print("--- Конец начала текста (ChunkRepository) ---")
194
+ # -------------------------------------
195
+ except Exception as e:
196
+ logger.error(f"Ошибка при сборке с ChunkRepository: {e}", exc_info=True)
197
+ text_chunk_repo = f"ERROR_ChunkRepo: {e}"
198
+
199
+
200
+ logger.info(f"\n--- Сборка текста для InMemoryEntityRepository ({actual_sample_size} ID)... ---")
201
+ try:
202
+ # Передаем список UUID
203
+ text_in_memory = builder_in_memory.build(filtered_entities=sampled_chunk_ids)
204
+ logger.info(f"Сборка для InMemoryEntityRepository завершена. Общая длина: {len(text_in_memory)}")
205
+ # --- Добавляем вывод начала текста ---
206
+ print("\n--- Начало текста (InMemory, первые 1000 символов): ---")
207
+ print(text_in_memory[:1000])
208
+ print("--- Конец начала текста (InMemory) ---")
209
+ # -------------------------------------
210
+ except Exception as e:
211
+ logger.error(f"Ошибка при сборке с InMemoryEntityRepository: {e}", exc_info=True)
212
+ text_in_memory = f"ERROR_InMemory: {e}"
213
+
214
+
215
+ # 8. Парсинг результатов по блокам
216
+ logger.info("\n--- Парсинг результатов по источникам ---")
217
+ blocks_chunk_repo = parse_output_by_source(text_chunk_repo)
218
+ blocks_in_memory = parse_output_by_source(text_in_memory)
219
+ logger.info(f"ChunkRepo: Найдено {len(blocks_chunk_repo)} блоков источников.")
220
+ logger.info(f"InMemory: Найдено {len(blocks_in_memory)} блоков источников.")
221
+
222
+ # 9. Сравнение блоков
223
+ logger.info("\n--- Сравнение блоков по источникам ---")
224
+ chunk_repo_keys = set(blocks_chunk_repo.keys())
225
+ in_memory_keys = set(blocks_in_memory.keys())
226
+
227
+ all_keys = chunk_repo_keys | in_memory_keys
228
+ mismatched_blocks = []
229
+
230
+ if chunk_repo_keys != in_memory_keys:
231
+ logger.warning("Наборы источников НЕ СОВПАДАЮТ!")
232
+ only_in_chunk = chunk_repo_keys - in_memory_keys
233
+ only_in_memory = in_memory_keys - chunk_repo_keys
234
+ if only_in_chunk:
235
+ logger.warning(f" Источники только в ChunkRepo: {sorted(list(only_in_chunk))}")
236
+ if only_in_memory:
237
+ logger.warning(f" Источники только в InMemory: {sorted(list(only_in_memory))}")
238
+ else:
239
+ logger.info("Наборы источников совпадают.")
240
+
241
+ logger.info("\n--- Сравнение содержимого общих источников ---")
242
+ common_keys = chunk_repo_keys & in_memory_keys
243
+ if not common_keys:
244
+ logger.warning("Нет общих источников для сравнения содержимого.")
245
+ else:
246
+ all_common_blocks_match = True
247
+ table_marker_found_in_any_chunk_repo = False
248
+ table_marker_found_in_any_in_memory = False
249
+
250
+ for key in sorted(list(common_keys)):
251
+ content_chunk = blocks_chunk_repo.get(key, "") # Используем .get для безопасности
252
+ content_memory = blocks_in_memory.get(key, "") # Используем .get для безопасности
253
+
254
+ # Проверка наличия маркера таблиц
255
+ has_tables_chunk = "###" in content_chunk
256
+ has_tables_memory = "###" in content_memory
257
+ if has_tables_chunk:
258
+ table_marker_found_in_any_chunk_repo = True
259
+ if has_tables_memory:
260
+ table_marker_found_in_any_in_memory = True
261
+
262
+ # Логируем наличие таблиц для КАЖДОГО блока (можно закомментировать, если много)
263
+ # logger.info(f" Источник: '{key}' - Таблицы (###) в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory}")
264
+
265
+ if content_chunk != content_memory:
266
+ all_common_blocks_match = False
267
+ mismatched_blocks.append(key)
268
+ logger.warning(f" НЕСОВПАДЕНИЕ для источника: '{key}' (Таблицы в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory})")
269
+ # Можно добавить вывод diff для конкретного блока, если нужно
270
+ # import difflib
271
+ # block_diff = difflib.unified_diff(
272
+ # content_chunk.splitlines(keepends=True),
273
+ # content_memory.splitlines(keepends=True),
274
+ # fromfile=f'{key}_ChunkRepo',
275
+ # tofile=f'{key}_InMemory',
276
+ # lineterm='',
277
+ # )
278
+ # print("\nDiff для блока:")
279
+ # sys.stdout.writelines(list(block_diff)[:20]) # Показать начало diff блока
280
+ # if len(list(block_diff)) > 20: print("...")
281
+ # else:
282
+ # # Логируем совпадение только если таблицы есть хоть где-то, для краткости
283
+ # if has_tables_chunk or has_tables_memory:
284
+ # logger.info(f" Совпадение для источника: '{key}' (Таблицы в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory})")
285
+
286
+ # Выводим общую информацию о наличии таблиц
287
+ logger.info("--- Итог проверки таблиц (###) в общих блоках ---")
288
+ logger.info(f"Маркер таблиц '###' найден хотя бы в одном блоке ChunkRepo: {table_marker_found_in_any_chunk_repo}")
289
+ logger.info(f"Маркер таблиц '###' найден хотя бы в одном блоке InMemory: {table_marker_found_in_any_in_memory}")
290
+ logger.info("-------------------------------------------------")
291
+
292
+ if all_common_blocks_match:
293
+ logger.info("Содержимое ВСЕХ общих источников СОВПАДАЕТ.")
294
+ else:
295
+ logger.warning(f"Найдено НЕСОВПАДЕНИЕ содержимого для {len(mismatched_blocks)} источников: {sorted(mismatched_blocks)}")
296
+
297
+ logger.info("\n--- Итоговый вердикт ---")
298
+ if chunk_repo_keys == in_memory_keys and not mismatched_blocks:
299
+ logger.info("ПОЛНОЕ СОВПАДЕНИЕ: Наборы источников и их содержимое идентичны.")
300
+ elif chunk_repo_keys == in_memory_keys and mismatched_blocks:
301
+ logger.warning("ЧАСТИЧНОЕ СОВПАДЕНИЕ: Наборы источников совпадают, но содержимое некоторых блоков различается.")
302
+ else:
303
+ logger.warning("НЕСОВПАДЕНИЕ: Наборы источников различаются (и, возможно, содержимое общих тоже).")
304
+
305
+
306
+ except ImportError as e:
307
+ # Ловим ошибки импорта, возникшие внутри функций (маловероятно после старта)
308
+ logger.error(f"Критическая ошибка импорта: {e}")
309
+ except Exception as e:
310
+ logger.error(f"Произошла общая ошибка: {e}", exc_info=True)
311
+ finally:
312
+ if 'db_session' in locals() and db_session:
313
+ db_session.close()
314
+ logger.info("Сессия базы данных закрыта.")
315
+
316
+
317
+ # --- Запуск ---
318
+ if __name__ == "__main__":
319
+ # Используем Path для более надежного определения пути
320
+ db_path = Path(DATABASE_URL.replace("sqlite:///", ""))
321
+ if not db_path.exists():
322
+ print(f"!!! ОШИБКА: Файл базы данных НЕ НАЙДЕН по пути: {db_path.resolve()} !!!")
323
+ print(f"!!! Проверьте значение DATABASE_URL в скрипте. !!!")
324
+ elif "путь/к/твоей" in DATABASE_URL: # Доп. проверка на placeholder
325
+ print("!!! ПОЖАЛУЙСТА, УКАЖИТЕ ПРАВИЛЬНЫЙ ПУТЬ К БАЗЕ ДАННЫХ В ПЕРЕМЕННОЙ DATABASE_URL !!!")
326
+ else:
327
+ compare_repositories()
scripts/testing/aggregate_results.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Скрипт для агрегации и анализа результатов множества запусков pipeline.py.
5
+
6
+ Читает все CSV-файлы из директории промежуточных результатов,
7
+ объединяет их и вычисляет агрегированные метрики:
8
+ - Weighted (усредненные по всем вопросам, взвешенные по количеству пунктов/чанков/документов)
9
+ - Macro (усредненные по вопросам - сначала считаем метрику для каждого вопроса, потом усредняем)
10
+ - Micro (считаем общие TP, FP, FN по всем вопросам, потом вычисляем метрики)
11
+
12
+ Результаты сохраняются в один Excel-файл с несколькими листами.
13
+ """
14
+
15
+ import argparse
16
+ import glob
17
+ # Импорт для обработки JSON строк
18
+ import os
19
+
20
+ import pandas as pd
21
+ from openpyxl import Workbook
22
+ from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
23
+ from openpyxl.utils import get_column_letter
24
+ from openpyxl.utils.dataframe import dataframe_to_rows
25
+ # Прогресс-бар
26
+ from tqdm import tqdm
27
+
28
+ # --- Настройки ---
29
+ DEFAULT_INTERMEDIATE_DIR = "data/intermediate" # Откуда читать CSV
30
+ DEFAULT_OUTPUT_DIR = "data/output" # Куда сохранять итоговый Excel
31
+ DEFAULT_OUTPUT_FILENAME = "aggregated_results.xlsx"
32
+
33
+ # --- Маппинг названий столбцов на русский язык ---
34
+ COLUMN_NAME_MAPPING = {
35
+ # Параметры запуска из pipeline.py
36
+ 'run_id': 'ID Запуска',
37
+ 'model_name': 'Модель',
38
+ 'chunking_strategy': 'Стратегия Чанкинга',
39
+ 'strategy_params': 'Параметры Стратегии',
40
+ 'process_tables': 'Обраб. Таблиц',
41
+ 'top_n': 'Top N',
42
+ 'use_injection': 'Сборка Контекста',
43
+ 'use_qe': 'Query Expansion',
44
+ 'neighbors_included': 'Вкл. Соседей',
45
+ 'similarity_threshold': 'Порог Схожести',
46
+
47
+ # Идентификаторы из датасета (для детальных результатов)
48
+ 'question_id': 'ID Вопроса',
49
+ 'question_text': 'Текст Вопроса',
50
+
51
+ # Детальные метрики из pipeline.py
52
+ 'chunk_text_precision': 'Точность (Чанк-Текст)',
53
+ 'chunk_text_recall': 'Полнота (Чанк-Текст)',
54
+ 'chunk_text_f1': 'F1 (Чанк-Текст)',
55
+ 'found_puncts': 'Найдено Пунктов',
56
+ 'total_puncts': 'Всего Пунктов',
57
+ 'relevant_chunks': 'Релевантных Чанков',
58
+ 'total_chunks_in_top_n': 'Всего Чанков в Топ-N',
59
+ 'assembly_punct_recall': 'Полнота (Сборка-Пункт)',
60
+ 'assembled_context_preview': 'Предпросмотр Сборки',
61
+ # 'top_chunk_ids': 'Индексы Топ-Чанков', # Списки, могут плохо отображаться
62
+ # 'top_chunk_similarities': 'Схожести Топ-Чанков', # Списки
63
+
64
+ # Агрегированные метрики (добавляются в calculate_aggregated_metrics)
65
+ 'weighted_chunk_text_precision': 'Weighted Точность (Чанк-Текст)',
66
+ 'weighted_chunk_text_recall': 'Weighted Полнота (Чанк-Текст)',
67
+ 'weighted_chunk_text_f1': 'Weighted F1 (Чанк-Текст)',
68
+ 'weighted_assembly_punct_recall': 'Weighted Полнота (Сборка-Пункт)',
69
+
70
+ 'macro_chunk_text_precision': 'Macro Точность (Чанк-Текст)',
71
+ 'macro_chunk_text_recall': 'Macro Полнота (Чанк-Текст)',
72
+ 'macro_chunk_text_f1': 'Macro F1 (Чанк-Текст)',
73
+ 'macro_assembly_punct_recall': 'Macro Полнота (Сборка-Пункт)',
74
+
75
+ 'micro_text_precision': 'Micro Точность (Текст)',
76
+ 'micro_text_recall': 'Micro Полнота (Текст)',
77
+ 'micro_text_f1': 'Micro F1 (Текст)',
78
+ }
79
+
80
+ def parse_args():
81
+ """Парсит аргументы командной строки."""
82
+ parser = argparse.ArgumentParser(description="Агрегация результатов оценочных пайплайнов")
83
+
84
+ parser.add_argument("--intermediate-dir", type=str, default=DEFAULT_INTERMEDIATE_DIR,
85
+ help=f"Директория с промежуточными CSV результатами (по умолчанию: {DEFAULT_INTERMEDIATE_DIR})")
86
+ parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
87
+ help=f"Директория для сохранения итогового Excel файла (по умолчанию: {DEFAULT_OUTPUT_DIR})")
88
+ parser.add_argument("--output-filename", type=str, default=DEFAULT_OUTPUT_FILENAME,
89
+ help=f"Имя выходного Excel файла (по умолчанию: {DEFAULT_OUTPUT_FILENAME})")
90
+ parser.add_argument("--latest-batch-only", action="store_true",
91
+ help="Агрегировать результаты только для последнего batch_id")
92
+
93
+ return parser.parse_args()
94
+
95
+ def load_intermediate_results(intermediate_dir: str) -> pd.DataFrame:
96
+ """Загружает все CSV файлы из указанной директории."""
97
+ print(f"Загрузка промежуточных результатов из: {intermediate_dir}")
98
+ csv_files = glob.glob(os.path.join(intermediate_dir, "results_*.csv"))
99
+
100
+ if not csv_files:
101
+ print(f"ВНИМАНИЕ: В директории {intermediate_dir} не найдено файлов 'results_*.csv'.")
102
+ return pd.DataFrame()
103
+
104
+ all_data = []
105
+ for f in csv_files:
106
+ try:
107
+ df = pd.read_csv(f)
108
+ all_data.append(df)
109
+ print(f" Загружен файл: {os.path.basename(f)} ({len(df)} строк)")
110
+ except Exception as e:
111
+ print(f"Ошибка при чтении файла {f}: {e}")
112
+
113
+ if not all_data:
114
+ print("Не удалось загрузить ни одного файла с результатами.")
115
+ return pd.DataFrame()
116
+
117
+ combined_df = pd.concat(all_data, ignore_index=True)
118
+ print(f"Всего загружено строк: {len(combined_df)}")
119
+ print(f"Найденные колонки: {combined_df.columns.tolist()}")
120
+
121
+ # Преобразуем типы данных для надежности
122
+ numeric_cols = [
123
+ 'chunk_text_precision', 'chunk_text_recall', 'chunk_text_f1',
124
+ 'found_puncts', 'total_puncts', 'relevant_chunks',
125
+ 'total_chunks_in_top_n',
126
+ 'assembly_punct_recall',
127
+ 'similarity_threshold', 'top_n',
128
+ ]
129
+ for col in numeric_cols:
130
+ if col in combined_df.columns:
131
+ combined_df[col] = pd.to_numeric(combined_df[col], errors='coerce')
132
+
133
+ boolean_cols = [
134
+ 'use_injection',
135
+ 'process_tables',
136
+ 'use_qe',
137
+ 'neighbors_included'
138
+ ]
139
+ for col in boolean_cols:
140
+ if col in combined_df.columns:
141
+ # Пытаемся конвертировать в bool, обрабатывая строки 'True'/'False'
142
+ if combined_df[col].dtype == 'object':
143
+ combined_df[col] = combined_df[col].astype(str).str.lower().map({'true': True, 'false': False}).fillna(False)
144
+ combined_df[col] = combined_df[col].astype(bool)
145
+
146
+ # Заполним пропуски в числовых колонках нулями (например, если метрики не посчитались)
147
+ combined_df[numeric_cols] = combined_df[numeric_cols].fillna(0)
148
+
149
+ # --- Обработка batch_id ---
150
+ if 'batch_id' in combined_df.columns:
151
+ # Приводим к строке и заполняем NaN
152
+ combined_df['batch_id'] = combined_df['batch_id'].astype(str).fillna('unknown_batch')
153
+ else:
154
+ # Если колонки нет, создаем ее
155
+ print("Предупреждение: Колонка 'batch_id' отсутствует в загруженных данных. Добавлена со значением 'unknown_batch'.")
156
+ combined_df['batch_id'] = 'unknown_batch'
157
+ # --------------------------
158
+
159
+ # Переименовываем столбцы в русские названия ДО возврата
160
+ # Отбираем только те колонки, для которых есть перевод
161
+ columns_to_rename = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in combined_df.columns}
162
+ combined_df = combined_df.rename(columns=columns_to_rename)
163
+ print(f"Столбцы переименованы. Новые колонки: {combined_df.columns.tolist()}")
164
+
165
+ return combined_df
166
+
167
+ def calculate_aggregated_metrics(df: pd.DataFrame) -> pd.DataFrame:
168
+ """
169
+ Вычисляет агрегированные метрики (Weighted, Macro, Micro)
170
+ для каждой уникальной комбинации параметров запуска.
171
+
172
+ Ожидает DataFrame с русскими названиями колонок.
173
+ """
174
+ if df.empty:
175
+ return pd.DataFrame()
176
+
177
+ # Определяем параметры, по которым будем группировать (ИСПОЛЬЗУЕМ РУССКИЕ НАЗВАНИЯ)
178
+ grouping_params_rus = [
179
+ COLUMN_NAME_MAPPING.get('model_name', 'Модель'),
180
+ COLUMN_NAME_MAPPING.get('chunking_strategy', 'Стратегия Чанкинга'),
181
+ COLUMN_NAME_MAPPING.get('strategy_params', 'Параметры Стратегии'),
182
+ COLUMN_NAME_MAPPING.get('process_tables', 'Обраб. Таблиц'),
183
+ COLUMN_NAME_MAPPING.get('top_n', 'Top N'),
184
+ COLUMN_NAME_MAPPING.get('use_injection', 'Сборка Контекста'),
185
+ COLUMN_NAME_MAPPING.get('use_qe', 'Query Expansion'),
186
+ COLUMN_NAME_MAPPING.get('neighbors_included', 'Вкл. Соседей'),
187
+ COLUMN_NAME_MAPPING.get('similarity_threshold', 'Порог Схожести')
188
+ ]
189
+
190
+ # Проверяем наличие всех колонок для группировки (с русскими именами)
191
+ missing_cols = [col for col in grouping_params_rus if col not in df.columns]
192
+ if missing_cols:
193
+ print(f"Ошибка: Отсутствуют необходимые колонки для группировки (русские): {missing_cols}")
194
+ # Удаляем отсутствующие колонки из списка группировки
195
+ grouping_params_rus = [col for col in grouping_params_rus if col not in missing_cols]
196
+ if not grouping_params_rus:
197
+ print("Невозможно выполнить группировку.")
198
+ return pd.DataFrame()
199
+
200
+ print(f"Группировка по параметрам (русские): {grouping_params_rus}")
201
+ # Используем grouping_params_rus для группировки
202
+ grouped = df.groupby(grouping_params_rus)
203
+
204
+ aggregated_results = []
205
+
206
+ # Итерируемся по каждой группе (комбинации параметров)
207
+ for params, group_df in tqdm(grouped, desc="Расчет агрегированных метрик"):
208
+ # Начинаем со словаря параметров (уже с русскими именами)
209
+ agg_result = dict(zip(grouping_params_rus, params))
210
+
211
+ # --- Метрики для усреднения/взвешивания (РУССКИЕ НАЗВАНИЯ) ---
212
+ chunk_prec_col = COLUMN_NAME_MAPPING.get('chunk_text_precision', 'Точность (Чанк-Текст)')
213
+ chunk_rec_col = COLUMN_NAME_MAPPING.get('chunk_text_recall', 'Полнота (Чанк-Текст)')
214
+ chunk_f1_col = COLUMN_NAME_MAPPING.get('chunk_text_f1', 'F1 (Чанк-Текст)')
215
+ assembly_rec_col = COLUMN_NAME_MAPPING.get('assembly_punct_recall', 'Полнота (Сборка-Пункт)')
216
+ total_chunks_col = COLUMN_NAME_MAPPING.get('total_chunks_in_top_n', 'Всего Чанков в Топ-N')
217
+ total_puncts_col = COLUMN_NAME_MAPPING.get('total_puncts', 'Всего Пунктов')
218
+ found_puncts_col = COLUMN_NAME_MAPPING.get('found_puncts', 'Найдено Пунктов') # Для micro
219
+ relevant_chunks_col = COLUMN_NAME_MAPPING.get('relevant_chunks', 'Релевантных Чанков') # Для micro
220
+
221
+ # Колонки, которые должны существовать для расчетов
222
+ required_metric_cols = [chunk_prec_col, chunk_rec_col, chunk_f1_col, assembly_rec_col]
223
+ required_count_cols = [total_chunks_col, total_puncts_col, found_puncts_col, relevant_chunks_col]
224
+ existing_metric_cols = [m for m in required_metric_cols if m in group_df.columns]
225
+ existing_count_cols = [c for c in required_count_cols if c in group_df.columns]
226
+
227
+ # --- Macro метрики (Простое усреднение метрик по вопросам) ---
228
+ if existing_metric_cols:
229
+ macro_metrics = group_df[existing_metric_cols].mean().rename(
230
+ # Генерируем имя 'Macro Имя Метрики'
231
+ lambda x: COLUMN_NAME_MAPPING.get(f"macro_{{key}}".format(key=next((k for k, v in COLUMN_NAME_MAPPING.items() if v == x), None)), f"Macro {x}")
232
+ ).to_dict()
233
+ agg_result.update(macro_metrics)
234
+ else:
235
+ print(f"Предупреждение: Пропуск Macro метрик для группы {params}, нет колонок метрик.")
236
+
237
+ # --- Weighted метрики (Взвешенное усреднение) ---
238
+ weighted_chunk_precision = 0.0
239
+ weighted_chunk_recall = 0.0
240
+ weighted_assembly_recall = 0.0
241
+ weighted_chunk_f1 = 0.0
242
+
243
+ # Проверяем наличие необходимых колонок для взвешенного расчета
244
+ can_calculate_weighted = True
245
+ if chunk_prec_col not in existing_metric_cols or total_chunks_col not in existing_count_cols:
246
+ print(f"Предупреждение: Пропуск Weighted Точность (Чанк-Текст) для группы {params}, отсутствуют {chunk_prec_col} или {total_chunks_col}.")
247
+ can_calculate_weighted = False
248
+ if chunk_rec_col not in existing_metric_cols or total_puncts_col not in existing_count_cols:
249
+ print(f"Предупреждение: Пропуск Weighted Полнота (Чанк-Текст) для группы {params}, отсутствуют {chunk_rec_col} или {total_puncts_col}.")
250
+ can_calculate_weighted = False
251
+ if assembly_rec_col not in existing_metric_cols or total_puncts_col not in existing_count_cols:
252
+ print(f"Пред��преждение: Пропуск Weighted Полнота (Сборка-Пункт) для группы {params}, отсутствуют {assembly_rec_col} или {total_puncts_col}.")
253
+ # Не сбрасываем can_calculate_weighted, т.к. другие weighted могут посчитаться
254
+
255
+ if can_calculate_weighted:
256
+ total_chunks_sum = group_df[total_chunks_col].sum()
257
+ total_puncts_sum = group_df[total_puncts_col].sum()
258
+
259
+ # Weighted Precision (Chunk-Text)
260
+ if total_chunks_sum > 0 and chunk_prec_col in existing_metric_cols:
261
+ weighted_chunk_precision = (group_df[chunk_prec_col] * group_df[total_chunks_col]).sum() / total_chunks_sum
262
+
263
+ # Weighted Recall (Chunk-Text)
264
+ if total_puncts_sum > 0 and chunk_rec_col in existing_metric_cols:
265
+ weighted_chunk_recall = (group_df[chunk_rec_col] * group_df[total_puncts_col]).sum() / total_puncts_sum
266
+
267
+ # Weighted Recall (Assembly-Punct)
268
+ if total_puncts_sum > 0 and assembly_rec_col in existing_metric_cols:
269
+ weighted_assembly_recall = (group_df[assembly_rec_col] * group_df[total_puncts_col]).sum() / total_puncts_sum
270
+
271
+ # Weighted F1 (Chunk-Text) - на основе weighted precision и recall
272
+ if weighted_chunk_precision + weighted_chunk_recall > 0:
273
+ weighted_chunk_f1 = (2 * weighted_chunk_precision * weighted_chunk_recall) / (weighted_chunk_precision + weighted_chunk_recall)
274
+
275
+ # Добавляем рассчитанные Weighted метрики в результат
276
+ agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_precision', 'Weighted Точность (Чанк-Текст)')] = weighted_chunk_precision
277
+ agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_recall', 'Weighted Полнота (Чанк-Текст)')] = weighted_chunk_recall
278
+ agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_f1', 'Weighted F1 (Чанк-Текст)')] = weighted_chunk_f1
279
+ agg_result[COLUMN_NAME_MAPPING.get('weighted_assembly_punct_recall', 'Weighted Полнота (Сборка-Пункт)')] = weighted_assembly_recall
280
+
281
+
282
+ # --- Micro метрики (На основе общих TP, FP, FN, ИСПОЛЬЗУЕМ РУССКИЕ НАЗВАНИЯ) ---
283
+ # Колонки уже определены выше
284
+ if not all(col in group_df.columns for col in [found_puncts_col, total_puncts_col, relevant_chunks_col, total_chunks_col]):
285
+ print(f"Предупреждение: Пропуск расчета micro-метрик для группы {params}, т.к. отсутствуют необходимые колонки.")
286
+ agg_result[COLUMN_NAME_MAPPING.get('micro_text_precision', 'Micro Точность (Текст)')] = 0.0
287
+ agg_result[COLUMN_NAME_MAPPING.get('micro_text_recall', 'Micro Полнота (Текст)')] = 0.0
288
+ agg_result[COLUMN_NAME_MAPPING.get('micro_text_f1', 'Micro F1 (Текст)')] = 0.0
289
+
290
+ # Добавляем результат группы в общий список
291
+ aggregated_results.append(agg_result)
292
+
293
+ # Создаем итоговый DataFrame (уже с русскими именами)
294
+ final_df = pd.DataFrame(aggregated_results)
295
+
296
+ print(f"Рассчитаны агрегированные метрики для {len(final_df)} комбинаций параметров.")
297
+ # Возвращаем DataFrame с русскими названиями колонок
298
+ return final_df
299
+
300
+ # --- Функции для форматирования Excel (адаптированы из combine_results.py) ---
301
+ def apply_excel_formatting(workbook: Workbook):
302
+ """Применяет форматирование ко всем листам книги Excel."""
303
+ header_font = Font(bold=True)
304
+ header_fill = PatternFill(start_color="D9D9D9", end_color="D9D9D9", fill_type="solid")
305
+ center_alignment = Alignment(horizontal='center', vertical='center')
306
+ wrap_alignment = Alignment(horizontal='center', vertical='center', wrap_text=True)
307
+ thin_border = Border(
308
+ left=Side(style='thin'),
309
+ right=Side(style='thin'),
310
+ top=Side(style='thin'),
311
+ bottom=Side(style='thin')
312
+ )
313
+ thick_top_border = Border(top=Side(style='thick'))
314
+
315
+ for sheet_name in workbook.sheetnames:
316
+ sheet = workbook[sheet_name]
317
+ if sheet.max_row <= 1: # Пропускаем пустые листы
318
+ continue
319
+
320
+ # Форматирование заголовков
321
+ for cell in sheet[1]:
322
+ cell.font = header_font
323
+ cell.fill = header_fill
324
+ cell.alignment = wrap_alignment
325
+ cell.border = thin_border
326
+
327
+ # Автоподбор ширины и форматирование ячеек
328
+ for col_idx, column_cells in enumerate(sheet.columns, 1):
329
+ max_length = 0
330
+ column_letter = get_column_letter(col_idx)
331
+ is_numeric_metric_col = False
332
+ header_value = sheet.cell(row=1, column=col_idx).value
333
+
334
+ # Проверяем, является ли колонка числовой метрикой
335
+ if isinstance(header_value, str) and any(m in header_value for m in ['precision', 'recall', 'f1', 'relevance']):
336
+ is_numeric_metric_col = True
337
+
338
+ for i, cell in enumerate(column_cells):
339
+ # Применяем границы ко всем ячейкам
340
+ cell.border = thin_border
341
+ # Центрируем все, кроме заголовка
342
+ if i > 0:
343
+ cell.alignment = center_alignment
344
+
345
+ # Формат для числовых метрик
346
+ if is_numeric_metric_col and i > 0 and isinstance(cell.value, (int, float)):
347
+ cell.number_format = '0.0000'
348
+
349
+ # Расчет ширины
350
+ try:
351
+ cell_len = len(str(cell.value))
352
+ if cell_len > max_length:
353
+ max_length = cell_len
354
+ except:
355
+ pass
356
+ adjusted_width = (max_length + 2) * 1.1
357
+ sheet.column_dimensions[column_letter].width = min(adjusted_width, 60) # Ограничиваем макс ширину
358
+
359
+ # Автофильтр
360
+ sheet.auto_filter.ref = sheet.dimensions
361
+
362
+ # Группировка строк (опционально, можно добавить логику из combine_results, если нужна)
363
+ # ... (здесь можно вставить apply_group_formatting, если требуется) ...
364
+
365
+ print("Форматирование Excel завершено.")
366
+
367
+
368
+ def save_to_excel(data_dict: dict[str, pd.DataFrame], output_path: str):
369
+ """Сохраняет несколько DataFrame в один Excel файл с форматированием."""
370
+ print(f"Сохранение результатов в Excel: {output_path}")
371
+ try:
372
+ workbook = Workbook()
373
+ workbook.remove(workbook.active) # Удаляем лист по умолчанию
374
+
375
+ for sheet_name, df in data_dict.items():
376
+ if df is not None and not df.empty:
377
+ sheet = workbook.create_sheet(title=sheet_name)
378
+ for r in dataframe_to_rows(df, index=False, header=True):
379
+ # Проверяем и заменяем недопустимые символы в ячейках
380
+ cleaned_row = []
381
+ for cell_value in r:
382
+ if isinstance(cell_value, str):
383
+ # Удаляем управляющие символы, кроме стандартных пробельных
384
+ cleaned_value = ''.join(c for c in cell_value if c.isprintable() or c in ' \t\n\r')
385
+ cleaned_row.append(cleaned_value)
386
+ else:
387
+ cleaned_row.append(cell_value)
388
+ sheet.append(cleaned_row)
389
+ print(f" Лист '{sheet_name}' добавлен ({len(df)} строк)")
390
+ else:
391
+ print(f" Лист '{sheet_name}' пропущен (нет данных)")
392
+
393
+ # Применяем форматирование ко всей книге
394
+ if workbook.sheetnames: # Проверяем, что есть хотя бы один лист
395
+ apply_excel_formatting(workbook)
396
+ workbook.save(output_path)
397
+ print("Excel файл успешно сохранен.")
398
+ else:
399
+ print("Нет данных для сохранения в Excel.")
400
+
401
+ except Exception as e:
402
+ print(f"Ошибка при сохранении Excel файла: {e}")
403
+
404
+
405
+ # --- Основная функция ---
406
+ def main():
407
+ """Основная функция скрипта."""
408
+ args = parse_args()
409
+
410
+ # 1. Загрузка данных
411
+ combined_df_eng = load_intermediate_results(args.intermediate_dir)
412
+
413
+ if combined_df_eng.empty:
414
+ print("Нет данных для агрегации. Завершение.")
415
+ return
416
+
417
+ # --- Фильтрация по последнему batch_id (если флаг установлен) ---
418
+ target_df = combined_df_eng # По умолчанию используем все данные
419
+ if args.latest_batch_only:
420
+ print("Фильтрация по последнему batch_id...")
421
+ if 'batch_id' not in combined_df_eng.columns:
422
+ print("Предупреждение: Колонка 'batch_id' не найдена. Агрегация будет выполнена по всем данным.")
423
+ else:
424
+ # Находим последний batch_id (сортируем строки по batch_id)
425
+ # Сначала отфильтруем 'unknown_batch'
426
+ valid_batches = combined_df_eng[combined_df_eng['batch_id'] != 'unknown_batch']['batch_id'].unique()
427
+ if len(valid_batches) > 0:
428
+ # Сортируем уникальные валидные ID и берем последний
429
+ latest_batch_id = sorted(valid_batches)[-1]
430
+ print(f"Используется последний валидный batch_id: {latest_batch_id}")
431
+ target_df = combined_df_eng[combined_df_eng['batch_id'] == latest_batch_id].copy()
432
+ if target_df.empty:
433
+ # Это не должно произойти, если latest_batch_id валидный, но на всякий случай
434
+ print(f"Предупреждение: Не найдено данных для batch_id {latest_batch_id}. Агрегация будет выполнена по всем данным.")
435
+ target_df = combined_df_eng
436
+ else:
437
+ print(f"Оставлено строк после фильтрации: {len(target_df)}")
438
+ else:
439
+ print("Предупреждение: Не найдено валидных batch_id для фильтрации. Агрегация будет выполнена по всем данным.")
440
+ # target_df уже равен combined_df_eng, так что ничего не делаем
441
+ # latest_batch_id = combined_df_eng['batch_id'].astype(str).sort_values().iloc[-1]
442
+ # print(f"Используется последний batch_id: {latest_batch_id}")
443
+ # target_df = combined_df_eng[combined_df_eng['batch_id'] == latest_batch_id].copy()
444
+ # if target_df.empty:
445
+ # print(f"Предупреждение: Нет данных для batch_id {latest_batch_id}. Агрегация будет выполнена по всем данным.")
446
+ # target_df = combined_df_eng # Возвращаемся ко всем данным, если фильтр дал пустоту
447
+ # else:
448
+ # print(f"Оставлено строк после фильтрации: {len(target_df)}")
449
+
450
+ # --- Заполнение NaN и переименование ПОСЛЕ возможной фильтрации ---
451
+ # Определяем числовые колонки еще раз (используя английские названия из маппинга)
452
+ numeric_cols_eng = [eng for eng, rus in COLUMN_NAME_MAPPING.items() \
453
+ if 'recall' in eng or 'precision' in eng or 'f1' in eng or 'puncts' in eng \
454
+ or 'chunks' in eng or 'threshold' in eng or 'top_n' in eng]
455
+ numeric_cols_in_df = [col for col in numeric_cols_eng if col in target_df.columns]
456
+ target_df[numeric_cols_in_df] = target_df[numeric_cols_in_df].fillna(0)
457
+
458
+ # Переименовываем
459
+ columns_to_rename_detailed = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in target_df.columns}
460
+ target_df_rus = target_df.rename(columns=columns_to_rename_detailed)
461
+
462
+ # 2. Расчет агрегированных метрик
463
+ # Передаем DataFrame с русскими названиями колонок, calculate_aggregated_metrics теперь их ожидает
464
+ aggregated_df_rus = calculate_aggregated_metrics(target_df_rus)
465
+ # Переименовываем столбцы агрегированного DF уже внутри calculate_aggregated_metrics
466
+ # aggregated_df_rus = pd.DataFrame() # Инициализируем на случай, если aggregated_df_eng пуст
467
+ # if not aggregated_df_eng.empty:
468
+ # columns_to_rename_agg = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in aggregated_df_eng.columns}
469
+ # aggregated_df_rus = aggregated_df_eng.rename(columns=columns_to_rename_agg)
470
+
471
+ # 3. Подготовка данных для сохранения (с русскими названиями)
472
+ data_to_save = {
473
+ "Детальные результаты": target_df_rus, # Используем переименованный DF
474
+ "Агрегированные метрики": aggregated_df_rus, # Используем переименованный DF
475
+ }
476
+
477
+ # 4. Сохранение в Excel
478
+ os.makedirs(args.output_dir, exist_ok=True)
479
+ output_file_path = os.path.join(args.output_dir, args.output_filename)
480
+ save_to_excel(data_to_save, output_file_path)
481
+
482
+ if __name__ == "__main__":
483
+ main()
scripts/testing/pipeline.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Основной пайплайн для оценки качества RAG системы.
5
+
6
+ Этот скрипт выполняет один прогон оценки для заданных параметров:
7
+ - Чтение документов и датасетов вопросов/ответов.
8
+ - Чанкинг документов.
9
+ - Векторизация вопросов и чанков.
10
+ - Оценка релевантности чанков к пунктам из датасета (Chunk-level).
11
+ - Сборка контекста из релевантных чанков (Assembly-level).
12
+ - Оценка релевантности собранного контекста к эталонным ответам.
13
+ - Сохранение детальных метрик для данного прогона.
14
+ """
15
+
16
+ import argparse
17
+ # Add necessary imports for caching
18
+ import hashlib
19
+ import json
20
+ import os
21
+ import pickle
22
+ import sys
23
+ from pathlib import Path
24
+ from typing import Any
25
+ from uuid import UUID, uuid4
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+ import torch
30
+ from fuzzywuzzy import fuzz
31
+ from sklearn.metrics.pairwise import cosine_similarity
32
+ from tqdm import tqdm
33
+ from transformers import AutoModel, AutoTokenizer
34
+
35
+ # --- Константы (могут быть переопределены аргументами) ---
36
+ DEFAULT_DATA_FOLDER = "data/input/docs"
37
+ DEFAULT_SEARCH_DATASET_PATH = "data/input/search_dataset_texts.xlsx"
38
+ DEFAULT_QA_DATASET_PATH = "data/input/question_answering.xlsx"
39
+ DEFAULT_MODEL_NAME = "intfloat/e5-base"
40
+ DEFAULT_BATCH_SIZE = 8
41
+ DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
42
+ DEFAULT_SIMILARITY_THRESHOLD = 0.7
43
+ DEFAULT_OUTPUT_DIR = "data/intermediate" # Директория для промежуточных результатов
44
+ DEFAULT_WORDS_PER_CHUNK = 50
45
+ DEFAULT_OVERLAP_WORDS = 25
46
+ DEFAULT_TOP_N = 20 # Значение N по умолчанию для топа чанков
47
+ # Add chunking strategy constant
48
+ DEFAULT_CHUNKING_STRATEGY = "fixed_size"
49
+ # Add cache directory constant
50
+ DEFAULT_CACHE_DIR = "data/cache"
51
+
52
+ # --- Добавление путей к библиотекам ---
53
+ # Добавляем путь к корневой папке проекта, чтобы можно было импортировать ntr_...
54
+ SCRIPT_DIR = Path(__file__).parent.resolve()
55
+ PROJECT_ROOT = SCRIPT_DIR.parent.parent # Перейти на два уровня вверх (scripts/testing -> scripts -> project root)
56
+ LIB_EXTRACTOR_PATH = PROJECT_ROOT / "lib" / "extractor"
57
+ sys.path.insert(0, str(LIB_EXTRACTOR_PATH))
58
+ # Добавляем путь к папке с ntr_text_fragmentation
59
+ sys.path.insert(0, str(LIB_EXTRACTOR_PATH / "ntr_text_fragmentation"))
60
+
61
+ # --- Импорты из локальных модулей ---
62
+ try:
63
+ from ntr_fileparser import ParsedDocument, UniversalParser
64
+ from ntr_text_fragmentation import Destructurer
65
+ from ntr_text_fragmentation.core.entity_repository import \
66
+ InMemoryEntityRepository
67
+ from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
68
+ from ntr_text_fragmentation.models.chunk import Chunk
69
+ from ntr_text_fragmentation.models.document import DocumentAsEntity
70
+ from ntr_text_fragmentation.models.linker_entity import LinkerEntity
71
+ except ImportError as e:
72
+ print(f"Ошибка импорта локальных модулей: {e}")
73
+ print(f"Проверьте пути: Project Root: {PROJECT_ROOT}, Extractor Lib: {LIB_EXTRACTOR_PATH}")
74
+ sys.exit(1)
75
+
76
+ # --- Вспомогательные функции (аналогичные evaluate_chunking.py) ---
77
+
78
+ def _average_pool(
79
+ last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
80
+ ) -> torch.Tensor:
81
+ """
82
+ Расчёт усредненного эмбеддинга по всем токенам.
83
+ (Копипаста из evaluate_chunking.py)
84
+ """
85
+ last_hidden = last_hidden_states.masked_fill(
86
+ ~attention_mask[..., None].bool(), 0.0
87
+ )
88
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
89
+
90
+ def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
91
+ """
92
+ Рассчитывает степень перекрытия между чанком и пунктом.
93
+ (Копипаста из evaluate_chunking.py)
94
+ """
95
+ if not chunk_text or not punct_text:
96
+ return 0.0
97
+ # Используем partial_ratio для лучшей обработки подстрок
98
+ return fuzz.partial_ratio(chunk_text, punct_text) / 100.0
99
+
100
+ # --- Функции загрузки и обработки данных ---
101
+
102
+ def parse_args():
103
+ """Парсит аргументы командной строки."""
104
+ parser = argparse.ArgumentParser(description="Пайплайн оценки RAG системы")
105
+
106
+ # Пути к данным
107
+ parser.add_argument("--data-folder", type=str, default=DEFAULT_DATA_FOLDER,
108
+ help=f"Папка с документами (по умолчанию: {DEFAULT_DATA_FOLDER})")
109
+ parser.add_argument("--search-dataset-path", type=str, default=DEFAULT_SEARCH_DATASET_PATH,
110
+ help=f"Путь к датасету для поиска (по умолчанию: {DEFAULT_SEARCH_DATASET_PATH})")
111
+ parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
112
+ help=f"Папка для сохранения промежуточных результатов (по умолчанию: {DEFAULT_OUTPUT_DIR})")
113
+ parser.add_argument("--run-id", type=str, default=f"run_{uuid4()}",
114
+ help="Уникальный идентификатор запуска (по умолчанию: генерируется)")
115
+
116
+ # Параметры модели и векторизации
117
+ parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME,
118
+ help=f"Название модели для векторизации (по умолчанию: {DEFAULT_MODEL_NAME})")
119
+ parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
120
+ help=f"Размер батча для векторизации (по умолчанию: {DEFAULT_BATCH_SIZE})")
121
+ parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, # type: ignore
122
+ help=f"Устройство для вычислений (по умолчанию: {DEFAULT_DEVICE})")
123
+ parser.add_argument("--use-sentence-transformers", action="store_true",
124
+ help="Использовать библиотеку sentence_transformers")
125
+
126
+ # Параметры чанкинга
127
+ parser.add_argument("--chunking-strategy", type=str, default=DEFAULT_CHUNKING_STRATEGY,
128
+ choices=list(Destructurer.STRATEGIES.keys()), # Use keys from Destructurer
129
+ help=f"Стратегия чанкинга (по умолчанию: {DEFAULT_CHUNKING_STRATEGY})")
130
+ parser.add_argument("--strategy-params", type=str, default='{}', # Default to empty JSON object
131
+ help="Параметры для стратегии чанкинга в формате JSON строки (например, '{\"words_per_chunk\": 50}')")
132
+ parser.add_argument("--no-process-tables", action="store_false", dest="process_tables",
133
+ help="Отключить обработку таблиц при чанкинге")
134
+ parser.set_defaults(process_tables=True) # Default is to process tables
135
+
136
+ # Параметры оценки
137
+ parser.add_argument("--similarity-threshold", type=float, default=DEFAULT_SIMILARITY_THRESHOLD,
138
+ help=f"Порог для нечеткого сравнения чанка и пункта (по умолчанию: {DEFAULT_SIMILARITY_THRESHOLD})")
139
+ parser.add_argument("--top-n", type=int, default=DEFAULT_TOP_N,
140
+ help=f"Количество топ-чанков для рассмотрения (по умолчанию: {DEFAULT_TOP_N})")
141
+ # Add cache directory argument
142
+ parser.add_argument("--cache-dir", type=str, default=DEFAULT_CACHE_DIR,
143
+ help=f"Директория для кэширования эмбеддингов и матриц схожести (по умолчанию: {DEFAULT_CACHE_DIR})")
144
+
145
+ # Параметры сборки контекста
146
+ parser.add_argument("--use-injection", action="store_true",
147
+ help="Выполнять ли сборку контекста и её оценку")
148
+ parser.add_argument("--use-qe", action="store_true",
149
+ help="Использовать столбец query_expansion вместо question для поиска (если он есть)")
150
+ parser.add_argument("--include-neighbors", action="store_true",
151
+ help="Включать ли соседние чанки (предыдущий/следующий) при сборке контекста")
152
+
153
+ # --- Добавляем аргумент для batch_id ---
154
+ parser.add_argument("--batch-id", type=str, default="batch_default",
155
+ help="Идентификатор серии запусков (передается из run_pipelines.py)")
156
+
157
+ # TODO: Добавить другие параметры при необходимости (например, параметры InjectionBuilder)
158
+
159
+ return parser.parse_args()
160
+
161
+ def read_documents(folder_path: str) -> dict[str, ParsedDocument]:
162
+ """
163
+ Читает все документы из указанной папки и создает сущности.
164
+
165
+ Args:
166
+ folder_path: Путь к папке с документами
167
+
168
+ Returns:
169
+ Словарь {имя_файла: объект ParsedDocument}
170
+ """
171
+ print(f"Чтение документов из {folder_path}...")
172
+ parser = UniversalParser()
173
+ documents_map = {}
174
+ doc_files = list(Path(folder_path).glob("*.docx"))
175
+
176
+ if not doc_files:
177
+ print(f"ВНИМАНИЕ: В папке {folder_path} не найдено *.docx файлов.")
178
+ return {}
179
+
180
+ for file_path in tqdm(doc_files, desc="Чтение документов"):
181
+ try:
182
+ doc_name = file_path.stem
183
+ # Парсим документ с помощью UniversalParser
184
+ parsed_document = parser.parse_by_path(str(file_path))
185
+ # Сохраняем распарсенный документ
186
+ documents_map[doc_name] = parsed_document
187
+ except Exception as e:
188
+ print(f"Ошибка при чтении файла {file_path}: {e}")
189
+
190
+ print(f"Прочитано документов: {len(documents_map)}")
191
+ return documents_map
192
+
193
+ def load_datasets(search_dataset_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
194
+ """
195
+ Загружает датасет для поиска и готовит данные для векторизации.
196
+
197
+ Args:
198
+ search_dataset_path: Путь к Excel с пунктами для поиска.
199
+
200
+ Returns:
201
+ Кортеж: (полный DataFrame поискового датасета, DataFrame с уникальными вопросами для векторизации).
202
+ """
203
+ print(f"Загрузка поискового датасета из {search_dataset_path}...")
204
+ try:
205
+ search_df = pd.read_excel(search_dataset_path)
206
+ print(f"Загружен поисковый датасет: {len(search_df)} строк, столбцы: {search_df.columns.tolist()}")
207
+
208
+ # Проверяем наличие обязательных столбцов
209
+ required_columns = ['id', 'question', 'text', 'filename']
210
+ missing_cols = [col for col in required_columns if col not in search_df.columns]
211
+ if missing_cols:
212
+ print(f"Ошибка: В поисковом датасете отсутствуют обязательные столбцы: {missing_cols}")
213
+ sys.exit(1)
214
+
215
+ # Преобразуем NaN в пустые строки для текстовых полей
216
+ # Добавляем 'query_expansion', если он есть, для обработки NaN
217
+ text_columns = ['question', 'text', 'item_type', 'filename']
218
+ if 'query_expansion' in search_df.columns:
219
+ text_columns.append('query_expansion')
220
+
221
+ for col in text_columns:
222
+ if col in search_df.columns:
223
+ search_df[col] = search_df[col].fillna('')
224
+ # Если необязательный item_type отсутствует, добавляем его пустым
225
+ elif col == 'item_type':
226
+ print(f"Предупреждение: столбец '{col}' отсутствует в поисковом датасете. Добавлен пустой столбец.")
227
+ search_df[col] = ''
228
+
229
+ # Убедимся, что 'id' имеет целочисленный тип
230
+ try:
231
+ search_df['id'] = search_df['id'].astype(int)
232
+ except ValueError as e:
233
+ print(f"Ошибка при приведении типов столбца 'id' в поисковом датасете: {e}. Убедитесь, что ID являются целыми числами.")
234
+ sys.exit(1)
235
+
236
+ except FileNotFoundError:
237
+ print(f"Ошибка: Поисковый датасет не найден по пути {search_dataset_path}")
238
+ sys.exit(1)
239
+ except Exception as e:
240
+ print(f"Ошибка при чтении поискового датасета: {e}")
241
+ sys.exit(1)
242
+
243
+ # Готовим DataFrame для векторизации уникальных вопросов
244
+ # Включаем query_expansion, если он есть
245
+ cols_for_embedding = ['id', 'question']
246
+ query_expansion_exists = 'query_expansion' in search_df.columns
247
+ if query_expansion_exists:
248
+ cols_for_embedding.append('query_expansion')
249
+ print("Столбец 'query_expansion' найден в поисковом датасете.")
250
+ else:
251
+ print("Столбец 'query_expansion' не найден в поисковом датасете.")
252
+
253
+ questions_to_embed = search_df[cols_for_embedding].drop_duplicates(subset=['id']).copy()
254
+
255
+ # Если query_expansion не существует, добавляем пустой столбец для единообразия
256
+ if not query_expansion_exists:
257
+ questions_to_embed['query_expansion'] = ''
258
+
259
+ print(f"Уникальных вопросов для векторизации: {len(questions_to_embed)}")
260
+
261
+ # Теперь search_df это и есть наш "объединенный" датасет (так как QA не используется)
262
+ return search_df, questions_to_embed
263
+
264
+
265
+ def perform_chunking(
266
+ documents_map: dict[str, ParsedDocument],
267
+ chunking_strategy: str,
268
+ process_tables: bool,
269
+ strategy_params_json: str # Expect JSON string
270
+ ) -> tuple[pd.DataFrame, list[LinkerEntity]]:
271
+ """
272
+ Выполняет чанкинг для всех документов.
273
+
274
+ Args:
275
+ documents_map: Словарь {имя_файла: сущность_документа}.
276
+ chunking_strategy: Имя используемой стратегии чанкинга.
277
+ process_tables: Флаг, указывающий, нужно ли обрабатывать таблицы.
278
+ strategy_params_json: Строка JSON с параметрами для стратегии.
279
+
280
+ Returns:
281
+ Кортеж: (DataFrame с чанками для поиска, список всех созданных сущностей LinkerEntity)
282
+ """
283
+ print("Выполнение чанкинга...")
284
+ searchable_chunks_data = [] # Данные только для чанков с in_search_text
285
+ final_entities: list[LinkerEntity] = [] # Список для ВСЕХ сущностей (доки, чанки, связи и т.д.)
286
+
287
+ # Parse strategy parameters from JSON string
288
+ try:
289
+ chunking_params = json.loads(strategy_params_json)
290
+ print(f"Параметры для стратегии '{chunking_strategy}': {chunking_params}")
291
+ except json.JSONDecodeError as e:
292
+ print(f"Ошибка парсинга JSON для strategy-params: '{strategy_params_json}'. Используются параметры по умолчанию стратегии. Ошибка: {e}")
293
+ chunking_params = {} # Use strategy defaults if JSON is invalid
294
+
295
+ print(f"Используется стратегия чанкинга: '{chunking_strategy}'")
296
+ print(f"Обработка таблиц: {'Включена' if process_tables else 'Отключена'}")
297
+
298
+ for doc_name, parsed_doc in tqdm(documents_map.items(), desc="Чанкинг документов"):
299
+ try:
300
+ # Инициализируем Destructurer ВНУТРИ цикла для КАЖДОГО документа
301
+ destructurer = Destructurer(
302
+ document=parsed_doc,
303
+ process_tables=process_tables,
304
+ strategy_name=chunking_strategy, # Передаем имя стратегии при инициализации
305
+ **chunking_params # И параметры стратегии
306
+ )
307
+ # Destructure создает DocumentAsEntity, чанки, связи и возвращает их как LinkerEntity
308
+ new_entities = destructurer.destructure()
309
+
310
+ # Добавляем ВСЕ созданные сущности (сериализованные LinkerEntity) в общий список
311
+ final_entities.extend(new_entities)
312
+
313
+ # Собираем данные для DataFrame только из тех сущностей,
314
+ # у которых есть поле in_search_text (это наши чанки для поиска)
315
+ for entity in new_entities:
316
+ # Проверяем наличие атрибута 'in_search_text', а не тип
317
+ if hasattr(entity, 'in_search_text') and entity.in_search_text:
318
+ entity_data = {
319
+ 'chunk_id': str(entity.id),
320
+ 'doc_name': doc_name, # Имя исходного файла
321
+ 'doc_id': str(entity.source_id), # ID сущности документа (DocumentAsEntity)
322
+ 'text': entity.in_search_text, # Текст для векторизации и поиска
323
+ 'type': entity.type, # Тип сущности (например, 'FixedSizeChunk')
324
+ 'strategy_params': json.dumps(chunking_params, ensure_ascii=False),
325
+ }
326
+ searchable_chunks_data.append(entity_data)
327
+ except Exception as e:
328
+ # Логируем ошибку и продолжаем с остальными документами
329
+ print(f"\nОшибка при чанкинге документа {doc_name}: {e}")
330
+ import traceback
331
+ traceback.print_exc() # Печатаем traceback для детальной отладки
332
+
333
+ # Создаем DataFrame только из чанков, предназначенных для поиска
334
+ chunks_df = pd.DataFrame(searchable_chunks_data)
335
+ print(f"Создано чанков для поиска: {len(chunks_df)}")
336
+
337
+ # Возвращаем DataFrame с чанками для поиска и ПОЛНЫЙ список всех LinkerEntity
338
+ return chunks_df, final_entities
339
+
340
+
341
+ def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool, device: str):
342
+ """Инициализирует модель и токенизатор."""
343
+ print(f"Загрузка модели {model_name} на устройство {device}...")
344
+ if use_sentence_transformers:
345
+ try:
346
+ from sentence_transformers import SentenceTransformer
347
+ model = SentenceTransformer(model_name, device=device)
348
+ tokenizer = None # sentence_transformers не требует отдельного токенизатора
349
+ print("Используется SentenceTransformer.")
350
+ return model, tokenizer
351
+ except ImportError:
352
+ print("Ошибка: Библиотека sentence_transformers не установлена. Установите: pip install sentence-transformers")
353
+ sys.exit(1)
354
+ else:
355
+ try:
356
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
357
+ model = AutoModel.from_pretrained(model_name).to(device)
358
+ model.eval()
359
+ print("Используется AutoModel и AutoTokenizer из transformers.")
360
+ return model, tokenizer
361
+ except Exception as e:
362
+ print(f"Ошибка при загрузке модели {model_name} из transformers: {e}")
363
+ sys.exit(1)
364
+
365
+
366
+ def get_embeddings(
367
+ texts: list[str],
368
+ model,
369
+ tokenizer,
370
+ batch_size: int,
371
+ use_sentence_transformers: bool,
372
+ device: str
373
+ ) -> np.ndarray:
374
+ """Получает эмбеддинги для списка текстов."""
375
+ all_embeddings = []
376
+ desc = "Векторизация (Sentence Transformers)" if use_sentence_transformers else "Векторизация (Transformers)"
377
+
378
+ for i in tqdm(range(0, len(texts), batch_size), desc=desc):
379
+ batch_texts = texts[i:i+batch_size]
380
+ if not batch_texts:
381
+ continue
382
+
383
+ if use_sentence_transformers:
384
+ # Эмбеддинги через sentence_transformers
385
+ embeddings = model.encode(batch_texts, batch_size=len(batch_texts), show_progress_bar=False)
386
+ all_embeddings.append(embeddings)
387
+ else:
388
+ # Эмбеддинги через transformers с average pooling
389
+ try:
390
+ encoding = tokenizer(
391
+ batch_texts,
392
+ padding=True,
393
+ truncation=True,
394
+ max_length=512, # Стандартное ограничение для многих моделей
395
+ return_tensors="pt"
396
+ ).to(device)
397
+
398
+ with torch.no_grad():
399
+ outputs = model(**encoding)
400
+ embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
401
+ all_embeddings.append(embeddings.cpu().numpy())
402
+ except Exception as e:
403
+ print(f"Ошибка при векторизации батча (индексы {i} - {i+batch_size}): {e}")
404
+ print(f"Тексты батча: {batch_texts[:2]}...")
405
+ # Добавляем нулевые векторы, чтобы не сломать vstack
406
+ # Определяем размер эмбеддинга
407
+ if all_embeddings:
408
+ embedding_dim = all_embeddings[0].shape[1]
409
+ else:
410
+ # Пытаемся получить размер из конфигурации модели
411
+ try:
412
+ embedding_dim = model.config.hidden_size
413
+ except AttributeError:
414
+ embedding_dim = 768 # Запасной вариант
415
+ print(f"Не удалось определить размер эмбеддинга, используется {embedding_dim}")
416
+
417
+ print(f"Добавление нулевых эмбеддингов размерности ({len(batch_texts)}, {embedding_dim})")
418
+ null_embeddings = np.zeros((len(batch_texts), embedding_dim), dtype=np.float32)
419
+ all_embeddings.append(null_embeddings)
420
+
421
+
422
+ if not all_embeddings:
423
+ print("ВНИМАНИЕ: Не удалось создать эмбеддинги.")
424
+ # Возвращаем пустой массив правильной формы, если возможно
425
+ try:
426
+ embedding_dim = model.config.hidden_size if not use_sentence_transformers else model.get_sentence_embedding_dimension()
427
+ except:
428
+ embedding_dim = 768
429
+ return np.empty((0, embedding_dim), dtype=np.float32)
430
+
431
+ # Объединяем эмбеддинги из всех батчей
432
+ try:
433
+ final_embeddings = np.vstack(all_embeddings)
434
+ except ValueError as e:
435
+ print(f"Ошибка при объединении эмбеддингов: {e}")
436
+ print("Размеры эмбеддингов в батчах:")
437
+ for i, emb_batch in enumerate(all_embeddings):
438
+ print(f" Батч {i}: {emb_batch.shape}")
439
+ # Попробуем определить ожидаемый размер и создать нулевой массив
440
+ if all_embeddings:
441
+ embedding_dim = all_embeddings[0].shape[1]
442
+ print(f"Возвращение ну��евого массива размерности ({len(texts)}, {embedding_dim})")
443
+ return np.zeros((len(texts), embedding_dim), dtype=np.float32)
444
+ else:
445
+ return np.empty((0, 768), dtype=np.float32) # Запасной вариант
446
+
447
+ print(f"Получено эмбеддингов: {final_embeddings.shape}")
448
+ return final_embeddings
449
+
450
+ # --- Caching Helper Functions ---
451
+
452
+ def _get_params_hash(
453
+ model_name: str,
454
+ process_tables: bool | None = None,
455
+ strategy_params: dict | None = None # Expect the parsed dictionary
456
+ ) -> str:
457
+ """Создает MD5 хэш из переданных параметров."""
458
+ hasher = hashlib.md5()
459
+ hasher.update(model_name.encode())
460
+ # Add chunking strategy and table processing flag if provided
461
+ if process_tables is not None:
462
+ hasher.update(str(process_tables).encode())
463
+ # Add strategy parameters (sort items to ensure consistent hash)
464
+ if strategy_params:
465
+ sorted_params = sorted(strategy_params.items())
466
+ hasher.update(json.dumps(sorted_params).encode())
467
+ return hasher.hexdigest()
468
+
469
+ def _get_cache_path(cache_dir: Path, hash_str: str, filename: str) -> Path:
470
+ """Формирует путь к файлу кэша, создавая поддиректории."""
471
+ # Используем первые 2 символа хэша для распределения по поддиректориям
472
+ # Это помогает избежать слишком большого количества файлов в одной директории
473
+ cache_subdir = cache_dir / hash_str[:2] / hash_str
474
+ cache_subdir.mkdir(parents=True, exist_ok=True)
475
+ return cache_subdir / filename
476
+
477
+ # --- Добавляем функцию для хэша чанкинга ---
478
+ def _get_chunking_cache_hash(
479
+ data_folder: str,
480
+ chunking_strategy: str,
481
+ process_tables: bool,
482
+ strategy_params: dict # Ожидаем словарь
483
+ ) -> str:
484
+ """Создает MD5 хэш для параметров чанкинга и папки с данными."""
485
+ hasher = hashlib.md5()
486
+ hasher.update(data_folder.encode())
487
+ hasher.update(chunking_strategy.encode())
488
+ hasher.update(str(process_tables).encode())
489
+ # Сортируем параметры для консистентности хэша
490
+ sorted_params = sorted(strategy_params.items())
491
+ hasher.update(json.dumps(sorted_params).encode())
492
+ return hasher.hexdigest()
493
+ # ---------------------------------------------
494
+
495
+ # --- Main Evaluation Function ---
496
+
497
+ def evaluate_run(
498
+ search_dataset: pd.DataFrame,
499
+ questions_to_embed: pd.DataFrame,
500
+ chunks_df: pd.DataFrame,
501
+ all_entities: list[LinkerEntity],
502
+ model: Any | None, # Принимаем None
503
+ tokenizer: Any | None, # Принимаем None
504
+ args: argparse.Namespace
505
+ ) -> pd.DataFrame:
506
+ """
507
+ Выполняет основной цикл оценки для одного набора параметров.
508
+
509
+ Args:
510
+ search_dataset: DataFrame поискового датасета.
511
+ questions_to_embed: DataFrame с уникальными вопросами для векторизации.
512
+ chunks_df: DataFrame с данными по чанкам.
513
+ all_entities: Список всех сущностей (документы, чанки, связи).
514
+ model: Модель для векторизации.
515
+ tokenizer: Токенизатор.
516
+ args: Аргументы командной строки.
517
+
518
+ Returns:
519
+ DataFrame с детальными метриками по каждому вопросу для этого запуска.
520
+ """
521
+ print("Начало этапа оценки...")
522
+
523
+ # Переменные для модели и токенизатора, инициализируем None
524
+ loaded_model: Any | None = model
525
+ loaded_tokenizer: Any | None = tokenizer
526
+
527
+ # --- Caching Setup ---
528
+ print("Настройка кэширования...")
529
+ CACHE_DIR_PATH = Path(args.cache_dir)
530
+ model_slug = args.model_name.split('/')[-1] # Basic slug for filename clarity
531
+
532
+ # --- Определяем, какой текст использовать для эмбеддингов вопросов ---
533
+ # и устанавливаем флаг qe_active, который будет влиять на кэш
534
+ if args.use_qe and 'query_expansion' in questions_to_embed.columns and questions_to_embed['query_expansion'].notna().any(): # Check if column exists and has non-NA values
535
+ print("Используется Query Expansion (столбец 'query_expansion') для векторизации вопросов.")
536
+ query_texts_to_embed = questions_to_embed['query_expansion'].tolist()
537
+ qe_active = True
538
+ else:
539
+ print("Используется оригинальный текст вопроса (столбец 'question') для векторизации.")
540
+ query_texts_to_embed = questions_to_embed['question'].tolist()
541
+ qe_active = False
542
+
543
+ # Cache key for question embeddings (ЗАВИСИТ от модели и флага use_qe)
544
+ question_params_for_hash = {
545
+ 'model_name': args.model_name,
546
+ 'use_qe': qe_active # Добавляем фактическое использование QE в параметры для хэша
547
+ }
548
+ question_hash = hashlib.md5(json.dumps(question_params_for_hash, sort_keys=True).encode()).hexdigest()
549
+ question_embeddings_cache_path = _get_cache_path(
550
+ CACHE_DIR_PATH, question_hash, f"q_embeddings_{model_slug}_qe{qe_active}.npy"
551
+ )
552
+
553
+ # Cache key for chunk embeddings (depends on model and chunking)
554
+ chunk_hash = _get_params_hash(
555
+ args.model_name,
556
+ args.process_tables, # Include table flag
557
+ json.loads(args.strategy_params) # Pass parsed params dictionary
558
+ )
559
+ chunk_embeddings_cache_path = _get_cache_path(
560
+ CACHE_DIR_PATH, chunk_hash,
561
+ f"c_emb_{model_slug}_s-{args.chunking_strategy}_t{args.process_tables}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy"
562
+ )
563
+
564
+ # Cache key for similarity matrix (depends on both sets of embeddings)
565
+ similarity_hash = f"{question_hash}_{chunk_hash}" # Combine hashes
566
+ similarity_cache_path = _get_cache_path(
567
+ CACHE_DIR_PATH, similarity_hash,
568
+ f"sim_{model_slug}_qe{qe_active}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy" # Добавляем флаг QE в имя файла
569
+ )
570
+
571
+ # 1. Векторизация вопросов и чанков (с кэшем)
572
+ question_embeddings = None
573
+ needs_model_load = False # Флаг, указывающий, нужна ли загрузка модели
574
+
575
+ if question_embeddings_cache_path.exists():
576
+ try:
577
+ print(f"Загрузка кэшированных эмбеддингов вопросов из: {question_embeddings_cache_path}")
578
+ question_embeddings = np.load(question_embeddings_cache_path, allow_pickle=False)
579
+ if len(question_embeddings) != len(questions_to_embed):
580
+ print(f"Предупреждение: Размер кэша эмбеддингов вопросов не совпадает. Пересчет.")
581
+ question_embeddings = None
582
+ else:
583
+ print("Кэш эмбеддингов вопросов успешно загружен.")
584
+ except Exception as e:
585
+ print(f"Ошибка загрузки кэша эмбеддингов вопросов: {e}. Пересчет.")
586
+ question_embeddings = None
587
+
588
+ if question_embeddings is None:
589
+ needs_model_load = True # Требуется модель для генерации эмбеддингов
590
+ print("Векторизация вопросов (потребуется загрузка модели)...")
591
+
592
+ chunk_embeddings = None
593
+ if chunk_embeddings_cache_path.exists():
594
+ try:
595
+ print(f"Загрузка кэшированных эмбеддингов чанков из: {chunk_embeddings_cache_path}")
596
+ chunk_embeddings = np.load(chunk_embeddings_cache_path, allow_pickle=False)
597
+ if len(chunk_embeddings) != len(chunks_df):
598
+ print(f"Предупреждение: Размер кэша эмбеддингов чанков не совпадает. Пересчет.")
599
+ chunk_embeddings = None
600
+ else:
601
+ print("Кэш эмбеддингов чанков успешно загружен.")
602
+ except Exception as e:
603
+ print(f"Ошибка загрузки кэша эмбеддингов чанков: {e}. Пересчет.")
604
+ chunk_embeddings = None
605
+
606
+ if chunk_embeddings is None:
607
+ needs_model_load = True # Требуется модель для генерации эмбеддингов
608
+ print("Векторизация чанков (потребуется загрузка модели)...")
609
+
610
+ # --- Отложенная загрузка модели, если необходимо ---
611
+ if needs_model_load and loaded_model is None:
612
+ print("\n--- Загрузка модели и токенизатора (т.к. кэш эмбеддингов отсутствует) ---")
613
+ loaded_model, loaded_tokenizer = setup_model_and_tokenizer(
614
+ args.model_name, args.use_sentence_transformers, args.device
615
+ )
616
+ print("--- Модель и токенизатор загружены ---\n")
617
+
618
+ # --- Повторная генерация эмбеддингов, если они не загрузились из кэша ---
619
+ if question_embeddings is None:
620
+ if loaded_model is None:
621
+ print("Критическая ошибка: Модель не загружена, но требуется для векторизации вопросов!")
622
+ # Возвращаем пустой DataFrame или выбрасываем исключение
623
+ return pd.DataFrame()
624
+
625
+ print("Повторная векторизация вопросов...")
626
+ question_embeddings = get_embeddings(
627
+ query_texts_to_embed,
628
+ loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
629
+ )
630
+ if question_embeddings.shape[0] > 0:
631
+ try:
632
+ print(f"Сохранение эмбеддингов вопросов в кэш: {question_embeddings_cache_path}")
633
+ np.save(question_embeddings_cache_path, question_embeddings, allow_pickle=False)
634
+ except Exception as e:
635
+ print(f"Не удалось сохранить кэш эмбеддингов вопросов: {e}")
636
+
637
+ if chunk_embeddings is None:
638
+ if loaded_model is None:
639
+ print("Критическая ошибка: Модель не загружена, но требуется для векторизации чанков!")
640
+ return pd.DataFrame()
641
+
642
+ print("Повторная векторизация чанков...")
643
+ chunk_texts = chunks_df['text'].fillna('').astype(str).tolist()
644
+ chunk_embeddings = get_embeddings(
645
+ chunk_texts,
646
+ loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
647
+ )
648
+ if chunk_embeddings.shape[0] > 0:
649
+ try:
650
+ print(f"Сохранение эмбеддингов чанков в кэш: {chunk_embeddings_cache_path}")
651
+ np.save(chunk_embeddings_cache_path, chunk_embeddings, allow_pickle=False)
652
+ except Exception as e:
653
+ print(f"Не удалось сохранить кэш эмбеддингов чанков: {e}")
654
+
655
+
656
+ # Проверка совпадения количества эмбеддингов и данных
657
+ if len(question_embeddings) != len(questions_to_embed):
658
+ print(f"Ошибка: Количество эмбеддингов вопросов ({len(question_embeddings)}) не совпадает с количеством уникальных вопросов ({len(questions_to_embed)}).")
659
+ # Можно либо прервать выполнение, либо попытаться исправить
660
+ # Например, взять первые N эмбеддингов, но это может быть некорректно
661
+ sys.exit(1)
662
+ if len(chunk_embeddings) != len(chunks_df):
663
+ print(f"Ошибка: Количество эмбеддингов чанков ({len(chunk_embeddings)}) не совпадает с количеством чанков в DataFrame ({len(chunks_df)}).")
664
+ # Попытка исправить (если ошибка небольшая) или выход
665
+ if abs(len(chunk_embeddings) - len(chunks_df)) < 5:
666
+ print("Попытка обрезать лишние эмбеддинги/данные...")
667
+ min_len = min(len(chunk_embeddings), len(chunks_df))
668
+ chunk_embeddings = chunk_embeddings[:min_len]
669
+ chunks_df = chunks_df.iloc[:min_len]
670
+ print(f"Размеры выровнены до {min_len}")
671
+ else:
672
+ sys.exit(1)
673
+
674
+
675
+ # Создаем маппинг ID вопроса к индексу в эмбеддингах
676
+ question_id_to_idx = {
677
+ row['id']: i for i, (_, row) in enumerate(questions_to_embed.iterrows())
678
+ }
679
+
680
+ # 2. Расчет косинусной близости
681
+ print("Расчет косинусной близости...")
682
+ # Проверка на пустые эмбеддинги
683
+ if question_embeddings.shape[0] == 0 or chunk_embeddings.shape[0] == 0:
684
+ print("Ошибка: Отсутствуют эмбеддинги вопросов или чанков для расчета близости.")
685
+ # Возвращаем пустой DataFrame или обрабатываем ошибку иначе
686
+ return pd.DataFrame()
687
+
688
+ similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
689
+
690
+ # 3. Инициализация InjectionBuilder (если нужно)
691
+ injection_builder = None
692
+ if args.use_injection:
693
+ print("Инициализация InjectionBuilder...")
694
+ repository = InMemoryEntityRepository(all_entities)
695
+ injection_builder = InjectionBuilder(repository)
696
+ # TODO: Зарегистрировать стратегии, если необходимо
697
+ # builder.register_strategy(...)
698
+
699
+ # 4. Цикл по уникальным вопросам для оценки
700
+ results = []
701
+ print(f"Оценка для {len(questions_to_embed)} уникальных вопросов...")
702
+
703
+ for question_id_iter, question_data in tqdm(questions_to_embed.iterrows(), total=len(questions_to_embed), desc="Оценка вопросов"): # Renamed loop variable
704
+ q_id = question_data['id']
705
+ q_text = question_data['question']
706
+
707
+ # Получаем все строки из исходного датасета для этого вопроса
708
+ question_rows = search_dataset[search_dataset['id'] == q_id] # Use search_dataset
709
+ if question_rows.empty:
710
+ print(f"Предупреждение: Нет данных в search_dataset для вопроса ID={q_id}")
711
+ continue
712
+
713
+ # Получаем пункты (relevant items)
714
+ puncts = question_rows['text'].tolist()
715
+ # reference_answer больше не используется и не извлекается
716
+
717
+ # Получаем индекс вопроса в матрице близости
718
+ if q_id not in question_id_to_idx:
719
+ print(f"Предупреждение: Вопрос ID={q_id} не найден в маппинге эмбеддингов.")
720
+ continue
721
+ question_idx = question_id_to_idx[q_id]
722
+
723
+ # --- Оценка на уровне чанков (Chunk-level) ---
724
+ chunk_level_metrics = evaluate_chunk_relevance(
725
+ q_id, question_idx, puncts,
726
+ similarity_matrix, chunks_df, args.top_n, args.similarity_threshold
727
+ )
728
+
729
+ # --- Оценка на уровне сборки (Assembly-level) ---
730
+ # Удаляем assembly_relevance, основанный на reference_answer
731
+ assembly_level_metrics = {} # Start with an empty dict for assembly metrics
732
+ assembled_context = ""
733
+ top_chunk_indices = chunk_level_metrics.get("top_chunk_ids", []) # Get indices first
734
+ neighbors_included = False # Flag to log
735
+
736
+ if args.use_injection and injection_builder and top_chunk_indices:
737
+ try:
738
+ # Преобразуем ID строк обратно в UUID чанков
739
+ top_chunk_uuids = [UUID(chunks_df.iloc[idx]['chunk_id']) for idx in top_chunk_indices]
740
+
741
+ final_chunk_uuids_for_assembly = set(top_chunk_uuids) # Start with top chunks
742
+
743
+ # --- Добавляем соседей, если нужно ---
744
+ if args.include_neighbors:
745
+ neighbors_included = True
746
+ # --- Убираем логирование индексов ---
747
+ neighbor_chunks = repository.get_neighboring_chunks(chunk_ids=top_chunk_uuids, max_distance=1)
748
+ neighbor_ids = {neighbor.id for neighbor in neighbor_chunks}
749
+ # --- Логирование до/после добавления ID соседей ---
750
+ print(f" [DEBUG QID {q_id}] Кол-во ID до добавления соседей: {len(final_chunk_uuids_for_assembly)}")
751
+ print(f" [DEBUG QID {q_id}] Кол-во найденных ID соседей: {len(neighbor_ids)}")
752
+ final_chunk_uuids_for_assembly.update(neighbor_ids)
753
+ print(f" [DEBUG QID {q_id}] Кол-во ID после добавления соседей: {len(final_chunk_uuids_for_assembly)}")
754
+ # --- Конец логирования ---
755
+ # --- Убираем логирование индексов ---
756
+ else:
757
+ # --- Убираем логирование индексов ---
758
+ pass # Ничего не делаем, если соседи не включены
759
+
760
+ # Собираем контекст
761
+ # Передаем финальный набор UUID (уникальный)
762
+ assembled_context = injection_builder.build(
763
+ filtered_entities=list(final_chunk_uuids_for_assembly),
764
+ # chunk_scores= {chunks_df.loc[idx, 'chunk_id']: sim for idx, sim in zip(top_chunk_ids_for_assembly, chunk_level_metrics.get('top_chunk_similarities',[]))} # Можно добавить веса
765
+ )
766
+
767
+ # --- Новая метрика: Assembly Punct Recall ---
768
+ # Оцениваем, сколько пунктов из датасета найдено в собранном контексте
769
+ # (по вашей идее: пункт считается найденным, если хотя бы одна его часть,
770
+ # разделенная переносом строки, найдена в контексте)
771
+ assembly_found_puncts = 0
772
+ assembly_total_puncts = len(puncts)
773
+ if assembly_total_puncts > 0 and assembled_context:
774
+ # Итерируемся по каждому исходному пункту
775
+ for punct_text in puncts:
776
+ # Разбиваем пункт на части по переносу строки
777
+ # Убираем пустые строки, которые могут появиться из-за двойных переносов
778
+ punct_parts = [part for part in punct_text.split('\n') if part.strip()]
779
+
780
+ # Если пункт пустой или состоит только из пробельных символов после разбивки,
781
+ # пропускаем его (не считаем ни найденным, ни не найденным в контексте recall)
782
+ if not punct_parts:
783
+ assembly_total_puncts -= 1 # Уменьшаем общее число пунктов для расчета recall
784
+ continue
785
+
786
+ is_punct_found = False
787
+ # Итерируемся по частям пункта
788
+ for part_text in punct_parts:
789
+ # Сравниваем КАЖДУЮ ЧАСТЬ пункта с собранным контекстом
790
+ if calculate_chunk_overlap(assembled_context, part_text.strip()) >= args.similarity_threshold:
791
+ # Если ХОТЯ БЫ ОДНА часть найдена, считаем ВЕСЬ пункт найденным
792
+ is_punct_found = True
793
+ break # Дальше части этого пункта можно не проверять
794
+
795
+ # Если флаг is_punct_found стал True, увеличиваем счетчик найденных пунктов
796
+ if is_punct_found:
797
+ assembly_found_puncts += 1
798
+
799
+ # Рассчитываем recall, только если были валидные пункты для проверки
800
+ if assembly_total_puncts > 0:
801
+ assembly_level_metrics["assembly_punct_recall"] = assembly_found_puncts / assembly_total_puncts
802
+ else:
803
+ assembly_level_metrics["assembly_punct_recall"] = 0.0 # Или можно None, если нет валидных пунктов
804
+ else:
805
+ assembly_level_metrics["assembly_punct_recall"] = 0.0
806
+ # Добавляем сам текст сборки для возможного анализа (усеченный)
807
+ assembly_level_metrics["assembled_context_preview"] = assembled_context[:500] + ("..." if len(assembled_context) > 500 else "")
808
+
809
+
810
+ except Exception as e:
811
+ print(f"Ошибка при сборке/оценке контекста для вопроса ID={q_id}: {e}")
812
+ # Записываем None или 0, чтобы не прерывать процесс
813
+ assembly_level_metrics["assembly_punct_recall"] = None # Indicate error
814
+ assembly_level_metrics["assembled_context_preview"] = f"Error during assembly: {e}"
815
+
816
+
817
+ # Собираем все метрики для вопроса
818
+ question_result = {
819
+ "run_id": args.run_id,
820
+ "batch_id": args.batch_id, # --- Добавляем batch_id в результаты ---
821
+ "question_id": q_id,
822
+ "question_text": q_text,
823
+ # Параметры запуска
824
+ "model_name": args.model_name,
825
+ "chunking_strategy": args.chunking_strategy, # Log strategy
826
+ "process_tables": args.process_tables, # Log table flag
827
+ "strategy_params": args.strategy_params, # Log JSON string
828
+ "top_n": args.top_n,
829
+ "use_injection": args.use_injection,
830
+ "use_qe": qe_active, # Log QE status
831
+ "neighbors_included": neighbors_included, # Log neighbor flag
832
+ "similarity_threshold": args.similarity_threshold,
833
+ # Метрики Chunk-level
834
+ **chunk_level_metrics,
835
+ # Метрики Assembly-level (теперь с recall по пунктам)
836
+ **assembly_level_metrics,
837
+ # Тексты для отладки (эталонный ответ удален, сборка добавлена выше)
838
+ # "assembled_context": assembled_context[:500] + "..." if assembled_context else "",
839
+ }
840
+ results.append(question_result)
841
+
842
+ print("Оценка завершена.")
843
+ return pd.DataFrame(results)
844
+
845
+
846
+ def evaluate_chunk_relevance(
847
+ question_id: int,
848
+ question_idx: int,
849
+ puncts: list[str],
850
+ similarity_matrix: np.ndarray,
851
+ chunks_df: pd.DataFrame,
852
+ top_n: int,
853
+ similarity_threshold: float
854
+ ) -> dict:
855
+ """
856
+ Оценивает релевантность чанков для одного вопроса.
857
+ (Адаптировано из evaluate_for_top_n_with_mapping в evaluate_chunking.py)
858
+
859
+ Возвращает словарь с метриками для этого вопроса.
860
+ """
861
+ metrics = {
862
+ "chunk_text_precision": 0.0,
863
+ "chunk_text_recall": 0.0,
864
+ "chunk_text_f1": 0.0,
865
+ "found_puncts": 0,
866
+ "total_puncts": len(puncts),
867
+ "relevant_chunks": 0,
868
+ "total_chunks_in_top_n": 0,
869
+ "top_chunk_ids": [], # Индексы строк в chunks_df
870
+ "top_chunk_similarities": [],
871
+ }
872
+
873
+ if chunks_df.empty or similarity_matrix.shape[1] == 0:
874
+ print(f"Предупреждение (QID {question_id}): Нет чанков для оценки.")
875
+ return metrics
876
+
877
+ # Получаем схожести всех чанков с текущим вопросом
878
+ question_similarities = similarity_matrix[question_idx, :]
879
+
880
+ # Сортируем чанки по схожести и берем top_n
881
+ # argsort возвращает индексы элементов, которые бы отсортировали массив
882
+ # Берем последние N индексов (-top_n:) и разворачиваем ([::-1]) для убывания
883
+ # Добавляем проверку на случай если top_n > количества чанков
884
+ if top_n >= similarity_matrix.shape[1]:
885
+ sorted_chunk_indices = np.argsort(question_similarities)[::-1] # Берем все, сортируем по убыванию
886
+ else:
887
+ sorted_chunk_indices = np.argsort(question_similarities)[-top_n:][::-1]
888
+
889
+ # Ограничиваем top_n, если чанков меньше (это должно быть сделано выше, но дублируем для надежности)
890
+ actual_top_n = min(top_n, len(sorted_chunk_indices))
891
+ top_chunk_indices = sorted_chunk_indices[:actual_top_n]
892
+
893
+ # Сохраняем ID и схожести топ-чанков
894
+ metrics["top_chunk_ids"] = top_chunk_indices.tolist()
895
+ metrics["top_chunk_similarities"] = question_similarities[top_chunk_indices].tolist()
896
+
897
+
898
+ # Отбираем данные топ-чанков
899
+ top_chunks_df = chunks_df.iloc[top_chunk_indices]
900
+ metrics["total_chunks_in_top_n"] = len(top_chunks_df)
901
+
902
+ if metrics["total_chunks_in_top_n"] == 0:
903
+ return metrics # Если нет топ-чанков, метрики остаются нулевыми
904
+
905
+ # Оценка на основе текста (пунктов)
906
+ punct_found = [False] * metrics["total_puncts"]
907
+ question_relevant_chunks = 0
908
+ for i, (idx, chunk_row) in enumerate(top_chunks_df.iterrows()):
909
+ chunk_text = chunk_row['text']
910
+ is_relevant_to_punct = False
911
+ for j, punct_text in enumerate(puncts):
912
+ overlap = calculate_chunk_overlap(chunk_text, punct_text)
913
+ if overlap >= similarity_threshold:
914
+ is_relevant_to_punct = True
915
+ punct_found[j] = True
916
+ if is_relevant_to_punct:
917
+ question_relevant_chunks += 1
918
+
919
+ metrics["found_puncts"] = sum(punct_found)
920
+ metrics["relevant_chunks"] = question_relevant_chunks
921
+
922
+ if metrics["total_chunks_in_top_n"] > 0:
923
+ metrics["chunk_text_precision"] = metrics["relevant_chunks"] / metrics["total_chunks_in_top_n"]
924
+ if metrics["total_puncts"] > 0:
925
+ metrics["chunk_text_recall"] = metrics["found_puncts"] / metrics["total_puncts"]
926
+ if metrics["chunk_text_precision"] + metrics["chunk_text_recall"] > 0:
927
+ metrics["chunk_text_f1"] = (2 * metrics["chunk_text_precision"] * metrics["chunk_text_recall"] /
928
+ (metrics["chunk_text_precision"] + metrics["chunk_text_recall"]))
929
+
930
+ return metrics
931
+
932
+
933
+ # --- Основная функция ---
934
+
935
+ def main():
936
+ """Основная функция скрипта."""
937
+ args = parse_args()
938
+ print(f"Запуск оценки с ID: {args.run_id}")
939
+ print(f"Параметры: {vars(args)}")
940
+
941
+ # --- Кэширование Чанкинга ---
942
+ CACHE_DIR_PATH = Path(args.cache_dir)
943
+ try:
944
+ # Парсим параметры стратегии один раз
945
+ parsed_strategy_params = json.loads(args.strategy_params)
946
+ except json.JSONDecodeError:
947
+ print(f"Предупреждение: Невалидный JSON в strategy_params: '{args.strategy_params}'. Используются параметры по умолчанию для хэша кэша.")
948
+ parsed_strategy_params = {}
949
+
950
+ chunking_hash = _get_chunking_cache_hash(
951
+ args.data_folder,
952
+ args.chunking_strategy,
953
+ args.process_tables,
954
+ parsed_strategy_params
955
+ )
956
+ chunks_df_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "chunks_df.parquet")
957
+ entities_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "final_entities.pkl")
958
+
959
+ chunks_df = None
960
+ all_entities = None
961
+
962
+ if chunks_df_cache_path.exists() and entities_cache_path.exists():
963
+ print(f"Найден кэш чанкинга (hash: {chunking_hash}). Загрузка...")
964
+ try:
965
+ chunks_df = pd.read_parquet(chunks_df_cache_path)
966
+ with open(entities_cache_path, 'rb') as f:
967
+ all_entities = pickle.load(f)
968
+ print(f"Кэш чанкинга успешно загружен: {len(chunks_df)} чанков, {len(all_entities)} сущностей.")
969
+ except Exception as e:
970
+ print(f"Ошибка загрузки кэша чанкинга: {e}. Выполняем чанкинг заново.")
971
+ chunks_df = None
972
+ all_entities = None
973
+
974
+ if chunks_df is None or all_entities is None:
975
+ print("Кэш чанкинга не найден или поврежден. Выполнение чтения документов и чанкинга...")
976
+ # 1. Загрузка данных
977
+ documents_map = read_documents(args.data_folder)
978
+ if not documents_map:
979
+ print("Нет документов для обработки. Завершение.")
980
+ return
981
+
982
+ # 2. Чанкинг
983
+ chunks_df, all_entities = perform_chunking(
984
+ documents_map,
985
+ args.chunking_strategy, # Pass strategy
986
+ args.process_tables, # Pass table flag
987
+ args.strategy_params # Pass JSON string parameters
988
+ )
989
+ if chunks_df.empty:
990
+ print("После чанкинга не осталось чанков для обработки. Завершение.")
991
+ return
992
+
993
+ # Сохраняем результаты чанкинга в кэш
994
+ try:
995
+ print(f"Сохранение результатов чанкинга в кэш (hash: {chunking_hash})...")
996
+ # Убедимся, что директория кэша существует (на всякий случай)
997
+ chunks_df_cache_path.parent.mkdir(parents=True, exist_ok=True)
998
+ entities_cache_path.parent.mkdir(parents=True, exist_ok=True)
999
+
1000
+ chunks_df.to_parquet(chunks_df_cache_path)
1001
+ with open(entities_cache_path, 'wb') as f:
1002
+ pickle.dump(all_entities, f)
1003
+ print("Результаты чанкинга сохранены в кэш.")
1004
+ except Exception as e:
1005
+ print(f"Ошибка сохранения кэша чанкинга: {e}")
1006
+
1007
+ # --- Конец Кэширования Чанкинга ---
1008
+
1009
+ # Загружаем поисковый датасет (это нужно делать всегда, т.к. он не кэшируется здесь)
1010
+ search_df, questions_to_embed = load_datasets(args.search_dataset_path)
1011
+
1012
+ # 3. Выполнение оценки (передаем загруженные или свежесгенерированные chunks_df и all_entities)
1013
+ results_df = evaluate_run(
1014
+ search_df, questions_to_embed, chunks_df, all_entities,
1015
+ None, None, args # Передаем None для model и tokenizer
1016
+ )
1017
+
1018
+ # 5. Сохранение результатов
1019
+ if not results_df.empty:
1020
+ os.makedirs(args.output_dir, exist_ok=True)
1021
+ # output_filename = f"results_{args.run_id}.csv"
1022
+ # Добавляем batch_id в имя файла для лучшей группировки
1023
+ output_filename = f"results_{args.batch_id}_{args.run_id}.csv"
1024
+ output_path = os.path.join(args.output_dir, output_filename)
1025
+ try:
1026
+ results_df.to_csv(output_path, index=False, encoding='utf-8')
1027
+ print(f"Детальные результаты сохранены в: {output_path}")
1028
+ except Exception as e:
1029
+ print(f"Ошибка при сохранении результатов в {output_path}: {e}")
1030
+ else:
1031
+ print("Нет результатов для сохранения.")
1032
+
1033
+ if __name__ == "__main__":
1034
+ main()
scripts/testing/plot_results.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Скрипт для визуализации агрегированных результатов тестирования RAG.
5
+
6
+ Читает данные из Excel-файла, сгенерированного aggregate_results.py,
7
+ и строит различные графики для анализа влияния параметров на метрики.
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+ import os
13
+
14
+ import matplotlib.pyplot as plt
15
+ import pandas as pd
16
+ import seaborn as sns
17
+
18
+ # --- Настройки ---
19
+ DEFAULT_RESULTS_FILE = "data/output/aggregated_results.xlsx" # Файл с агрегированными данными
20
+ DEFAULT_PLOTS_DIR = "data/output/plots" # Куда сохранять графики
21
+
22
+ # Настройки графиков
23
+ plt.rcParams['font.family'] = 'DejaVu Sans' # Шрифт с поддержкой кириллицы
24
+ sns.set_style("whitegrid")
25
+ FIGSIZE = (16, 10) # Увеличенный размер для сложных графиков
26
+ DPI = 300
27
+ PALETTE = "viridis" # Цветовая палитра
28
+
29
+ # --- Маппинг названий столбцов (копия из aggregate_results.py) ---
30
+ COLUMN_NAME_MAPPING = {
31
+ # Параметры запуска из pipeline.py
32
+ 'run_id': 'ID Запуска',
33
+ 'model_name': 'Модель',
34
+ 'chunking_strategy': 'Стратегия Чанкинга',
35
+ 'strategy_params': 'Параметры Стратегии',
36
+ 'process_tables': 'Обраб. Таблиц',
37
+ 'top_n': 'Top N',
38
+ 'use_injection': 'Сборка Контекста',
39
+ 'use_qe': 'Query Expansion',
40
+ 'neighbors_included': 'Вкл. Соседей',
41
+ 'similarity_threshold': 'Порог Схожести',
42
+
43
+ # Идентификаторы из датасета (для детальных результатов)
44
+ 'question_id': 'ID Вопроса',
45
+ 'question_text': 'Текст Вопроса',
46
+
47
+ # Детальные метрики из pipeline.py
48
+ 'chunk_text_precision': 'Точность (Чанк-Текст)',
49
+ 'chunk_text_recall': 'Полнота (Чанк-Текст)',
50
+ 'chunk_text_f1': 'F1 (Чанк-Текст)',
51
+ 'found_puncts': 'Найдено Пунктов',
52
+ 'total_puncts': 'Всего Пунктов',
53
+ 'relevant_chunks': 'Релевантных Чанков',
54
+ 'total_chunks_in_top_n': 'Всего Чанков в Топ-N',
55
+ 'assembly_punct_recall': 'Полнота (Сборка-Пункт)',
56
+ 'assembled_context_preview': 'Предпросмотр Сборки',
57
+ # 'top_chunk_ids': 'Индексы Топ-Чанков', # Списки, могут плохо отображаться
58
+ # 'top_chunk_similarities': 'Схожести Топ-Чанков', # Списки
59
+
60
+ # Агрегированные метрики (добавляются в calculate_aggregated_metrics)
61
+ 'weighted_chunk_text_precision': 'Weighted Точность (Чанк-Текст)',
62
+ 'weighted_chunk_text_recall': 'Weighted Полнота (Чанк-Текст)',
63
+ 'weighted_chunk_text_f1': 'Weighted F1 (Чанк-Текст)',
64
+ 'weighted_assembly_punct_recall': 'Weighted Полнота (Сборка-Пункт)',
65
+
66
+ 'macro_chunk_text_precision': 'Macro Точность (Чанк-Текст)',
67
+ 'macro_chunk_text_recall': 'Macro Полнота (Чанк-Текст)',
68
+ 'macro_chunk_text_f1': 'Macro F1 (Чанк-Текст)',
69
+ 'macro_assembly_punct_recall': 'Macro Полнота (Сборка-Пункт)',
70
+
71
+ 'micro_text_precision': 'Micro Точность (Текст)',
72
+ 'micro_text_recall': 'Micro Полнота (Текст)',
73
+ 'micro_text_f1': 'Micro F1 (Текст)',
74
+ }
75
+ # --- Конец маппинга ---
76
+
77
+ def parse_args():
78
+ """Парсит аргументы командной строки."""
79
+ parser = argparse.ArgumentParser(description="Визуализация результатов тестирования RAG")
80
+
81
+ parser.add_argument("--results-file", type=str, default=DEFAULT_RESULTS_FILE,
82
+ help=f"Путь к Excel-файлу с агрегированными результатами (по умолчанию: {DEFAULT_RESULTS_FILE})")
83
+ parser.add_argument("--plots-dir", type=str, default=DEFAULT_PLOTS_DIR,
84
+ help=f"Директория для сохранения графиков (по умолчанию: {DEFAULT_PLOTS_DIR})")
85
+ parser.add_argument("--sheet-name", type=str, default="Агрегированные метрики",
86
+ help="Название листа в Excel-файле для чтения данных")
87
+
88
+ return parser.parse_args()
89
+
90
+ def setup_plots_directory(plots_dir: str) -> None:
91
+ """Создает директорию для графиков, если она не существует."""
92
+ if not os.path.exists(plots_dir):
93
+ os.makedirs(plots_dir)
94
+ print(f"Создана директория для графиков: {plots_dir}")
95
+ else:
96
+ print(f"Использование существующей директории для графиков: {plots_dir}")
97
+
98
+ def load_aggregated_data(file_path: str, sheet_name: str) -> pd.DataFrame:
99
+ """Загружает данные из указанного листа Excel-файла."""
100
+ print(f"Загрузка данных из файла: {file_path}, лист: {sheet_name}")
101
+ try:
102
+ df = pd.read_excel(file_path, sheet_name=sheet_name)
103
+ print(f"Загружено {len(df)} строк.")
104
+ print(f"Колонки: {df.columns.tolist()}")
105
+ # Добавим проверку на необходимые колонки (РУССКИЕ НАЗВАНИЯ)
106
+ required_cols_rus = [
107
+ COLUMN_NAME_MAPPING['model_name'], COLUMN_NAME_MAPPING['chunking_strategy'],
108
+ COLUMN_NAME_MAPPING['strategy_params'], COLUMN_NAME_MAPPING['process_tables'],
109
+ COLUMN_NAME_MAPPING['top_n'], COLUMN_NAME_MAPPING['use_injection'],
110
+ COLUMN_NAME_MAPPING['use_qe'], COLUMN_NAME_MAPPING['neighbors_included'],
111
+ COLUMN_NAME_MAPPING['similarity_threshold']
112
+ ]
113
+ # Проверяем только те, что есть в маппинге
114
+ missing_required = [col for col in required_cols_rus if col not in df.columns]
115
+ if missing_required:
116
+ print(f"Предупреждение: Не все ожидаемые колонки параметров найдены в данных: {missing_required}")
117
+
118
+ # --- Добавим парсинг strategy_params из JSON строки в словарь ---
119
+ params_col = COLUMN_NAME_MAPPING['strategy_params']
120
+ if params_col in df.columns:
121
+ def safe_json_loads(x):
122
+ try:
123
+ # Обработка NaN и пустых строк
124
+ if pd.isna(x) or not isinstance(x, str) or not x.strip():
125
+ return {}
126
+ return json.loads(x)
127
+ except (json.JSONDecodeError, TypeError):
128
+ return {} # Возвращаем пустой словарь при ошибке
129
+
130
+ df[params_col] = df[params_col].apply(safe_json_loads)
131
+ # Создаем строковое представление для группировки и лейблов
132
+ df[f"{params_col}_str"] = df[params_col].apply(
133
+ lambda d: json.dumps(d, sort_keys=True, ensure_ascii=False)
134
+ )
135
+ print(f"Колонка '{params_col}' преобразована из JSON строк.")
136
+ # --------------------------------------------------------------
137
+
138
+ return df
139
+ except FileNotFoundError:
140
+ print(f"Ошибка: Файл не найден: {file_path}")
141
+ return pd.DataFrame()
142
+ except ValueError as e:
143
+ print(f"Ошибка: Лист '{sheet_name}' не найден в файле {file_path}. Доступные листы: {pd.ExcelFile(file_path).sheet_names}")
144
+ return pd.DataFrame()
145
+ except Exception as e:
146
+ print(f"Ошибка при чтении Excel файла: {e}")
147
+ return pd.DataFrame()
148
+
149
+ # --- Функции построения графиков --- #
150
+
151
+ def plot_metric_vs_top_n(
152
+ df: pd.DataFrame,
153
+ metric_name_rus: str, # Ожидаем русское имя метрики
154
+ fixed_strategy: str | None,
155
+ fixed_strategy_params: str | None, # Ожидаем строку JSON или None
156
+ plots_dir: str
157
+ ) -> None:
158
+ """
159
+ Строит график зависимости метрики от top_n для разных моделей
160
+ (при фиксированных параметрах чанкинга).
161
+ Разделяет линии по значению use_injection.
162
+ Использует русские названия колонок.
163
+ """
164
+ # Используем русские названия колонок из маппинга
165
+ metric_col_rus = metric_name_rus # Передаем уже готовое русское имя
166
+ top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
167
+ model_col_rus = COLUMN_NAME_MAPPING['model_name']
168
+ injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
169
+ strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
170
+ params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str" # Используем строковое представление
171
+
172
+ if metric_col_rus not in df.columns:
173
+ print(f"График пропущен: Колонка '{metric_col_rus}' не найдена.")
174
+ return
175
+
176
+ plot_df = df.copy()
177
+
178
+ # Фильтруем по параметрам чанкинга, если задано
179
+ chunk_suffix = "all_strategies_all_params"
180
+ if fixed_strategy and strategy_col_rus in plot_df.columns:
181
+ plot_df = plot_df[plot_df[strategy_col_rus] == fixed_strategy]
182
+ chunk_suffix = f"strategy_{fixed_strategy}"
183
+ # Фильтруем по строковому пред��тавлению параметров
184
+ if fixed_strategy_params and params_str_col_rus in plot_df.columns:
185
+ plot_df = plot_df[plot_df[params_str_col_rus] == fixed_strategy_params]
186
+ # Генерируем короткий хэш для параметров в названии файла
187
+ params_hash = hash(fixed_strategy_params) # Хэш от строки
188
+ chunk_suffix += f"_params-{params_hash:x}" # Hex hash
189
+
190
+ if plot_df.empty:
191
+ print(f"График Metric vs Top-N пропущен: Нет данных для strategy={fixed_strategy}, params={fixed_strategy_params}")
192
+ return
193
+
194
+ plt.figure(figsize=FIGSIZE)
195
+ sns.lineplot(
196
+ data=plot_df,
197
+ x=top_n_col_rus,
198
+ y=metric_col_rus,
199
+ hue=model_col_rus,
200
+ style=injection_col_rus, # Разные стили линий для True/False
201
+ markers=True,
202
+ markersize=8,
203
+ linewidth=2,
204
+ palette=PALETTE
205
+ )
206
+
207
+ plt.title(f"Зависимость {metric_col_rus} от top_n ({chunk_suffix})")
208
+ plt.xlabel("Top N")
209
+ plt.ylabel(metric_col_rus.replace("_", " ").title())
210
+ plt.legend(title="Модель / Сборка", bbox_to_anchor=(1.05, 1), loc='upper left')
211
+ plt.grid(True, linestyle='--', alpha=0.7)
212
+ plt.tight_layout(rect=[0, 0, 0.85, 1]) # Оставляем место для легенды
213
+
214
+ filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_vs_top_n_{chunk_suffix}.png"
215
+ filepath = os.path.join(plots_dir, filename)
216
+ plt.savefig(filepath, dpi=DPI)
217
+ plt.close()
218
+ print(f"Создан график: {filepath}")
219
+
220
+ def plot_injection_comparison(
221
+ df: pd.DataFrame,
222
+ metric_name_rus: str, # Ожидаем русское имя метрики
223
+ plots_dir: str
224
+ ) -> None:
225
+ """
226
+ Сравнивает метрики с использованием и без использования сборки контекста
227
+ в виде парных столбчатых диаграмм для разных моделей и параметров чанкинга.
228
+ Использует русские названия колонок.
229
+ """
230
+ # Русские названия колонок
231
+ metric_col_rus = metric_name_rus
232
+ injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
233
+ model_col_rus = COLUMN_NAME_MAPPING['model_name']
234
+ strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
235
+ params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str"
236
+ tables_col_rus = COLUMN_NAME_MAPPING['process_tables']
237
+ qe_col_rus = COLUMN_NAME_MAPPING['use_qe']
238
+ neighbors_col_rus = COLUMN_NAME_MAPPING['neighbors_included']
239
+ top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
240
+ threshold_col_rus = COLUMN_NAME_MAPPING['similarity_threshold']
241
+
242
+ if metric_col_rus not in df.columns or injection_col_rus not in df.columns:
243
+ print(f"График сравнения сборки пропущен: Колонки '{metric_col_rus}' или '{injection_col_rus}' не найдены.")
244
+ return
245
+
246
+ plot_df = df.copy()
247
+ # Используем русские названия при создании лейбла
248
+ plot_df['config_label'] = plot_df.apply(
249
+ lambda r: (
250
+ f"{r.get(model_col_rus, 'N/A')}\n"
251
+ f"Стратегия: {r.get(strategy_col_rus, 'N/A')}\n"
252
+ # Используем строковое представление параметров
253
+ f"Параметры: {r.get(params_str_col_rus, '{}')[:30]}...\n"
254
+ f"Табл: {r.get(tables_col_rus, 'N/A')}, QE: {r.get(qe_col_rus, 'N/A')}, Соседи: {r.get(neighbors_col_rus, 'N/A')}\n"
255
+ f"TopN: {int(r.get(top_n_col_rus, 0))}, Порог: {r.get(threshold_col_rus, 0):.2f}"
256
+ ),
257
+ axis=1
258
+ )
259
+
260
+ # Оставляем только строки, где есть и True, и False для данного флага
261
+ # Группируем по config_label, считаем уникальные значения флага use_injection
262
+ counts = plot_df.groupby('config_label')[injection_col_rus].nunique()
263
+ configs_with_both = counts[counts >= 2].index # Используем >= 2 на случай дубликатов
264
+ plot_df = plot_df[plot_df['config_label'].isin(configs_with_both)]
265
+
266
+ if plot_df.empty:
267
+ print(f"График сравнения сборки пропущен: Нет конфигураций с обоими вариантами {injection_col_rus}.")
268
+ return
269
+
270
+ # Ограничим количество конфигураций для читаемости (по средней метрике)
271
+ top_configs = plot_df.groupby('config_label')[metric_col_rus].mean().nlargest(10).index # Уменьшил до 10
272
+ plot_df = plot_df[plot_df['config_label'].isin(top_configs)]
273
+
274
+ if plot_df.empty:
275
+ print(f"График сравнения сборки пропущен: Не осталось да��ных после фильтрации топ-конфигураций.")
276
+ return
277
+
278
+ plt.figure(figsize=(FIGSIZE[0]*0.9, FIGSIZE[1]*0.7)) # Уменьшил размер
279
+ sns.barplot(
280
+ data=plot_df,
281
+ x='config_label',
282
+ y=metric_col_rus,
283
+ hue=injection_col_rus,
284
+ palette=PALETTE
285
+ )
286
+
287
+ plt.title(f"Сравнение {metric_col_rus} с/без {injection_col_rus}")
288
+ plt.xlabel("Конфигурация")
289
+ plt.ylabel(metric_col_rus)
290
+ plt.xticks(rotation=60, ha='right', fontsize=8) # Уменьшил шрифт, увеличил поворот
291
+ plt.legend(title=injection_col_rus)
292
+ plt.grid(True, axis='y', linestyle='--', alpha=0.7)
293
+ plt.tight_layout()
294
+
295
+ filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_injection_comparison.png"
296
+ filepath = os.path.join(plots_dir, filename)
297
+ plt.savefig(filepath, dpi=DPI)
298
+ plt.close()
299
+ print(f"Создан график: {filepath}")
300
+
301
+ # --- Новая функция для сравнения булевых флагов ---
302
+ def plot_boolean_flag_comparison(
303
+ df: pd.DataFrame,
304
+ metric_name_rus: str, # Ожидаем русское имя метрики
305
+ flag_column_eng: str, # Ожидаем английское имя флага для поиска в маппинге
306
+ plots_dir: str
307
+ ) -> None:
308
+ """
309
+ Сравнивает метрики при True/False значениях указанного булева флага
310
+ в виде парных столбчатых диаграмм для разных конфигураций.
311
+ Использует русские названия колонок.
312
+ """
313
+ # Русские названия колонок
314
+ metric_col_rus = metric_name_rus
315
+ try:
316
+ flag_col_rus = COLUMN_NAME_MAPPING[flag_column_eng]
317
+ except KeyError:
318
+ print(f"Ошибка: Английское имя флага '{flag_column_eng}' не найдено в COLUMN_NAME_MAPPING.")
319
+ return
320
+
321
+ model_col_rus = COLUMN_NAME_MAPPING['model_name']
322
+ strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
323
+ params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str"
324
+ injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
325
+ top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
326
+ # Другие флаги
327
+ tables_col_rus = COLUMN_NAME_MAPPING['process_tables']
328
+ qe_col_rus = COLUMN_NAME_MAPPING['use_qe']
329
+ neighbors_col_rus = COLUMN_NAME_MAPPING['neighbors_included']
330
+
331
+
332
+ if metric_col_rus not in df.columns or flag_col_rus not in df.columns:
333
+ print(f"График сравнения флага '{flag_col_rus}' пропущен: Колонки '{metric_col_rus}' или '{flag_col_rus}' не найдены.")
334
+ return
335
+
336
+ plot_df = df.copy()
337
+ # Создаем обобщенный лейбл конфигурации, исключая сам флаг
338
+ plot_df['config_label'] = plot_df.apply(
339
+ lambda r: (
340
+ f"{r.get(model_col_rus, 'N/A')}\n"
341
+ f"Стратегия: {r.get(strategy_col_rus, 'N/A')} Параметры: {r.get(params_str_col_rus, '{}')[:20]}...\n"
342
+ f"Сборка: {r.get(injection_col_rus, 'N/A')}, TopN: {int(r.get(top_n_col_rus, 0))}"
343
+ # Динамически добавляем другие флаги, кроме сравниваемого
344
+ + (f", Табл: {r.get(tables_col_rus, 'N/A')}" if flag_col_rus != tables_col_rus else "")
345
+ + (f", QE: {r.get(qe_col_rus, 'N/A')}" if flag_col_rus != qe_col_rus else "")
346
+ + (f", Соседи: {r.get(neighbors_col_rus, 'N/A')}" if flag_col_rus != neighbors_col_rus else "")
347
+ ),
348
+ axis=1
349
+ )
350
+
351
+ # Оставляем только строки, где есть и True, и False для данного флага
352
+ counts = plot_df.groupby('config_label')[flag_col_rus].nunique()
353
+ configs_with_both = counts[counts >= 2].index # Используем >= 2
354
+ plot_df = plot_df[plot_df['config_label'].isin(configs_with_both)]
355
+
356
+ if plot_df.empty:
357
+ print(f"График сравнения флага '{flag_col_rus}' пропущен: Нет конфигураций с обоими вариантами {flag_col_rus}.")
358
+ return
359
+
360
+ # Ограничим количество конфигураций для читаемости (по средней метрике)
361
+ top_configs = plot_df.groupby('config_label')[metric_col_rus].mean().nlargest(10).index # Уменьшил до 10
362
+ plot_df = plot_df[plot_df['config_label'].isin(top_configs)]
363
+
364
+ if plot_df.empty:
365
+ print(f"График сравнения флага '{flag_col_rus}' пропущен: Не осталось данных после фильтрации топ-конфигураций.")
366
+ return
367
+
368
+ plt.figure(figsize=(FIGSIZE[0]*0.9, FIGSIZE[1]*0.7)) # Уменьшил размер
369
+ sns.barplot(
370
+ data=plot_df,
371
+ x='config_label',
372
+ y=metric_col_rus,
373
+ hue=flag_col_rus,
374
+ palette=PALETTE
375
+ )
376
+
377
+ plt.title(f"Сравнение {metric_col_rus} в зависимости от '{flag_col_rus}'")
378
+ plt.xlabel("Конфигурация")
379
+ plt.ylabel(metric_col_rus)
380
+ plt.xticks(rotation=60, ha='right', fontsize=8) # Уменьшил шрифт, увеличил поворот
381
+ plt.legend(title=f"{flag_col_rus}")
382
+ plt.grid(True, axis='y', linestyle='--', alpha=0.7)
383
+ plt.tight_layout()
384
+
385
+ filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_{flag_column_eng}_comparison.png"
386
+ filepath = os.path.join(plots_dir, filename)
387
+ plt.savefig(filepath, dpi=DPI)
388
+ plt.close()
389
+ print(f"Создан график: {filepath}")
390
+
391
+ # --- Основная функция ---
392
+ def main():
393
+ """Основная функция скрипта."""
394
+ args = parse_args()
395
+
396
+ setup_plots_directory(args.plots_dir)
397
+ df = load_aggregated_data(args.results_file, args.sheet_name)
398
+
399
+ if df.empty:
400
+ print("Нет данных для построения графиков. Завершение.")
401
+ return
402
+
403
+ # Определяем метрики для построения графиков (используем английские ключи для поиска русских имен)
404
+ metric_keys = [
405
+ 'weighted_chunk_text_recall', 'weighted_chunk_text_f1', 'weighted_assembly_punct_recall',
406
+ 'macro_chunk_text_recall', 'macro_chunk_text_f1', 'macro_assembly_punct_recall',
407
+ 'micro_text_recall', 'micro_text_f1'
408
+ ]
409
+
410
+ # Получаем существующие русские имена метрик в DataFrame
411
+ existing_metrics_rus = [COLUMN_NAME_MAPPING.get(key) for key in metric_keys if COLUMN_NAME_MAPPING.get(key) in df.columns]
412
+
413
+ # Определяем фиксированные параметры для некоторых графиков
414
+ strategy_col_rus = COLUMN_NAME_MAPPING.get('chunking_strategy')
415
+ params_str_col_rus = f"{COLUMN_NAME_MAPPING.get('strategy_params')}_str"
416
+ model_col_rus = COLUMN_NAME_MAPPING.get('model_name')
417
+
418
+ fixed_strategy_example = df[strategy_col_rus].unique()[0] if strategy_col_rus in df.columns and len(df[strategy_col_rus].unique()) > 0 else None
419
+ fixed_strategy_params_example = None
420
+ if fixed_strategy_example and params_str_col_rus in df.columns:
421
+ params_list = df[df[strategy_col_rus] == fixed_strategy_example][params_str_col_rus].unique()
422
+ if len(params_list) > 0:
423
+ fixed_strategy_params_example = params_list[0]
424
+
425
+ fixed_model_example = df[model_col_rus].unique()[0] if model_col_rus in df.columns and len(df[model_col_rus].unique()) > 0 else None
426
+ fixed_top_n_example = 20
427
+
428
+ print("--- Построение графиков ---")
429
+
430
+ # 1. Графики Metric vs Top-N
431
+ print("\n1. Зависимость метрик от Top-N:")
432
+ for metric_name_rus in existing_metrics_rus:
433
+ # Проверяем, что метрика не micro (у micro нет зависимости от top_n)
434
+ if 'Micro' in metric_name_rus:
435
+ continue
436
+ plot_metric_vs_top_n(
437
+ df, metric_name_rus,
438
+ fixed_strategy_example, fixed_strategy_params_example,
439
+ args.plots_dir
440
+ )
441
+
442
+ # 2. Графики Metric vs Chunking
443
+ print("\n2. Зависимость метрик от параметров чанкинга: [Пропущено - требует переосмысления]")
444
+ # plot_metric_vs_chunking(...) # Закомментировано
445
+
446
+ # 3. Графики сравнения Use Injection
447
+ print("\n3. Сравнение метрик с/без сборки контекста:")
448
+ for metric_name_rus in existing_metrics_rus:
449
+ plot_injection_comparison(df, metric_name_rus, args.plots_dir)
450
+
451
+ # 4. Графики сравнения других булевых флагов
452
+ boolean_flags_eng = ['process_tables', 'use_qe', 'neighbors_included']
453
+ print("\n4. Сравнение метрик в зависимости от булевых флагов:")
454
+ for flag_eng in boolean_flags_eng:
455
+ flag_rus = COLUMN_NAME_MAPPING.get(flag_eng)
456
+ if not flag_rus or flag_rus not in df.columns:
457
+ print(f" Пропуск сравнения для флага: '{flag_eng}' (колонка '{flag_rus}' не найдена)")
458
+ continue
459
+ print(f" Сравнение для флага: '{flag_rus}'")
460
+ for metric_name_rus in existing_metrics_rus:
461
+ plot_boolean_flag_comparison(df, metric_name_rus, flag_eng, args.plots_dir)
462
+
463
+ print("\n--- Построение графиков завершено ---")
464
+
465
+ if __name__ == "__main__":
466
+ main()
scripts/testing/run_pipelines.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Скрипт для запуска множества пайплайнов оценки (`pipeline.py`)
5
+ с различными комбинациями параметров.
6
+
7
+ Собирает команды для `pipeline.py` и запускает их последовательно,
8
+ логируя вывод каждого запуска.
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+ import pathlib
15
+ import subprocess
16
+ import sys
17
+ import time
18
+ from datetime import datetime
19
+ from itertools import product
20
+ from uuid import uuid4
21
+
22
+ # --- Конфигурация Экспериментов ---
23
+
24
+ # Модели для тестирования
25
+ MODELS_TO_TEST = [
26
+ # "intfloat/e5-base",
27
+ # "intfloat/e5-large",
28
+ "BAAI/bge-m3",
29
+ # "deepvk/USER-bge-m3"
30
+ # "ai-forever/FRIDA" # Требует --use-sentence-transformers
31
+ ]
32
+
33
+ # Параметры чанкинга (слова / перекрытие)
34
+ CHUNKING_PARAMS = [
35
+ # Пример для стратегии "fixed_size"
36
+ {"strategy": "fixed_size", "params": {"words_per_chunk": 50, "overlap_words": 25}},
37
+ # {"strategy": "fixed_size", "params": {"words_per_chunk": 100, "overlap_words": 25}},
38
+ # {"strategy": "fixed_size", "params": {"words_per_chunk": 50, "overlap_words": 0}},
39
+ # TODO: Добавить другие стратегии и их параметры, если нужно
40
+ # {"strategy": "some_other_strategy", "params": {"param1": "value1"}}
41
+ ]
42
+
43
+ # Значения Top-N для ретривера
44
+ TOP_N_VALUES = [20, 50, 100]
45
+
46
+ # Использовать ли сборку контекста (InjectionBuilder)
47
+ USE_INJECTION_OPTIONS = [False, True]
48
+
49
+ # Порог схожести для fuzzy сравнения (чанк/пункт)
50
+ SIMILARITY_THRESHOLDS = [0.7]
51
+
52
+ # Опции использования Query Expansion
53
+ USE_QE_OPTIONS = [False, True]
54
+
55
+ # Опции обработки таблиц
56
+ PROCESS_TABLES_OPTIONS = [True]
57
+
58
+ # Опции включения соседей
59
+ INCLUDE_NEIGHBORS_OPTIONS = [True]
60
+
61
+ # --- Настройки Скрипта ---
62
+ DEFAULT_LOG_DIR = "logs" # Директория для логов отдельных запусков pipeline.py
63
+ DEFAULT_INTERMEDIATE_DIR = "data/intermediate" # Куда pipeline.py сохраняет свои результаты
64
+ DEFAULT_PYTHON_EXECUTABLE = sys.executable # Использовать тот же python, что и для запуска этого скрипта
65
+
66
+ def parse_args():
67
+ """Парсит аргументы командной строки."""
68
+ parser = argparse.ArgumentParser(description="Запуск серии оценочных пайплайнов")
69
+
70
+ # Флаги для пропуска определенных измерений
71
+ parser.add_argument("--skip-models", action="store_true",
72
+ help="Пропустить итерацию по разным моделям (использовать первую в списке)")
73
+ parser.add_argument("--skip-chunking", action="store_true",
74
+ help="Пропустить итерацию по разным параметрам чанкинга (использовать первую в списке)")
75
+ parser.add_argument("--skip-top-n", action="store_true",
76
+ help="Пропустить итерацию по разным top_n (использовать первое значение)")
77
+ parser.add_argument("--skip-injection", action="store_true",
78
+ help="Пропустить итерацию по опциям сборки контекста (использовать False)")
79
+ parser.add_argument("--skip-thresholds", action="store_true",
80
+ help="Пропустить итерацию по порогам схожести (использовать первый)")
81
+ parser.add_argument("--skip-process-tables", action="store_true",
82
+ help="Пропустить итерацию по обработке таблиц (использовать True)")
83
+ parser.add_argument("--skip-include-neighbors", action="store_true",
84
+ help="Пропустить итерацию по включению соседей (использовать False)")
85
+ parser.add_argument("--skip-qe", action="store_true",
86
+ help="Пропустить итерацию по использованию Query Expansion (использовать False)")
87
+
88
+ # Настройки путей и выполнения
89
+ parser.add_argument("--log-dir", type=str, default=DEFAULT_LOG_DIR,
90
+ help=f"Директория для сохранения логов запусков (по умолчанию: {DEFAULT_LOG_DIR})")
91
+ parser.add_argument("--intermediate-dir", type=str, default=DEFAULT_INTERMEDIATE_DIR,
92
+ help=f"Директория для промежуточных результатов pipeline.py (по умолчанию: {DEFAULT_INTERMEDIATE_DIR})")
93
+ parser.add_argument("--device", type=str, default="cuda:0",
94
+ help="Устройство для вычислений в pipeline.py (напр., cpu, cuda:0)")
95
+ parser.add_argument("--python-executable", type=str, default=DEFAULT_PYTHON_EXECUTABLE,
96
+ help="Путь к интерпретатору Python для запуска pipeline.py")
97
+
98
+ # Параметры, передаваемые в pipeline.py (если не перебираются)
99
+ parser.add_argument("--data-folder", type=str, default="data/input/docs", help="Папка с документами для pipeline.py")
100
+ parser.add_argument("--search-dataset-path", type=str, default="data/input/search_dataset_text.xlsx", help="Поисковый датасет для pipeline.py")
101
+ parser.add_argument("--qa-dataset-path", type=str, default="data/input/question_answering.xlsx", help="QA датасет для pipeline.py")
102
+
103
+ return parser.parse_args()
104
+
105
+ def run_single_pipeline(cmd: list[str], log_path: str):
106
+ """
107
+ Запускает один экземпляр pipeline.py и логирует его вывод.
108
+
109
+ Args:
110
+ cmd: Список аргументов команды для subprocess.
111
+ log_path: Путь к файлу для сохранения лога.
112
+
113
+ Returns:
114
+ Код возврата процесса.
115
+ """
116
+ print(f"\n--- Запуск: {' '.join(cmd)} ---")
117
+ print(f"--- Лог: {log_path} --- ")
118
+
119
+ start_time = time.time()
120
+ return_code = -1
121
+
122
+ try:
123
+ with open(log_path, "w", encoding="utf-8") as log_file:
124
+ log_file.write(f"Команда: {' '.join(cmd)}\n")
125
+ log_file.write(f"Время запуска: {datetime.now()}\n\n")
126
+ log_file.flush()
127
+
128
+ # Запускаем процесс
129
+ process = subprocess.Popen(
130
+ cmd,
131
+ stdout=subprocess.PIPE,
132
+ stderr=subprocess.STDOUT, # Перенаправляем stderr в stdout
133
+ text=True,
134
+ encoding='utf-8', # Указываем кодировку
135
+ errors='replace', # Заменяем ошибки кодирования
136
+ bufsize=1 # Построчная буферизация
137
+ )
138
+
139
+ # Читаем и пишем вывод построчно
140
+ for line in process.stdout:
141
+ print(line, end="") # Выводим в консоль
142
+ log_file.write(line) # Пишем в лог
143
+ log_file.flush()
144
+
145
+ # Ждем завершения и получаем код возврата
146
+ process.wait()
147
+ return_code = process.returncode
148
+
149
+ except Exception as e:
150
+ print(f"\nОшибка при запуске процесса: {e}")
151
+ with open(log_path, "a", encoding="utf-8") as log_file:
152
+ log_file.write(f"\nОшибка при запуске: {e}\n")
153
+ return_code = 1 # Считаем ошибкой
154
+
155
+ end_time = time.time()
156
+ duration = end_time - start_time
157
+
158
+ result_message = f"Успешно завершено за {duration:.2f} сек."
159
+ if return_code != 0:
160
+ result_message = f"Завершено с ошибкой (код {return_code}) за {duration:.2f} сек."
161
+
162
+ print(f"--- {result_message} ---")
163
+ with open(log_path, "a", encoding="utf-8") as log_file:
164
+ log_file.write(f"\nВремя завершения: {datetime.now()}")
165
+ log_file.write(f"\nДлительность: {duration:.2f} сек.")
166
+ log_file.write(f"\nКод возврата: {return_code}\n")
167
+
168
+ return return_code
169
+
170
+ def main():
171
+ """Основная функция скрипта."""
172
+ args = parse_args()
173
+
174
+ # --- Генерируем ID для всей серии запусков ---
175
+ batch_run_id = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
176
+ print(f"Запуск серии экспериментов. Batch ID: {batch_run_id}")
177
+
178
+ # Создаем директории для логов и промежуточных результатов
179
+ os.makedirs(args.log_dir, exist_ok=True)
180
+ os.makedirs(args.intermediate_dir, exist_ok=True)
181
+
182
+ # Определяем абсолютный путь к pipeline.py
183
+ RUN_PIPELINES_SCRIPT_PATH = pathlib.Path(__file__).resolve()
184
+ SCRIPTS_TESTING_DIR = RUN_PIPELINES_SCRIPT_PATH.parent
185
+ PIPELINE_SCRIPT_PATH = SCRIPTS_TESTING_DIR / "pipeline.py"
186
+
187
+ # --- Определяем параметры для перебора ---
188
+ models = [MODELS_TO_TEST[0]] if args.skip_models else MODELS_TO_TEST
189
+ chunking_configs = [CHUNKING_PARAMS[0]] if args.skip_chunking else CHUNKING_PARAMS
190
+ top_n_list = [TOP_N_VALUES[0]] if args.skip_top_n else TOP_N_VALUES
191
+ use_injection_list = [False] if args.skip_injection else USE_INJECTION_OPTIONS
192
+ threshold_list = [SIMILARITY_THRESHOLDS[0]] if args.skip_thresholds else SIMILARITY_THRESHOLDS
193
+
194
+ # Определяем списки для новых измерений
195
+ process_tables_list = [PROCESS_TABLES_OPTIONS[0]] if args.skip_process_tables else PROCESS_TABLES_OPTIONS
196
+ include_neighbors_list = [INCLUDE_NEIGHBORS_OPTIONS[0]] if args.skip_include_neighbors else INCLUDE_NEIGHBORS_OPTIONS
197
+ use_qe_list = [USE_QE_OPTIONS[0]] if args.skip_qe else USE_QE_OPTIONS
198
+
199
+ # --- Создаем список всех комбинаций параметров ---
200
+ parameter_combinations = list(product(
201
+ models,
202
+ chunking_configs,
203
+ top_n_list,
204
+ use_injection_list,
205
+ threshold_list,
206
+ process_tables_list,
207
+ include_neighbors_list,
208
+ use_qe_list
209
+ ))
210
+
211
+ total_runs = len(parameter_combinations)
212
+ print(f"Всего запланировано запусков: {total_runs}")
213
+
214
+ # --- Запускаем пайплайны для каждой комбинации ---
215
+ completed_runs = 0
216
+ failed_runs = 0
217
+ start_time_all = time.time()
218
+
219
+ for i, (model, chunk_cfg, top_n, use_injection, threshold, process_tables, include_neighbors, use_qe) in enumerate(parameter_combinations):
220
+ print(f"\n{'='*80}")
221
+ print(f"Запуск {i+1}/{total_runs}")
222
+ print(f" Модель: {model}")
223
+ # Логируем параметры чанкинга
224
+ strategy = chunk_cfg['strategy']
225
+ params = chunk_cfg['params']
226
+ params_str = json.dumps(params, ensure_ascii=False)
227
+ print(f" Чанкинг: Стратегия='{strategy}', Параметры={params_str}")
228
+ print(f" Обработка таблиц: {process_tables}")
229
+ print(f" Top-N: {top_n}")
230
+ print(f" Сборка контекста: {use_injection}")
231
+ print(f" Query Expansion: {use_qe}")
232
+ print(f" Включение соседей: {include_neighbors}")
233
+ print(f" Порог схожести: {threshold}")
234
+ print(f"{'='*80}")
235
+
236
+ # Генерируем уникальный ID для этого запуска
237
+ run_id = f"run_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid4().hex[:8]}"
238
+
239
+ # Формируем команду для pipeline.py
240
+ cmd = [
241
+ args.python_executable,
242
+ str(PIPELINE_SCRIPT_PATH), # Используем абсолютный путь
243
+ "--run-id", run_id,
244
+ "--batch-id", batch_run_id,
245
+ "--data-folder", args.data_folder,
246
+ "--search-dataset-path", args.search_dataset_path,
247
+ "--output-dir", args.intermediate_dir,
248
+ "--model-name", model,
249
+ "--chunking-strategy", strategy,
250
+ "--strategy-params", params_str,
251
+ "--top-n", str(top_n),
252
+ "--similarity-threshold", str(threshold),
253
+ "--device", args.device,
254
+ ]
255
+
256
+ # Добавляем флаг --use-injection, если нужно
257
+ if use_injection:
258
+ cmd.append("--use-injection")
259
+
260
+ # Добавляем флаг --no-process-tables, если process_tables == False
261
+ if not process_tables:
262
+ cmd.append("--no-process-tables")
263
+
264
+ # Добавляем флаг --include-neighbors, если include_neighbors == True
265
+ if include_neighbors:
266
+ cmd.append("--include-neighbors")
267
+
268
+ # Добавляем флаг --use-qe, если use_qe == True
269
+ if use_qe:
270
+ cmd.append("--use-qe")
271
+
272
+ # Добавляем флаг --use-sentence-transformers для определенных моделей
273
+ if "FRIDA" in model or "sentence-transformer" in model.lower(): # Пример
274
+ cmd.append("--use-sentence-transformers")
275
+
276
+ # Формируем путь к лог-файлу
277
+ log_filename = f"{run_id}_log.txt"
278
+ log_path = os.path.join(args.log_dir, log_filename)
279
+
280
+ # Запускаем пайплайн
281
+ return_code = run_single_pipeline(cmd, log_path)
282
+
283
+ if return_code == 0:
284
+ completed_runs += 1
285
+ else:
286
+ failed_runs += 1
287
+ print(f"*** ВНИМАНИЕ: Запуск {i+1} завершился с ошибкой! Лог: {log_path} ***")
288
+
289
+ # --- Вывод итоговой статистики ---
290
+ end_time_all = time.time()
291
+ total_duration = end_time_all - start_time_all
292
+
293
+ print(f"\n{'='*80}")
294
+ print("Все запуски завершены.")
295
+ print(f"Общее время выполнения: {total_duration:.2f} сек ({total_duration/60:.2f} мин)")
296
+ print(f"Всего запусков: {total_runs}")
297
+ print(f"Успешно завершено: {completed_runs}")
298
+ print(f"Завершено с ошибками: {failed_runs}")
299
+ print(f"Промежуточные результаты сохранены в: {args.intermediate_dir}")
300
+ print(f"Логи запусков сохранены в: {args.log_dir}")
301
+ print(f"{'='*80}")
302
+
303
+ if __name__ == "__main__":
304
+ main()