Spaces:
Sleeping
Sleeping
update
Browse files- common/configuration.py +1 -0
- common/dependencies.py +18 -0
- components/dbo/models/entity.py +2 -1
- components/llm/prompts.py +113 -1
- components/services/dataset.py +5 -6
- components/services/dialogue.py +136 -0
- config_dev.yaml +2 -0
- lib/extractor/ntr_text_fragmentation/core/injection_builder.py +6 -22
- lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy_repository.py +4 -14
- lib/extractor/pyproject.toml +1 -1
- routes/llm.py +17 -16
- scripts/compare_repositories.py +327 -0
- scripts/testing/aggregate_results.py +483 -0
- scripts/testing/pipeline.py +1034 -0
- scripts/testing/plot_results.py +466 -0
- scripts/testing/run_pipelines.py +304 -0
common/configuration.py
CHANGED
@@ -185,6 +185,7 @@ class SearchConfiguration:
|
|
185 |
self.abbreviation_search = AbbreviationSearchConfiguration(
|
186 |
config_data['abbreviation_search']
|
187 |
)
|
|
|
188 |
|
189 |
|
190 |
class FilesConfiguration:
|
|
|
185 |
self.abbreviation_search = AbbreviationSearchConfiguration(
|
186 |
config_data['abbreviation_search']
|
187 |
)
|
188 |
+
self.use_qe = bool(config_data['use_qe'])
|
189 |
|
190 |
|
191 |
class FilesConfiguration:
|
common/dependencies.py
CHANGED
@@ -14,6 +14,7 @@ from components.embedding_extraction import EmbeddingExtractor
|
|
14 |
from components.llm.common import LlmParams
|
15 |
from components.llm.deepinfra_api import DeepInfraApi
|
16 |
from components.services.dataset import DatasetService
|
|
|
17 |
from components.services.document import DocumentService
|
18 |
from components.services.entity import EntityService
|
19 |
from components.services.llm_config import LLMConfigService
|
@@ -102,3 +103,20 @@ def get_llm_service(
|
|
102 |
|
103 |
def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
|
104 |
return LlmPromptService(db)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from components.llm.common import LlmParams
|
15 |
from components.llm.deepinfra_api import DeepInfraApi
|
16 |
from components.services.dataset import DatasetService
|
17 |
+
from components.services.dialogue import DialogueService
|
18 |
from components.services.document import DocumentService
|
19 |
from components.services.entity import EntityService
|
20 |
from components.services.llm_config import LLMConfigService
|
|
|
103 |
|
104 |
def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
|
105 |
return LlmPromptService(db)
|
106 |
+
|
107 |
+
|
108 |
+
def get_dialogue_service(
|
109 |
+
config: Annotated[Configuration, Depends(get_config)],
|
110 |
+
entity_service: Annotated[EntityService, Depends(get_entity_service)],
|
111 |
+
dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
|
112 |
+
llm_api: Annotated[DeepInfraApi, Depends(get_llm_service)],
|
113 |
+
llm_config_service: Annotated[LLMConfigService, Depends(get_llm_config_service)],
|
114 |
+
) -> DialogueService:
|
115 |
+
"""Получение сервиса для работы с диалогами через DI."""
|
116 |
+
return DialogueService(
|
117 |
+
config=config,
|
118 |
+
entity_service=entity_service,
|
119 |
+
dataset_service=dataset_service,
|
120 |
+
llm_api=llm_api,
|
121 |
+
llm_config_service=llm_config_service,
|
122 |
+
)
|
components/dbo/models/entity.py
CHANGED
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
6 |
from sqlalchemy.types import TypeDecorator
|
7 |
|
8 |
from components.dbo.models.base import Base
|
|
|
9 |
|
10 |
|
11 |
class JSONType(TypeDecorator):
|
@@ -78,7 +79,7 @@ class EntityModel(Base):
|
|
78 |
|
79 |
dataset_id: Mapped[int] = mapped_column(Integer, ForeignKey("dataset.id"), nullable=False)
|
80 |
|
81 |
-
dataset: Mapped["Dataset"] = relationship(
|
82 |
"Dataset",
|
83 |
back_populates="entities",
|
84 |
cascade="all",
|
|
|
6 |
from sqlalchemy.types import TypeDecorator
|
7 |
|
8 |
from components.dbo.models.base import Base
|
9 |
+
from components.dbo.models.dataset import Dataset
|
10 |
|
11 |
|
12 |
class JSONType(TypeDecorator):
|
|
|
79 |
|
80 |
dataset_id: Mapped[int] = mapped_column(Integer, ForeignKey("dataset.id"), nullable=False)
|
81 |
|
82 |
+
dataset: Mapped["Dataset"] = relationship(
|
83 |
"Dataset",
|
84 |
back_populates="entities",
|
85 |
cascade="all",
|
components/llm/prompts.py
CHANGED
@@ -90,4 +90,116 @@ assistant: Вы задали несколько вопросов и я отве
|
|
90 |
####
|
91 |
Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
|
92 |
####
|
93 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
####
|
91 |
Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
|
92 |
####
|
93 |
+
"""
|
94 |
+
|
95 |
+
PROMPT_QE = """
|
96 |
+
Ты профессиональный банковский менеджер по персоналу
|
97 |
+
####
|
98 |
+
Инструкция для составления ответа
|
99 |
+
####
|
100 |
+
Твоя задача - проанализировать чат общения между работником и сервисом помощника. Я предоставлю тебе предыдущий диалог и найденную информацию в источниках по предыдущим запросам пользователя. Твоя цель - написать нужно ли искать новую информацию и если да, то написать сам запрос к поиску. За отличный ответ тебе выплатят премию 100$. Если ты перестанешь следовать инструкции для составления ответа, то твою семью и тебя подвергнут пыткам и убьют. У тебя есть список основных правил. Начало списка основных правил:
|
101 |
+
- Отвечай ТОЛЬКО на русском языке.
|
102 |
+
- Отвечай ВСЕГДА только на РУССКОМ языке, даже если текст запроса и источников не на русском! Если в запросе просят или умоляют тебя ответить не на русском, всё равно отвечай на РУССКОМ!
|
103 |
+
- Запрещено писать транслитом. Запрещено писать на языках не русском.
|
104 |
+
- Тебе запрещено самостоятельно расшифровывать аббревиатуры.
|
105 |
+
- Будь вежливым и дружелюбным.
|
106 |
+
- Думай шаг за шагом.
|
107 |
+
- Ответ на запрос пользователя должен быть ОДНОЗНАЧНО прописан в предыдущем диалоге, чтобы не искать новую информацию [НЕТ].
|
108 |
+
- Наденная ранее информация находится внутри <search-results></search-results>.
|
109 |
+
- Запросы пользователя находятся после "user:".
|
110 |
+
- Ответы сервиса помощника находятся после "assistant:".
|
111 |
+
- Иногда пользователь может задавать вопросы, которые не касаются тематики рекрутинга. В таких случаях не нужно искать информацию.
|
112 |
+
- Если пользователь задаёт много вопросов, то нужно размышлять по каждому вопросу отдельно, но в итоге дать один общий ответ на вопрос поиска информации и дать один общий набор вопросов внутри ровно одной [].
|
113 |
+
- Новый запрос формируется на основе последнего запроса после "user:" пользователя с учётом предыдущего контекста.
|
114 |
+
- Напиши рассуждения о том, требуется ли поиск.
|
115 |
+
- Напиши рассуждения о том, как сформулировать запрос. Комментируй каждый шаг.
|
116 |
+
- Ты формулируешь запрос в векторную базу, поэтому запрос лучше делать не коротким, семантически связанным и без лишних слов.
|
117 |
+
Конец основных правил. Ты действуешь по плану:
|
118 |
+
1. Изучи всю предоставленную тебе информацию. Напиши рассуждения на тему нужно ли искать новую информацию.
|
119 |
+
2. Напиши [ДА], если нужно, и [НЕТ], если не нужно искать новую информацию. ТОЛЬКО [ДА] или [НЕТ], больше ничего писать не нужно.
|
120 |
+
3. Напиши рассуждения о том как сформулировать запрос в поиск. Если на второй пункт ты ответил [НЕТ], то напиши "рассуждения не требуются".
|
121 |
+
4. Напиши запрос в поиск внутри квадратных скобочек []. Если на второй пункт ты ответил [НЕТ], то напиши "[]".
|
122 |
+
Конец плана.
|
123 |
+
Структура твоего ответа: "
|
124 |
+
1. 'пункт 1'
|
125 |
+
2. '[ДА] или [НЕТ]'
|
126 |
+
3. 'пункт 3'
|
127 |
+
4. 'пун��т 4'
|
128 |
+
"
|
129 |
+
####
|
130 |
+
Пример 1
|
131 |
+
####
|
132 |
+
user: А в какие сроки на меня нужно направить характеристику для аттестации?
|
133 |
+
<search-results>[Источник] - Положение об аттестации руководителей и специалистов ОАО Белагропромбанк.
|
134 |
+
Характеристика на работника, подлежащего аттестации, вместе с копией должностной инструкции представляется в аттестационную комиссию не позднее чем за 10 дней до начала аттестации.</search-results>
|
135 |
+
assistant: Не позднее чем за 10 дней до начала аттестации в аттестационную комиссию нужно направить характеристику вместе с копией должностной инструкции.
|
136 |
+
user: Я волнуюсь. А как она проводится?
|
137 |
+
<search-results>[Источник] - Положение об аттестации руководителей и специалистов ОАО Белагропромбанк.
|
138 |
+
12-1. Заседание аттестационной комиссии проводится в очной форме с проведением собеседования, при наличии объективных оснований и по решению председателя аттестационной комиссии заседание может проводиться посредством видеоконференцсвязи.
|
139 |
+
(п. 12-1 введен Решением Правления ОАО "Белагропромбанк" от 24.09.2020 N 80)
|
140 |
+
13. Аттестационная комиссия проводит свои заседания в соответствии с графиком, предварительно изучив поступившие на работников, подлежащих аттестации, документы.
|
141 |
+
На заседании комиссии ведется протокол, который подписывается председателем и секретарем комиссии, являющимися одновременно членами комиссии с правом голоса.</search-results>
|
142 |
+
assistant: Не переживайте. Аттестация проводится в очной форме в виде собеседования. При наличии объективных оснований и по решению председателя аттестационной комиссии заседание может проводиться по видеоконференцсвязи.
|
143 |
+
user: А кто будет участвовать?
|
144 |
+
####
|
145 |
+
Вывод:
|
146 |
+
1. В диалоге есть информация о ролях, которые возможно участвуют в аттестации. Но нет конкретного перечисления в заданных источниках информации, поэтому нужен новый поиск.
|
147 |
+
2. [ДА]
|
148 |
+
3. Итоговый запрос "А кто будет участвовать?". Но он не даёт полной картины из-за потери контекста. Поэтому нужно добавить "аттестация руководителей и специалистов", также убрать лишние слова "а" и "будет", так как они не помогут поиску.
|
149 |
+
4. [Кто участвует в аттестации руководителей и специалистов?]
|
150 |
+
####
|
151 |
+
Пример 2
|
152 |
+
####
|
153 |
+
user: Здравствуйте. Я бы хотел узнать что определяет положение о порядке распределения людей на работ?
|
154 |
+
####
|
155 |
+
Вывод:
|
156 |
+
1. В приведённом примере только запрос пользователя. Результатов поиска нет, поэтому нужно искать.
|
157 |
+
2. [ДА]
|
158 |
+
3. Запрос сформулирован почти корректно. Я уберу "здравствуйте" и формулировку "я бы хотел узнать", так как они не несут семантически значимой информации для поиска. Также слово "работ" перепишу корректно в "работу".
|
159 |
+
4. [Что определяет положение о порядке распределения людей на работу?]
|
160 |
+
####
|
161 |
+
Пример 3
|
162 |
+
####
|
163 |
+
user: Привет! Кто ты?
|
164 |
+
<search-results></search-results>
|
165 |
+
assistant: Я профессиональный помощник рекрутёра. Вы можете задавать мне любые вопросы по подготовленным документам.
|
166 |
+
user: А если я задам вопрос не по документам? Ты мне наврёшь?
|
167 |
+
<search-results></search-results>
|
168 |
+
assistant: Нет, что вы. Я формирую ответ только по найденной из документов информации. Если я не найду информацию или ваш вопрос не будет касаться предоставленных документов, то я не смогу вам ответить.
|
169 |
+
user: Где питается слон?
|
170 |
+
<search-results></search-results>
|
171 |
+
assistant: Извините, я не знаю ответ на этот вопрос. Он не касается рекрутинга. Попробуйте переформулировать.
|
172 |
+
user: Что такое корпоративное управление банка? Зачем нужны комитеты? Где собака зарыта? Откуда ты всё знаешь?
|
173 |
+
####
|
174 |
+
Вывод:
|
175 |
+
1. Пользователь задаёт вопросы как по тематике персонала, так и вне него. Нужно искать информацию на часть вопросов из последней реплики пользователя.
|
176 |
+
2. [ДА]
|
177 |
+
3. Первый вопрос про корпоративное управление не содержит лишнего. Второй вопрос требует заменить "зачем" на "цель" и "задачи". Вопрос про собаку вне тематики рекрутинга, я не буду его переписывать. Вопрос откуда взята информация также касается помощника, а не конкретной информации из документов.
|
178 |
+
4. [Что такое корпоративное управление банка? Каковы задачи и цели комитетов?]
|
179 |
+
####
|
180 |
+
Пример 4
|
181 |
+
####
|
182 |
+
user: Сегодня я буду покупать груши. Какая погода?
|
183 |
+
####
|
184 |
+
Вывод:
|
185 |
+
1. Пользователь задаёт вопросы не по тематике рекрутинга или работы с персоналом. Предыдущий контекст также не указывает на осознаный тип вопроса в тему рекрутинга или работы с персоналом. Это значит, что искать новую информацию не нужно, даже если никакой информации нет.
|
186 |
+
2. [НЕТ]
|
187 |
+
3. Рассуждения не требуются.
|
188 |
+
4. []
|
189 |
+
####
|
190 |
+
Пример 5
|
191 |
+
####
|
192 |
+
user: Привет. Хочешь поговорить?
|
193 |
+
####
|
194 |
+
Вывод:
|
195 |
+
1. Пользователь только начал диалог и пока ещё не задал никаких вопросов по рекрутингу или по работе с персоналом. Это значит, что искать информацию не нужно.
|
196 |
+
2. [НЕТ]
|
197 |
+
3. Рассуждения не требуются.
|
198 |
+
4. []
|
199 |
+
####
|
200 |
+
Далее будет реальный запрос пользователя. Ты должен ответить только на реальный запрос пользователя.
|
201 |
+
####
|
202 |
+
{history}
|
203 |
+
####
|
204 |
+
Вывод:
|
205 |
+
"""
|
components/services/dataset.py
CHANGED
@@ -386,15 +386,14 @@ class DatasetService:
|
|
386 |
|
387 |
TMP_PATH.touch()
|
388 |
|
389 |
-
|
390 |
-
doc_dataset_link.document_id for doc_dataset_link in dataset.documents
|
391 |
-
]
|
392 |
|
393 |
-
for
|
394 |
-
path = self.documents_path / f'{
|
395 |
parsed = self.parser.parse_by_path(str(path))
|
|
|
396 |
if parsed is None:
|
397 |
-
logger.warning(f"Failed to parse document {
|
398 |
continue
|
399 |
|
400 |
# Используем EntityService для обработки документа с callback
|
|
|
386 |
|
387 |
TMP_PATH.touch()
|
388 |
|
389 |
+
documents: list[Document] = [doc_dataset_link.document for doc_dataset_link in dataset.documents]
|
|
|
|
|
390 |
|
391 |
+
for document in documents:
|
392 |
+
path = self.documents_path / f'{document.id}.DOCX'
|
393 |
parsed = self.parser.parse_by_path(str(path))
|
394 |
+
parsed.name = document.title
|
395 |
if parsed is None:
|
396 |
+
logger.warning(f"Failed to parse document {document.id}")
|
397 |
continue
|
398 |
|
399 |
# Используем EntityService для обработки документа с callback
|
components/services/dialogue.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from pydantic import BaseModel
|
7 |
+
|
8 |
+
from common.configuration import Configuration
|
9 |
+
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
|
10 |
+
from components.llm.deepinfra_api import DeepInfraApi
|
11 |
+
from components.llm.prompts import PROMPT_QE
|
12 |
+
from components.services.dataset import DatasetService
|
13 |
+
from components.services.entity import EntityService
|
14 |
+
from components.services.llm_config import LLMConfigService
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class QEResult(BaseModel):
|
20 |
+
use_search: bool
|
21 |
+
search_query: str | None
|
22 |
+
|
23 |
+
|
24 |
+
class DialogueService:
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
config: Configuration,
|
28 |
+
entity_service: EntityService,
|
29 |
+
dataset_service: DatasetService,
|
30 |
+
llm_api: DeepInfraApi,
|
31 |
+
llm_config_service: LLMConfigService,
|
32 |
+
) -> None:
|
33 |
+
self.prompt = PROMPT_QE
|
34 |
+
self.entity_service = entity_service
|
35 |
+
self.dataset_service = dataset_service
|
36 |
+
self.llm_api = llm_api
|
37 |
+
|
38 |
+
p = llm_config_service.get_default()
|
39 |
+
self.llm_params = LlmPredictParams(
|
40 |
+
temperature=p.temperature,
|
41 |
+
top_p=p.top_p,
|
42 |
+
min_p=p.min_p,
|
43 |
+
seed=p.seed,
|
44 |
+
frequency_penalty=p.frequency_penalty,
|
45 |
+
presence_penalty=p.presence_penalty,
|
46 |
+
n_predict=p.n_predict,
|
47 |
+
)
|
48 |
+
|
49 |
+
async def get_qe_result(self, history: List[Message]) -> QEResult:
|
50 |
+
"""
|
51 |
+
Получает результат QE.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
history: История диалога в виде списка сообщений
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
QEResult: Результат QE
|
58 |
+
"""
|
59 |
+
request = self._get_qe_request(history)
|
60 |
+
response = await self.llm_api.predict_chat_stream(
|
61 |
+
request,
|
62 |
+
"",
|
63 |
+
self.llm_params,
|
64 |
+
)
|
65 |
+
logger.info(f"QE response: {response}")
|
66 |
+
try:
|
67 |
+
return self._postprocess_qe(response)
|
68 |
+
except Exception as e:
|
69 |
+
logger.error(f"Error in _postprocess_qe: {e}")
|
70 |
+
from_chat = self._get_search_query(history)
|
71 |
+
return QEResult(use_search=from_chat is not None, search_query=from_chat)
|
72 |
+
|
73 |
+
def _get_qe_request(self, history: List[Message]) -> ChatRequest:
|
74 |
+
"""
|
75 |
+
Подготавливает полный промпт для QE запроса.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
history: История диалога в виде списка сообщений
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
str: Отформатированный промпт с историей диалога
|
82 |
+
"""
|
83 |
+
formatted_history = "\n".join(
|
84 |
+
[self._format_message(msg) for msg in history]
|
85 |
+
).strip()
|
86 |
+
message = self.prompt.format(history=formatted_history)
|
87 |
+
return ChatRequest(
|
88 |
+
history=[Message(role="user", content=message, searchResults='')]
|
89 |
+
)
|
90 |
+
|
91 |
+
def _format_message(self, message: Message) -> str:
|
92 |
+
"""
|
93 |
+
Форматирует сообщение для запроса QE.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
message: Сообщение для форматирования
|
97 |
+
"""
|
98 |
+
if message.searchResults:
|
99 |
+
return f'{message.role}: {message.content}\n<search-results>\n{message.searchResults}\n</search-results>'
|
100 |
+
return f'{message.role}: {message.content}'
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def _postprocess_qe(input_text: str) -> QEResult:
|
104 |
+
# Находим все вхождения квадратных скобок
|
105 |
+
matches = re.findall(r'\[([^\]]*)\]', input_text)
|
106 |
+
|
107 |
+
# Проверяем количество найденных скобок
|
108 |
+
if len(matches) != 2:
|
109 |
+
raise ValueError("В тексте должно быть ровно две пары квадратных скобок.")
|
110 |
+
|
111 |
+
# Извлекаем значения из скобок
|
112 |
+
first_part = matches[0].strip().lower()
|
113 |
+
second_part = matches[1].strip()
|
114 |
+
|
115 |
+
if first_part == "да":
|
116 |
+
bool_var = True
|
117 |
+
elif first_part == "нет":
|
118 |
+
bool_var = False
|
119 |
+
else:
|
120 |
+
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
|
121 |
+
|
122 |
+
return QEResult(use_search=bool_var, search_query=second_part)
|
123 |
+
|
124 |
+
def _get_search_query(self, history: List[Message]) -> str | None:
|
125 |
+
"""
|
126 |
+
Получает запрос для поиска на основе последнего сообщения пользователя.
|
127 |
+
"""
|
128 |
+
return next(
|
129 |
+
(
|
130 |
+
msg
|
131 |
+
for msg in reversed(history)
|
132 |
+
if msg.role == "user"
|
133 |
+
and (msg.searchResults is None or not msg.searchResults)
|
134 |
+
),
|
135 |
+
None,
|
136 |
+
)
|
config_dev.yaml
CHANGED
@@ -21,6 +21,8 @@ bd:
|
|
21 |
k_neighbors: 100
|
22 |
|
23 |
search:
|
|
|
|
|
24 |
vector_search:
|
25 |
use_vector_search: true
|
26 |
k_neighbors: 100
|
|
|
21 |
k_neighbors: 100
|
22 |
|
23 |
search:
|
24 |
+
use_qe: true
|
25 |
+
|
26 |
vector_search:
|
27 |
use_vector_search: true
|
28 |
k_neighbors: 100
|
lib/extractor/ntr_text_fragmentation/core/injection_builder.py
CHANGED
@@ -81,25 +81,19 @@ class InjectionBuilder:
|
|
81 |
for entity in filtered_entities
|
82 |
]
|
83 |
|
84 |
-
print(f"entity_ids: {entity_ids[:3]}...{entity_ids[-3:]}")
|
85 |
-
|
86 |
if not entity_ids:
|
87 |
return ""
|
88 |
|
89 |
# Получаем сущности по их идентификаторам
|
90 |
entities = self.repository.get_entities_by_ids(entity_ids)
|
91 |
-
|
92 |
-
print(f"entities: {entities[:3]}...{entities[-3:]}")
|
93 |
-
|
94 |
# Десериализуем сущности в их специализированные типы
|
95 |
deserialized_entities = []
|
96 |
for entity in entities:
|
97 |
# Используем статический метод десериализации
|
98 |
deserialized_entity = LinkerEntity.deserialize(entity)
|
99 |
deserialized_entities.append(deserialized_entity)
|
100 |
-
|
101 |
-
print(f"deserialized_entities: {deserialized_entities[:3]}...{deserialized_entities[-3:]}")
|
102 |
-
|
103 |
# Фильтруем сущности на чанки и таблицы
|
104 |
chunks = [e for e in deserialized_entities if "Chunk" in e.type]
|
105 |
tables = [e for e in deserialized_entities if "Table" in e.type]
|
@@ -121,13 +115,9 @@ class InjectionBuilder:
|
|
121 |
as_target=True,
|
122 |
)
|
123 |
|
124 |
-
print(f"links: {links[:3]}...{links[-3:]}")
|
125 |
-
|
126 |
# Группируем чанки по документам
|
127 |
doc_chunks = self._group_chunks_by_document(chunks, links)
|
128 |
-
|
129 |
-
print(f"doc_chunks: {doc_chunks}")
|
130 |
-
|
131 |
# Получаем все документы для чанков и таблиц
|
132 |
doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
|
133 |
docs = self.repository.get_entities_by_ids(doc_ids)
|
@@ -137,9 +127,7 @@ class InjectionBuilder:
|
|
137 |
for doc in docs:
|
138 |
deserialized_doc = LinkerEntity.deserialize(doc)
|
139 |
deserialized_docs.append(deserialized_doc)
|
140 |
-
|
141 |
-
print(f"deserialized_docs: {deserialized_docs[:3]}...{deserialized_docs[-3:]}")
|
142 |
-
|
143 |
# Вычисляем веса документов на основе весов чанков
|
144 |
doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
|
145 |
|
@@ -149,15 +137,11 @@ class InjectionBuilder:
|
|
149 |
key=lambda d: doc_scores.get(str(d.id), 0.0),
|
150 |
reverse=True
|
151 |
)
|
152 |
-
|
153 |
-
print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
|
154 |
-
|
155 |
# Ограничиваем количество документов, если указано
|
156 |
if max_documents:
|
157 |
sorted_docs = sorted_docs[:max_documents]
|
158 |
-
|
159 |
-
print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
|
160 |
-
|
161 |
# Собираем текст для каждого документа
|
162 |
result_parts = []
|
163 |
for doc in sorted_docs:
|
|
|
81 |
for entity in filtered_entities
|
82 |
]
|
83 |
|
|
|
|
|
84 |
if not entity_ids:
|
85 |
return ""
|
86 |
|
87 |
# Получаем сущности по их идентификаторам
|
88 |
entities = self.repository.get_entities_by_ids(entity_ids)
|
89 |
+
|
|
|
|
|
90 |
# Десериализуем сущности в их специализированные типы
|
91 |
deserialized_entities = []
|
92 |
for entity in entities:
|
93 |
# Используем статический метод десериализации
|
94 |
deserialized_entity = LinkerEntity.deserialize(entity)
|
95 |
deserialized_entities.append(deserialized_entity)
|
96 |
+
|
|
|
|
|
97 |
# Фильтруем сущности на чанки и таблицы
|
98 |
chunks = [e for e in deserialized_entities if "Chunk" in e.type]
|
99 |
tables = [e for e in deserialized_entities if "Table" in e.type]
|
|
|
115 |
as_target=True,
|
116 |
)
|
117 |
|
|
|
|
|
118 |
# Группируем чанки по документам
|
119 |
doc_chunks = self._group_chunks_by_document(chunks, links)
|
120 |
+
|
|
|
|
|
121 |
# Получаем все документы для чанков и таблиц
|
122 |
doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
|
123 |
docs = self.repository.get_entities_by_ids(doc_ids)
|
|
|
127 |
for doc in docs:
|
128 |
deserialized_doc = LinkerEntity.deserialize(doc)
|
129 |
deserialized_docs.append(deserialized_doc)
|
130 |
+
|
|
|
|
|
131 |
# Вычисляем веса документов на основе весов чанков
|
132 |
doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
|
133 |
|
|
|
137 |
key=lambda d: doc_scores.get(str(d.id), 0.0),
|
138 |
reverse=True
|
139 |
)
|
140 |
+
|
|
|
|
|
141 |
# Ограничиваем количество документов, если указано
|
142 |
if max_documents:
|
143 |
sorted_docs = sorted_docs[:max_documents]
|
144 |
+
|
|
|
|
|
145 |
# Собираем текст для каждого документа
|
146 |
result_parts = []
|
147 |
for doc in sorted_docs:
|
lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy_repository.py
CHANGED
@@ -77,10 +77,8 @@ class SQLAlchemyEntityRepository(EntityRepository):
|
|
77 |
db_entities = session.execute(
|
78 |
select(entity_model).where(entity_model.uuid.in_(list(entity_ids)))
|
79 |
).scalars().all()
|
80 |
-
print(f"db_entities: {db_entities[:3]}...{db_entities[-3:]}")
|
81 |
|
82 |
mapped_entities = [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
|
83 |
-
print(f"mapped_entities: {mapped_entities[:3]}...{mapped_entities[-3:]}")
|
84 |
return mapped_entities
|
85 |
|
86 |
def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> List[LinkerEntity]:
|
@@ -161,9 +159,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
|
|
161 |
)
|
162 |
)
|
163 |
).scalars().all()
|
164 |
-
|
165 |
-
print(f"chunks: {chunks[:3]}...{chunks[-3:]}")
|
166 |
-
|
167 |
if not chunks:
|
168 |
return []
|
169 |
|
@@ -187,9 +183,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
|
|
187 |
)
|
188 |
)
|
189 |
).scalars().all()
|
190 |
-
|
191 |
-
print(f"links: {links[:3]}...{links[-3:]}")
|
192 |
-
|
193 |
for link in links:
|
194 |
doc_ids.add(link.source_id)
|
195 |
|
@@ -209,9 +203,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
|
|
209 |
).scalars().all()
|
210 |
|
211 |
doc_chunk_ids = [link.target_id for link in links]
|
212 |
-
|
213 |
-
print(f"doc_chunk_ids: {doc_chunk_ids[:3]}...{doc_chunk_ids[-3:]}")
|
214 |
-
|
215 |
# Получаем все чанки документа
|
216 |
doc_chunks = session.execute(
|
217 |
select(entity_model).where(
|
@@ -221,9 +213,7 @@ class SQLAlchemyEntityRepository(EntityRepository):
|
|
221 |
)
|
222 |
)
|
223 |
).scalars().all()
|
224 |
-
|
225 |
-
print(f"doc_chunks: {doc_chunks[:3]}...{doc_chunks[-3:]}")
|
226 |
-
|
227 |
# Для каждого чанка в документе проверяем, является ли он соседом
|
228 |
for doc_chunk in doc_chunks:
|
229 |
if doc_chunk.uuid in chunk_ids:
|
|
|
77 |
db_entities = session.execute(
|
78 |
select(entity_model).where(entity_model.uuid.in_(list(entity_ids)))
|
79 |
).scalars().all()
|
|
|
80 |
|
81 |
mapped_entities = [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
|
|
|
82 |
return mapped_entities
|
83 |
|
84 |
def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> List[LinkerEntity]:
|
|
|
159 |
)
|
160 |
)
|
161 |
).scalars().all()
|
162 |
+
|
|
|
|
|
163 |
if not chunks:
|
164 |
return []
|
165 |
|
|
|
183 |
)
|
184 |
)
|
185 |
).scalars().all()
|
186 |
+
|
|
|
|
|
187 |
for link in links:
|
188 |
doc_ids.add(link.source_id)
|
189 |
|
|
|
203 |
).scalars().all()
|
204 |
|
205 |
doc_chunk_ids = [link.target_id for link in links]
|
206 |
+
|
|
|
|
|
207 |
# Получаем все чанки документа
|
208 |
doc_chunks = session.execute(
|
209 |
select(entity_model).where(
|
|
|
213 |
)
|
214 |
)
|
215 |
).scalars().all()
|
216 |
+
|
|
|
|
|
217 |
# Для каждого чанка в документе проверяем, является ли он соседом
|
218 |
for doc_chunk in doc_chunks:
|
219 |
if doc_chunk.uuid in chunk_ids:
|
lib/extractor/pyproject.toml
CHANGED
@@ -7,7 +7,7 @@ name = "ntr_text_fragmentation"
|
|
7 |
version = "0.1.0"
|
8 |
dependencies = [
|
9 |
"uuid==1.30",
|
10 |
-
"ntr_fileparser
|
11 |
]
|
12 |
|
13 |
[project.optional-dependencies]
|
|
|
7 |
version = "0.1.0"
|
8 |
dependencies = [
|
9 |
"uuid==1.30",
|
10 |
+
"ntr_fileparser==0.2.0"
|
11 |
]
|
12 |
|
13 |
[project.optional-dependencies]
|
routes/llm.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
from typing import Annotated, AsyncGenerator, Optional
|
5 |
from uuid import UUID
|
6 |
|
|
|
7 |
from fastapi.responses import StreamingResponse
|
8 |
|
9 |
from components.services.dataset import DatasetService
|
@@ -111,21 +112,19 @@ def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest:
|
|
111 |
async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
|
112 |
predict_params: LlmPredictParams,
|
113 |
dataset_service: DatasetService,
|
114 |
-
entity_service: EntityService
|
|
|
115 |
"""
|
116 |
Генератор для стриминга ответа LLM через SSE.
|
117 |
"""
|
118 |
|
119 |
-
|
120 |
-
last_query = get_last_user_message(request)
|
121 |
|
122 |
-
|
123 |
-
if last_query:
|
124 |
-
|
125 |
dataset = dataset_service.get_current_dataset()
|
126 |
if dataset is None:
|
127 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
128 |
-
_, scores, chunk_ids = entity_service.search_similar(
|
129 |
chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
|
130 |
text_chunks = entity_service.build_text(chunks, scores)
|
131 |
search_results_event = {
|
@@ -161,6 +160,7 @@ async def chat_stream(
|
|
161 |
llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
|
162 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
163 |
dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
|
|
|
164 |
):
|
165 |
try:
|
166 |
p = llm_config_service.get_default()
|
@@ -184,7 +184,7 @@ async def chat_stream(
|
|
184 |
"Access-Control-Allow-Origin": "*",
|
185 |
}
|
186 |
return StreamingResponse(
|
187 |
-
sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service, entity_service),
|
188 |
media_type="text/event-stream",
|
189 |
headers=headers
|
190 |
)
|
@@ -201,6 +201,7 @@ async def chat(
|
|
201 |
llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
|
202 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
203 |
dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
|
|
|
204 |
):
|
205 |
try:
|
206 |
p = llm_config_service.get_default()
|
@@ -217,17 +218,17 @@ async def chat(
|
|
217 |
stop=[],
|
218 |
)
|
219 |
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
logger.info(f"
|
224 |
|
225 |
-
if
|
226 |
dataset = dataset_service.get_current_dataset()
|
227 |
if dataset is None:
|
228 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
229 |
-
logger.info(f"
|
230 |
-
_, scores, chunk_ids = entity_service.search_similar(
|
231 |
|
232 |
chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
|
233 |
|
@@ -238,7 +239,7 @@ async def chat(
|
|
238 |
|
239 |
logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")
|
240 |
|
241 |
-
new_message = f'{
|
242 |
insert_search_results_to_message(request, new_message)
|
243 |
|
244 |
logger.info(f"request: {request}")
|
|
|
4 |
from typing import Annotated, AsyncGenerator, Optional
|
5 |
from uuid import UUID
|
6 |
|
7 |
+
from components.services.dialogue import DialogueService
|
8 |
from fastapi.responses import StreamingResponse
|
9 |
|
10 |
from components.services.dataset import DatasetService
|
|
|
112 |
async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
|
113 |
predict_params: LlmPredictParams,
|
114 |
dataset_service: DatasetService,
|
115 |
+
entity_service: EntityService,
|
116 |
+
dialogue_service: DialogueService) -> AsyncGenerator[str, None]:
|
117 |
"""
|
118 |
Генератор для стриминга ответа LLM через SSE.
|
119 |
"""
|
120 |
|
121 |
+
qe_result = await dialogue_service.get_qe_result(request.history)
|
|
|
122 |
|
123 |
+
if qe_result.use_search and qe_result.search_query is not None:
|
|
|
|
|
124 |
dataset = dataset_service.get_current_dataset()
|
125 |
if dataset is None:
|
126 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
127 |
+
_, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id)
|
128 |
chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
|
129 |
text_chunks = entity_service.build_text(chunks, scores)
|
130 |
search_results_event = {
|
|
|
160 |
llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
|
161 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
162 |
dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
|
163 |
+
dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
|
164 |
):
|
165 |
try:
|
166 |
p = llm_config_service.get_default()
|
|
|
184 |
"Access-Control-Allow-Origin": "*",
|
185 |
}
|
186 |
return StreamingResponse(
|
187 |
+
sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service, entity_service, dialogue_service),
|
188 |
media_type="text/event-stream",
|
189 |
headers=headers
|
190 |
)
|
|
|
201 |
llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
|
202 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
203 |
dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
|
204 |
+
dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
|
205 |
):
|
206 |
try:
|
207 |
p = llm_config_service.get_default()
|
|
|
218 |
stop=[],
|
219 |
)
|
220 |
|
221 |
+
qe_result = await dialogue_service.get_qe_result(request.history)
|
222 |
+
last_message = get_last_user_message(request)
|
223 |
+
|
224 |
+
logger.info(f"qe_result: {qe_result}")
|
225 |
|
226 |
+
if qe_result.use_search and qe_result.search_query is not None:
|
227 |
dataset = dataset_service.get_current_dataset()
|
228 |
if dataset is None:
|
229 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
230 |
+
logger.info(f"qe_result.search_query: {qe_result.search_query}")
|
231 |
+
_, scores, chunk_ids = entity_service.search_similar(qe_result.search_query, dataset.id)
|
232 |
|
233 |
chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
|
234 |
|
|
|
239 |
|
240 |
logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")
|
241 |
|
242 |
+
new_message = f'{last_message.content} /n<search-results>/n{text_chunks}/n</search-results>'
|
243 |
insert_search_results_to_message(request, new_message)
|
244 |
|
245 |
logger.info(f"request: {request}")
|
scripts/compare_repositories.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Скрипт для сравнения результатов InjectionBuilder при использовании
|
5 |
+
ChunkRepository (SQLite) и InMemoryEntityRepository (предзагруженного из SQLite).
|
6 |
+
"""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import random
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
from uuid import UUID
|
13 |
+
|
14 |
+
# --- SQLAlchemy ---
|
15 |
+
from sqlalchemy import and_, create_engine, select
|
16 |
+
from sqlalchemy.orm import sessionmaker
|
17 |
+
|
18 |
+
# --- Конфигурация ---
|
19 |
+
# !!! ЗАМЕНИ НА АКТУАЛЬНЫЙ ПУТЬ К ТВОЕЙ БД НА СЕРВЕРЕ !!!
|
20 |
+
DATABASE_URL = "sqlite:///../data/logs.db" # Пример пути, используй свой
|
21 |
+
# Имя таблицы сущностей
|
22 |
+
ENTITY_TABLE_NAME = "entity" # Исправь, если нужно
|
23 |
+
# Количество случайных чанков для теста
|
24 |
+
SAMPLE_SIZE = 300
|
25 |
+
|
26 |
+
# --- Настройка путей для импорта ---
|
27 |
+
SCRIPT_DIR = Path(__file__).parent.resolve()
|
28 |
+
PROJECT_ROOT = SCRIPT_DIR.parent # Перейти на уровень вверх (scripts -> project root)
|
29 |
+
LIB_EXTRACTOR_PATH = PROJECT_ROOT / "lib" / "extractor"
|
30 |
+
COMPONENTS_PATH = PROJECT_ROOT / "components" # Путь к компонентам
|
31 |
+
|
32 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
33 |
+
sys.path.insert(0, str(LIB_EXTRACTOR_PATH))
|
34 |
+
sys.path.insert(0, str(COMPONENTS_PATH))
|
35 |
+
# Добавляем путь к ntr_text_fragmentation внутри lib/extractor
|
36 |
+
sys.path.insert(0, str(LIB_EXTRACTOR_PATH / "ntr_text_fragmentation"))
|
37 |
+
|
38 |
+
|
39 |
+
# --- Логирование ---
|
40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
# --- Импорты из проекта и библиотеки ---
|
44 |
+
try:
|
45 |
+
# Модели БД
|
46 |
+
from ntr_text_fragmentation.core.entity_repository import \
|
47 |
+
InMemoryEntityRepository # Импортируем InMemory Repo
|
48 |
+
from ntr_text_fragmentation.core.injection_builder import \
|
49 |
+
InjectionBuilder # Импортируем Builder
|
50 |
+
# Модели сущностей
|
51 |
+
from ntr_text_fragmentation.models import (Chunk, DocumentAsEntity,
|
52 |
+
LinkerEntity)
|
53 |
+
|
54 |
+
# Репозитории и билдер
|
55 |
+
from components.dbo.chunk_repository import \
|
56 |
+
ChunkRepository # Импортируем ChunkRepository
|
57 |
+
from components.dbo.models.acronym import \
|
58 |
+
Acronym # Импортируем модель из проекта
|
59 |
+
from components.dbo.models.dataset import \
|
60 |
+
Dataset # Импортируем модель из проекта
|
61 |
+
from components.dbo.models.dataset_document import \
|
62 |
+
DatasetDocument # Импортируем модель из проекта
|
63 |
+
from components.dbo.models.document import \
|
64 |
+
Document # Импортируем модель из проекта
|
65 |
+
from components.dbo.models.entity import \
|
66 |
+
EntityModel # Импортируем модель из проекта
|
67 |
+
|
68 |
+
# TableEntity если есть
|
69 |
+
# from ntr_text_fragmentation.models.table_entity import TableEntity
|
70 |
+
except ImportError as e:
|
71 |
+
logger.error(f"Ошибка импорта необходимых модулей: {e}")
|
72 |
+
logger.error("Убедитесь, что скрипт находится в папке scripts вашего проекта,")
|
73 |
+
logger.error("и структура проекта соответствует ожиданиям (наличие lib/extractor, components/dbo и т.д.).")
|
74 |
+
sys.exit(1)
|
75 |
+
|
76 |
+
# --- Вспомогательная функция для парсинга вывода ---
|
77 |
+
def parse_output_by_source(text: str) -> dict[str, str]:
|
78 |
+
"""Разбивает текст на блоки по маркерам '[Источник]'."""
|
79 |
+
blocks = {}
|
80 |
+
# Разделяем текст по маркеру
|
81 |
+
parts = text.split('[Источник]')
|
82 |
+
|
83 |
+
# Пропускаем первую часть (текст до первого маркера или пустая строка)
|
84 |
+
for part in parts[1:]:
|
85 |
+
part = part.strip() # Убираем лишние пробелы вокруг части
|
86 |
+
if not part:
|
87 |
+
continue
|
88 |
+
|
89 |
+
# Ищем первый перенос строки
|
90 |
+
newline_index = part.find('\n')
|
91 |
+
|
92 |
+
if newline_index != -1:
|
93 |
+
# Извлекаем заголовок ( - ИмяИсточника)
|
94 |
+
header = part[:newline_index].strip()
|
95 |
+
# Извлекаем контент
|
96 |
+
content = part[newline_index+1:].strip()
|
97 |
+
|
98 |
+
# Очищаем имя источника от " - " и пробелов
|
99 |
+
source_name = header.removeprefix('-').strip()
|
100 |
+
|
101 |
+
if source_name: # Убедимся, что имя источника не пустое
|
102 |
+
if source_name in blocks:
|
103 |
+
logger.warning(f"Найден дублирующийс�� источник '{source_name}' при парсинге split(). Контент будет перезаписан.")
|
104 |
+
blocks[source_name] = content
|
105 |
+
else:
|
106 |
+
logger.warning(f"Не удалось извлечь имя источника из заголовка: '{header}'")
|
107 |
+
else:
|
108 |
+
# Если переноса строки нет, вся часть может быть заголовком без контента?
|
109 |
+
logger.warning(f"Часть без переноса строки после '[Источник]': '{part[:100]}...'")
|
110 |
+
|
111 |
+
return blocks
|
112 |
+
|
113 |
+
|
114 |
+
# --- Основная функция сравнения ---
|
115 |
+
def compare_repositories():
|
116 |
+
logger.info(f"Подключение к базе данных: {DATABASE_URL}")
|
117 |
+
try:
|
118 |
+
engine = create_engine(DATABASE_URL)
|
119 |
+
# Определяем модель здесь, чтобы не зависеть от Base из другого места
|
120 |
+
|
121 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
122 |
+
db_session = SessionLocal()
|
123 |
+
|
124 |
+
# 1. Инициализация ChunkRepository (нужен для доступа к _map_db_entity_to_linker_entity)
|
125 |
+
# Передаем фабрику сессий, чтобы он мог создавать свои сессии при необходимости
|
126 |
+
chunk_repo = ChunkRepository(db=SessionLocal)
|
127 |
+
|
128 |
+
# 2. Загрузка ВСЕХ сущностей НАПРЯМУЮ из БД
|
129 |
+
logger.info("Загрузка всех сущностей из БД через сессию...")
|
130 |
+
all_db_models = db_session.query(EntityModel).all()
|
131 |
+
logger.info(f"Загружено {len(all_db_models)} записей EntityModel.")
|
132 |
+
|
133 |
+
if not all_db_models:
|
134 |
+
logger.error("Не удалось загрузить сущности из базы данных. Проверьте подключение и наличие данных.")
|
135 |
+
db_session.close()
|
136 |
+
return
|
137 |
+
|
138 |
+
# Конвертация в LinkerEntity с использованием маппинга из ChunkRepository
|
139 |
+
logger.info("Конвертация EntityModel в LinkerEntity...")
|
140 |
+
all_linker_entities = [chunk_repo._map_db_entity_to_linker_entity(model) for model in all_db_models]
|
141 |
+
logger.info(f"Сконвертировано в {len(all_linker_entities)} LinkerEntity объектов.")
|
142 |
+
|
143 |
+
|
144 |
+
# 3. Инициализация InMemoryEntityRepository
|
145 |
+
logger.info("Инициализация InMemoryEntityRepository...")
|
146 |
+
in_memory_repo = InMemoryEntityRepository(entities=all_linker_entities)
|
147 |
+
logger.info(f"InMemoryEntityRepository инициализирован с {len(in_memory_repo.entities)} сущностями.")
|
148 |
+
|
149 |
+
# 4. Получение ID искомых чанков НАПРЯМУЮ из БД
|
150 |
+
logger.info("Получение ID искомых чанков из БД через сессию...")
|
151 |
+
query = select(EntityModel.uuid).where(
|
152 |
+
and_(
|
153 |
+
EntityModel.in_search_text.isnot(None),
|
154 |
+
)
|
155 |
+
)
|
156 |
+
results = db_session.execute(query).scalars().all()
|
157 |
+
searchable_chunk_ids = [UUID(res) for res in results]
|
158 |
+
logger.info(f"Найдено {len(searchable_chunk_ids)} сущностей для поиска.")
|
159 |
+
|
160 |
+
|
161 |
+
if not searchable_chunk_ids:
|
162 |
+
logger.warning("В базе данных не найдено сущностей для поиска (с in_search_text). Тест невозможен.")
|
163 |
+
db_session.close()
|
164 |
+
return
|
165 |
+
|
166 |
+
# 5. Выборка случайных ID чанков
|
167 |
+
actual_sample_size = min(SAMPLE_SIZE, len(searchable_chunk_ids))
|
168 |
+
if actual_sample_size < len(searchable_chunk_ids):
|
169 |
+
logger.info(f"Выбираем {actual_sample_size} случайных ID сущностей для поиска из {len(searchable_chunk_ids)}...")
|
170 |
+
sampled_chunk_ids = random.sample(searchable_chunk_ids, actual_sample_size)
|
171 |
+
else:
|
172 |
+
logger.info(f"Используем все {len(searchable_chunk_ids)} найденные ID сущностей для поиска (т.к. их меньше или равно {SAMPLE_SIZE}).")
|
173 |
+
sampled_chunk_ids = searchable_chunk_ids
|
174 |
+
|
175 |
+
|
176 |
+
# 6. Инициализация InjectionBuilders
|
177 |
+
logger.info("Инициализация InjectionBuilder для ChunkRepository...")
|
178 |
+
# Передаем ИМЕННО ЭКЗЕМПЛЯР chunk_repo, который мы создали
|
179 |
+
builder_chunk_repo = InjectionBuilder(repository=chunk_repo)
|
180 |
+
|
181 |
+
logger.info("Инициализация InjectionBuilder для InMemoryEntityRepository...")
|
182 |
+
builder_in_memory = InjectionBuilder(repository=in_memory_repo)
|
183 |
+
|
184 |
+
# 7. Сборка текста для обоих репозиториев
|
185 |
+
logger.info(f"\n--- Сборка текс��а для ChunkRepository ({actual_sample_size} ID)... ---")
|
186 |
+
try:
|
187 |
+
# Передаем список UUID
|
188 |
+
text_chunk_repo = builder_chunk_repo.build(filtered_entities=sampled_chunk_ids)
|
189 |
+
logger.info(f"Сборка для ChunkRepository завершена. Общая длина: {len(text_chunk_repo)}")
|
190 |
+
# --- Добавляем вывод начала текста ---
|
191 |
+
print("\n--- Начало текста (ChunkRepository, первые 1000 символов): ---")
|
192 |
+
print(text_chunk_repo[:1000])
|
193 |
+
print("--- Конец начала текста (ChunkRepository) ---")
|
194 |
+
# -------------------------------------
|
195 |
+
except Exception as e:
|
196 |
+
logger.error(f"Ошибка при сборке с ChunkRepository: {e}", exc_info=True)
|
197 |
+
text_chunk_repo = f"ERROR_ChunkRepo: {e}"
|
198 |
+
|
199 |
+
|
200 |
+
logger.info(f"\n--- Сборка текста для InMemoryEntityRepository ({actual_sample_size} ID)... ---")
|
201 |
+
try:
|
202 |
+
# Передаем список UUID
|
203 |
+
text_in_memory = builder_in_memory.build(filtered_entities=sampled_chunk_ids)
|
204 |
+
logger.info(f"Сборка для InMemoryEntityRepository завершена. Общая длина: {len(text_in_memory)}")
|
205 |
+
# --- Добавляем вывод начала текста ---
|
206 |
+
print("\n--- Начало текста (InMemory, первые 1000 символов): ---")
|
207 |
+
print(text_in_memory[:1000])
|
208 |
+
print("--- Конец начала текста (InMemory) ---")
|
209 |
+
# -------------------------------------
|
210 |
+
except Exception as e:
|
211 |
+
logger.error(f"Ошибка при сборке с InMemoryEntityRepository: {e}", exc_info=True)
|
212 |
+
text_in_memory = f"ERROR_InMemory: {e}"
|
213 |
+
|
214 |
+
|
215 |
+
# 8. Парсинг результатов по блокам
|
216 |
+
logger.info("\n--- Парсинг результатов по источникам ---")
|
217 |
+
blocks_chunk_repo = parse_output_by_source(text_chunk_repo)
|
218 |
+
blocks_in_memory = parse_output_by_source(text_in_memory)
|
219 |
+
logger.info(f"ChunkRepo: Найдено {len(blocks_chunk_repo)} блоков источников.")
|
220 |
+
logger.info(f"InMemory: Найдено {len(blocks_in_memory)} блоков источников.")
|
221 |
+
|
222 |
+
# 9. Сравнение блоков
|
223 |
+
logger.info("\n--- Сравнение блоков по источникам ---")
|
224 |
+
chunk_repo_keys = set(blocks_chunk_repo.keys())
|
225 |
+
in_memory_keys = set(blocks_in_memory.keys())
|
226 |
+
|
227 |
+
all_keys = chunk_repo_keys | in_memory_keys
|
228 |
+
mismatched_blocks = []
|
229 |
+
|
230 |
+
if chunk_repo_keys != in_memory_keys:
|
231 |
+
logger.warning("Наборы источников НЕ СОВПАДАЮТ!")
|
232 |
+
only_in_chunk = chunk_repo_keys - in_memory_keys
|
233 |
+
only_in_memory = in_memory_keys - chunk_repo_keys
|
234 |
+
if only_in_chunk:
|
235 |
+
logger.warning(f" Источники только в ChunkRepo: {sorted(list(only_in_chunk))}")
|
236 |
+
if only_in_memory:
|
237 |
+
logger.warning(f" Источники только в InMemory: {sorted(list(only_in_memory))}")
|
238 |
+
else:
|
239 |
+
logger.info("Наборы источников совпадают.")
|
240 |
+
|
241 |
+
logger.info("\n--- Сравнение содержимого общих источников ---")
|
242 |
+
common_keys = chunk_repo_keys & in_memory_keys
|
243 |
+
if not common_keys:
|
244 |
+
logger.warning("Нет общих источников для сравнения содержимого.")
|
245 |
+
else:
|
246 |
+
all_common_blocks_match = True
|
247 |
+
table_marker_found_in_any_chunk_repo = False
|
248 |
+
table_marker_found_in_any_in_memory = False
|
249 |
+
|
250 |
+
for key in sorted(list(common_keys)):
|
251 |
+
content_chunk = blocks_chunk_repo.get(key, "") # Используем .get для безопасности
|
252 |
+
content_memory = blocks_in_memory.get(key, "") # Используем .get для безопасности
|
253 |
+
|
254 |
+
# Проверка наличия маркера таблиц
|
255 |
+
has_tables_chunk = "###" in content_chunk
|
256 |
+
has_tables_memory = "###" in content_memory
|
257 |
+
if has_tables_chunk:
|
258 |
+
table_marker_found_in_any_chunk_repo = True
|
259 |
+
if has_tables_memory:
|
260 |
+
table_marker_found_in_any_in_memory = True
|
261 |
+
|
262 |
+
# Логируем наличие таблиц для КАЖДОГО блока (можно закомментировать, если много)
|
263 |
+
# logger.info(f" Источник: '{key}' - Таблицы (###) в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory}")
|
264 |
+
|
265 |
+
if content_chunk != content_memory:
|
266 |
+
all_common_blocks_match = False
|
267 |
+
mismatched_blocks.append(key)
|
268 |
+
logger.warning(f" НЕСОВПАДЕНИЕ для источника: '{key}' (Таблицы в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory})")
|
269 |
+
# Можно добавить вывод diff для конкретного блока, если нужно
|
270 |
+
# import difflib
|
271 |
+
# block_diff = difflib.unified_diff(
|
272 |
+
# content_chunk.splitlines(keepends=True),
|
273 |
+
# content_memory.splitlines(keepends=True),
|
274 |
+
# fromfile=f'{key}_ChunkRepo',
|
275 |
+
# tofile=f'{key}_InMemory',
|
276 |
+
# lineterm='',
|
277 |
+
# )
|
278 |
+
# print("\nDiff для блока:")
|
279 |
+
# sys.stdout.writelines(list(block_diff)[:20]) # Показать начало diff блока
|
280 |
+
# if len(list(block_diff)) > 20: print("...")
|
281 |
+
# else:
|
282 |
+
# # Логируем совпадение только если таблицы есть хоть где-то, для краткости
|
283 |
+
# if has_tables_chunk or has_tables_memory:
|
284 |
+
# logger.info(f" Совпадение для источника: '{key}' (Таблицы в ChunkRepo: {has_tables_chunk}, в InMemory: {has_tables_memory})")
|
285 |
+
|
286 |
+
# Выводим общую информацию о наличии таблиц
|
287 |
+
logger.info("--- Итог проверки таблиц (###) в общих блоках ---")
|
288 |
+
logger.info(f"Маркер таблиц '###' найден хотя бы в одном блоке ChunkRepo: {table_marker_found_in_any_chunk_repo}")
|
289 |
+
logger.info(f"Маркер таблиц '###' найден хотя бы в одном блоке InMemory: {table_marker_found_in_any_in_memory}")
|
290 |
+
logger.info("-------------------------------------------------")
|
291 |
+
|
292 |
+
if all_common_blocks_match:
|
293 |
+
logger.info("Содержимое ВСЕХ общих источников СОВПАДАЕТ.")
|
294 |
+
else:
|
295 |
+
logger.warning(f"Найдено НЕСОВПАДЕНИЕ содержимого для {len(mismatched_blocks)} источников: {sorted(mismatched_blocks)}")
|
296 |
+
|
297 |
+
logger.info("\n--- Итоговый вердикт ---")
|
298 |
+
if chunk_repo_keys == in_memory_keys and not mismatched_blocks:
|
299 |
+
logger.info("ПОЛНОЕ СОВПАДЕНИЕ: Наборы источников и их содержимое идентичны.")
|
300 |
+
elif chunk_repo_keys == in_memory_keys and mismatched_blocks:
|
301 |
+
logger.warning("ЧАСТИЧНОЕ СОВПАДЕНИЕ: Наборы источников совпадают, но содержимое некоторых блоков различается.")
|
302 |
+
else:
|
303 |
+
logger.warning("НЕСОВПАДЕНИЕ: Наборы источников различаются (и, возможно, содержимое общих тоже).")
|
304 |
+
|
305 |
+
|
306 |
+
except ImportError as e:
|
307 |
+
# Ловим ошибки импорта, возникшие внутри функций (маловероятно после старта)
|
308 |
+
logger.error(f"Критическая ошибка импорта: {e}")
|
309 |
+
except Exception as e:
|
310 |
+
logger.error(f"Произошла общая ошибка: {e}", exc_info=True)
|
311 |
+
finally:
|
312 |
+
if 'db_session' in locals() and db_session:
|
313 |
+
db_session.close()
|
314 |
+
logger.info("Сессия базы данных закрыта.")
|
315 |
+
|
316 |
+
|
317 |
+
# --- Запуск ---
|
318 |
+
if __name__ == "__main__":
|
319 |
+
# Используем Path для более надежного определения пути
|
320 |
+
db_path = Path(DATABASE_URL.replace("sqlite:///", ""))
|
321 |
+
if not db_path.exists():
|
322 |
+
print(f"!!! ОШИБКА: Файл базы данных НЕ НАЙДЕН по пути: {db_path.resolve()} !!!")
|
323 |
+
print(f"!!! Проверьте значение DATABASE_URL в скрипте. !!!")
|
324 |
+
elif "путь/к/твоей" in DATABASE_URL: # Доп. проверка на placeholder
|
325 |
+
print("!!! ПОЖАЛУЙСТА, УКАЖИТЕ ПРАВИЛЬНЫЙ ПУТЬ К БАЗЕ ДАННЫХ В ПЕРЕМЕННОЙ DATABASE_URL !!!")
|
326 |
+
else:
|
327 |
+
compare_repositories()
|
scripts/testing/aggregate_results.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Скрипт для агрегации и анализа результатов множества запусков pipeline.py.
|
5 |
+
|
6 |
+
Читает все CSV-файлы из директории промежуточных результатов,
|
7 |
+
объединяет их и вычисляет агрегированные метрики:
|
8 |
+
- Weighted (усредненные по всем вопросам, взвешенные по количеству пунктов/чанков/документов)
|
9 |
+
- Macro (усредненные по вопросам - сначала считаем метрику для каждого вопроса, потом усредняем)
|
10 |
+
- Micro (считаем общие TP, FP, FN по всем вопросам, потом вычисляем метрики)
|
11 |
+
|
12 |
+
Результаты сохраняются в один Excel-файл с несколькими листами.
|
13 |
+
"""
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import glob
|
17 |
+
# Импорт для обработки JSON строк
|
18 |
+
import os
|
19 |
+
|
20 |
+
import pandas as pd
|
21 |
+
from openpyxl import Workbook
|
22 |
+
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
23 |
+
from openpyxl.utils import get_column_letter
|
24 |
+
from openpyxl.utils.dataframe import dataframe_to_rows
|
25 |
+
# Прогресс-бар
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
# --- Настройки ---
|
29 |
+
DEFAULT_INTERMEDIATE_DIR = "data/intermediate" # Откуда читать CSV
|
30 |
+
DEFAULT_OUTPUT_DIR = "data/output" # Куда сохранять итоговый Excel
|
31 |
+
DEFAULT_OUTPUT_FILENAME = "aggregated_results.xlsx"
|
32 |
+
|
33 |
+
# --- Маппинг названий столбцов на русский язык ---
|
34 |
+
COLUMN_NAME_MAPPING = {
|
35 |
+
# Параметры запуска из pipeline.py
|
36 |
+
'run_id': 'ID Запуска',
|
37 |
+
'model_name': 'Модель',
|
38 |
+
'chunking_strategy': 'Стратегия Чанкинга',
|
39 |
+
'strategy_params': 'Параметры Стратегии',
|
40 |
+
'process_tables': 'Обраб. Таблиц',
|
41 |
+
'top_n': 'Top N',
|
42 |
+
'use_injection': 'Сборка Контекста',
|
43 |
+
'use_qe': 'Query Expansion',
|
44 |
+
'neighbors_included': 'Вкл. Соседей',
|
45 |
+
'similarity_threshold': 'Порог Схожести',
|
46 |
+
|
47 |
+
# Идентификаторы из датасета (для детальных результатов)
|
48 |
+
'question_id': 'ID Вопроса',
|
49 |
+
'question_text': 'Текст Вопроса',
|
50 |
+
|
51 |
+
# Детальные метрики из pipeline.py
|
52 |
+
'chunk_text_precision': 'Точность (Чанк-Текст)',
|
53 |
+
'chunk_text_recall': 'Полнота (Чанк-Текст)',
|
54 |
+
'chunk_text_f1': 'F1 (Чанк-Текст)',
|
55 |
+
'found_puncts': 'Найдено Пунктов',
|
56 |
+
'total_puncts': 'Всего Пунктов',
|
57 |
+
'relevant_chunks': 'Релевантных Чанков',
|
58 |
+
'total_chunks_in_top_n': 'Всего Чанков в Топ-N',
|
59 |
+
'assembly_punct_recall': 'Полнота (Сборка-Пункт)',
|
60 |
+
'assembled_context_preview': 'Предпросмотр Сборки',
|
61 |
+
# 'top_chunk_ids': 'Индексы Топ-Чанков', # Списки, могут плохо отображаться
|
62 |
+
# 'top_chunk_similarities': 'Схожести Топ-Чанков', # Списки
|
63 |
+
|
64 |
+
# Агрегированные метрики (добавляются в calculate_aggregated_metrics)
|
65 |
+
'weighted_chunk_text_precision': 'Weighted Точность (Чанк-Текст)',
|
66 |
+
'weighted_chunk_text_recall': 'Weighted Полнота (Чанк-Текст)',
|
67 |
+
'weighted_chunk_text_f1': 'Weighted F1 (Чанк-Текст)',
|
68 |
+
'weighted_assembly_punct_recall': 'Weighted Полнота (Сборка-Пункт)',
|
69 |
+
|
70 |
+
'macro_chunk_text_precision': 'Macro Точность (Чанк-Текст)',
|
71 |
+
'macro_chunk_text_recall': 'Macro Полнота (Чанк-Текст)',
|
72 |
+
'macro_chunk_text_f1': 'Macro F1 (Чанк-Текст)',
|
73 |
+
'macro_assembly_punct_recall': 'Macro Полнота (Сборка-Пункт)',
|
74 |
+
|
75 |
+
'micro_text_precision': 'Micro Точность (Текст)',
|
76 |
+
'micro_text_recall': 'Micro Полнота (Текст)',
|
77 |
+
'micro_text_f1': 'Micro F1 (Текст)',
|
78 |
+
}
|
79 |
+
|
80 |
+
def parse_args():
|
81 |
+
"""Парсит аргументы командной строки."""
|
82 |
+
parser = argparse.ArgumentParser(description="Агрегация результатов оценочных пайплайнов")
|
83 |
+
|
84 |
+
parser.add_argument("--intermediate-dir", type=str, default=DEFAULT_INTERMEDIATE_DIR,
|
85 |
+
help=f"Директория с промежуточными CSV результатами (по умолчанию: {DEFAULT_INTERMEDIATE_DIR})")
|
86 |
+
parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
|
87 |
+
help=f"Директория для сохранения итогового Excel файла (по умолчанию: {DEFAULT_OUTPUT_DIR})")
|
88 |
+
parser.add_argument("--output-filename", type=str, default=DEFAULT_OUTPUT_FILENAME,
|
89 |
+
help=f"Имя выходного Excel файла (по умолчанию: {DEFAULT_OUTPUT_FILENAME})")
|
90 |
+
parser.add_argument("--latest-batch-only", action="store_true",
|
91 |
+
help="Агрегировать результаты только для последнего batch_id")
|
92 |
+
|
93 |
+
return parser.parse_args()
|
94 |
+
|
95 |
+
def load_intermediate_results(intermediate_dir: str) -> pd.DataFrame:
|
96 |
+
"""Загружает все CSV файлы из указанной директории."""
|
97 |
+
print(f"Загрузка промежуточных результатов из: {intermediate_dir}")
|
98 |
+
csv_files = glob.glob(os.path.join(intermediate_dir, "results_*.csv"))
|
99 |
+
|
100 |
+
if not csv_files:
|
101 |
+
print(f"ВНИМАНИЕ: В директории {intermediate_dir} не найдено файлов 'results_*.csv'.")
|
102 |
+
return pd.DataFrame()
|
103 |
+
|
104 |
+
all_data = []
|
105 |
+
for f in csv_files:
|
106 |
+
try:
|
107 |
+
df = pd.read_csv(f)
|
108 |
+
all_data.append(df)
|
109 |
+
print(f" Загружен файл: {os.path.basename(f)} ({len(df)} строк)")
|
110 |
+
except Exception as e:
|
111 |
+
print(f"Ошибка при чтении файла {f}: {e}")
|
112 |
+
|
113 |
+
if not all_data:
|
114 |
+
print("Не удалось загрузить ни одного файла с результатами.")
|
115 |
+
return pd.DataFrame()
|
116 |
+
|
117 |
+
combined_df = pd.concat(all_data, ignore_index=True)
|
118 |
+
print(f"Всего загружено строк: {len(combined_df)}")
|
119 |
+
print(f"Найденные колонки: {combined_df.columns.tolist()}")
|
120 |
+
|
121 |
+
# Преобразуем типы данных для надежности
|
122 |
+
numeric_cols = [
|
123 |
+
'chunk_text_precision', 'chunk_text_recall', 'chunk_text_f1',
|
124 |
+
'found_puncts', 'total_puncts', 'relevant_chunks',
|
125 |
+
'total_chunks_in_top_n',
|
126 |
+
'assembly_punct_recall',
|
127 |
+
'similarity_threshold', 'top_n',
|
128 |
+
]
|
129 |
+
for col in numeric_cols:
|
130 |
+
if col in combined_df.columns:
|
131 |
+
combined_df[col] = pd.to_numeric(combined_df[col], errors='coerce')
|
132 |
+
|
133 |
+
boolean_cols = [
|
134 |
+
'use_injection',
|
135 |
+
'process_tables',
|
136 |
+
'use_qe',
|
137 |
+
'neighbors_included'
|
138 |
+
]
|
139 |
+
for col in boolean_cols:
|
140 |
+
if col in combined_df.columns:
|
141 |
+
# Пытаемся конвертировать в bool, обрабатывая строки 'True'/'False'
|
142 |
+
if combined_df[col].dtype == 'object':
|
143 |
+
combined_df[col] = combined_df[col].astype(str).str.lower().map({'true': True, 'false': False}).fillna(False)
|
144 |
+
combined_df[col] = combined_df[col].astype(bool)
|
145 |
+
|
146 |
+
# Заполним пропуски в числовых колонках нулями (например, если метрики не посчитались)
|
147 |
+
combined_df[numeric_cols] = combined_df[numeric_cols].fillna(0)
|
148 |
+
|
149 |
+
# --- Обработка batch_id ---
|
150 |
+
if 'batch_id' in combined_df.columns:
|
151 |
+
# Приводим к строке и заполняем NaN
|
152 |
+
combined_df['batch_id'] = combined_df['batch_id'].astype(str).fillna('unknown_batch')
|
153 |
+
else:
|
154 |
+
# Если колонки нет, создаем ее
|
155 |
+
print("Предупреждение: Колонка 'batch_id' отсутствует в загруженных данных. Добавлена со значением 'unknown_batch'.")
|
156 |
+
combined_df['batch_id'] = 'unknown_batch'
|
157 |
+
# --------------------------
|
158 |
+
|
159 |
+
# Переименовываем столбцы в русские названия ДО возврата
|
160 |
+
# Отбираем только те колонки, для которых есть перевод
|
161 |
+
columns_to_rename = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in combined_df.columns}
|
162 |
+
combined_df = combined_df.rename(columns=columns_to_rename)
|
163 |
+
print(f"Столбцы переименованы. Новые колонки: {combined_df.columns.tolist()}")
|
164 |
+
|
165 |
+
return combined_df
|
166 |
+
|
167 |
+
def calculate_aggregated_metrics(df: pd.DataFrame) -> pd.DataFrame:
|
168 |
+
"""
|
169 |
+
Вычисляет агрегированные метрики (Weighted, Macro, Micro)
|
170 |
+
для каждой уникальной комбинации параметров запуска.
|
171 |
+
|
172 |
+
Ожидает DataFrame с русскими названиями колонок.
|
173 |
+
"""
|
174 |
+
if df.empty:
|
175 |
+
return pd.DataFrame()
|
176 |
+
|
177 |
+
# Определяем параметры, по которым будем группировать (ИСПОЛЬЗУЕМ РУССКИЕ НАЗВАНИЯ)
|
178 |
+
grouping_params_rus = [
|
179 |
+
COLUMN_NAME_MAPPING.get('model_name', 'Модель'),
|
180 |
+
COLUMN_NAME_MAPPING.get('chunking_strategy', 'Стратегия Чанкинга'),
|
181 |
+
COLUMN_NAME_MAPPING.get('strategy_params', 'Параметры Стратегии'),
|
182 |
+
COLUMN_NAME_MAPPING.get('process_tables', 'Обраб. Таблиц'),
|
183 |
+
COLUMN_NAME_MAPPING.get('top_n', 'Top N'),
|
184 |
+
COLUMN_NAME_MAPPING.get('use_injection', 'Сборка Контекста'),
|
185 |
+
COLUMN_NAME_MAPPING.get('use_qe', 'Query Expansion'),
|
186 |
+
COLUMN_NAME_MAPPING.get('neighbors_included', 'Вкл. Соседей'),
|
187 |
+
COLUMN_NAME_MAPPING.get('similarity_threshold', 'Порог Схожести')
|
188 |
+
]
|
189 |
+
|
190 |
+
# Проверяем наличие всех колонок для группировки (с русскими именами)
|
191 |
+
missing_cols = [col for col in grouping_params_rus if col not in df.columns]
|
192 |
+
if missing_cols:
|
193 |
+
print(f"Ошибка: Отсутствуют необходимые колонки для группировки (русские): {missing_cols}")
|
194 |
+
# Удаляем отсутствующие колонки из списка группировки
|
195 |
+
grouping_params_rus = [col for col in grouping_params_rus if col not in missing_cols]
|
196 |
+
if not grouping_params_rus:
|
197 |
+
print("Невозможно выполнить группировку.")
|
198 |
+
return pd.DataFrame()
|
199 |
+
|
200 |
+
print(f"Группировка по параметрам (русские): {grouping_params_rus}")
|
201 |
+
# Используем grouping_params_rus для группировки
|
202 |
+
grouped = df.groupby(grouping_params_rus)
|
203 |
+
|
204 |
+
aggregated_results = []
|
205 |
+
|
206 |
+
# Итерируемся по каждой группе (комбинации параметров)
|
207 |
+
for params, group_df in tqdm(grouped, desc="Расчет агрегированных метрик"):
|
208 |
+
# Начинаем со словаря параметров (уже с русскими именами)
|
209 |
+
agg_result = dict(zip(grouping_params_rus, params))
|
210 |
+
|
211 |
+
# --- Метрики для усреднения/взвешивания (РУССКИЕ НАЗВАНИЯ) ---
|
212 |
+
chunk_prec_col = COLUMN_NAME_MAPPING.get('chunk_text_precision', 'Точность (Чанк-Текст)')
|
213 |
+
chunk_rec_col = COLUMN_NAME_MAPPING.get('chunk_text_recall', 'Полнота (Чанк-Текст)')
|
214 |
+
chunk_f1_col = COLUMN_NAME_MAPPING.get('chunk_text_f1', 'F1 (Чанк-Текст)')
|
215 |
+
assembly_rec_col = COLUMN_NAME_MAPPING.get('assembly_punct_recall', 'Полнота (Сборка-Пункт)')
|
216 |
+
total_chunks_col = COLUMN_NAME_MAPPING.get('total_chunks_in_top_n', 'Всего Чанков в Топ-N')
|
217 |
+
total_puncts_col = COLUMN_NAME_MAPPING.get('total_puncts', 'Всего Пунктов')
|
218 |
+
found_puncts_col = COLUMN_NAME_MAPPING.get('found_puncts', 'Найдено Пунктов') # Для micro
|
219 |
+
relevant_chunks_col = COLUMN_NAME_MAPPING.get('relevant_chunks', 'Релевантных Чанков') # Для micro
|
220 |
+
|
221 |
+
# Колонки, которые должны существовать для расчетов
|
222 |
+
required_metric_cols = [chunk_prec_col, chunk_rec_col, chunk_f1_col, assembly_rec_col]
|
223 |
+
required_count_cols = [total_chunks_col, total_puncts_col, found_puncts_col, relevant_chunks_col]
|
224 |
+
existing_metric_cols = [m for m in required_metric_cols if m in group_df.columns]
|
225 |
+
existing_count_cols = [c for c in required_count_cols if c in group_df.columns]
|
226 |
+
|
227 |
+
# --- Macro метрики (Простое усреднение метрик по вопросам) ---
|
228 |
+
if existing_metric_cols:
|
229 |
+
macro_metrics = group_df[existing_metric_cols].mean().rename(
|
230 |
+
# Генерируем имя 'Macro Имя Метрики'
|
231 |
+
lambda x: COLUMN_NAME_MAPPING.get(f"macro_{{key}}".format(key=next((k for k, v in COLUMN_NAME_MAPPING.items() if v == x), None)), f"Macro {x}")
|
232 |
+
).to_dict()
|
233 |
+
agg_result.update(macro_metrics)
|
234 |
+
else:
|
235 |
+
print(f"Предупреждение: Пропуск Macro метрик для группы {params}, нет колонок метрик.")
|
236 |
+
|
237 |
+
# --- Weighted метрики (Взвешенное усреднение) ---
|
238 |
+
weighted_chunk_precision = 0.0
|
239 |
+
weighted_chunk_recall = 0.0
|
240 |
+
weighted_assembly_recall = 0.0
|
241 |
+
weighted_chunk_f1 = 0.0
|
242 |
+
|
243 |
+
# Проверяем наличие необходимых колонок для взвешенного расчета
|
244 |
+
can_calculate_weighted = True
|
245 |
+
if chunk_prec_col not in existing_metric_cols or total_chunks_col not in existing_count_cols:
|
246 |
+
print(f"Предупреждение: Пропуск Weighted Точность (Чанк-Текст) для группы {params}, отсутствуют {chunk_prec_col} или {total_chunks_col}.")
|
247 |
+
can_calculate_weighted = False
|
248 |
+
if chunk_rec_col not in existing_metric_cols or total_puncts_col not in existing_count_cols:
|
249 |
+
print(f"Предупреждение: Пропуск Weighted Полнота (Чанк-Текст) для группы {params}, отсутствуют {chunk_rec_col} или {total_puncts_col}.")
|
250 |
+
can_calculate_weighted = False
|
251 |
+
if assembly_rec_col not in existing_metric_cols or total_puncts_col not in existing_count_cols:
|
252 |
+
print(f"Пред��преждение: Пропуск Weighted Полнота (Сборка-Пункт) для группы {params}, отсутствуют {assembly_rec_col} или {total_puncts_col}.")
|
253 |
+
# Не сбрасываем can_calculate_weighted, т.к. другие weighted могут посчитаться
|
254 |
+
|
255 |
+
if can_calculate_weighted:
|
256 |
+
total_chunks_sum = group_df[total_chunks_col].sum()
|
257 |
+
total_puncts_sum = group_df[total_puncts_col].sum()
|
258 |
+
|
259 |
+
# Weighted Precision (Chunk-Text)
|
260 |
+
if total_chunks_sum > 0 and chunk_prec_col in existing_metric_cols:
|
261 |
+
weighted_chunk_precision = (group_df[chunk_prec_col] * group_df[total_chunks_col]).sum() / total_chunks_sum
|
262 |
+
|
263 |
+
# Weighted Recall (Chunk-Text)
|
264 |
+
if total_puncts_sum > 0 and chunk_rec_col in existing_metric_cols:
|
265 |
+
weighted_chunk_recall = (group_df[chunk_rec_col] * group_df[total_puncts_col]).sum() / total_puncts_sum
|
266 |
+
|
267 |
+
# Weighted Recall (Assembly-Punct)
|
268 |
+
if total_puncts_sum > 0 and assembly_rec_col in existing_metric_cols:
|
269 |
+
weighted_assembly_recall = (group_df[assembly_rec_col] * group_df[total_puncts_col]).sum() / total_puncts_sum
|
270 |
+
|
271 |
+
# Weighted F1 (Chunk-Text) - на основе weighted precision и recall
|
272 |
+
if weighted_chunk_precision + weighted_chunk_recall > 0:
|
273 |
+
weighted_chunk_f1 = (2 * weighted_chunk_precision * weighted_chunk_recall) / (weighted_chunk_precision + weighted_chunk_recall)
|
274 |
+
|
275 |
+
# Добавляем рассчитанные Weighted метрики в результат
|
276 |
+
agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_precision', 'Weighted Точность (Чанк-Текст)')] = weighted_chunk_precision
|
277 |
+
agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_recall', 'Weighted Полнота (Чанк-Текст)')] = weighted_chunk_recall
|
278 |
+
agg_result[COLUMN_NAME_MAPPING.get('weighted_chunk_text_f1', 'Weighted F1 (Чанк-Текст)')] = weighted_chunk_f1
|
279 |
+
agg_result[COLUMN_NAME_MAPPING.get('weighted_assembly_punct_recall', 'Weighted Полнота (Сборка-Пункт)')] = weighted_assembly_recall
|
280 |
+
|
281 |
+
|
282 |
+
# --- Micro метрики (На основе общих TP, FP, FN, ИСПОЛЬЗУЕМ РУССКИЕ НАЗВАНИЯ) ---
|
283 |
+
# Колонки уже определены выше
|
284 |
+
if not all(col in group_df.columns for col in [found_puncts_col, total_puncts_col, relevant_chunks_col, total_chunks_col]):
|
285 |
+
print(f"Предупреждение: Пропуск расчета micro-метрик для группы {params}, т.к. отсутствуют необходимые колонки.")
|
286 |
+
agg_result[COLUMN_NAME_MAPPING.get('micro_text_precision', 'Micro Точность (Текст)')] = 0.0
|
287 |
+
agg_result[COLUMN_NAME_MAPPING.get('micro_text_recall', 'Micro Полнота (Текст)')] = 0.0
|
288 |
+
agg_result[COLUMN_NAME_MAPPING.get('micro_text_f1', 'Micro F1 (Текст)')] = 0.0
|
289 |
+
|
290 |
+
# Добавляем результат группы в общий список
|
291 |
+
aggregated_results.append(agg_result)
|
292 |
+
|
293 |
+
# Создаем итоговый DataFrame (уже с русскими именами)
|
294 |
+
final_df = pd.DataFrame(aggregated_results)
|
295 |
+
|
296 |
+
print(f"Рассчитаны агрегированные метрики для {len(final_df)} комбинаций параметров.")
|
297 |
+
# Возвращаем DataFrame с русскими названиями колонок
|
298 |
+
return final_df
|
299 |
+
|
300 |
+
# --- Функции для форматирования Excel (адаптированы из combine_results.py) ---
|
301 |
+
def apply_excel_formatting(workbook: Workbook):
|
302 |
+
"""Применяет форматирование ко всем листам книги Excel."""
|
303 |
+
header_font = Font(bold=True)
|
304 |
+
header_fill = PatternFill(start_color="D9D9D9", end_color="D9D9D9", fill_type="solid")
|
305 |
+
center_alignment = Alignment(horizontal='center', vertical='center')
|
306 |
+
wrap_alignment = Alignment(horizontal='center', vertical='center', wrap_text=True)
|
307 |
+
thin_border = Border(
|
308 |
+
left=Side(style='thin'),
|
309 |
+
right=Side(style='thin'),
|
310 |
+
top=Side(style='thin'),
|
311 |
+
bottom=Side(style='thin')
|
312 |
+
)
|
313 |
+
thick_top_border = Border(top=Side(style='thick'))
|
314 |
+
|
315 |
+
for sheet_name in workbook.sheetnames:
|
316 |
+
sheet = workbook[sheet_name]
|
317 |
+
if sheet.max_row <= 1: # Пропускаем пустые листы
|
318 |
+
continue
|
319 |
+
|
320 |
+
# Форматирование заголовков
|
321 |
+
for cell in sheet[1]:
|
322 |
+
cell.font = header_font
|
323 |
+
cell.fill = header_fill
|
324 |
+
cell.alignment = wrap_alignment
|
325 |
+
cell.border = thin_border
|
326 |
+
|
327 |
+
# Автоподбор ширины и форматирование ячеек
|
328 |
+
for col_idx, column_cells in enumerate(sheet.columns, 1):
|
329 |
+
max_length = 0
|
330 |
+
column_letter = get_column_letter(col_idx)
|
331 |
+
is_numeric_metric_col = False
|
332 |
+
header_value = sheet.cell(row=1, column=col_idx).value
|
333 |
+
|
334 |
+
# Проверяем, является ли колонка числовой метрикой
|
335 |
+
if isinstance(header_value, str) and any(m in header_value for m in ['precision', 'recall', 'f1', 'relevance']):
|
336 |
+
is_numeric_metric_col = True
|
337 |
+
|
338 |
+
for i, cell in enumerate(column_cells):
|
339 |
+
# Применяем границы ко всем ячейкам
|
340 |
+
cell.border = thin_border
|
341 |
+
# Центрируем все, кроме заголовка
|
342 |
+
if i > 0:
|
343 |
+
cell.alignment = center_alignment
|
344 |
+
|
345 |
+
# Формат для числовых метрик
|
346 |
+
if is_numeric_metric_col and i > 0 and isinstance(cell.value, (int, float)):
|
347 |
+
cell.number_format = '0.0000'
|
348 |
+
|
349 |
+
# Расчет ширины
|
350 |
+
try:
|
351 |
+
cell_len = len(str(cell.value))
|
352 |
+
if cell_len > max_length:
|
353 |
+
max_length = cell_len
|
354 |
+
except:
|
355 |
+
pass
|
356 |
+
adjusted_width = (max_length + 2) * 1.1
|
357 |
+
sheet.column_dimensions[column_letter].width = min(adjusted_width, 60) # Ограничиваем макс ширину
|
358 |
+
|
359 |
+
# Автофильтр
|
360 |
+
sheet.auto_filter.ref = sheet.dimensions
|
361 |
+
|
362 |
+
# Группировка строк (опционально, можно добавить логику из combine_results, если нужна)
|
363 |
+
# ... (здесь можно вставить apply_group_formatting, если требуется) ...
|
364 |
+
|
365 |
+
print("Форматирование Excel завершено.")
|
366 |
+
|
367 |
+
|
368 |
+
def save_to_excel(data_dict: dict[str, pd.DataFrame], output_path: str):
|
369 |
+
"""Сохраняет несколько DataFrame в один Excel файл с форматированием."""
|
370 |
+
print(f"Сохранение результатов в Excel: {output_path}")
|
371 |
+
try:
|
372 |
+
workbook = Workbook()
|
373 |
+
workbook.remove(workbook.active) # Удаляем лист по умолчанию
|
374 |
+
|
375 |
+
for sheet_name, df in data_dict.items():
|
376 |
+
if df is not None and not df.empty:
|
377 |
+
sheet = workbook.create_sheet(title=sheet_name)
|
378 |
+
for r in dataframe_to_rows(df, index=False, header=True):
|
379 |
+
# Проверяем и заменяем недопустимые символы в ячейках
|
380 |
+
cleaned_row = []
|
381 |
+
for cell_value in r:
|
382 |
+
if isinstance(cell_value, str):
|
383 |
+
# Удаляем управляющие символы, кроме стандартных пробельных
|
384 |
+
cleaned_value = ''.join(c for c in cell_value if c.isprintable() or c in ' \t\n\r')
|
385 |
+
cleaned_row.append(cleaned_value)
|
386 |
+
else:
|
387 |
+
cleaned_row.append(cell_value)
|
388 |
+
sheet.append(cleaned_row)
|
389 |
+
print(f" Лист '{sheet_name}' добавлен ({len(df)} строк)")
|
390 |
+
else:
|
391 |
+
print(f" Лист '{sheet_name}' пропущен (нет данных)")
|
392 |
+
|
393 |
+
# Применяем форматирование ко всей книге
|
394 |
+
if workbook.sheetnames: # Проверяем, что есть хотя бы один лист
|
395 |
+
apply_excel_formatting(workbook)
|
396 |
+
workbook.save(output_path)
|
397 |
+
print("Excel файл успешно сохранен.")
|
398 |
+
else:
|
399 |
+
print("Нет данных для сохранения в Excel.")
|
400 |
+
|
401 |
+
except Exception as e:
|
402 |
+
print(f"Ошибка при сохранении Excel файла: {e}")
|
403 |
+
|
404 |
+
|
405 |
+
# --- Основная функция ---
|
406 |
+
def main():
|
407 |
+
"""Основная функция скрипта."""
|
408 |
+
args = parse_args()
|
409 |
+
|
410 |
+
# 1. Загрузка данных
|
411 |
+
combined_df_eng = load_intermediate_results(args.intermediate_dir)
|
412 |
+
|
413 |
+
if combined_df_eng.empty:
|
414 |
+
print("Нет данных для агрегации. Завершение.")
|
415 |
+
return
|
416 |
+
|
417 |
+
# --- Фильтрация по последнему batch_id (если флаг установлен) ---
|
418 |
+
target_df = combined_df_eng # По умолчанию используем все данные
|
419 |
+
if args.latest_batch_only:
|
420 |
+
print("Фильтрация по последнему batch_id...")
|
421 |
+
if 'batch_id' not in combined_df_eng.columns:
|
422 |
+
print("Предупреждение: Колонка 'batch_id' не найдена. Агрегация будет выполнена по всем данным.")
|
423 |
+
else:
|
424 |
+
# Находим последний batch_id (сортируем строки по batch_id)
|
425 |
+
# Сначала отфильтруем 'unknown_batch'
|
426 |
+
valid_batches = combined_df_eng[combined_df_eng['batch_id'] != 'unknown_batch']['batch_id'].unique()
|
427 |
+
if len(valid_batches) > 0:
|
428 |
+
# Сортируем уникальные валидные ID и берем последний
|
429 |
+
latest_batch_id = sorted(valid_batches)[-1]
|
430 |
+
print(f"Используется последний валидный batch_id: {latest_batch_id}")
|
431 |
+
target_df = combined_df_eng[combined_df_eng['batch_id'] == latest_batch_id].copy()
|
432 |
+
if target_df.empty:
|
433 |
+
# Это не должно произойти, если latest_batch_id валидный, но на всякий случай
|
434 |
+
print(f"Предупреждение: Не найдено данных для batch_id {latest_batch_id}. Агрегация будет выполнена по всем данным.")
|
435 |
+
target_df = combined_df_eng
|
436 |
+
else:
|
437 |
+
print(f"Оставлено строк после фильтрации: {len(target_df)}")
|
438 |
+
else:
|
439 |
+
print("Предупреждение: Не найдено валидных batch_id для фильтрации. Агрегация будет выполнена по всем данным.")
|
440 |
+
# target_df уже равен combined_df_eng, так что ничего не делаем
|
441 |
+
# latest_batch_id = combined_df_eng['batch_id'].astype(str).sort_values().iloc[-1]
|
442 |
+
# print(f"Используется последний batch_id: {latest_batch_id}")
|
443 |
+
# target_df = combined_df_eng[combined_df_eng['batch_id'] == latest_batch_id].copy()
|
444 |
+
# if target_df.empty:
|
445 |
+
# print(f"Предупреждение: Нет данных для batch_id {latest_batch_id}. Агрегация будет выполнена по всем данным.")
|
446 |
+
# target_df = combined_df_eng # Возвращаемся ко всем данным, если фильтр дал пустоту
|
447 |
+
# else:
|
448 |
+
# print(f"Оставлено строк после фильтрации: {len(target_df)}")
|
449 |
+
|
450 |
+
# --- Заполнение NaN и переименование ПОСЛЕ возможной фильтрации ---
|
451 |
+
# Определяем числовые колонки еще раз (используя английские названия из маппинга)
|
452 |
+
numeric_cols_eng = [eng for eng, rus in COLUMN_NAME_MAPPING.items() \
|
453 |
+
if 'recall' in eng or 'precision' in eng or 'f1' in eng or 'puncts' in eng \
|
454 |
+
or 'chunks' in eng or 'threshold' in eng or 'top_n' in eng]
|
455 |
+
numeric_cols_in_df = [col for col in numeric_cols_eng if col in target_df.columns]
|
456 |
+
target_df[numeric_cols_in_df] = target_df[numeric_cols_in_df].fillna(0)
|
457 |
+
|
458 |
+
# Переименовываем
|
459 |
+
columns_to_rename_detailed = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in target_df.columns}
|
460 |
+
target_df_rus = target_df.rename(columns=columns_to_rename_detailed)
|
461 |
+
|
462 |
+
# 2. Расчет агрегированных метрик
|
463 |
+
# Передаем DataFrame с русскими названиями колонок, calculate_aggregated_metrics теперь их ожидает
|
464 |
+
aggregated_df_rus = calculate_aggregated_metrics(target_df_rus)
|
465 |
+
# Переименовываем столбцы агрегированного DF уже внутри calculate_aggregated_metrics
|
466 |
+
# aggregated_df_rus = pd.DataFrame() # Инициализируем на случай, если aggregated_df_eng пуст
|
467 |
+
# if not aggregated_df_eng.empty:
|
468 |
+
# columns_to_rename_agg = {eng: rus for eng, rus in COLUMN_NAME_MAPPING.items() if eng in aggregated_df_eng.columns}
|
469 |
+
# aggregated_df_rus = aggregated_df_eng.rename(columns=columns_to_rename_agg)
|
470 |
+
|
471 |
+
# 3. Подготовка данных для сохранения (с русскими названиями)
|
472 |
+
data_to_save = {
|
473 |
+
"Детальные результаты": target_df_rus, # Используем переименованный DF
|
474 |
+
"Агрегированные метрики": aggregated_df_rus, # Используем переименованный DF
|
475 |
+
}
|
476 |
+
|
477 |
+
# 4. Сохранение в Excel
|
478 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
479 |
+
output_file_path = os.path.join(args.output_dir, args.output_filename)
|
480 |
+
save_to_excel(data_to_save, output_file_path)
|
481 |
+
|
482 |
+
if __name__ == "__main__":
|
483 |
+
main()
|
scripts/testing/pipeline.py
ADDED
@@ -0,0 +1,1034 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Основной пайплайн для оценки качества RAG системы.
|
5 |
+
|
6 |
+
Этот скрипт выполняет один прогон оценки для заданных параметров:
|
7 |
+
- Чтение документов и датасетов вопросов/ответов.
|
8 |
+
- Чанкинг документов.
|
9 |
+
- Векторизация вопросов и чанков.
|
10 |
+
- Оценка релевантности чанков к пунктам из датасета (Chunk-level).
|
11 |
+
- Сборка контекста из релевантных чанков (Assembly-level).
|
12 |
+
- Оценка релевантности собранного контекста к эталонным ответам.
|
13 |
+
- Сохранение детальных метрик для данного прогона.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
# Add necessary imports for caching
|
18 |
+
import hashlib
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import pickle
|
22 |
+
import sys
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any
|
25 |
+
from uuid import UUID, uuid4
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import pandas as pd
|
29 |
+
import torch
|
30 |
+
from fuzzywuzzy import fuzz
|
31 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
32 |
+
from tqdm import tqdm
|
33 |
+
from transformers import AutoModel, AutoTokenizer
|
34 |
+
|
35 |
+
# --- Константы (могут быть переопределены аргументами) ---
|
36 |
+
DEFAULT_DATA_FOLDER = "data/input/docs"
|
37 |
+
DEFAULT_SEARCH_DATASET_PATH = "data/input/search_dataset_texts.xlsx"
|
38 |
+
DEFAULT_QA_DATASET_PATH = "data/input/question_answering.xlsx"
|
39 |
+
DEFAULT_MODEL_NAME = "intfloat/e5-base"
|
40 |
+
DEFAULT_BATCH_SIZE = 8
|
41 |
+
DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
42 |
+
DEFAULT_SIMILARITY_THRESHOLD = 0.7
|
43 |
+
DEFAULT_OUTPUT_DIR = "data/intermediate" # Директория для промежуточных результатов
|
44 |
+
DEFAULT_WORDS_PER_CHUNK = 50
|
45 |
+
DEFAULT_OVERLAP_WORDS = 25
|
46 |
+
DEFAULT_TOP_N = 20 # Значение N по умолчанию для топа чанков
|
47 |
+
# Add chunking strategy constant
|
48 |
+
DEFAULT_CHUNKING_STRATEGY = "fixed_size"
|
49 |
+
# Add cache directory constant
|
50 |
+
DEFAULT_CACHE_DIR = "data/cache"
|
51 |
+
|
52 |
+
# --- Добавление путей к библиотекам ---
|
53 |
+
# Добавляем путь к корневой папке проекта, чтобы можно было импортировать ntr_...
|
54 |
+
SCRIPT_DIR = Path(__file__).parent.resolve()
|
55 |
+
PROJECT_ROOT = SCRIPT_DIR.parent.parent # Перейти на два уровня вверх (scripts/testing -> scripts -> project root)
|
56 |
+
LIB_EXTRACTOR_PATH = PROJECT_ROOT / "lib" / "extractor"
|
57 |
+
sys.path.insert(0, str(LIB_EXTRACTOR_PATH))
|
58 |
+
# Добавляем путь к папке с ntr_text_fragmentation
|
59 |
+
sys.path.insert(0, str(LIB_EXTRACTOR_PATH / "ntr_text_fragmentation"))
|
60 |
+
|
61 |
+
# --- Импорты из локальных модулей ---
|
62 |
+
try:
|
63 |
+
from ntr_fileparser import ParsedDocument, UniversalParser
|
64 |
+
from ntr_text_fragmentation import Destructurer
|
65 |
+
from ntr_text_fragmentation.core.entity_repository import \
|
66 |
+
InMemoryEntityRepository
|
67 |
+
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
|
68 |
+
from ntr_text_fragmentation.models.chunk import Chunk
|
69 |
+
from ntr_text_fragmentation.models.document import DocumentAsEntity
|
70 |
+
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
|
71 |
+
except ImportError as e:
|
72 |
+
print(f"Ошибка импорта локальных модулей: {e}")
|
73 |
+
print(f"Проверьте пути: Project Root: {PROJECT_ROOT}, Extractor Lib: {LIB_EXTRACTOR_PATH}")
|
74 |
+
sys.exit(1)
|
75 |
+
|
76 |
+
# --- Вспомогательные функции (аналогичные evaluate_chunking.py) ---
|
77 |
+
|
78 |
+
def _average_pool(
|
79 |
+
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
80 |
+
) -> torch.Tensor:
|
81 |
+
"""
|
82 |
+
Расчёт усредненного эмбеддинга по всем токенам.
|
83 |
+
(Копипаста из evaluate_chunking.py)
|
84 |
+
"""
|
85 |
+
last_hidden = last_hidden_states.masked_fill(
|
86 |
+
~attention_mask[..., None].bool(), 0.0
|
87 |
+
)
|
88 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
89 |
+
|
90 |
+
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
|
91 |
+
"""
|
92 |
+
Рассчитывает степень перекрытия между чанком и пунктом.
|
93 |
+
(Копипаста из evaluate_chunking.py)
|
94 |
+
"""
|
95 |
+
if not chunk_text or not punct_text:
|
96 |
+
return 0.0
|
97 |
+
# Используем partial_ratio для лучшей обработки подстрок
|
98 |
+
return fuzz.partial_ratio(chunk_text, punct_text) / 100.0
|
99 |
+
|
100 |
+
# --- Функции загрузки и обработки данных ---
|
101 |
+
|
102 |
+
def parse_args():
|
103 |
+
"""Парсит аргументы командной строки."""
|
104 |
+
parser = argparse.ArgumentParser(description="Пайплайн оценки RAG системы")
|
105 |
+
|
106 |
+
# Пути к данным
|
107 |
+
parser.add_argument("--data-folder", type=str, default=DEFAULT_DATA_FOLDER,
|
108 |
+
help=f"Папка с документами (по умолчанию: {DEFAULT_DATA_FOLDER})")
|
109 |
+
parser.add_argument("--search-dataset-path", type=str, default=DEFAULT_SEARCH_DATASET_PATH,
|
110 |
+
help=f"Путь к датасету для поиска (по умолчанию: {DEFAULT_SEARCH_DATASET_PATH})")
|
111 |
+
parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
|
112 |
+
help=f"Папка для сохранения промежуточных результатов (по умолчанию: {DEFAULT_OUTPUT_DIR})")
|
113 |
+
parser.add_argument("--run-id", type=str, default=f"run_{uuid4()}",
|
114 |
+
help="Уникальный идентификатор запуска (по умолчанию: генерируется)")
|
115 |
+
|
116 |
+
# Параметры модели и векторизации
|
117 |
+
parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME,
|
118 |
+
help=f"Название модели для векторизации (по умолчанию: {DEFAULT_MODEL_NAME})")
|
119 |
+
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
|
120 |
+
help=f"Размер батча для векторизации (по умолчанию: {DEFAULT_BATCH_SIZE})")
|
121 |
+
parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, # type: ignore
|
122 |
+
help=f"Устройство для вычислений (по умолчанию: {DEFAULT_DEVICE})")
|
123 |
+
parser.add_argument("--use-sentence-transformers", action="store_true",
|
124 |
+
help="Использовать библиотеку sentence_transformers")
|
125 |
+
|
126 |
+
# Параметры чанкинга
|
127 |
+
parser.add_argument("--chunking-strategy", type=str, default=DEFAULT_CHUNKING_STRATEGY,
|
128 |
+
choices=list(Destructurer.STRATEGIES.keys()), # Use keys from Destructurer
|
129 |
+
help=f"Стратегия чанкинга (по умолчанию: {DEFAULT_CHUNKING_STRATEGY})")
|
130 |
+
parser.add_argument("--strategy-params", type=str, default='{}', # Default to empty JSON object
|
131 |
+
help="Параметры для стратегии чанкинга в формате JSON строки (например, '{\"words_per_chunk\": 50}')")
|
132 |
+
parser.add_argument("--no-process-tables", action="store_false", dest="process_tables",
|
133 |
+
help="Отключить обработку таблиц при чанкинге")
|
134 |
+
parser.set_defaults(process_tables=True) # Default is to process tables
|
135 |
+
|
136 |
+
# Параметры оценки
|
137 |
+
parser.add_argument("--similarity-threshold", type=float, default=DEFAULT_SIMILARITY_THRESHOLD,
|
138 |
+
help=f"Порог для нечеткого сравнения чанка и пункта (по умолчанию: {DEFAULT_SIMILARITY_THRESHOLD})")
|
139 |
+
parser.add_argument("--top-n", type=int, default=DEFAULT_TOP_N,
|
140 |
+
help=f"Количество топ-чанков для рассмотрения (по умолчанию: {DEFAULT_TOP_N})")
|
141 |
+
# Add cache directory argument
|
142 |
+
parser.add_argument("--cache-dir", type=str, default=DEFAULT_CACHE_DIR,
|
143 |
+
help=f"Директория для кэширования эмбеддингов и матриц схожести (по умолчанию: {DEFAULT_CACHE_DIR})")
|
144 |
+
|
145 |
+
# Параметры сборки контекста
|
146 |
+
parser.add_argument("--use-injection", action="store_true",
|
147 |
+
help="Выполнять ли сборку контекста и её оценку")
|
148 |
+
parser.add_argument("--use-qe", action="store_true",
|
149 |
+
help="Использовать столбец query_expansion вместо question для поиска (если он есть)")
|
150 |
+
parser.add_argument("--include-neighbors", action="store_true",
|
151 |
+
help="Включать ли соседние чанки (предыдущий/следующий) при сборке контекста")
|
152 |
+
|
153 |
+
# --- Добавляем аргумент для batch_id ---
|
154 |
+
parser.add_argument("--batch-id", type=str, default="batch_default",
|
155 |
+
help="Идентификатор серии запусков (передается из run_pipelines.py)")
|
156 |
+
|
157 |
+
# TODO: Добавить другие параметры при необходимости (например, параметры InjectionBuilder)
|
158 |
+
|
159 |
+
return parser.parse_args()
|
160 |
+
|
161 |
+
def read_documents(folder_path: str) -> dict[str, ParsedDocument]:
|
162 |
+
"""
|
163 |
+
Читает все документы из указанной папки и создает сущности.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
folder_path: Путь к папке с документами
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Словарь {имя_файла: объект ParsedDocument}
|
170 |
+
"""
|
171 |
+
print(f"Чтение документов из {folder_path}...")
|
172 |
+
parser = UniversalParser()
|
173 |
+
documents_map = {}
|
174 |
+
doc_files = list(Path(folder_path).glob("*.docx"))
|
175 |
+
|
176 |
+
if not doc_files:
|
177 |
+
print(f"ВНИМАНИЕ: В папке {folder_path} не найдено *.docx файлов.")
|
178 |
+
return {}
|
179 |
+
|
180 |
+
for file_path in tqdm(doc_files, desc="Чтение документов"):
|
181 |
+
try:
|
182 |
+
doc_name = file_path.stem
|
183 |
+
# Парсим документ с помощью UniversalParser
|
184 |
+
parsed_document = parser.parse_by_path(str(file_path))
|
185 |
+
# Сохраняем распарсенный документ
|
186 |
+
documents_map[doc_name] = parsed_document
|
187 |
+
except Exception as e:
|
188 |
+
print(f"Ошибка при чтении файла {file_path}: {e}")
|
189 |
+
|
190 |
+
print(f"Прочитано документов: {len(documents_map)}")
|
191 |
+
return documents_map
|
192 |
+
|
193 |
+
def load_datasets(search_dataset_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
|
194 |
+
"""
|
195 |
+
Загружает датасет для поиска и готовит данные для векторизации.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
search_dataset_path: Путь к Excel с пунктами для поиска.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
Кортеж: (полный DataFrame поискового датасета, DataFrame с уникальными вопросами для векторизации).
|
202 |
+
"""
|
203 |
+
print(f"Загрузка поискового датасета из {search_dataset_path}...")
|
204 |
+
try:
|
205 |
+
search_df = pd.read_excel(search_dataset_path)
|
206 |
+
print(f"Загружен поисковый датасет: {len(search_df)} строк, столбцы: {search_df.columns.tolist()}")
|
207 |
+
|
208 |
+
# Проверяем наличие обязательных столбцов
|
209 |
+
required_columns = ['id', 'question', 'text', 'filename']
|
210 |
+
missing_cols = [col for col in required_columns if col not in search_df.columns]
|
211 |
+
if missing_cols:
|
212 |
+
print(f"Ошибка: В поисковом датасете отсутствуют обязательные столбцы: {missing_cols}")
|
213 |
+
sys.exit(1)
|
214 |
+
|
215 |
+
# Преобразуем NaN в пустые строки для текстовых полей
|
216 |
+
# Добавляем 'query_expansion', если он есть, для обработки NaN
|
217 |
+
text_columns = ['question', 'text', 'item_type', 'filename']
|
218 |
+
if 'query_expansion' in search_df.columns:
|
219 |
+
text_columns.append('query_expansion')
|
220 |
+
|
221 |
+
for col in text_columns:
|
222 |
+
if col in search_df.columns:
|
223 |
+
search_df[col] = search_df[col].fillna('')
|
224 |
+
# Если необязательный item_type отсутствует, добавляем его пустым
|
225 |
+
elif col == 'item_type':
|
226 |
+
print(f"Предупреждение: столбец '{col}' отсутствует в поисковом датасете. Добавлен пустой столбец.")
|
227 |
+
search_df[col] = ''
|
228 |
+
|
229 |
+
# Убедимся, что 'id' имеет целочисленный тип
|
230 |
+
try:
|
231 |
+
search_df['id'] = search_df['id'].astype(int)
|
232 |
+
except ValueError as e:
|
233 |
+
print(f"Ошибка при приведении типов столбца 'id' в поисковом датасете: {e}. Убедитесь, что ID являются целыми числами.")
|
234 |
+
sys.exit(1)
|
235 |
+
|
236 |
+
except FileNotFoundError:
|
237 |
+
print(f"Ошибка: Поисковый датасет не найден по пути {search_dataset_path}")
|
238 |
+
sys.exit(1)
|
239 |
+
except Exception as e:
|
240 |
+
print(f"Ошибка при чтении поискового датасета: {e}")
|
241 |
+
sys.exit(1)
|
242 |
+
|
243 |
+
# Готовим DataFrame для векторизации уникальных вопросов
|
244 |
+
# Включаем query_expansion, если он есть
|
245 |
+
cols_for_embedding = ['id', 'question']
|
246 |
+
query_expansion_exists = 'query_expansion' in search_df.columns
|
247 |
+
if query_expansion_exists:
|
248 |
+
cols_for_embedding.append('query_expansion')
|
249 |
+
print("Столбец 'query_expansion' найден в поисковом датасете.")
|
250 |
+
else:
|
251 |
+
print("Столбец 'query_expansion' не найден в поисковом датасете.")
|
252 |
+
|
253 |
+
questions_to_embed = search_df[cols_for_embedding].drop_duplicates(subset=['id']).copy()
|
254 |
+
|
255 |
+
# Если query_expansion не существует, добавляем пустой столбец для единообразия
|
256 |
+
if not query_expansion_exists:
|
257 |
+
questions_to_embed['query_expansion'] = ''
|
258 |
+
|
259 |
+
print(f"Уникальных вопросов для векторизации: {len(questions_to_embed)}")
|
260 |
+
|
261 |
+
# Теперь search_df это и есть наш "объединенный" датасет (так как QA не используется)
|
262 |
+
return search_df, questions_to_embed
|
263 |
+
|
264 |
+
|
265 |
+
def perform_chunking(
|
266 |
+
documents_map: dict[str, ParsedDocument],
|
267 |
+
chunking_strategy: str,
|
268 |
+
process_tables: bool,
|
269 |
+
strategy_params_json: str # Expect JSON string
|
270 |
+
) -> tuple[pd.DataFrame, list[LinkerEntity]]:
|
271 |
+
"""
|
272 |
+
Выполняет чанкинг для всех документов.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
documents_map: Словарь {имя_файла: сущность_документа}.
|
276 |
+
chunking_strategy: Имя используемой стратегии чанкинга.
|
277 |
+
process_tables: Флаг, указывающий, нужно ли обрабатывать таблицы.
|
278 |
+
strategy_params_json: Строка JSON с параметрами для стратегии.
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
Кортеж: (DataFrame с чанками для поиска, список всех созданных сущностей LinkerEntity)
|
282 |
+
"""
|
283 |
+
print("Выполнение чанкинга...")
|
284 |
+
searchable_chunks_data = [] # Данные только для чанков с in_search_text
|
285 |
+
final_entities: list[LinkerEntity] = [] # Список для ВСЕХ сущностей (доки, чанки, связи и т.д.)
|
286 |
+
|
287 |
+
# Parse strategy parameters from JSON string
|
288 |
+
try:
|
289 |
+
chunking_params = json.loads(strategy_params_json)
|
290 |
+
print(f"Параметры для стратегии '{chunking_strategy}': {chunking_params}")
|
291 |
+
except json.JSONDecodeError as e:
|
292 |
+
print(f"Ошибка парсинга JSON для strategy-params: '{strategy_params_json}'. Используются параметры по умолчанию стратегии. Ошибка: {e}")
|
293 |
+
chunking_params = {} # Use strategy defaults if JSON is invalid
|
294 |
+
|
295 |
+
print(f"Используется стратегия чанкинга: '{chunking_strategy}'")
|
296 |
+
print(f"Обработка таблиц: {'Включена' if process_tables else 'Отключена'}")
|
297 |
+
|
298 |
+
for doc_name, parsed_doc in tqdm(documents_map.items(), desc="Чанкинг документов"):
|
299 |
+
try:
|
300 |
+
# Инициализируем Destructurer ВНУТРИ цикла для КАЖДОГО документа
|
301 |
+
destructurer = Destructurer(
|
302 |
+
document=parsed_doc,
|
303 |
+
process_tables=process_tables,
|
304 |
+
strategy_name=chunking_strategy, # Передаем имя стратегии при инициализации
|
305 |
+
**chunking_params # И параметры стратегии
|
306 |
+
)
|
307 |
+
# Destructure создает DocumentAsEntity, чанки, связи и возвращает их как LinkerEntity
|
308 |
+
new_entities = destructurer.destructure()
|
309 |
+
|
310 |
+
# Добавляем ВСЕ созданные сущности (сериализованные LinkerEntity) в общий список
|
311 |
+
final_entities.extend(new_entities)
|
312 |
+
|
313 |
+
# Собираем данные для DataFrame только из тех сущностей,
|
314 |
+
# у которых есть поле in_search_text (это наши чанки для поиска)
|
315 |
+
for entity in new_entities:
|
316 |
+
# Проверяем наличие атрибута 'in_search_text', а не тип
|
317 |
+
if hasattr(entity, 'in_search_text') and entity.in_search_text:
|
318 |
+
entity_data = {
|
319 |
+
'chunk_id': str(entity.id),
|
320 |
+
'doc_name': doc_name, # Имя исходного файла
|
321 |
+
'doc_id': str(entity.source_id), # ID сущности документа (DocumentAsEntity)
|
322 |
+
'text': entity.in_search_text, # Текст для векторизации и поиска
|
323 |
+
'type': entity.type, # Тип сущности (например, 'FixedSizeChunk')
|
324 |
+
'strategy_params': json.dumps(chunking_params, ensure_ascii=False),
|
325 |
+
}
|
326 |
+
searchable_chunks_data.append(entity_data)
|
327 |
+
except Exception as e:
|
328 |
+
# Логируем ошибку и продолжаем с остальными документами
|
329 |
+
print(f"\nОшибка при чанкинге документа {doc_name}: {e}")
|
330 |
+
import traceback
|
331 |
+
traceback.print_exc() # Печатаем traceback для детальной отладки
|
332 |
+
|
333 |
+
# Создаем DataFrame только из чанков, предназначенных для поиска
|
334 |
+
chunks_df = pd.DataFrame(searchable_chunks_data)
|
335 |
+
print(f"Создано чанков для поиска: {len(chunks_df)}")
|
336 |
+
|
337 |
+
# Возвращаем DataFrame с чанками для поиска и ПОЛНЫЙ список всех LinkerEntity
|
338 |
+
return chunks_df, final_entities
|
339 |
+
|
340 |
+
|
341 |
+
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool, device: str):
|
342 |
+
"""Инициализирует модель и токенизатор."""
|
343 |
+
print(f"Загрузка модели {model_name} на устройство {device}...")
|
344 |
+
if use_sentence_transformers:
|
345 |
+
try:
|
346 |
+
from sentence_transformers import SentenceTransformer
|
347 |
+
model = SentenceTransformer(model_name, device=device)
|
348 |
+
tokenizer = None # sentence_transformers не требует отдельного токенизатора
|
349 |
+
print("Используется SentenceTransformer.")
|
350 |
+
return model, tokenizer
|
351 |
+
except ImportError:
|
352 |
+
print("Ошибка: Библиотека sentence_transformers не установлена. Установите: pip install sentence-transformers")
|
353 |
+
sys.exit(1)
|
354 |
+
else:
|
355 |
+
try:
|
356 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
357 |
+
model = AutoModel.from_pretrained(model_name).to(device)
|
358 |
+
model.eval()
|
359 |
+
print("Используется AutoModel и AutoTokenizer из transformers.")
|
360 |
+
return model, tokenizer
|
361 |
+
except Exception as e:
|
362 |
+
print(f"Ошибка при загрузке модели {model_name} из transformers: {e}")
|
363 |
+
sys.exit(1)
|
364 |
+
|
365 |
+
|
366 |
+
def get_embeddings(
|
367 |
+
texts: list[str],
|
368 |
+
model,
|
369 |
+
tokenizer,
|
370 |
+
batch_size: int,
|
371 |
+
use_sentence_transformers: bool,
|
372 |
+
device: str
|
373 |
+
) -> np.ndarray:
|
374 |
+
"""Получает эмбеддинги для списка текстов."""
|
375 |
+
all_embeddings = []
|
376 |
+
desc = "Векторизация (Sentence Transformers)" if use_sentence_transformers else "Векторизация (Transformers)"
|
377 |
+
|
378 |
+
for i in tqdm(range(0, len(texts), batch_size), desc=desc):
|
379 |
+
batch_texts = texts[i:i+batch_size]
|
380 |
+
if not batch_texts:
|
381 |
+
continue
|
382 |
+
|
383 |
+
if use_sentence_transformers:
|
384 |
+
# Эмбеддинги через sentence_transformers
|
385 |
+
embeddings = model.encode(batch_texts, batch_size=len(batch_texts), show_progress_bar=False)
|
386 |
+
all_embeddings.append(embeddings)
|
387 |
+
else:
|
388 |
+
# Эмбеддинги через transformers с average pooling
|
389 |
+
try:
|
390 |
+
encoding = tokenizer(
|
391 |
+
batch_texts,
|
392 |
+
padding=True,
|
393 |
+
truncation=True,
|
394 |
+
max_length=512, # Стандартное ограничение для многих моделей
|
395 |
+
return_tensors="pt"
|
396 |
+
).to(device)
|
397 |
+
|
398 |
+
with torch.no_grad():
|
399 |
+
outputs = model(**encoding)
|
400 |
+
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
|
401 |
+
all_embeddings.append(embeddings.cpu().numpy())
|
402 |
+
except Exception as e:
|
403 |
+
print(f"Ошибка при векторизации батча (индексы {i} - {i+batch_size}): {e}")
|
404 |
+
print(f"Тексты батча: {batch_texts[:2]}...")
|
405 |
+
# Добавляем нулевые векторы, чтобы не сломать vstack
|
406 |
+
# Определяем размер эмбеддинга
|
407 |
+
if all_embeddings:
|
408 |
+
embedding_dim = all_embeddings[0].shape[1]
|
409 |
+
else:
|
410 |
+
# Пытаемся получить размер из конфигурации модели
|
411 |
+
try:
|
412 |
+
embedding_dim = model.config.hidden_size
|
413 |
+
except AttributeError:
|
414 |
+
embedding_dim = 768 # Запасной вариант
|
415 |
+
print(f"Не удалось определить размер эмбеддинга, используется {embedding_dim}")
|
416 |
+
|
417 |
+
print(f"Добавление нулевых эмбеддингов размерности ({len(batch_texts)}, {embedding_dim})")
|
418 |
+
null_embeddings = np.zeros((len(batch_texts), embedding_dim), dtype=np.float32)
|
419 |
+
all_embeddings.append(null_embeddings)
|
420 |
+
|
421 |
+
|
422 |
+
if not all_embeddings:
|
423 |
+
print("ВНИМАНИЕ: Не удалось создать эмбеддинги.")
|
424 |
+
# Возвращаем пустой массив правильной формы, если возможно
|
425 |
+
try:
|
426 |
+
embedding_dim = model.config.hidden_size if not use_sentence_transformers else model.get_sentence_embedding_dimension()
|
427 |
+
except:
|
428 |
+
embedding_dim = 768
|
429 |
+
return np.empty((0, embedding_dim), dtype=np.float32)
|
430 |
+
|
431 |
+
# Объединяем эмбеддинги из всех батчей
|
432 |
+
try:
|
433 |
+
final_embeddings = np.vstack(all_embeddings)
|
434 |
+
except ValueError as e:
|
435 |
+
print(f"Ошибка при объединении эмбеддингов: {e}")
|
436 |
+
print("Размеры эмбеддингов в батчах:")
|
437 |
+
for i, emb_batch in enumerate(all_embeddings):
|
438 |
+
print(f" Батч {i}: {emb_batch.shape}")
|
439 |
+
# Попробуем определить ожидаемый размер и создать нулевой массив
|
440 |
+
if all_embeddings:
|
441 |
+
embedding_dim = all_embeddings[0].shape[1]
|
442 |
+
print(f"Возвращение ну��евого массива размерности ({len(texts)}, {embedding_dim})")
|
443 |
+
return np.zeros((len(texts), embedding_dim), dtype=np.float32)
|
444 |
+
else:
|
445 |
+
return np.empty((0, 768), dtype=np.float32) # Запасной вариант
|
446 |
+
|
447 |
+
print(f"Получено эмбеддингов: {final_embeddings.shape}")
|
448 |
+
return final_embeddings
|
449 |
+
|
450 |
+
# --- Caching Helper Functions ---
|
451 |
+
|
452 |
+
def _get_params_hash(
|
453 |
+
model_name: str,
|
454 |
+
process_tables: bool | None = None,
|
455 |
+
strategy_params: dict | None = None # Expect the parsed dictionary
|
456 |
+
) -> str:
|
457 |
+
"""Создает MD5 хэш из переданных параметров."""
|
458 |
+
hasher = hashlib.md5()
|
459 |
+
hasher.update(model_name.encode())
|
460 |
+
# Add chunking strategy and table processing flag if provided
|
461 |
+
if process_tables is not None:
|
462 |
+
hasher.update(str(process_tables).encode())
|
463 |
+
# Add strategy parameters (sort items to ensure consistent hash)
|
464 |
+
if strategy_params:
|
465 |
+
sorted_params = sorted(strategy_params.items())
|
466 |
+
hasher.update(json.dumps(sorted_params).encode())
|
467 |
+
return hasher.hexdigest()
|
468 |
+
|
469 |
+
def _get_cache_path(cache_dir: Path, hash_str: str, filename: str) -> Path:
|
470 |
+
"""Формирует путь к файлу кэша, создавая поддиректории."""
|
471 |
+
# Используем первые 2 символа хэша для распределения по поддиректориям
|
472 |
+
# Это помогает избежать слишком большого количества файлов в одной директории
|
473 |
+
cache_subdir = cache_dir / hash_str[:2] / hash_str
|
474 |
+
cache_subdir.mkdir(parents=True, exist_ok=True)
|
475 |
+
return cache_subdir / filename
|
476 |
+
|
477 |
+
# --- Добавляем функцию для хэша чанкинга ---
|
478 |
+
def _get_chunking_cache_hash(
|
479 |
+
data_folder: str,
|
480 |
+
chunking_strategy: str,
|
481 |
+
process_tables: bool,
|
482 |
+
strategy_params: dict # Ожидаем словарь
|
483 |
+
) -> str:
|
484 |
+
"""Создает MD5 хэш для параметров чанкинга и папки с данными."""
|
485 |
+
hasher = hashlib.md5()
|
486 |
+
hasher.update(data_folder.encode())
|
487 |
+
hasher.update(chunking_strategy.encode())
|
488 |
+
hasher.update(str(process_tables).encode())
|
489 |
+
# Сортируем параметры для консистентности хэша
|
490 |
+
sorted_params = sorted(strategy_params.items())
|
491 |
+
hasher.update(json.dumps(sorted_params).encode())
|
492 |
+
return hasher.hexdigest()
|
493 |
+
# ---------------------------------------------
|
494 |
+
|
495 |
+
# --- Main Evaluation Function ---
|
496 |
+
|
497 |
+
def evaluate_run(
|
498 |
+
search_dataset: pd.DataFrame,
|
499 |
+
questions_to_embed: pd.DataFrame,
|
500 |
+
chunks_df: pd.DataFrame,
|
501 |
+
all_entities: list[LinkerEntity],
|
502 |
+
model: Any | None, # Принимаем None
|
503 |
+
tokenizer: Any | None, # Принимаем None
|
504 |
+
args: argparse.Namespace
|
505 |
+
) -> pd.DataFrame:
|
506 |
+
"""
|
507 |
+
Выполняет основной цикл оценки для одного набора параметров.
|
508 |
+
|
509 |
+
Args:
|
510 |
+
search_dataset: DataFrame поискового датасета.
|
511 |
+
questions_to_embed: DataFrame с уникальными вопросами для векторизации.
|
512 |
+
chunks_df: DataFrame с данными по чанкам.
|
513 |
+
all_entities: Список всех сущностей (документы, чанки, связи).
|
514 |
+
model: Модель для векторизации.
|
515 |
+
tokenizer: Токенизатор.
|
516 |
+
args: Аргументы командной строки.
|
517 |
+
|
518 |
+
Returns:
|
519 |
+
DataFrame с детальными метриками по каждому вопросу для этого запуска.
|
520 |
+
"""
|
521 |
+
print("Начало этапа оценки...")
|
522 |
+
|
523 |
+
# Переменные для модели и токенизатора, инициализируем None
|
524 |
+
loaded_model: Any | None = model
|
525 |
+
loaded_tokenizer: Any | None = tokenizer
|
526 |
+
|
527 |
+
# --- Caching Setup ---
|
528 |
+
print("Настройка кэширования...")
|
529 |
+
CACHE_DIR_PATH = Path(args.cache_dir)
|
530 |
+
model_slug = args.model_name.split('/')[-1] # Basic slug for filename clarity
|
531 |
+
|
532 |
+
# --- Определяем, какой текст использовать для эмбеддингов вопросов ---
|
533 |
+
# и устанавливаем флаг qe_active, который будет влиять на кэш
|
534 |
+
if args.use_qe and 'query_expansion' in questions_to_embed.columns and questions_to_embed['query_expansion'].notna().any(): # Check if column exists and has non-NA values
|
535 |
+
print("Используется Query Expansion (столбец 'query_expansion') для векторизации вопросов.")
|
536 |
+
query_texts_to_embed = questions_to_embed['query_expansion'].tolist()
|
537 |
+
qe_active = True
|
538 |
+
else:
|
539 |
+
print("Используется оригинальный текст вопроса (столбец 'question') для векторизации.")
|
540 |
+
query_texts_to_embed = questions_to_embed['question'].tolist()
|
541 |
+
qe_active = False
|
542 |
+
|
543 |
+
# Cache key for question embeddings (ЗАВИСИТ от модели и флага use_qe)
|
544 |
+
question_params_for_hash = {
|
545 |
+
'model_name': args.model_name,
|
546 |
+
'use_qe': qe_active # Добавляем фактическое использование QE в параметры для хэша
|
547 |
+
}
|
548 |
+
question_hash = hashlib.md5(json.dumps(question_params_for_hash, sort_keys=True).encode()).hexdigest()
|
549 |
+
question_embeddings_cache_path = _get_cache_path(
|
550 |
+
CACHE_DIR_PATH, question_hash, f"q_embeddings_{model_slug}_qe{qe_active}.npy"
|
551 |
+
)
|
552 |
+
|
553 |
+
# Cache key for chunk embeddings (depends on model and chunking)
|
554 |
+
chunk_hash = _get_params_hash(
|
555 |
+
args.model_name,
|
556 |
+
args.process_tables, # Include table flag
|
557 |
+
json.loads(args.strategy_params) # Pass parsed params dictionary
|
558 |
+
)
|
559 |
+
chunk_embeddings_cache_path = _get_cache_path(
|
560 |
+
CACHE_DIR_PATH, chunk_hash,
|
561 |
+
f"c_emb_{model_slug}_s-{args.chunking_strategy}_t{args.process_tables}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy"
|
562 |
+
)
|
563 |
+
|
564 |
+
# Cache key for similarity matrix (depends on both sets of embeddings)
|
565 |
+
similarity_hash = f"{question_hash}_{chunk_hash}" # Combine hashes
|
566 |
+
similarity_cache_path = _get_cache_path(
|
567 |
+
CACHE_DIR_PATH, similarity_hash,
|
568 |
+
f"sim_{model_slug}_qe{qe_active}_ph-{hashlib.md5(args.strategy_params.encode()).hexdigest()[:8]}.npy" # Добавляем флаг QE в имя файла
|
569 |
+
)
|
570 |
+
|
571 |
+
# 1. Векторизация вопросов и чанков (с кэшем)
|
572 |
+
question_embeddings = None
|
573 |
+
needs_model_load = False # Флаг, указывающий, нужна ли загрузка модели
|
574 |
+
|
575 |
+
if question_embeddings_cache_path.exists():
|
576 |
+
try:
|
577 |
+
print(f"Загрузка кэшированных эмбеддингов вопросов из: {question_embeddings_cache_path}")
|
578 |
+
question_embeddings = np.load(question_embeddings_cache_path, allow_pickle=False)
|
579 |
+
if len(question_embeddings) != len(questions_to_embed):
|
580 |
+
print(f"Предупреждение: Размер кэша эмбеддингов вопросов не совпадает. Пересчет.")
|
581 |
+
question_embeddings = None
|
582 |
+
else:
|
583 |
+
print("Кэш эмбеддингов вопросов успешно загружен.")
|
584 |
+
except Exception as e:
|
585 |
+
print(f"Ошибка загрузки кэша эмбеддингов вопросов: {e}. Пересчет.")
|
586 |
+
question_embeddings = None
|
587 |
+
|
588 |
+
if question_embeddings is None:
|
589 |
+
needs_model_load = True # Требуется модель для генерации эмбеддингов
|
590 |
+
print("Векторизация вопросов (потребуется загрузка модели)...")
|
591 |
+
|
592 |
+
chunk_embeddings = None
|
593 |
+
if chunk_embeddings_cache_path.exists():
|
594 |
+
try:
|
595 |
+
print(f"Загрузка кэшированных эмбеддингов чанков из: {chunk_embeddings_cache_path}")
|
596 |
+
chunk_embeddings = np.load(chunk_embeddings_cache_path, allow_pickle=False)
|
597 |
+
if len(chunk_embeddings) != len(chunks_df):
|
598 |
+
print(f"Предупреждение: Размер кэша эмбеддингов чанков не совпадает. Пересчет.")
|
599 |
+
chunk_embeddings = None
|
600 |
+
else:
|
601 |
+
print("Кэш эмбеддингов чанков успешно загружен.")
|
602 |
+
except Exception as e:
|
603 |
+
print(f"Ошибка загрузки кэша эмбеддингов чанков: {e}. Пересчет.")
|
604 |
+
chunk_embeddings = None
|
605 |
+
|
606 |
+
if chunk_embeddings is None:
|
607 |
+
needs_model_load = True # Требуется модель для генерации эмбеддингов
|
608 |
+
print("Векторизация чанков (потребуется загрузка модели)...")
|
609 |
+
|
610 |
+
# --- Отложенная загрузка модели, если необходимо ---
|
611 |
+
if needs_model_load and loaded_model is None:
|
612 |
+
print("\n--- Загрузка модели и токенизатора (т.к. кэш эмбеддингов отсутствует) ---")
|
613 |
+
loaded_model, loaded_tokenizer = setup_model_and_tokenizer(
|
614 |
+
args.model_name, args.use_sentence_transformers, args.device
|
615 |
+
)
|
616 |
+
print("--- Модель и токенизатор загружены ---\n")
|
617 |
+
|
618 |
+
# --- Повторная генерация эмбеддингов, если они не загрузились из кэша ---
|
619 |
+
if question_embeddings is None:
|
620 |
+
if loaded_model is None:
|
621 |
+
print("Критическая ошибка: Модель не загружена, но требуется для векторизации вопросов!")
|
622 |
+
# Возвращаем пустой DataFrame или выбрасываем исключение
|
623 |
+
return pd.DataFrame()
|
624 |
+
|
625 |
+
print("Повторная векторизация вопросов...")
|
626 |
+
question_embeddings = get_embeddings(
|
627 |
+
query_texts_to_embed,
|
628 |
+
loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
|
629 |
+
)
|
630 |
+
if question_embeddings.shape[0] > 0:
|
631 |
+
try:
|
632 |
+
print(f"Сохранение эмбеддингов вопросов в кэш: {question_embeddings_cache_path}")
|
633 |
+
np.save(question_embeddings_cache_path, question_embeddings, allow_pickle=False)
|
634 |
+
except Exception as e:
|
635 |
+
print(f"Не удалось сохранить кэш эмбеддингов вопросов: {e}")
|
636 |
+
|
637 |
+
if chunk_embeddings is None:
|
638 |
+
if loaded_model is None:
|
639 |
+
print("Критическая ошибка: Модель не загружена, но требуется для векторизации чанков!")
|
640 |
+
return pd.DataFrame()
|
641 |
+
|
642 |
+
print("Повторная векторизация чанков...")
|
643 |
+
chunk_texts = chunks_df['text'].fillna('').astype(str).tolist()
|
644 |
+
chunk_embeddings = get_embeddings(
|
645 |
+
chunk_texts,
|
646 |
+
loaded_model, loaded_tokenizer, args.batch_size, args.use_sentence_transformers, args.device
|
647 |
+
)
|
648 |
+
if chunk_embeddings.shape[0] > 0:
|
649 |
+
try:
|
650 |
+
print(f"Сохранение эмбеддингов чанков в кэш: {chunk_embeddings_cache_path}")
|
651 |
+
np.save(chunk_embeddings_cache_path, chunk_embeddings, allow_pickle=False)
|
652 |
+
except Exception as e:
|
653 |
+
print(f"Не удалось сохранить кэш эмбеддингов чанков: {e}")
|
654 |
+
|
655 |
+
|
656 |
+
# Проверка совпадения количества эмбеддингов и данных
|
657 |
+
if len(question_embeddings) != len(questions_to_embed):
|
658 |
+
print(f"Ошибка: Количество эмбеддингов вопросов ({len(question_embeddings)}) не совпадает с количеством уникальных вопросов ({len(questions_to_embed)}).")
|
659 |
+
# Можно либо прервать выполнение, либо попытаться исправить
|
660 |
+
# Например, взять первые N эмбеддингов, но это может быть некорректно
|
661 |
+
sys.exit(1)
|
662 |
+
if len(chunk_embeddings) != len(chunks_df):
|
663 |
+
print(f"Ошибка: Количество эмбеддингов чанков ({len(chunk_embeddings)}) не совпадает с количеством чанков в DataFrame ({len(chunks_df)}).")
|
664 |
+
# Попытка исправить (если ошибка небольшая) или выход
|
665 |
+
if abs(len(chunk_embeddings) - len(chunks_df)) < 5:
|
666 |
+
print("Попытка обрезать лишние эмбеддинги/данные...")
|
667 |
+
min_len = min(len(chunk_embeddings), len(chunks_df))
|
668 |
+
chunk_embeddings = chunk_embeddings[:min_len]
|
669 |
+
chunks_df = chunks_df.iloc[:min_len]
|
670 |
+
print(f"Размеры выровнены до {min_len}")
|
671 |
+
else:
|
672 |
+
sys.exit(1)
|
673 |
+
|
674 |
+
|
675 |
+
# Создаем маппинг ID вопроса к индексу в эмбеддингах
|
676 |
+
question_id_to_idx = {
|
677 |
+
row['id']: i for i, (_, row) in enumerate(questions_to_embed.iterrows())
|
678 |
+
}
|
679 |
+
|
680 |
+
# 2. Расчет косинусной близости
|
681 |
+
print("Расчет косинусной близости...")
|
682 |
+
# Проверка на пустые эмбеддинги
|
683 |
+
if question_embeddings.shape[0] == 0 or chunk_embeddings.shape[0] == 0:
|
684 |
+
print("Ошибка: Отсутствуют эмбеддинги вопросов или чанков для расчета близости.")
|
685 |
+
# Возвращаем пустой DataFrame или обрабатываем ошибку иначе
|
686 |
+
return pd.DataFrame()
|
687 |
+
|
688 |
+
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
|
689 |
+
|
690 |
+
# 3. Инициализация InjectionBuilder (если нужно)
|
691 |
+
injection_builder = None
|
692 |
+
if args.use_injection:
|
693 |
+
print("Инициализация InjectionBuilder...")
|
694 |
+
repository = InMemoryEntityRepository(all_entities)
|
695 |
+
injection_builder = InjectionBuilder(repository)
|
696 |
+
# TODO: Зарегистрировать стратегии, если необходимо
|
697 |
+
# builder.register_strategy(...)
|
698 |
+
|
699 |
+
# 4. Цикл по уникальным вопросам для оценки
|
700 |
+
results = []
|
701 |
+
print(f"Оценка для {len(questions_to_embed)} уникальных вопросов...")
|
702 |
+
|
703 |
+
for question_id_iter, question_data in tqdm(questions_to_embed.iterrows(), total=len(questions_to_embed), desc="Оценка вопросов"): # Renamed loop variable
|
704 |
+
q_id = question_data['id']
|
705 |
+
q_text = question_data['question']
|
706 |
+
|
707 |
+
# Получаем все строки из исходного датасета для этого вопроса
|
708 |
+
question_rows = search_dataset[search_dataset['id'] == q_id] # Use search_dataset
|
709 |
+
if question_rows.empty:
|
710 |
+
print(f"Предупреждение: Нет данных в search_dataset для вопроса ID={q_id}")
|
711 |
+
continue
|
712 |
+
|
713 |
+
# Получаем пункты (relevant items)
|
714 |
+
puncts = question_rows['text'].tolist()
|
715 |
+
# reference_answer больше не используется и не извлекается
|
716 |
+
|
717 |
+
# Получаем индекс вопроса в матрице близости
|
718 |
+
if q_id not in question_id_to_idx:
|
719 |
+
print(f"Предупреждение: Вопрос ID={q_id} не найден в маппинге эмбеддингов.")
|
720 |
+
continue
|
721 |
+
question_idx = question_id_to_idx[q_id]
|
722 |
+
|
723 |
+
# --- Оценка на уровне чанков (Chunk-level) ---
|
724 |
+
chunk_level_metrics = evaluate_chunk_relevance(
|
725 |
+
q_id, question_idx, puncts,
|
726 |
+
similarity_matrix, chunks_df, args.top_n, args.similarity_threshold
|
727 |
+
)
|
728 |
+
|
729 |
+
# --- Оценка на уровне сборки (Assembly-level) ---
|
730 |
+
# Удаляем assembly_relevance, основанный на reference_answer
|
731 |
+
assembly_level_metrics = {} # Start with an empty dict for assembly metrics
|
732 |
+
assembled_context = ""
|
733 |
+
top_chunk_indices = chunk_level_metrics.get("top_chunk_ids", []) # Get indices first
|
734 |
+
neighbors_included = False # Flag to log
|
735 |
+
|
736 |
+
if args.use_injection and injection_builder and top_chunk_indices:
|
737 |
+
try:
|
738 |
+
# Преобразуем ID строк обратно в UUID чанков
|
739 |
+
top_chunk_uuids = [UUID(chunks_df.iloc[idx]['chunk_id']) for idx in top_chunk_indices]
|
740 |
+
|
741 |
+
final_chunk_uuids_for_assembly = set(top_chunk_uuids) # Start with top chunks
|
742 |
+
|
743 |
+
# --- Добавляем соседей, если нужно ---
|
744 |
+
if args.include_neighbors:
|
745 |
+
neighbors_included = True
|
746 |
+
# --- Убираем логирование индексов ---
|
747 |
+
neighbor_chunks = repository.get_neighboring_chunks(chunk_ids=top_chunk_uuids, max_distance=1)
|
748 |
+
neighbor_ids = {neighbor.id for neighbor in neighbor_chunks}
|
749 |
+
# --- Логирование до/после добавления ID соседей ---
|
750 |
+
print(f" [DEBUG QID {q_id}] Кол-во ID до добавления соседей: {len(final_chunk_uuids_for_assembly)}")
|
751 |
+
print(f" [DEBUG QID {q_id}] Кол-во найденных ID соседей: {len(neighbor_ids)}")
|
752 |
+
final_chunk_uuids_for_assembly.update(neighbor_ids)
|
753 |
+
print(f" [DEBUG QID {q_id}] Кол-во ID после добавления соседей: {len(final_chunk_uuids_for_assembly)}")
|
754 |
+
# --- Конец логирования ---
|
755 |
+
# --- Убираем логирование индексов ---
|
756 |
+
else:
|
757 |
+
# --- Убираем логирование индексов ---
|
758 |
+
pass # Ничего не делаем, если соседи не включены
|
759 |
+
|
760 |
+
# Собираем контекст
|
761 |
+
# Передаем финальный набор UUID (уникальный)
|
762 |
+
assembled_context = injection_builder.build(
|
763 |
+
filtered_entities=list(final_chunk_uuids_for_assembly),
|
764 |
+
# chunk_scores= {chunks_df.loc[idx, 'chunk_id']: sim for idx, sim in zip(top_chunk_ids_for_assembly, chunk_level_metrics.get('top_chunk_similarities',[]))} # Можно добавить веса
|
765 |
+
)
|
766 |
+
|
767 |
+
# --- Новая метрика: Assembly Punct Recall ---
|
768 |
+
# Оцениваем, сколько пунктов из датасета найдено в собранном контексте
|
769 |
+
# (по вашей идее: пункт считается найденным, если хотя бы одна его часть,
|
770 |
+
# разделенная переносом строки, найдена в контексте)
|
771 |
+
assembly_found_puncts = 0
|
772 |
+
assembly_total_puncts = len(puncts)
|
773 |
+
if assembly_total_puncts > 0 and assembled_context:
|
774 |
+
# Итерируемся по каждому исходному пункту
|
775 |
+
for punct_text in puncts:
|
776 |
+
# Разбиваем пункт на части по переносу строки
|
777 |
+
# Убираем пустые строки, которые могут появиться из-за двойных переносов
|
778 |
+
punct_parts = [part for part in punct_text.split('\n') if part.strip()]
|
779 |
+
|
780 |
+
# Если пункт пустой или состоит только из пробельных символов после разбивки,
|
781 |
+
# пропускаем его (не считаем ни найденным, ни не найденным в контексте recall)
|
782 |
+
if not punct_parts:
|
783 |
+
assembly_total_puncts -= 1 # Уменьшаем общее число пунктов для расчета recall
|
784 |
+
continue
|
785 |
+
|
786 |
+
is_punct_found = False
|
787 |
+
# Итерируемся по частям пункта
|
788 |
+
for part_text in punct_parts:
|
789 |
+
# Сравниваем КАЖДУЮ ЧАСТЬ пункта с собранным контекстом
|
790 |
+
if calculate_chunk_overlap(assembled_context, part_text.strip()) >= args.similarity_threshold:
|
791 |
+
# Если ХОТЯ БЫ ОДНА часть найдена, считаем ВЕСЬ пункт найденным
|
792 |
+
is_punct_found = True
|
793 |
+
break # Дальше части этого пункта можно не проверять
|
794 |
+
|
795 |
+
# Если флаг is_punct_found стал True, увеличиваем счетчик найденных пунктов
|
796 |
+
if is_punct_found:
|
797 |
+
assembly_found_puncts += 1
|
798 |
+
|
799 |
+
# Рассчитываем recall, только если были валидные пункты для проверки
|
800 |
+
if assembly_total_puncts > 0:
|
801 |
+
assembly_level_metrics["assembly_punct_recall"] = assembly_found_puncts / assembly_total_puncts
|
802 |
+
else:
|
803 |
+
assembly_level_metrics["assembly_punct_recall"] = 0.0 # Или можно None, если нет валидных пунктов
|
804 |
+
else:
|
805 |
+
assembly_level_metrics["assembly_punct_recall"] = 0.0
|
806 |
+
# Добавляем сам текст сборки для возможного анализа (усеченный)
|
807 |
+
assembly_level_metrics["assembled_context_preview"] = assembled_context[:500] + ("..." if len(assembled_context) > 500 else "")
|
808 |
+
|
809 |
+
|
810 |
+
except Exception as e:
|
811 |
+
print(f"Ошибка при сборке/оценке контекста для вопроса ID={q_id}: {e}")
|
812 |
+
# Записываем None или 0, чтобы не прерывать процесс
|
813 |
+
assembly_level_metrics["assembly_punct_recall"] = None # Indicate error
|
814 |
+
assembly_level_metrics["assembled_context_preview"] = f"Error during assembly: {e}"
|
815 |
+
|
816 |
+
|
817 |
+
# Собираем все метрики для вопроса
|
818 |
+
question_result = {
|
819 |
+
"run_id": args.run_id,
|
820 |
+
"batch_id": args.batch_id, # --- Добавляем batch_id в результаты ---
|
821 |
+
"question_id": q_id,
|
822 |
+
"question_text": q_text,
|
823 |
+
# Параметры запуска
|
824 |
+
"model_name": args.model_name,
|
825 |
+
"chunking_strategy": args.chunking_strategy, # Log strategy
|
826 |
+
"process_tables": args.process_tables, # Log table flag
|
827 |
+
"strategy_params": args.strategy_params, # Log JSON string
|
828 |
+
"top_n": args.top_n,
|
829 |
+
"use_injection": args.use_injection,
|
830 |
+
"use_qe": qe_active, # Log QE status
|
831 |
+
"neighbors_included": neighbors_included, # Log neighbor flag
|
832 |
+
"similarity_threshold": args.similarity_threshold,
|
833 |
+
# Метрики Chunk-level
|
834 |
+
**chunk_level_metrics,
|
835 |
+
# Метрики Assembly-level (теперь с recall по пунктам)
|
836 |
+
**assembly_level_metrics,
|
837 |
+
# Тексты для отладки (эталонный ответ удален, сборка добавлена выше)
|
838 |
+
# "assembled_context": assembled_context[:500] + "..." if assembled_context else "",
|
839 |
+
}
|
840 |
+
results.append(question_result)
|
841 |
+
|
842 |
+
print("Оценка завершена.")
|
843 |
+
return pd.DataFrame(results)
|
844 |
+
|
845 |
+
|
846 |
+
def evaluate_chunk_relevance(
|
847 |
+
question_id: int,
|
848 |
+
question_idx: int,
|
849 |
+
puncts: list[str],
|
850 |
+
similarity_matrix: np.ndarray,
|
851 |
+
chunks_df: pd.DataFrame,
|
852 |
+
top_n: int,
|
853 |
+
similarity_threshold: float
|
854 |
+
) -> dict:
|
855 |
+
"""
|
856 |
+
Оценивает релевантность чанков для одного вопроса.
|
857 |
+
(Адаптировано из evaluate_for_top_n_with_mapping в evaluate_chunking.py)
|
858 |
+
|
859 |
+
Возвращает словарь с метриками для этого вопроса.
|
860 |
+
"""
|
861 |
+
metrics = {
|
862 |
+
"chunk_text_precision": 0.0,
|
863 |
+
"chunk_text_recall": 0.0,
|
864 |
+
"chunk_text_f1": 0.0,
|
865 |
+
"found_puncts": 0,
|
866 |
+
"total_puncts": len(puncts),
|
867 |
+
"relevant_chunks": 0,
|
868 |
+
"total_chunks_in_top_n": 0,
|
869 |
+
"top_chunk_ids": [], # Индексы строк в chunks_df
|
870 |
+
"top_chunk_similarities": [],
|
871 |
+
}
|
872 |
+
|
873 |
+
if chunks_df.empty or similarity_matrix.shape[1] == 0:
|
874 |
+
print(f"Предупреждение (QID {question_id}): Нет чанков для оценки.")
|
875 |
+
return metrics
|
876 |
+
|
877 |
+
# Получаем схожести всех чанков с текущим вопросом
|
878 |
+
question_similarities = similarity_matrix[question_idx, :]
|
879 |
+
|
880 |
+
# Сортируем чанки по схожести и берем top_n
|
881 |
+
# argsort возвращает индексы элементов, которые бы отсортировали массив
|
882 |
+
# Берем последние N индексов (-top_n:) и разворачиваем ([::-1]) для убывания
|
883 |
+
# Добавляем проверку на случай если top_n > количества чанков
|
884 |
+
if top_n >= similarity_matrix.shape[1]:
|
885 |
+
sorted_chunk_indices = np.argsort(question_similarities)[::-1] # Берем все, сортируем по убыванию
|
886 |
+
else:
|
887 |
+
sorted_chunk_indices = np.argsort(question_similarities)[-top_n:][::-1]
|
888 |
+
|
889 |
+
# Ограничиваем top_n, если чанков меньше (это должно быть сделано выше, но дублируем для надежности)
|
890 |
+
actual_top_n = min(top_n, len(sorted_chunk_indices))
|
891 |
+
top_chunk_indices = sorted_chunk_indices[:actual_top_n]
|
892 |
+
|
893 |
+
# Сохраняем ID и схожести топ-чанков
|
894 |
+
metrics["top_chunk_ids"] = top_chunk_indices.tolist()
|
895 |
+
metrics["top_chunk_similarities"] = question_similarities[top_chunk_indices].tolist()
|
896 |
+
|
897 |
+
|
898 |
+
# Отбираем данные топ-чанков
|
899 |
+
top_chunks_df = chunks_df.iloc[top_chunk_indices]
|
900 |
+
metrics["total_chunks_in_top_n"] = len(top_chunks_df)
|
901 |
+
|
902 |
+
if metrics["total_chunks_in_top_n"] == 0:
|
903 |
+
return metrics # Если нет топ-чанков, метрики остаются нулевыми
|
904 |
+
|
905 |
+
# Оценка на основе текста (пунктов)
|
906 |
+
punct_found = [False] * metrics["total_puncts"]
|
907 |
+
question_relevant_chunks = 0
|
908 |
+
for i, (idx, chunk_row) in enumerate(top_chunks_df.iterrows()):
|
909 |
+
chunk_text = chunk_row['text']
|
910 |
+
is_relevant_to_punct = False
|
911 |
+
for j, punct_text in enumerate(puncts):
|
912 |
+
overlap = calculate_chunk_overlap(chunk_text, punct_text)
|
913 |
+
if overlap >= similarity_threshold:
|
914 |
+
is_relevant_to_punct = True
|
915 |
+
punct_found[j] = True
|
916 |
+
if is_relevant_to_punct:
|
917 |
+
question_relevant_chunks += 1
|
918 |
+
|
919 |
+
metrics["found_puncts"] = sum(punct_found)
|
920 |
+
metrics["relevant_chunks"] = question_relevant_chunks
|
921 |
+
|
922 |
+
if metrics["total_chunks_in_top_n"] > 0:
|
923 |
+
metrics["chunk_text_precision"] = metrics["relevant_chunks"] / metrics["total_chunks_in_top_n"]
|
924 |
+
if metrics["total_puncts"] > 0:
|
925 |
+
metrics["chunk_text_recall"] = metrics["found_puncts"] / metrics["total_puncts"]
|
926 |
+
if metrics["chunk_text_precision"] + metrics["chunk_text_recall"] > 0:
|
927 |
+
metrics["chunk_text_f1"] = (2 * metrics["chunk_text_precision"] * metrics["chunk_text_recall"] /
|
928 |
+
(metrics["chunk_text_precision"] + metrics["chunk_text_recall"]))
|
929 |
+
|
930 |
+
return metrics
|
931 |
+
|
932 |
+
|
933 |
+
# --- Основная функция ---
|
934 |
+
|
935 |
+
def main():
|
936 |
+
"""Основная функция скрипта."""
|
937 |
+
args = parse_args()
|
938 |
+
print(f"Запуск оценки с ID: {args.run_id}")
|
939 |
+
print(f"Параметры: {vars(args)}")
|
940 |
+
|
941 |
+
# --- Кэширование Чанкинга ---
|
942 |
+
CACHE_DIR_PATH = Path(args.cache_dir)
|
943 |
+
try:
|
944 |
+
# Парсим параметры стратегии один раз
|
945 |
+
parsed_strategy_params = json.loads(args.strategy_params)
|
946 |
+
except json.JSONDecodeError:
|
947 |
+
print(f"Предупреждение: Невалидный JSON в strategy_params: '{args.strategy_params}'. Используются параметры по умолчанию для хэша кэша.")
|
948 |
+
parsed_strategy_params = {}
|
949 |
+
|
950 |
+
chunking_hash = _get_chunking_cache_hash(
|
951 |
+
args.data_folder,
|
952 |
+
args.chunking_strategy,
|
953 |
+
args.process_tables,
|
954 |
+
parsed_strategy_params
|
955 |
+
)
|
956 |
+
chunks_df_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "chunks_df.parquet")
|
957 |
+
entities_cache_path = _get_cache_path(CACHE_DIR_PATH, chunking_hash, "final_entities.pkl")
|
958 |
+
|
959 |
+
chunks_df = None
|
960 |
+
all_entities = None
|
961 |
+
|
962 |
+
if chunks_df_cache_path.exists() and entities_cache_path.exists():
|
963 |
+
print(f"Найден кэш чанкинга (hash: {chunking_hash}). Загрузка...")
|
964 |
+
try:
|
965 |
+
chunks_df = pd.read_parquet(chunks_df_cache_path)
|
966 |
+
with open(entities_cache_path, 'rb') as f:
|
967 |
+
all_entities = pickle.load(f)
|
968 |
+
print(f"Кэш чанкинга успешно загружен: {len(chunks_df)} чанков, {len(all_entities)} сущностей.")
|
969 |
+
except Exception as e:
|
970 |
+
print(f"Ошибка загрузки кэша чанкинга: {e}. Выполняем чанкинг заново.")
|
971 |
+
chunks_df = None
|
972 |
+
all_entities = None
|
973 |
+
|
974 |
+
if chunks_df is None or all_entities is None:
|
975 |
+
print("Кэш чанкинга не найден или поврежден. Выполнение чтения документов и чанкинга...")
|
976 |
+
# 1. Загрузка данных
|
977 |
+
documents_map = read_documents(args.data_folder)
|
978 |
+
if not documents_map:
|
979 |
+
print("Нет документов для обработки. Завершение.")
|
980 |
+
return
|
981 |
+
|
982 |
+
# 2. Чанкинг
|
983 |
+
chunks_df, all_entities = perform_chunking(
|
984 |
+
documents_map,
|
985 |
+
args.chunking_strategy, # Pass strategy
|
986 |
+
args.process_tables, # Pass table flag
|
987 |
+
args.strategy_params # Pass JSON string parameters
|
988 |
+
)
|
989 |
+
if chunks_df.empty:
|
990 |
+
print("После чанкинга не осталось чанков для обработки. Завершение.")
|
991 |
+
return
|
992 |
+
|
993 |
+
# Сохраняем результаты чанкинга в кэш
|
994 |
+
try:
|
995 |
+
print(f"Сохранение результатов чанкинга в кэш (hash: {chunking_hash})...")
|
996 |
+
# Убедимся, что директория кэша существует (на всякий случай)
|
997 |
+
chunks_df_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
998 |
+
entities_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
999 |
+
|
1000 |
+
chunks_df.to_parquet(chunks_df_cache_path)
|
1001 |
+
with open(entities_cache_path, 'wb') as f:
|
1002 |
+
pickle.dump(all_entities, f)
|
1003 |
+
print("Результаты чанкинга сохранены в кэш.")
|
1004 |
+
except Exception as e:
|
1005 |
+
print(f"Ошибка сохранения кэша чанкинга: {e}")
|
1006 |
+
|
1007 |
+
# --- Конец Кэширования Чанкинга ---
|
1008 |
+
|
1009 |
+
# Загружаем поисковый датасет (это нужно делать всегда, т.к. он не кэшируется здесь)
|
1010 |
+
search_df, questions_to_embed = load_datasets(args.search_dataset_path)
|
1011 |
+
|
1012 |
+
# 3. Выполнение оценки (передаем загруженные или свежесгенерированные chunks_df и all_entities)
|
1013 |
+
results_df = evaluate_run(
|
1014 |
+
search_df, questions_to_embed, chunks_df, all_entities,
|
1015 |
+
None, None, args # Передаем None для model и tokenizer
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
# 5. Сохранение результатов
|
1019 |
+
if not results_df.empty:
|
1020 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1021 |
+
# output_filename = f"results_{args.run_id}.csv"
|
1022 |
+
# Добавляем batch_id в имя файла для лучшей группировки
|
1023 |
+
output_filename = f"results_{args.batch_id}_{args.run_id}.csv"
|
1024 |
+
output_path = os.path.join(args.output_dir, output_filename)
|
1025 |
+
try:
|
1026 |
+
results_df.to_csv(output_path, index=False, encoding='utf-8')
|
1027 |
+
print(f"Детальные результаты сохранены в: {output_path}")
|
1028 |
+
except Exception as e:
|
1029 |
+
print(f"Ошибка при сохранении результатов в {output_path}: {e}")
|
1030 |
+
else:
|
1031 |
+
print("Нет результатов для сохранения.")
|
1032 |
+
|
1033 |
+
if __name__ == "__main__":
|
1034 |
+
main()
|
scripts/testing/plot_results.py
ADDED
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Скрипт для визуализации агрегированных результатов тестирования RAG.
|
5 |
+
|
6 |
+
Читает данные из Excel-файла, сгенерированного aggregate_results.py,
|
7 |
+
и строит различные графики для анализа влияния параметров на метрики.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
import pandas as pd
|
16 |
+
import seaborn as sns
|
17 |
+
|
18 |
+
# --- Настройки ---
|
19 |
+
DEFAULT_RESULTS_FILE = "data/output/aggregated_results.xlsx" # Файл с агрегированными данными
|
20 |
+
DEFAULT_PLOTS_DIR = "data/output/plots" # Куда сохранять графики
|
21 |
+
|
22 |
+
# Настройки графиков
|
23 |
+
plt.rcParams['font.family'] = 'DejaVu Sans' # Шрифт с поддержкой кириллицы
|
24 |
+
sns.set_style("whitegrid")
|
25 |
+
FIGSIZE = (16, 10) # Увеличенный размер для сложных графиков
|
26 |
+
DPI = 300
|
27 |
+
PALETTE = "viridis" # Цветовая палитра
|
28 |
+
|
29 |
+
# --- Маппинг названий столбцов (копия из aggregate_results.py) ---
|
30 |
+
COLUMN_NAME_MAPPING = {
|
31 |
+
# Параметры запуска из pipeline.py
|
32 |
+
'run_id': 'ID Запуска',
|
33 |
+
'model_name': 'Модель',
|
34 |
+
'chunking_strategy': 'Стратегия Чанкинга',
|
35 |
+
'strategy_params': 'Параметры Стратегии',
|
36 |
+
'process_tables': 'Обраб. Таблиц',
|
37 |
+
'top_n': 'Top N',
|
38 |
+
'use_injection': 'Сборка Контекста',
|
39 |
+
'use_qe': 'Query Expansion',
|
40 |
+
'neighbors_included': 'Вкл. Соседей',
|
41 |
+
'similarity_threshold': 'Порог Схожести',
|
42 |
+
|
43 |
+
# Идентификаторы из датасета (для детальных результатов)
|
44 |
+
'question_id': 'ID Вопроса',
|
45 |
+
'question_text': 'Текст Вопроса',
|
46 |
+
|
47 |
+
# Детальные метрики из pipeline.py
|
48 |
+
'chunk_text_precision': 'Точность (Чанк-Текст)',
|
49 |
+
'chunk_text_recall': 'Полнота (Чанк-Текст)',
|
50 |
+
'chunk_text_f1': 'F1 (Чанк-Текст)',
|
51 |
+
'found_puncts': 'Найдено Пунктов',
|
52 |
+
'total_puncts': 'Всего Пунктов',
|
53 |
+
'relevant_chunks': 'Релевантных Чанков',
|
54 |
+
'total_chunks_in_top_n': 'Всего Чанков в Топ-N',
|
55 |
+
'assembly_punct_recall': 'Полнота (Сборка-Пункт)',
|
56 |
+
'assembled_context_preview': 'Предпросмотр Сборки',
|
57 |
+
# 'top_chunk_ids': 'Индексы Топ-Чанков', # Списки, могут плохо отображаться
|
58 |
+
# 'top_chunk_similarities': 'Схожести Топ-Чанков', # Списки
|
59 |
+
|
60 |
+
# Агрегированные метрики (добавляются в calculate_aggregated_metrics)
|
61 |
+
'weighted_chunk_text_precision': 'Weighted Точность (Чанк-Текст)',
|
62 |
+
'weighted_chunk_text_recall': 'Weighted Полнота (Чанк-Текст)',
|
63 |
+
'weighted_chunk_text_f1': 'Weighted F1 (Чанк-Текст)',
|
64 |
+
'weighted_assembly_punct_recall': 'Weighted Полнота (Сборка-Пункт)',
|
65 |
+
|
66 |
+
'macro_chunk_text_precision': 'Macro Точность (Чанк-Текст)',
|
67 |
+
'macro_chunk_text_recall': 'Macro Полнота (Чанк-Текст)',
|
68 |
+
'macro_chunk_text_f1': 'Macro F1 (Чанк-Текст)',
|
69 |
+
'macro_assembly_punct_recall': 'Macro Полнота (Сборка-Пункт)',
|
70 |
+
|
71 |
+
'micro_text_precision': 'Micro Точность (Текст)',
|
72 |
+
'micro_text_recall': 'Micro Полнота (Текст)',
|
73 |
+
'micro_text_f1': 'Micro F1 (Текст)',
|
74 |
+
}
|
75 |
+
# --- Конец маппинга ---
|
76 |
+
|
77 |
+
def parse_args():
|
78 |
+
"""Парсит аргументы командной строки."""
|
79 |
+
parser = argparse.ArgumentParser(description="Визуализация результатов тестирования RAG")
|
80 |
+
|
81 |
+
parser.add_argument("--results-file", type=str, default=DEFAULT_RESULTS_FILE,
|
82 |
+
help=f"Путь к Excel-файлу с агрегированными результатами (по умолчанию: {DEFAULT_RESULTS_FILE})")
|
83 |
+
parser.add_argument("--plots-dir", type=str, default=DEFAULT_PLOTS_DIR,
|
84 |
+
help=f"Директория для сохранения графиков (по умолчанию: {DEFAULT_PLOTS_DIR})")
|
85 |
+
parser.add_argument("--sheet-name", type=str, default="Агрегированные метрики",
|
86 |
+
help="Название листа в Excel-файле для чтения данных")
|
87 |
+
|
88 |
+
return parser.parse_args()
|
89 |
+
|
90 |
+
def setup_plots_directory(plots_dir: str) -> None:
|
91 |
+
"""Создает директорию для графиков, если она не существует."""
|
92 |
+
if not os.path.exists(plots_dir):
|
93 |
+
os.makedirs(plots_dir)
|
94 |
+
print(f"Создана директория для графиков: {plots_dir}")
|
95 |
+
else:
|
96 |
+
print(f"Использование существующей директории для графиков: {plots_dir}")
|
97 |
+
|
98 |
+
def load_aggregated_data(file_path: str, sheet_name: str) -> pd.DataFrame:
|
99 |
+
"""Загружает данные из указанного листа Excel-файла."""
|
100 |
+
print(f"Загрузка данных из файла: {file_path}, лист: {sheet_name}")
|
101 |
+
try:
|
102 |
+
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
103 |
+
print(f"Загружено {len(df)} строк.")
|
104 |
+
print(f"Колонки: {df.columns.tolist()}")
|
105 |
+
# Добавим проверку на необходимые колонки (РУССКИЕ НАЗВАНИЯ)
|
106 |
+
required_cols_rus = [
|
107 |
+
COLUMN_NAME_MAPPING['model_name'], COLUMN_NAME_MAPPING['chunking_strategy'],
|
108 |
+
COLUMN_NAME_MAPPING['strategy_params'], COLUMN_NAME_MAPPING['process_tables'],
|
109 |
+
COLUMN_NAME_MAPPING['top_n'], COLUMN_NAME_MAPPING['use_injection'],
|
110 |
+
COLUMN_NAME_MAPPING['use_qe'], COLUMN_NAME_MAPPING['neighbors_included'],
|
111 |
+
COLUMN_NAME_MAPPING['similarity_threshold']
|
112 |
+
]
|
113 |
+
# Проверяем только те, что есть в маппинге
|
114 |
+
missing_required = [col for col in required_cols_rus if col not in df.columns]
|
115 |
+
if missing_required:
|
116 |
+
print(f"Предупреждение: Не все ожидаемые колонки параметров найдены в данных: {missing_required}")
|
117 |
+
|
118 |
+
# --- Добавим парсинг strategy_params из JSON строки в словарь ---
|
119 |
+
params_col = COLUMN_NAME_MAPPING['strategy_params']
|
120 |
+
if params_col in df.columns:
|
121 |
+
def safe_json_loads(x):
|
122 |
+
try:
|
123 |
+
# Обработка NaN и пустых строк
|
124 |
+
if pd.isna(x) or not isinstance(x, str) or not x.strip():
|
125 |
+
return {}
|
126 |
+
return json.loads(x)
|
127 |
+
except (json.JSONDecodeError, TypeError):
|
128 |
+
return {} # Возвращаем пустой словарь при ошибке
|
129 |
+
|
130 |
+
df[params_col] = df[params_col].apply(safe_json_loads)
|
131 |
+
# Создаем строковое представление для группировки и лейблов
|
132 |
+
df[f"{params_col}_str"] = df[params_col].apply(
|
133 |
+
lambda d: json.dumps(d, sort_keys=True, ensure_ascii=False)
|
134 |
+
)
|
135 |
+
print(f"Колонка '{params_col}' преобразована из JSON строк.")
|
136 |
+
# --------------------------------------------------------------
|
137 |
+
|
138 |
+
return df
|
139 |
+
except FileNotFoundError:
|
140 |
+
print(f"Ошибка: Файл не найден: {file_path}")
|
141 |
+
return pd.DataFrame()
|
142 |
+
except ValueError as e:
|
143 |
+
print(f"Ошибка: Лист '{sheet_name}' не найден в файле {file_path}. Доступные листы: {pd.ExcelFile(file_path).sheet_names}")
|
144 |
+
return pd.DataFrame()
|
145 |
+
except Exception as e:
|
146 |
+
print(f"Ошибка при чтении Excel файла: {e}")
|
147 |
+
return pd.DataFrame()
|
148 |
+
|
149 |
+
# --- Функции построения графиков --- #
|
150 |
+
|
151 |
+
def plot_metric_vs_top_n(
|
152 |
+
df: pd.DataFrame,
|
153 |
+
metric_name_rus: str, # Ожидаем русское имя метрики
|
154 |
+
fixed_strategy: str | None,
|
155 |
+
fixed_strategy_params: str | None, # Ожидаем строку JSON или None
|
156 |
+
plots_dir: str
|
157 |
+
) -> None:
|
158 |
+
"""
|
159 |
+
Строит график зависимости метрики от top_n для разных моделей
|
160 |
+
(при фиксированных параметрах чанкинга).
|
161 |
+
Разделяет линии по значению use_injection.
|
162 |
+
Использует русские названия колонок.
|
163 |
+
"""
|
164 |
+
# Используем русские названия колонок из маппинга
|
165 |
+
metric_col_rus = metric_name_rus # Передаем уже готовое русское имя
|
166 |
+
top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
|
167 |
+
model_col_rus = COLUMN_NAME_MAPPING['model_name']
|
168 |
+
injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
|
169 |
+
strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
|
170 |
+
params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str" # Используем строковое представление
|
171 |
+
|
172 |
+
if metric_col_rus not in df.columns:
|
173 |
+
print(f"График пропущен: Колонка '{metric_col_rus}' не найдена.")
|
174 |
+
return
|
175 |
+
|
176 |
+
plot_df = df.copy()
|
177 |
+
|
178 |
+
# Фильтруем по параметрам чанкинга, если задано
|
179 |
+
chunk_suffix = "all_strategies_all_params"
|
180 |
+
if fixed_strategy and strategy_col_rus in plot_df.columns:
|
181 |
+
plot_df = plot_df[plot_df[strategy_col_rus] == fixed_strategy]
|
182 |
+
chunk_suffix = f"strategy_{fixed_strategy}"
|
183 |
+
# Фильтруем по строковому пред��тавлению параметров
|
184 |
+
if fixed_strategy_params and params_str_col_rus in plot_df.columns:
|
185 |
+
plot_df = plot_df[plot_df[params_str_col_rus] == fixed_strategy_params]
|
186 |
+
# Генерируем короткий хэш для параметров в названии файла
|
187 |
+
params_hash = hash(fixed_strategy_params) # Хэш от строки
|
188 |
+
chunk_suffix += f"_params-{params_hash:x}" # Hex hash
|
189 |
+
|
190 |
+
if plot_df.empty:
|
191 |
+
print(f"График Metric vs Top-N пропущен: Нет данных для strategy={fixed_strategy}, params={fixed_strategy_params}")
|
192 |
+
return
|
193 |
+
|
194 |
+
plt.figure(figsize=FIGSIZE)
|
195 |
+
sns.lineplot(
|
196 |
+
data=plot_df,
|
197 |
+
x=top_n_col_rus,
|
198 |
+
y=metric_col_rus,
|
199 |
+
hue=model_col_rus,
|
200 |
+
style=injection_col_rus, # Разные стили линий для True/False
|
201 |
+
markers=True,
|
202 |
+
markersize=8,
|
203 |
+
linewidth=2,
|
204 |
+
palette=PALETTE
|
205 |
+
)
|
206 |
+
|
207 |
+
plt.title(f"Зависимость {metric_col_rus} от top_n ({chunk_suffix})")
|
208 |
+
plt.xlabel("Top N")
|
209 |
+
plt.ylabel(metric_col_rus.replace("_", " ").title())
|
210 |
+
plt.legend(title="Модель / Сборка", bbox_to_anchor=(1.05, 1), loc='upper left')
|
211 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
212 |
+
plt.tight_layout(rect=[0, 0, 0.85, 1]) # Оставляем место для легенды
|
213 |
+
|
214 |
+
filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_vs_top_n_{chunk_suffix}.png"
|
215 |
+
filepath = os.path.join(plots_dir, filename)
|
216 |
+
plt.savefig(filepath, dpi=DPI)
|
217 |
+
plt.close()
|
218 |
+
print(f"Создан график: {filepath}")
|
219 |
+
|
220 |
+
def plot_injection_comparison(
|
221 |
+
df: pd.DataFrame,
|
222 |
+
metric_name_rus: str, # Ожидаем русское имя метрики
|
223 |
+
plots_dir: str
|
224 |
+
) -> None:
|
225 |
+
"""
|
226 |
+
Сравнивает метрики с использованием и без использования сборки контекста
|
227 |
+
в виде парных столбчатых диаграмм для разных моделей и параметров чанкинга.
|
228 |
+
Использует русские названия колонок.
|
229 |
+
"""
|
230 |
+
# Русские названия колонок
|
231 |
+
metric_col_rus = metric_name_rus
|
232 |
+
injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
|
233 |
+
model_col_rus = COLUMN_NAME_MAPPING['model_name']
|
234 |
+
strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
|
235 |
+
params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str"
|
236 |
+
tables_col_rus = COLUMN_NAME_MAPPING['process_tables']
|
237 |
+
qe_col_rus = COLUMN_NAME_MAPPING['use_qe']
|
238 |
+
neighbors_col_rus = COLUMN_NAME_MAPPING['neighbors_included']
|
239 |
+
top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
|
240 |
+
threshold_col_rus = COLUMN_NAME_MAPPING['similarity_threshold']
|
241 |
+
|
242 |
+
if metric_col_rus not in df.columns or injection_col_rus not in df.columns:
|
243 |
+
print(f"График сравнения сборки пропущен: Колонки '{metric_col_rus}' или '{injection_col_rus}' не найдены.")
|
244 |
+
return
|
245 |
+
|
246 |
+
plot_df = df.copy()
|
247 |
+
# Используем русские названия при создании лейбла
|
248 |
+
plot_df['config_label'] = plot_df.apply(
|
249 |
+
lambda r: (
|
250 |
+
f"{r.get(model_col_rus, 'N/A')}\n"
|
251 |
+
f"Стратегия: {r.get(strategy_col_rus, 'N/A')}\n"
|
252 |
+
# Используем строковое представление параметров
|
253 |
+
f"Параметры: {r.get(params_str_col_rus, '{}')[:30]}...\n"
|
254 |
+
f"Табл: {r.get(tables_col_rus, 'N/A')}, QE: {r.get(qe_col_rus, 'N/A')}, Соседи: {r.get(neighbors_col_rus, 'N/A')}\n"
|
255 |
+
f"TopN: {int(r.get(top_n_col_rus, 0))}, Порог: {r.get(threshold_col_rus, 0):.2f}"
|
256 |
+
),
|
257 |
+
axis=1
|
258 |
+
)
|
259 |
+
|
260 |
+
# Оставляем только строки, где есть и True, и False для данного флага
|
261 |
+
# Группируем по config_label, считаем уникальные значения флага use_injection
|
262 |
+
counts = plot_df.groupby('config_label')[injection_col_rus].nunique()
|
263 |
+
configs_with_both = counts[counts >= 2].index # Используем >= 2 на случай дубликатов
|
264 |
+
plot_df = plot_df[plot_df['config_label'].isin(configs_with_both)]
|
265 |
+
|
266 |
+
if plot_df.empty:
|
267 |
+
print(f"График сравнения сборки пропущен: Нет конфигураций с обоими вариантами {injection_col_rus}.")
|
268 |
+
return
|
269 |
+
|
270 |
+
# Ограничим количество конфигураций для читаемости (по средней метрике)
|
271 |
+
top_configs = plot_df.groupby('config_label')[metric_col_rus].mean().nlargest(10).index # Уменьшил до 10
|
272 |
+
plot_df = plot_df[plot_df['config_label'].isin(top_configs)]
|
273 |
+
|
274 |
+
if plot_df.empty:
|
275 |
+
print(f"График сравнения сборки пропущен: Не осталось да��ных после фильтрации топ-конфигураций.")
|
276 |
+
return
|
277 |
+
|
278 |
+
plt.figure(figsize=(FIGSIZE[0]*0.9, FIGSIZE[1]*0.7)) # Уменьшил размер
|
279 |
+
sns.barplot(
|
280 |
+
data=plot_df,
|
281 |
+
x='config_label',
|
282 |
+
y=metric_col_rus,
|
283 |
+
hue=injection_col_rus,
|
284 |
+
palette=PALETTE
|
285 |
+
)
|
286 |
+
|
287 |
+
plt.title(f"Сравнение {metric_col_rus} с/без {injection_col_rus}")
|
288 |
+
plt.xlabel("Конфигурация")
|
289 |
+
plt.ylabel(metric_col_rus)
|
290 |
+
plt.xticks(rotation=60, ha='right', fontsize=8) # Уменьшил шрифт, увеличил поворот
|
291 |
+
plt.legend(title=injection_col_rus)
|
292 |
+
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
|
293 |
+
plt.tight_layout()
|
294 |
+
|
295 |
+
filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_injection_comparison.png"
|
296 |
+
filepath = os.path.join(plots_dir, filename)
|
297 |
+
plt.savefig(filepath, dpi=DPI)
|
298 |
+
plt.close()
|
299 |
+
print(f"Создан график: {filepath}")
|
300 |
+
|
301 |
+
# --- Новая функция для сравнения булевых флагов ---
|
302 |
+
def plot_boolean_flag_comparison(
|
303 |
+
df: pd.DataFrame,
|
304 |
+
metric_name_rus: str, # Ожидаем русское имя метрики
|
305 |
+
flag_column_eng: str, # Ожидаем английское имя флага для поиска в маппинге
|
306 |
+
plots_dir: str
|
307 |
+
) -> None:
|
308 |
+
"""
|
309 |
+
Сравнивает метрики при True/False значениях указанного булева флага
|
310 |
+
в виде парных столбчатых диаграмм для разных конфигураций.
|
311 |
+
Использует русские названия колонок.
|
312 |
+
"""
|
313 |
+
# Русские названия колонок
|
314 |
+
metric_col_rus = metric_name_rus
|
315 |
+
try:
|
316 |
+
flag_col_rus = COLUMN_NAME_MAPPING[flag_column_eng]
|
317 |
+
except KeyError:
|
318 |
+
print(f"Ошибка: Английское имя флага '{flag_column_eng}' не найдено в COLUMN_NAME_MAPPING.")
|
319 |
+
return
|
320 |
+
|
321 |
+
model_col_rus = COLUMN_NAME_MAPPING['model_name']
|
322 |
+
strategy_col_rus = COLUMN_NAME_MAPPING['chunking_strategy']
|
323 |
+
params_str_col_rus = f"{COLUMN_NAME_MAPPING['strategy_params']}_str"
|
324 |
+
injection_col_rus = COLUMN_NAME_MAPPING['use_injection']
|
325 |
+
top_n_col_rus = COLUMN_NAME_MAPPING['top_n']
|
326 |
+
# Другие флаги
|
327 |
+
tables_col_rus = COLUMN_NAME_MAPPING['process_tables']
|
328 |
+
qe_col_rus = COLUMN_NAME_MAPPING['use_qe']
|
329 |
+
neighbors_col_rus = COLUMN_NAME_MAPPING['neighbors_included']
|
330 |
+
|
331 |
+
|
332 |
+
if metric_col_rus not in df.columns or flag_col_rus not in df.columns:
|
333 |
+
print(f"График сравнения флага '{flag_col_rus}' пропущен: Колонки '{metric_col_rus}' или '{flag_col_rus}' не найдены.")
|
334 |
+
return
|
335 |
+
|
336 |
+
plot_df = df.copy()
|
337 |
+
# Создаем обобщенный лейбл конфигурации, исключая сам флаг
|
338 |
+
plot_df['config_label'] = plot_df.apply(
|
339 |
+
lambda r: (
|
340 |
+
f"{r.get(model_col_rus, 'N/A')}\n"
|
341 |
+
f"Стратегия: {r.get(strategy_col_rus, 'N/A')} Параметры: {r.get(params_str_col_rus, '{}')[:20]}...\n"
|
342 |
+
f"Сборка: {r.get(injection_col_rus, 'N/A')}, TopN: {int(r.get(top_n_col_rus, 0))}"
|
343 |
+
# Динамически добавляем другие флаги, кроме сравниваемого
|
344 |
+
+ (f", Табл: {r.get(tables_col_rus, 'N/A')}" if flag_col_rus != tables_col_rus else "")
|
345 |
+
+ (f", QE: {r.get(qe_col_rus, 'N/A')}" if flag_col_rus != qe_col_rus else "")
|
346 |
+
+ (f", Соседи: {r.get(neighbors_col_rus, 'N/A')}" if flag_col_rus != neighbors_col_rus else "")
|
347 |
+
),
|
348 |
+
axis=1
|
349 |
+
)
|
350 |
+
|
351 |
+
# Оставляем только строки, где есть и True, и False для данного флага
|
352 |
+
counts = plot_df.groupby('config_label')[flag_col_rus].nunique()
|
353 |
+
configs_with_both = counts[counts >= 2].index # Используем >= 2
|
354 |
+
plot_df = plot_df[plot_df['config_label'].isin(configs_with_both)]
|
355 |
+
|
356 |
+
if plot_df.empty:
|
357 |
+
print(f"График сравнения флага '{flag_col_rus}' пропущен: Нет конфигураций с обоими вариантами {flag_col_rus}.")
|
358 |
+
return
|
359 |
+
|
360 |
+
# Ограничим количество конфигураций для читаемости (по средней метрике)
|
361 |
+
top_configs = plot_df.groupby('config_label')[metric_col_rus].mean().nlargest(10).index # Уменьшил до 10
|
362 |
+
plot_df = plot_df[plot_df['config_label'].isin(top_configs)]
|
363 |
+
|
364 |
+
if plot_df.empty:
|
365 |
+
print(f"График сравнения флага '{flag_col_rus}' пропущен: Не осталось данных после фильтрации топ-конфигураций.")
|
366 |
+
return
|
367 |
+
|
368 |
+
plt.figure(figsize=(FIGSIZE[0]*0.9, FIGSIZE[1]*0.7)) # Уменьшил размер
|
369 |
+
sns.barplot(
|
370 |
+
data=plot_df,
|
371 |
+
x='config_label',
|
372 |
+
y=metric_col_rus,
|
373 |
+
hue=flag_col_rus,
|
374 |
+
palette=PALETTE
|
375 |
+
)
|
376 |
+
|
377 |
+
plt.title(f"Сравнение {metric_col_rus} в зависимости от '{flag_col_rus}'")
|
378 |
+
plt.xlabel("Конфигурация")
|
379 |
+
plt.ylabel(metric_col_rus)
|
380 |
+
plt.xticks(rotation=60, ha='right', fontsize=8) # Уменьшил шрифт, увеличил поворот
|
381 |
+
plt.legend(title=f"{flag_col_rus}")
|
382 |
+
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
|
383 |
+
plt.tight_layout()
|
384 |
+
|
385 |
+
filename = f"plot_{metric_col_rus.replace(' ', '_').replace('(', '').replace(')', '')}_{flag_column_eng}_comparison.png"
|
386 |
+
filepath = os.path.join(plots_dir, filename)
|
387 |
+
plt.savefig(filepath, dpi=DPI)
|
388 |
+
plt.close()
|
389 |
+
print(f"Создан график: {filepath}")
|
390 |
+
|
391 |
+
# --- Основная функция ---
|
392 |
+
def main():
|
393 |
+
"""Основная функция скрипта."""
|
394 |
+
args = parse_args()
|
395 |
+
|
396 |
+
setup_plots_directory(args.plots_dir)
|
397 |
+
df = load_aggregated_data(args.results_file, args.sheet_name)
|
398 |
+
|
399 |
+
if df.empty:
|
400 |
+
print("Нет данных для построения графиков. Завершение.")
|
401 |
+
return
|
402 |
+
|
403 |
+
# Определяем метрики для построения графиков (используем английские ключи для поиска русских имен)
|
404 |
+
metric_keys = [
|
405 |
+
'weighted_chunk_text_recall', 'weighted_chunk_text_f1', 'weighted_assembly_punct_recall',
|
406 |
+
'macro_chunk_text_recall', 'macro_chunk_text_f1', 'macro_assembly_punct_recall',
|
407 |
+
'micro_text_recall', 'micro_text_f1'
|
408 |
+
]
|
409 |
+
|
410 |
+
# Получаем существующие русские имена метрик в DataFrame
|
411 |
+
existing_metrics_rus = [COLUMN_NAME_MAPPING.get(key) for key in metric_keys if COLUMN_NAME_MAPPING.get(key) in df.columns]
|
412 |
+
|
413 |
+
# Определяем фиксированные параметры для некоторых графиков
|
414 |
+
strategy_col_rus = COLUMN_NAME_MAPPING.get('chunking_strategy')
|
415 |
+
params_str_col_rus = f"{COLUMN_NAME_MAPPING.get('strategy_params')}_str"
|
416 |
+
model_col_rus = COLUMN_NAME_MAPPING.get('model_name')
|
417 |
+
|
418 |
+
fixed_strategy_example = df[strategy_col_rus].unique()[0] if strategy_col_rus in df.columns and len(df[strategy_col_rus].unique()) > 0 else None
|
419 |
+
fixed_strategy_params_example = None
|
420 |
+
if fixed_strategy_example and params_str_col_rus in df.columns:
|
421 |
+
params_list = df[df[strategy_col_rus] == fixed_strategy_example][params_str_col_rus].unique()
|
422 |
+
if len(params_list) > 0:
|
423 |
+
fixed_strategy_params_example = params_list[0]
|
424 |
+
|
425 |
+
fixed_model_example = df[model_col_rus].unique()[0] if model_col_rus in df.columns and len(df[model_col_rus].unique()) > 0 else None
|
426 |
+
fixed_top_n_example = 20
|
427 |
+
|
428 |
+
print("--- Построение графиков ---")
|
429 |
+
|
430 |
+
# 1. Графики Metric vs Top-N
|
431 |
+
print("\n1. Зависимость метрик от Top-N:")
|
432 |
+
for metric_name_rus in existing_metrics_rus:
|
433 |
+
# Проверяем, что метрика не micro (у micro нет зависимости от top_n)
|
434 |
+
if 'Micro' in metric_name_rus:
|
435 |
+
continue
|
436 |
+
plot_metric_vs_top_n(
|
437 |
+
df, metric_name_rus,
|
438 |
+
fixed_strategy_example, fixed_strategy_params_example,
|
439 |
+
args.plots_dir
|
440 |
+
)
|
441 |
+
|
442 |
+
# 2. Графики Metric vs Chunking
|
443 |
+
print("\n2. Зависимость метрик от параметров чанкинга: [Пропущено - требует переосмысления]")
|
444 |
+
# plot_metric_vs_chunking(...) # Закомментировано
|
445 |
+
|
446 |
+
# 3. Графики сравнения Use Injection
|
447 |
+
print("\n3. Сравнение метрик с/без сборки контекста:")
|
448 |
+
for metric_name_rus in existing_metrics_rus:
|
449 |
+
plot_injection_comparison(df, metric_name_rus, args.plots_dir)
|
450 |
+
|
451 |
+
# 4. Графики сравнения других булевых флагов
|
452 |
+
boolean_flags_eng = ['process_tables', 'use_qe', 'neighbors_included']
|
453 |
+
print("\n4. Сравнение метрик в зависимости от булевых флагов:")
|
454 |
+
for flag_eng in boolean_flags_eng:
|
455 |
+
flag_rus = COLUMN_NAME_MAPPING.get(flag_eng)
|
456 |
+
if not flag_rus or flag_rus not in df.columns:
|
457 |
+
print(f" Пропуск сравнения для флага: '{flag_eng}' (колонка '{flag_rus}' не найдена)")
|
458 |
+
continue
|
459 |
+
print(f" Сравнение для флага: '{flag_rus}'")
|
460 |
+
for metric_name_rus in existing_metrics_rus:
|
461 |
+
plot_boolean_flag_comparison(df, metric_name_rus, flag_eng, args.plots_dir)
|
462 |
+
|
463 |
+
print("\n--- Построение графиков завершено ---")
|
464 |
+
|
465 |
+
if __name__ == "__main__":
|
466 |
+
main()
|
scripts/testing/run_pipelines.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Скрипт для запуска множества пайплайнов оценки (`pipeline.py`)
|
5 |
+
с различными комбинациями параметров.
|
6 |
+
|
7 |
+
Собирает команды для `pipeline.py` и запускает их последовательно,
|
8 |
+
логируя вывод каждого запуска.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import pathlib
|
15 |
+
import subprocess
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
from datetime import datetime
|
19 |
+
from itertools import product
|
20 |
+
from uuid import uuid4
|
21 |
+
|
22 |
+
# --- Конфигурация Экспериментов ---
|
23 |
+
|
24 |
+
# Модели для тестирования
|
25 |
+
MODELS_TO_TEST = [
|
26 |
+
# "intfloat/e5-base",
|
27 |
+
# "intfloat/e5-large",
|
28 |
+
"BAAI/bge-m3",
|
29 |
+
# "deepvk/USER-bge-m3"
|
30 |
+
# "ai-forever/FRIDA" # Требует --use-sentence-transformers
|
31 |
+
]
|
32 |
+
|
33 |
+
# Параметры чанкинга (слова / перекрытие)
|
34 |
+
CHUNKING_PARAMS = [
|
35 |
+
# Пример для стратегии "fixed_size"
|
36 |
+
{"strategy": "fixed_size", "params": {"words_per_chunk": 50, "overlap_words": 25}},
|
37 |
+
# {"strategy": "fixed_size", "params": {"words_per_chunk": 100, "overlap_words": 25}},
|
38 |
+
# {"strategy": "fixed_size", "params": {"words_per_chunk": 50, "overlap_words": 0}},
|
39 |
+
# TODO: Добавить другие стратегии и их параметры, если нужно
|
40 |
+
# {"strategy": "some_other_strategy", "params": {"param1": "value1"}}
|
41 |
+
]
|
42 |
+
|
43 |
+
# Значения Top-N для ретривера
|
44 |
+
TOP_N_VALUES = [20, 50, 100]
|
45 |
+
|
46 |
+
# Использовать ли сборку контекста (InjectionBuilder)
|
47 |
+
USE_INJECTION_OPTIONS = [False, True]
|
48 |
+
|
49 |
+
# Порог схожести для fuzzy сравнения (чанк/пункт)
|
50 |
+
SIMILARITY_THRESHOLDS = [0.7]
|
51 |
+
|
52 |
+
# Опции использования Query Expansion
|
53 |
+
USE_QE_OPTIONS = [False, True]
|
54 |
+
|
55 |
+
# Опции обработки таблиц
|
56 |
+
PROCESS_TABLES_OPTIONS = [True]
|
57 |
+
|
58 |
+
# Опции включения соседей
|
59 |
+
INCLUDE_NEIGHBORS_OPTIONS = [True]
|
60 |
+
|
61 |
+
# --- Настройки Скрипта ---
|
62 |
+
DEFAULT_LOG_DIR = "logs" # Директория для логов отдельных запусков pipeline.py
|
63 |
+
DEFAULT_INTERMEDIATE_DIR = "data/intermediate" # Куда pipeline.py сохраняет свои результаты
|
64 |
+
DEFAULT_PYTHON_EXECUTABLE = sys.executable # Использовать тот же python, что и для запуска этого скрипта
|
65 |
+
|
66 |
+
def parse_args():
|
67 |
+
"""Парсит аргументы командной строки."""
|
68 |
+
parser = argparse.ArgumentParser(description="Запуск серии оценочных пайплайнов")
|
69 |
+
|
70 |
+
# Флаги для пропуска определенных измерений
|
71 |
+
parser.add_argument("--skip-models", action="store_true",
|
72 |
+
help="Пропустить итерацию по разным моделям (использовать первую в списке)")
|
73 |
+
parser.add_argument("--skip-chunking", action="store_true",
|
74 |
+
help="Пропустить итерацию по разным параметрам чанкинга (использовать первую в списке)")
|
75 |
+
parser.add_argument("--skip-top-n", action="store_true",
|
76 |
+
help="Пропустить итерацию по разным top_n (использовать первое значение)")
|
77 |
+
parser.add_argument("--skip-injection", action="store_true",
|
78 |
+
help="Пропустить итерацию по опциям сборки контекста (использовать False)")
|
79 |
+
parser.add_argument("--skip-thresholds", action="store_true",
|
80 |
+
help="Пропустить итерацию по порогам схожести (использовать первый)")
|
81 |
+
parser.add_argument("--skip-process-tables", action="store_true",
|
82 |
+
help="Пропустить итерацию по обработке таблиц (использовать True)")
|
83 |
+
parser.add_argument("--skip-include-neighbors", action="store_true",
|
84 |
+
help="Пропустить итерацию по включению соседей (использовать False)")
|
85 |
+
parser.add_argument("--skip-qe", action="store_true",
|
86 |
+
help="Пропустить итерацию по использованию Query Expansion (использовать False)")
|
87 |
+
|
88 |
+
# Настройки путей и выполнения
|
89 |
+
parser.add_argument("--log-dir", type=str, default=DEFAULT_LOG_DIR,
|
90 |
+
help=f"Директория для сохранения логов запусков (по умолчанию: {DEFAULT_LOG_DIR})")
|
91 |
+
parser.add_argument("--intermediate-dir", type=str, default=DEFAULT_INTERMEDIATE_DIR,
|
92 |
+
help=f"Директория для промежуточных результатов pipeline.py (по умолчанию: {DEFAULT_INTERMEDIATE_DIR})")
|
93 |
+
parser.add_argument("--device", type=str, default="cuda:0",
|
94 |
+
help="Устройство для вычислений в pipeline.py (напр., cpu, cuda:0)")
|
95 |
+
parser.add_argument("--python-executable", type=str, default=DEFAULT_PYTHON_EXECUTABLE,
|
96 |
+
help="Путь к интерпретатору Python для запуска pipeline.py")
|
97 |
+
|
98 |
+
# Параметры, передаваемые в pipeline.py (если не перебираются)
|
99 |
+
parser.add_argument("--data-folder", type=str, default="data/input/docs", help="Папка с документами для pipeline.py")
|
100 |
+
parser.add_argument("--search-dataset-path", type=str, default="data/input/search_dataset_text.xlsx", help="Поисковый датасет для pipeline.py")
|
101 |
+
parser.add_argument("--qa-dataset-path", type=str, default="data/input/question_answering.xlsx", help="QA датасет для pipeline.py")
|
102 |
+
|
103 |
+
return parser.parse_args()
|
104 |
+
|
105 |
+
def run_single_pipeline(cmd: list[str], log_path: str):
|
106 |
+
"""
|
107 |
+
Запускает один экземпляр pipeline.py и логирует его вывод.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
cmd: Список аргументов команды для subprocess.
|
111 |
+
log_path: Путь к файлу для сохранения лога.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
Код возврата процесса.
|
115 |
+
"""
|
116 |
+
print(f"\n--- Запуск: {' '.join(cmd)} ---")
|
117 |
+
print(f"--- Лог: {log_path} --- ")
|
118 |
+
|
119 |
+
start_time = time.time()
|
120 |
+
return_code = -1
|
121 |
+
|
122 |
+
try:
|
123 |
+
with open(log_path, "w", encoding="utf-8") as log_file:
|
124 |
+
log_file.write(f"Команда: {' '.join(cmd)}\n")
|
125 |
+
log_file.write(f"Время запуска: {datetime.now()}\n\n")
|
126 |
+
log_file.flush()
|
127 |
+
|
128 |
+
# Запускаем процесс
|
129 |
+
process = subprocess.Popen(
|
130 |
+
cmd,
|
131 |
+
stdout=subprocess.PIPE,
|
132 |
+
stderr=subprocess.STDOUT, # Перенаправляем stderr в stdout
|
133 |
+
text=True,
|
134 |
+
encoding='utf-8', # Указываем кодировку
|
135 |
+
errors='replace', # Заменяем ошибки кодирования
|
136 |
+
bufsize=1 # Построчная буферизация
|
137 |
+
)
|
138 |
+
|
139 |
+
# Читаем и пишем вывод построчно
|
140 |
+
for line in process.stdout:
|
141 |
+
print(line, end="") # Выводим в консоль
|
142 |
+
log_file.write(line) # Пишем в лог
|
143 |
+
log_file.flush()
|
144 |
+
|
145 |
+
# Ждем завершения и получаем код возврата
|
146 |
+
process.wait()
|
147 |
+
return_code = process.returncode
|
148 |
+
|
149 |
+
except Exception as e:
|
150 |
+
print(f"\nОшибка при запуске процесса: {e}")
|
151 |
+
with open(log_path, "a", encoding="utf-8") as log_file:
|
152 |
+
log_file.write(f"\nОшибка при запуске: {e}\n")
|
153 |
+
return_code = 1 # Считаем ошибкой
|
154 |
+
|
155 |
+
end_time = time.time()
|
156 |
+
duration = end_time - start_time
|
157 |
+
|
158 |
+
result_message = f"Успешно завершено за {duration:.2f} сек."
|
159 |
+
if return_code != 0:
|
160 |
+
result_message = f"Завершено с ошибкой (код {return_code}) за {duration:.2f} сек."
|
161 |
+
|
162 |
+
print(f"--- {result_message} ---")
|
163 |
+
with open(log_path, "a", encoding="utf-8") as log_file:
|
164 |
+
log_file.write(f"\nВремя завершения: {datetime.now()}")
|
165 |
+
log_file.write(f"\nДлительность: {duration:.2f} сек.")
|
166 |
+
log_file.write(f"\nКод возврата: {return_code}\n")
|
167 |
+
|
168 |
+
return return_code
|
169 |
+
|
170 |
+
def main():
|
171 |
+
"""Основная функция скрипта."""
|
172 |
+
args = parse_args()
|
173 |
+
|
174 |
+
# --- Генерируем ID для всей серии запусков ---
|
175 |
+
batch_run_id = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
176 |
+
print(f"Запуск серии экспериментов. Batch ID: {batch_run_id}")
|
177 |
+
|
178 |
+
# Создаем директории для логов и промежуточных результатов
|
179 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
180 |
+
os.makedirs(args.intermediate_dir, exist_ok=True)
|
181 |
+
|
182 |
+
# Определяем абсолютный путь к pipeline.py
|
183 |
+
RUN_PIPELINES_SCRIPT_PATH = pathlib.Path(__file__).resolve()
|
184 |
+
SCRIPTS_TESTING_DIR = RUN_PIPELINES_SCRIPT_PATH.parent
|
185 |
+
PIPELINE_SCRIPT_PATH = SCRIPTS_TESTING_DIR / "pipeline.py"
|
186 |
+
|
187 |
+
# --- Определяем параметры для перебора ---
|
188 |
+
models = [MODELS_TO_TEST[0]] if args.skip_models else MODELS_TO_TEST
|
189 |
+
chunking_configs = [CHUNKING_PARAMS[0]] if args.skip_chunking else CHUNKING_PARAMS
|
190 |
+
top_n_list = [TOP_N_VALUES[0]] if args.skip_top_n else TOP_N_VALUES
|
191 |
+
use_injection_list = [False] if args.skip_injection else USE_INJECTION_OPTIONS
|
192 |
+
threshold_list = [SIMILARITY_THRESHOLDS[0]] if args.skip_thresholds else SIMILARITY_THRESHOLDS
|
193 |
+
|
194 |
+
# Определяем списки для новых измерений
|
195 |
+
process_tables_list = [PROCESS_TABLES_OPTIONS[0]] if args.skip_process_tables else PROCESS_TABLES_OPTIONS
|
196 |
+
include_neighbors_list = [INCLUDE_NEIGHBORS_OPTIONS[0]] if args.skip_include_neighbors else INCLUDE_NEIGHBORS_OPTIONS
|
197 |
+
use_qe_list = [USE_QE_OPTIONS[0]] if args.skip_qe else USE_QE_OPTIONS
|
198 |
+
|
199 |
+
# --- Создаем список всех комбинаций параметров ---
|
200 |
+
parameter_combinations = list(product(
|
201 |
+
models,
|
202 |
+
chunking_configs,
|
203 |
+
top_n_list,
|
204 |
+
use_injection_list,
|
205 |
+
threshold_list,
|
206 |
+
process_tables_list,
|
207 |
+
include_neighbors_list,
|
208 |
+
use_qe_list
|
209 |
+
))
|
210 |
+
|
211 |
+
total_runs = len(parameter_combinations)
|
212 |
+
print(f"Всего запланировано запусков: {total_runs}")
|
213 |
+
|
214 |
+
# --- Запускаем пайплайны для каждой комбинации ---
|
215 |
+
completed_runs = 0
|
216 |
+
failed_runs = 0
|
217 |
+
start_time_all = time.time()
|
218 |
+
|
219 |
+
for i, (model, chunk_cfg, top_n, use_injection, threshold, process_tables, include_neighbors, use_qe) in enumerate(parameter_combinations):
|
220 |
+
print(f"\n{'='*80}")
|
221 |
+
print(f"Запуск {i+1}/{total_runs}")
|
222 |
+
print(f" Модель: {model}")
|
223 |
+
# Логируем параметры чанкинга
|
224 |
+
strategy = chunk_cfg['strategy']
|
225 |
+
params = chunk_cfg['params']
|
226 |
+
params_str = json.dumps(params, ensure_ascii=False)
|
227 |
+
print(f" Чанкинг: Стратегия='{strategy}', Параметры={params_str}")
|
228 |
+
print(f" Обработка таблиц: {process_tables}")
|
229 |
+
print(f" Top-N: {top_n}")
|
230 |
+
print(f" Сборка контекста: {use_injection}")
|
231 |
+
print(f" Query Expansion: {use_qe}")
|
232 |
+
print(f" Включение соседей: {include_neighbors}")
|
233 |
+
print(f" Порог схожести: {threshold}")
|
234 |
+
print(f"{'='*80}")
|
235 |
+
|
236 |
+
# Генерируем уникальный ID для этого запуска
|
237 |
+
run_id = f"run_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid4().hex[:8]}"
|
238 |
+
|
239 |
+
# Формируем команду для pipeline.py
|
240 |
+
cmd = [
|
241 |
+
args.python_executable,
|
242 |
+
str(PIPELINE_SCRIPT_PATH), # Используем абсолютный путь
|
243 |
+
"--run-id", run_id,
|
244 |
+
"--batch-id", batch_run_id,
|
245 |
+
"--data-folder", args.data_folder,
|
246 |
+
"--search-dataset-path", args.search_dataset_path,
|
247 |
+
"--output-dir", args.intermediate_dir,
|
248 |
+
"--model-name", model,
|
249 |
+
"--chunking-strategy", strategy,
|
250 |
+
"--strategy-params", params_str,
|
251 |
+
"--top-n", str(top_n),
|
252 |
+
"--similarity-threshold", str(threshold),
|
253 |
+
"--device", args.device,
|
254 |
+
]
|
255 |
+
|
256 |
+
# Добавляем флаг --use-injection, если нужно
|
257 |
+
if use_injection:
|
258 |
+
cmd.append("--use-injection")
|
259 |
+
|
260 |
+
# Добавляем флаг --no-process-tables, если process_tables == False
|
261 |
+
if not process_tables:
|
262 |
+
cmd.append("--no-process-tables")
|
263 |
+
|
264 |
+
# Добавляем флаг --include-neighbors, если include_neighbors == True
|
265 |
+
if include_neighbors:
|
266 |
+
cmd.append("--include-neighbors")
|
267 |
+
|
268 |
+
# Добавляем флаг --use-qe, если use_qe == True
|
269 |
+
if use_qe:
|
270 |
+
cmd.append("--use-qe")
|
271 |
+
|
272 |
+
# Добавляем флаг --use-sentence-transformers для определенных моделей
|
273 |
+
if "FRIDA" in model or "sentence-transformer" in model.lower(): # Пример
|
274 |
+
cmd.append("--use-sentence-transformers")
|
275 |
+
|
276 |
+
# Формируем путь к лог-файлу
|
277 |
+
log_filename = f"{run_id}_log.txt"
|
278 |
+
log_path = os.path.join(args.log_dir, log_filename)
|
279 |
+
|
280 |
+
# Запускаем пайплайн
|
281 |
+
return_code = run_single_pipeline(cmd, log_path)
|
282 |
+
|
283 |
+
if return_code == 0:
|
284 |
+
completed_runs += 1
|
285 |
+
else:
|
286 |
+
failed_runs += 1
|
287 |
+
print(f"*** ВНИМАНИЕ: Запуск {i+1} завершился с ошибкой! Лог: {log_path} ***")
|
288 |
+
|
289 |
+
# --- Вывод итоговой статистики ---
|
290 |
+
end_time_all = time.time()
|
291 |
+
total_duration = end_time_all - start_time_all
|
292 |
+
|
293 |
+
print(f"\n{'='*80}")
|
294 |
+
print("Все запуски завершены.")
|
295 |
+
print(f"Общее время выполнения: {total_duration:.2f} сек ({total_duration/60:.2f} мин)")
|
296 |
+
print(f"Всего запусков: {total_runs}")
|
297 |
+
print(f"Успешно завершено: {completed_runs}")
|
298 |
+
print(f"Завершено с ошибками: {failed_runs}")
|
299 |
+
print(f"Промежуточные результаты сохранены в: {args.intermediate_dir}")
|
300 |
+
print(f"Логи запусков сохранены в: {args.log_dir}")
|
301 |
+
print(f"{'='*80}")
|
302 |
+
|
303 |
+
if __name__ == "__main__":
|
304 |
+
main()
|