Spaces:
Running
on
T4
Running
on
T4
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +5 -2
- common/db.py +1 -2
- common/dependencies.py +59 -38
- components/dbo/chunk_repository.py +249 -0
- components/dbo/models/dataset.py +6 -1
- components/dbo/models/entity.py +85 -0
- components/embedding_extraction.py +11 -8
- components/nmd/faiss_vector_search.py +25 -14
- components/services/dataset.py +85 -137
- components/services/document.py +12 -12
- components/services/entity.py +210 -0
- lib/extractor/.cursor/rules/project-description.mdc +86 -0
- lib/extractor/.gitignore +11 -0
- lib/extractor/README.md +60 -0
- lib/extractor/docs/architecture.puml +149 -0
- lib/extractor/ntr_text_fragmentation/__init__.py +19 -0
- lib/extractor/ntr_text_fragmentation/additors/__init__.py +10 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py +5 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/table_entity.py +74 -0
- lib/extractor/ntr_text_fragmentation/additors/tables_processor.py +117 -0
- lib/extractor/ntr_text_fragmentation/chunking/__init__.py +11 -0
- lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py +86 -0
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py +11 -0
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/__init__.py +9 -0
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py +143 -0
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py +568 -0
- lib/extractor/ntr_text_fragmentation/core/__init__.py +9 -0
- lib/extractor/ntr_text_fragmentation/core/destructurer.py +143 -0
- lib/extractor/ntr_text_fragmentation/core/entity_repository.py +258 -0
- lib/extractor/ntr_text_fragmentation/core/injection_builder.py +429 -0
- lib/extractor/ntr_text_fragmentation/integrations/__init__.py +9 -0
- lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy_repository.py +339 -0
- lib/extractor/ntr_text_fragmentation/models/__init__.py +13 -0
- lib/extractor/ntr_text_fragmentation/models/chunk.py +48 -0
- lib/extractor/ntr_text_fragmentation/models/document.py +49 -0
- lib/extractor/ntr_text_fragmentation/models/linker_entity.py +217 -0
- lib/extractor/pyproject.toml +26 -0
- lib/extractor/scripts/README_test_chunking.md +107 -0
- lib/extractor/scripts/analyze_missing_puncts.py +547 -0
- lib/extractor/scripts/combine_results.py +1352 -0
- lib/extractor/scripts/debug_question_chunks.py +392 -0
- lib/extractor/scripts/evaluate_chunking.py +800 -0
- lib/extractor/scripts/plot_macro_metrics.py +348 -0
- lib/extractor/scripts/prepare_dataset.py +578 -0
- lib/extractor/scripts/run_chunking_experiments.sh +156 -0
- lib/extractor/scripts/run_experiments.py +206 -0
- lib/extractor/scripts/search_api.py +748 -0
- lib/extractor/scripts/test_chunking_visualization.py +235 -0
- lib/extractor/tests/__init__.py +3 -0
- lib/extractor/tests/chunking/__init__.py +3 -0
Dockerfile
CHANGED
@@ -30,13 +30,16 @@ RUN python -m pip install \
|
|
30 |
torch==2.6.0+cu126 \
|
31 |
--index-url https://download.pytorch.org/whl/cu126
|
32 |
|
|
|
33 |
COPY requirements.txt /app/
|
34 |
RUN python -m pip install -r requirements.txt
|
35 |
-
# RUN python -m pip install --ignore-installed elasticsearch==7.11.0 || true
|
36 |
|
37 |
COPY . .
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
|
41 |
|
42 |
EXPOSE ${PORT}
|
|
|
30 |
torch==2.6.0+cu126 \
|
31 |
--index-url https://download.pytorch.org/whl/cu126
|
32 |
|
33 |
+
|
34 |
COPY requirements.txt /app/
|
35 |
RUN python -m pip install -r requirements.txt
|
|
|
36 |
|
37 |
COPY . .
|
38 |
+
RUN python -m pip install -e ./lib/parser
|
39 |
+
RUN python -m pip install --no-deps -e ./lib/extractor
|
40 |
+
# RUN python -m pip install --ignore-installed elasticsearch==7.11.0 || true
|
41 |
|
42 |
+
RUN mkdir -p /data/regulation_datasets /data/documents /logs
|
43 |
|
44 |
|
45 |
EXPOSE ${PORT}
|
common/db.py
CHANGED
@@ -16,13 +16,12 @@ import components.dbo.models.document
|
|
16 |
import components.dbo.models.log
|
17 |
import components.dbo.models.llm_prompt
|
18 |
import components.dbo.models.llm_config
|
19 |
-
|
20 |
|
21 |
CONFIG_PATH = os.environ.get('CONFIG_PATH', './config_dev.yaml')
|
22 |
config = Configuration(CONFIG_PATH)
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
-
print("sql url:", config.common_config.log_sql_path)
|
26 |
engine = create_engine(config.common_config.log_sql_path, connect_args={'check_same_thread': False})
|
27 |
|
28 |
session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
16 |
import components.dbo.models.log
|
17 |
import components.dbo.models.llm_prompt
|
18 |
import components.dbo.models.llm_config
|
19 |
+
import components.dbo.models.entity
|
20 |
|
21 |
CONFIG_PATH = os.environ.get('CONFIG_PATH', './config_dev.yaml')
|
22 |
config = Configuration(CONFIG_PATH)
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
|
|
25 |
engine = create_engine(config.common_config.log_sql_path, connect_args={'check_same_thread': False})
|
26 |
|
27 |
session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
common/dependencies.py
CHANGED
@@ -1,21 +1,22 @@
|
|
1 |
import logging
|
2 |
-
from logging import Logger
|
3 |
import os
|
|
|
|
|
|
|
4 |
from fastapi import Depends
|
|
|
|
|
5 |
|
6 |
from common.configuration import Configuration
|
|
|
|
|
|
|
7 |
from components.llm.common import LlmParams
|
8 |
from components.llm.deepinfra_api import DeepInfraApi
|
9 |
from components.services.dataset import DatasetService
|
10 |
-
from components.embedding_extraction import EmbeddingExtractor
|
11 |
-
from components.datasets.dispatcher import Dispatcher
|
12 |
from components.services.document import DocumentService
|
13 |
-
from components.services.
|
14 |
from components.services.llm_config import LLMConfigService
|
15 |
-
|
16 |
-
from typing import Annotated
|
17 |
-
from sqlalchemy.orm import sessionmaker, Session
|
18 |
-
from common.db import session_factory
|
19 |
from components.services.llm_prompt import LlmPromptService
|
20 |
|
21 |
|
@@ -28,56 +29,76 @@ def get_db() -> sessionmaker:
|
|
28 |
|
29 |
|
30 |
def get_logger() -> Logger:
|
31 |
-
return logging.getLogger(__name__)
|
32 |
|
33 |
|
34 |
-
def get_embedding_extractor(
|
|
|
|
|
35 |
return EmbeddingExtractor(
|
36 |
config.db_config.faiss.model_embedding_path,
|
37 |
config.db_config.faiss.device,
|
38 |
)
|
39 |
|
40 |
|
41 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
|
|
|
43 |
config: Annotated[Configuration, Depends(get_config)],
|
44 |
-
|
45 |
-
|
46 |
-
return
|
47 |
|
48 |
-
def get_dispatcher(vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
|
49 |
-
config: Annotated[Configuration, Depends(get_config)],
|
50 |
-
logger: Annotated[Logger, Depends(get_logger)],
|
51 |
-
dataset_service: Annotated[DatasetService, Depends(get_dataset_service)]) -> Dispatcher:
|
52 |
-
return Dispatcher(vectorizer, config, logger, dataset_service)
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
-
def get_document_service(
|
60 |
-
|
61 |
-
|
|
|
|
|
62 |
return DocumentService(dataset_service, config, db)
|
63 |
|
64 |
|
65 |
def get_llm_config_service(db: Annotated[Session, Depends(get_db)]) -> LLMConfigService:
|
66 |
return LLMConfigService(db)
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
80 |
return DeepInfraApi(params=llm_params)
|
81 |
|
|
|
82 |
def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
|
83 |
-
return LlmPromptService(db)
|
|
|
1 |
import logging
|
|
|
2 |
import os
|
3 |
+
from logging import Logger
|
4 |
+
from typing import Annotated
|
5 |
+
|
6 |
from fastapi import Depends
|
7 |
+
from ntr_text_fragmentation import InjectionBuilder
|
8 |
+
from sqlalchemy.orm import Session, sessionmaker
|
9 |
|
10 |
from common.configuration import Configuration
|
11 |
+
from common.db import session_factory
|
12 |
+
from components.dbo.chunk_repository import ChunkRepository
|
13 |
+
from components.embedding_extraction import EmbeddingExtractor
|
14 |
from components.llm.common import LlmParams
|
15 |
from components.llm.deepinfra_api import DeepInfraApi
|
16 |
from components.services.dataset import DatasetService
|
|
|
|
|
17 |
from components.services.document import DocumentService
|
18 |
+
from components.services.entity import EntityService
|
19 |
from components.services.llm_config import LLMConfigService
|
|
|
|
|
|
|
|
|
20 |
from components.services.llm_prompt import LlmPromptService
|
21 |
|
22 |
|
|
|
29 |
|
30 |
|
31 |
def get_logger() -> Logger:
|
32 |
+
return logging.getLogger(__name__)
|
33 |
|
34 |
|
35 |
+
def get_embedding_extractor(
|
36 |
+
config: Annotated[Configuration, Depends(get_config)],
|
37 |
+
) -> EmbeddingExtractor:
|
38 |
return EmbeddingExtractor(
|
39 |
config.db_config.faiss.model_embedding_path,
|
40 |
config.db_config.faiss.device,
|
41 |
)
|
42 |
|
43 |
|
44 |
+
def get_chunk_repository(db: Annotated[Session, Depends(get_db)]) -> ChunkRepository:
|
45 |
+
return ChunkRepository(db)
|
46 |
+
|
47 |
+
|
48 |
+
def get_injection_builder(
|
49 |
+
chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
|
50 |
+
) -> InjectionBuilder:
|
51 |
+
return InjectionBuilder(chunk_repository)
|
52 |
+
|
53 |
+
|
54 |
+
def get_entity_service(
|
55 |
vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
|
56 |
+
chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
|
57 |
config: Annotated[Configuration, Depends(get_config)],
|
58 |
+
) -> EntityService:
|
59 |
+
"""Получение сервиса для работы с сущностями через DI."""
|
60 |
+
return EntityService(vectorizer, chunk_repository, config)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
def get_dataset_service(
|
64 |
+
entity_service: Annotated[EntityService, Depends(get_entity_service)],
|
65 |
+
config: Annotated[Configuration, Depends(get_config)],
|
66 |
+
db: Annotated[sessionmaker, Depends(get_db)],
|
67 |
+
) -> DatasetService:
|
68 |
+
"""Получение сервиса для работы с датасетами через DI."""
|
69 |
+
return DatasetService(entity_service, config, db)
|
70 |
|
71 |
|
72 |
+
def get_document_service(
|
73 |
+
dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
|
74 |
+
config: Annotated[Configuration, Depends(get_config)],
|
75 |
+
db: Annotated[sessionmaker, Depends(get_db)],
|
76 |
+
) -> DocumentService:
|
77 |
return DocumentService(dataset_service, config, db)
|
78 |
|
79 |
|
80 |
def get_llm_config_service(db: Annotated[Session, Depends(get_db)]) -> LLMConfigService:
|
81 |
return LLMConfigService(db)
|
82 |
|
83 |
+
|
84 |
+
def get_llm_service(
|
85 |
+
config: Annotated[Configuration, Depends(get_config)],
|
86 |
+
) -> DeepInfraApi:
|
87 |
+
|
88 |
+
llm_params = LlmParams(
|
89 |
+
**{
|
90 |
+
"url": config.llm_config.base_url,
|
91 |
+
"model": config.llm_config.model,
|
92 |
+
"tokenizer": config.llm_config.tokenizer,
|
93 |
+
"type": "deepinfra",
|
94 |
+
"default": True,
|
95 |
+
"predict_params": None, # должны задаваться при каждом запросе
|
96 |
+
"api_key": os.environ.get(config.llm_config.api_key_env),
|
97 |
+
"context_length": 128000,
|
98 |
+
}
|
99 |
+
)
|
100 |
return DeepInfraApi(params=llm_params)
|
101 |
|
102 |
+
|
103 |
def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
|
104 |
+
return LlmPromptService(db)
|
components/dbo/chunk_repository.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from uuid import UUID
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from ntr_text_fragmentation import LinkerEntity
|
5 |
+
from ntr_text_fragmentation.integrations import SQLAlchemyEntityRepository
|
6 |
+
from sqlalchemy import and_, select
|
7 |
+
from sqlalchemy.orm import Session
|
8 |
+
|
9 |
+
from components.dbo.models.entity import EntityModel
|
10 |
+
|
11 |
+
|
12 |
+
class ChunkRepository(SQLAlchemyEntityRepository):
|
13 |
+
def __init__(self, db: Session):
|
14 |
+
super().__init__(db)
|
15 |
+
|
16 |
+
def _entity_model_class(self):
|
17 |
+
return EntityModel
|
18 |
+
|
19 |
+
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel):
|
20 |
+
"""
|
21 |
+
Преобразует сущность из базы данных в LinkerEntity.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
db_entity: Сущность из базы данных
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
LinkerEntity
|
28 |
+
"""
|
29 |
+
# Преобразуем строковые ID в UUID
|
30 |
+
entity = LinkerEntity(
|
31 |
+
id=UUID(db_entity.uuid), # Преобразуем строку в UUID
|
32 |
+
name=db_entity.name,
|
33 |
+
text=db_entity.text,
|
34 |
+
type=db_entity.entity_type,
|
35 |
+
in_search_text=db_entity.in_search_text,
|
36 |
+
metadata=db_entity.metadata_json,
|
37 |
+
source_id=UUID(db_entity.source_id) if db_entity.source_id else None, # Преобразуем строку в UUID
|
38 |
+
target_id=UUID(db_entity.target_id) if db_entity.target_id else None, # Преобразуем строку в UUID
|
39 |
+
number_in_relation=db_entity.number_in_relation,
|
40 |
+
)
|
41 |
+
return LinkerEntity.deserialize(entity)
|
42 |
+
|
43 |
+
def add_entities(
|
44 |
+
self,
|
45 |
+
entities: list[LinkerEntity],
|
46 |
+
dataset_id: int,
|
47 |
+
embeddings: dict[str, np.ndarray],
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Добавляет сущности в базу данных.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
entities: Список сущностей для добавления
|
54 |
+
dataset_id: ID датасета
|
55 |
+
embeddings: Словарь эмбеддингов {entity_id: embedding}
|
56 |
+
"""
|
57 |
+
with self.db() as session:
|
58 |
+
for entity in entities:
|
59 |
+
# Преобразуем UUID в строку для хранения в базе
|
60 |
+
entity_id = str(entity.id)
|
61 |
+
|
62 |
+
if entity_id in embeddings:
|
63 |
+
embedding = embeddings[entity_id]
|
64 |
+
else:
|
65 |
+
embedding = None
|
66 |
+
|
67 |
+
session.add(
|
68 |
+
EntityModel(
|
69 |
+
uuid=str(entity.id), # UUID в строку
|
70 |
+
name=entity.name,
|
71 |
+
text=entity.text,
|
72 |
+
entity_type=entity.type,
|
73 |
+
in_search_text=entity.in_search_text,
|
74 |
+
metadata_json=entity.metadata,
|
75 |
+
source_id=str(entity.source_id) if entity.source_id else None, # UUID в строку
|
76 |
+
target_id=str(entity.target_id) if entity.target_id else None, # UUID в строку
|
77 |
+
number_in_relation=entity.number_in_relation,
|
78 |
+
chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index
|
79 |
+
dataset_id=dataset_id,
|
80 |
+
embedding=embedding,
|
81 |
+
)
|
82 |
+
)
|
83 |
+
|
84 |
+
session.commit()
|
85 |
+
|
86 |
+
def get_searching_entities(
|
87 |
+
self,
|
88 |
+
dataset_id: int,
|
89 |
+
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
|
90 |
+
with self.db() as session:
|
91 |
+
models = (
|
92 |
+
session.query(EntityModel)
|
93 |
+
.filter(EntityModel.in_search_text is not None)
|
94 |
+
.filter(EntityModel.dataset_id == dataset_id)
|
95 |
+
.all()
|
96 |
+
)
|
97 |
+
return (
|
98 |
+
[self._map_db_entity_to_linker_entity(model) for model in models],
|
99 |
+
[model.embedding for model in models],
|
100 |
+
)
|
101 |
+
|
102 |
+
def get_chunks_by_ids(
|
103 |
+
self,
|
104 |
+
chunk_ids: list[str],
|
105 |
+
) -> list[LinkerEntity]:
|
106 |
+
"""
|
107 |
+
Получение чанков по их ID.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
chunk_ids: Список ID чанков
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Список чанков
|
114 |
+
"""
|
115 |
+
# Преобразуем все ID в строки для единообразия
|
116 |
+
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
|
117 |
+
|
118 |
+
with self.db() as session:
|
119 |
+
models = (
|
120 |
+
session.query(EntityModel)
|
121 |
+
.filter(EntityModel.uuid.in_(str_chunk_ids))
|
122 |
+
.all()
|
123 |
+
)
|
124 |
+
return [self._map_db_entity_to_linker_entity(model) for model in models]
|
125 |
+
|
126 |
+
def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[LinkerEntity]:
|
127 |
+
"""
|
128 |
+
Получить сущности по списку идентификаторов.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
entity_ids: Список идентифи��аторов сущностей
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
Список сущностей, соответствующих указанным идентификаторам
|
135 |
+
"""
|
136 |
+
if not entity_ids:
|
137 |
+
return []
|
138 |
+
|
139 |
+
# Преобразуем UUID в строки
|
140 |
+
str_entity_ids = [str(entity_id) for entity_id in entity_ids]
|
141 |
+
|
142 |
+
with self.db() as session:
|
143 |
+
entity_model = self._entity_model_class()
|
144 |
+
db_entities = session.execute(
|
145 |
+
select(entity_model).where(entity_model.uuid.in_(str_entity_ids))
|
146 |
+
).scalars().all()
|
147 |
+
|
148 |
+
return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
|
149 |
+
|
150 |
+
def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]:
|
151 |
+
"""
|
152 |
+
Получить соседние чанки для указанных чанков.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
chunk_ids: Список идентификаторов чанков
|
156 |
+
max_distance: Максимальное расстояние до соседа
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
Список соседних чанков
|
160 |
+
"""
|
161 |
+
if not chunk_ids:
|
162 |
+
return []
|
163 |
+
|
164 |
+
# Преобразуем UUID в строки
|
165 |
+
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
|
166 |
+
|
167 |
+
with self.db() as session:
|
168 |
+
entity_model = self._entity_model_class()
|
169 |
+
result = []
|
170 |
+
|
171 |
+
# Сначала получаем указанные чанки, чтобы узнать их индексы и документы
|
172 |
+
chunks = session.execute(
|
173 |
+
select(entity_model).where(
|
174 |
+
and_(
|
175 |
+
entity_model.uuid.in_(str_chunk_ids),
|
176 |
+
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
|
177 |
+
)
|
178 |
+
)
|
179 |
+
).scalars().all()
|
180 |
+
|
181 |
+
if not chunks:
|
182 |
+
return []
|
183 |
+
|
184 |
+
# Находим документы для чанков через связи
|
185 |
+
doc_ids = set()
|
186 |
+
chunk_indices = {}
|
187 |
+
|
188 |
+
for chunk in chunks:
|
189 |
+
chunk_indices[chunk.uuid] = chunk.chunk_index
|
190 |
+
|
191 |
+
# Находим связь от документа к чанку
|
192 |
+
links = session.execute(
|
193 |
+
select(entity_model).where(
|
194 |
+
and_(
|
195 |
+
entity_model.target_id == chunk.uuid,
|
196 |
+
entity_model.name == "document_to_chunk"
|
197 |
+
)
|
198 |
+
)
|
199 |
+
).scalars().all()
|
200 |
+
|
201 |
+
for link in links:
|
202 |
+
doc_ids.add(link.source_id)
|
203 |
+
|
204 |
+
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
|
205 |
+
return []
|
206 |
+
|
207 |
+
# Для каждого документа находим все его чанки
|
208 |
+
for doc_id in doc_ids:
|
209 |
+
# Находим все связи от документа к чанкам
|
210 |
+
links = session.execute(
|
211 |
+
select(entity_model).where(
|
212 |
+
and_(
|
213 |
+
entity_model.source_id == doc_id,
|
214 |
+
entity_model.name == "document_to_chunk"
|
215 |
+
)
|
216 |
+
)
|
217 |
+
).scalars().all()
|
218 |
+
|
219 |
+
doc_chunk_ids = [link.target_id for link in links]
|
220 |
+
|
221 |
+
# Получаем все чанки документа
|
222 |
+
doc_chunks = session.execute(
|
223 |
+
select(entity_model).where(
|
224 |
+
and_(
|
225 |
+
entity_model.uuid.in_(doc_chunk_ids),
|
226 |
+
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
|
227 |
+
)
|
228 |
+
)
|
229 |
+
).scalars().all()
|
230 |
+
|
231 |
+
# Для каждого чанка в документе проверяем, является ли он соседом
|
232 |
+
for doc_chunk in doc_chunks:
|
233 |
+
if doc_chunk.uuid in str_chunk_ids:
|
234 |
+
continue
|
235 |
+
|
236 |
+
if doc_chunk.chunk_index is None:
|
237 |
+
continue
|
238 |
+
|
239 |
+
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
|
240 |
+
is_neighbor = False
|
241 |
+
for orig_chunk_id, orig_index in chunk_indices.items():
|
242 |
+
if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance:
|
243 |
+
is_neighbor = True
|
244 |
+
break
|
245 |
+
|
246 |
+
if is_neighbor:
|
247 |
+
result.append(self._map_db_entity_to_linker_entity(doc_chunk))
|
248 |
+
|
249 |
+
return result
|
components/dbo/models/dataset.py
CHANGED
@@ -23,4 +23,9 @@ class Dataset(Base):
|
|
23 |
documents: Mapped[list["DatasetDocument"]] = relationship(
|
24 |
"DatasetDocument", back_populates="dataset",
|
25 |
cascade="all, delete-orphan"
|
26 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
23 |
documents: Mapped[list["DatasetDocument"]] = relationship(
|
24 |
"DatasetDocument", back_populates="dataset",
|
25 |
cascade="all, delete-orphan"
|
26 |
+
)
|
27 |
+
|
28 |
+
entities: Mapped[list["EntityModel"]] = relationship(
|
29 |
+
"EntityModel", back_populates="dataset",
|
30 |
+
cascade="all, delete-orphan"
|
31 |
+
)
|
components/dbo/models/entity.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from sqlalchemy import ForeignKey, Integer, LargeBinary, String
|
5 |
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
6 |
+
from sqlalchemy.types import TypeDecorator
|
7 |
+
|
8 |
+
from components.dbo.models.base import Base
|
9 |
+
|
10 |
+
|
11 |
+
class JSONType(TypeDecorator):
|
12 |
+
"""Тип для хранения JSON в SQLite."""
|
13 |
+
|
14 |
+
impl = String
|
15 |
+
cache_ok = True
|
16 |
+
|
17 |
+
def process_bind_param(self, value, dialect):
|
18 |
+
"""Сохранение dict в JSON строку."""
|
19 |
+
if value is None:
|
20 |
+
return None
|
21 |
+
return json.dumps(value)
|
22 |
+
|
23 |
+
def process_result_value(self, value, dialect):
|
24 |
+
"""Загрузка JSON строки в dict."""
|
25 |
+
if value is None:
|
26 |
+
return None
|
27 |
+
return json.loads(value)
|
28 |
+
|
29 |
+
|
30 |
+
class EmbeddingType(TypeDecorator):
|
31 |
+
"""Тип для хранения эмбеддингов в SQLite."""
|
32 |
+
|
33 |
+
impl = LargeBinary
|
34 |
+
cache_ok = True
|
35 |
+
|
36 |
+
def process_bind_param(self, value, dialect):
|
37 |
+
"""Сохранение numpy array в базу."""
|
38 |
+
if value is None:
|
39 |
+
return None
|
40 |
+
# Убеждаемся, что массив двумерный перед сохранением
|
41 |
+
value = np.asarray(value, dtype=np.float32)
|
42 |
+
if value.ndim == 1:
|
43 |
+
value = value.reshape(1, -1)
|
44 |
+
return value.tobytes()
|
45 |
+
|
46 |
+
def process_result_value(self, value, dialect):
|
47 |
+
"""Загрузка из базы в numpy array."""
|
48 |
+
if value is None:
|
49 |
+
return None
|
50 |
+
return np.frombuffer(value, dtype=np.float32)
|
51 |
+
|
52 |
+
|
53 |
+
class EntityModel(Base):
|
54 |
+
"""
|
55 |
+
SQLAlchemy модель для хранения сущностей.
|
56 |
+
"""
|
57 |
+
|
58 |
+
__tablename__ = "entity"
|
59 |
+
|
60 |
+
uuid: Mapped[str] = mapped_column(String, unique=True)
|
61 |
+
name: Mapped[str] = mapped_column(String, nullable=False)
|
62 |
+
text: Mapped[str] = mapped_column(String, nullable=False)
|
63 |
+
in_search_text: Mapped[str] = mapped_column(String, nullable=True)
|
64 |
+
entity_type: Mapped[str] = mapped_column(String, nullable=False)
|
65 |
+
|
66 |
+
# Поля для связей (триплетный подход)
|
67 |
+
source_id: Mapped[str] = mapped_column(String, nullable=True)
|
68 |
+
target_id: Mapped[str] = mapped_column(String, nullable=True)
|
69 |
+
number_in_relation: Mapped[int] = mapped_column(Integer, nullable=True)
|
70 |
+
|
71 |
+
# Поле для индекса чанка в документе
|
72 |
+
chunk_index: Mapped[int] = mapped_column(Integer, nullable=True)
|
73 |
+
|
74 |
+
# JSON-поле для хранения метаданных
|
75 |
+
metadata_json: Mapped[dict] = mapped_column(JSONType, nullable=True)
|
76 |
+
|
77 |
+
embedding: Mapped[np.ndarray] = mapped_column(EmbeddingType, nullable=True)
|
78 |
+
|
79 |
+
dataset_id: Mapped[int] = mapped_column(Integer, ForeignKey("dataset.id"), nullable=False)
|
80 |
+
|
81 |
+
dataset: Mapped["Dataset"] = relationship( # type: ignore
|
82 |
+
"Dataset",
|
83 |
+
back_populates="entities",
|
84 |
+
cascade="all",
|
85 |
+
)
|
components/embedding_extraction.py
CHANGED
@@ -5,10 +5,10 @@ import numpy as np
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import DataLoader
|
8 |
-
from transformers import AutoModel, AutoTokenizer, BatchEncoding,
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
@@ -41,8 +41,8 @@ class EmbeddingExtractor:
|
|
41 |
|
42 |
self.device = device
|
43 |
# Инициализация модели
|
44 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
45 |
-
self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id).to(
|
46 |
self.device
|
47 |
)
|
48 |
self.model.eval()
|
@@ -122,7 +122,6 @@ class EmbeddingExtractor:
|
|
122 |
|
123 |
return embedding.cpu().numpy()
|
124 |
|
125 |
-
# TODO: В будущем стоит объединить vectorize и query_embed_extraction
|
126 |
def vectorize(
|
127 |
self,
|
128 |
texts: list[str] | str,
|
@@ -162,7 +161,11 @@ class EmbeddingExtractor:
|
|
162 |
|
163 |
logger.info('Vectorized all %d batches', len(embeddings))
|
164 |
|
165 |
-
|
|
|
|
|
|
|
|
|
166 |
|
167 |
@torch.no_grad()
|
168 |
def _vectorize_batch(
|
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import DataLoader
|
8 |
+
from transformers import (AutoModel, AutoTokenizer, BatchEncoding,
|
9 |
+
XLMRobertaModel)
|
10 |
+
from transformers.modeling_outputs import \
|
11 |
+
BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
41 |
|
42 |
self.device = device
|
43 |
# Инициализация модели
|
44 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
|
45 |
+
self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id, local_files_only=True).to(
|
46 |
self.device
|
47 |
)
|
48 |
self.model.eval()
|
|
|
122 |
|
123 |
return embedding.cpu().numpy()
|
124 |
|
|
|
125 |
def vectorize(
|
126 |
self,
|
127 |
texts: list[str] | str,
|
|
|
161 |
|
162 |
logger.info('Vectorized all %d batches', len(embeddings))
|
163 |
|
164 |
+
result = torch.cat(embeddings).numpy()
|
165 |
+
# Всегда возвращаем двумерный массив
|
166 |
+
if result.ndim == 1:
|
167 |
+
result = result.reshape(1, -1)
|
168 |
+
return result
|
169 |
|
170 |
@torch.no_grad()
|
171 |
def _vectorize_batch(
|
components/nmd/faiss_vector_search.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
import logging
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import pandas as pd
|
5 |
import faiss
|
|
|
6 |
|
7 |
-
from common.constants import COLUMN_EMBEDDING
|
8 |
-
from common.constants import DO_NORMALIZATION
|
9 |
from common.configuration import DataBaseConfiguration
|
|
|
10 |
from components.embedding_extraction import EmbeddingExtractor
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
@@ -14,7 +12,10 @@ logger = logging.getLogger(__name__)
|
|
14 |
|
15 |
class FaissVectorSearch:
|
16 |
def __init__(
|
17 |
-
self,
|
|
|
|
|
|
|
18 |
):
|
19 |
self.model = model
|
20 |
self.config = config
|
@@ -23,26 +24,36 @@ class FaissVectorSearch:
|
|
23 |
self.k_neighbors = config.ranker.k_neighbors
|
24 |
else:
|
25 |
self.k_neighbors = config.search.vector_search.k_neighbors
|
26 |
-
self.
|
|
|
27 |
|
28 |
-
def __create_index(self,
|
29 |
"""Load the metadata file."""
|
30 |
-
if len(
|
31 |
self.index = None
|
32 |
return
|
33 |
-
|
34 |
-
embeddings = np.array(df[COLUMN_EMBEDDING].tolist())
|
35 |
dim = embeddings.shape[1]
|
36 |
-
self.index = faiss.
|
37 |
self.index.add(embeddings)
|
38 |
|
39 |
def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
40 |
"""
|
41 |
Поиск векторов в индексе.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
"""
|
43 |
logger.info(f"Searching vectors in index for query: {query}")
|
44 |
if self.index is None:
|
45 |
return (np.array([]), np.array([]), np.array([]))
|
46 |
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
|
47 |
-
|
48 |
-
|
|
|
|
1 |
import logging
|
2 |
+
|
|
|
|
|
3 |
import faiss
|
4 |
+
import numpy as np
|
5 |
|
|
|
|
|
6 |
from common.configuration import DataBaseConfiguration
|
7 |
+
from common.constants import DO_NORMALIZATION
|
8 |
from components.embedding_extraction import EmbeddingExtractor
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
|
|
12 |
|
13 |
class FaissVectorSearch:
|
14 |
def __init__(
|
15 |
+
self,
|
16 |
+
model: EmbeddingExtractor,
|
17 |
+
ids_to_embeddings: dict[str, np.ndarray],
|
18 |
+
config: DataBaseConfiguration,
|
19 |
):
|
20 |
self.model = model
|
21 |
self.config = config
|
|
|
24 |
self.k_neighbors = config.ranker.k_neighbors
|
25 |
else:
|
26 |
self.k_neighbors = config.search.vector_search.k_neighbors
|
27 |
+
self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
|
28 |
+
self.__create_index(ids_to_embeddings)
|
29 |
|
30 |
+
def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
|
31 |
"""Load the metadata file."""
|
32 |
+
if len(ids_to_embeddings) == 0:
|
33 |
self.index = None
|
34 |
return
|
35 |
+
embeddings = np.array(list(ids_to_embeddings.values()))
|
|
|
36 |
dim = embeddings.shape[1]
|
37 |
+
self.index = faiss.IndexFlatIP(dim)
|
38 |
self.index.add(embeddings)
|
39 |
|
40 |
def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
41 |
"""
|
42 |
Поиск векторов в индексе.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
query: Строка, запрос для поиска.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
|
49 |
+
- np.ndarray: Вектор запроса (1, embedding_size)
|
50 |
+
- np.ndarray: Оценки косинусного сходства (чем больше, тем лучше)
|
51 |
+
- np.ndarray: Идентификаторы найденных векторов
|
52 |
"""
|
53 |
logger.info(f"Searching vectors in index for query: {query}")
|
54 |
if self.index is None:
|
55 |
return (np.array([]), np.array([]), np.array([]))
|
56 |
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
|
57 |
+
similarities, indexes = self.index.search(query_embeds, self.k_neighbors)
|
58 |
+
ids = [self.index_to_id[index] for index in indexes[0]]
|
59 |
+
return query_embeds, similarities[0], np.array(ids)
|
components/services/dataset.py
CHANGED
@@ -4,33 +4,27 @@ import os
|
|
4 |
import shutil
|
5 |
import zipfile
|
6 |
from datetime import datetime
|
7 |
-
from multiprocessing import Process
|
8 |
from pathlib import Path
|
9 |
-
from typing import Optional
|
10 |
-
from threading import Lock
|
11 |
|
12 |
import pandas as pd
|
13 |
import torch
|
14 |
from fastapi import BackgroundTasks, HTTPException, UploadFile
|
|
|
|
|
15 |
|
16 |
from common.common import get_source_format
|
17 |
from common.configuration import Configuration
|
18 |
-
from components.embedding_extraction import EmbeddingExtractor
|
19 |
-
from components.parser.features.documents_dataset import DocumentsDataset
|
20 |
-
from components.parser.pipeline import DatasetCreationPipeline
|
21 |
-
from components.parser.xml.structures import ParsedXML
|
22 |
-
from components.parser.xml.xml_parser import XMLParser
|
23 |
-
from sqlalchemy.orm import Session
|
24 |
-
from components.dbo.models.acronym import Acronym
|
25 |
from components.dbo.models.dataset import Dataset
|
26 |
from components.dbo.models.dataset_document import DatasetDocument
|
27 |
from components.dbo.models.document import Document
|
|
|
28 |
from schemas.dataset import Dataset as DatasetSchema
|
29 |
from schemas.dataset import DatasetExpanded as DatasetExpandedSchema
|
30 |
from schemas.dataset import DatasetProcessing
|
31 |
from schemas.dataset import DocumentsPage as DocumentsPageSchema
|
32 |
from schemas.dataset import SortQueryList
|
33 |
from schemas.document import Document as DocumentSchema
|
|
|
34 |
logger = logging.getLogger(__name__)
|
35 |
|
36 |
|
@@ -38,24 +32,31 @@ class DatasetService:
|
|
38 |
"""
|
39 |
Сервис для работы с датасетами.
|
40 |
"""
|
41 |
-
|
42 |
def __init__(
|
43 |
-
self,
|
44 |
-
|
45 |
config: Configuration,
|
46 |
-
db: Session
|
47 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
logger.info("DatasetService initializing")
|
49 |
self.db = db
|
50 |
self.config = config
|
51 |
-
self.parser =
|
52 |
-
self.
|
53 |
self.regulations_path = Path(config.db_config.files.regulations_path)
|
54 |
self.documents_path = Path(config.db_config.files.documents_path)
|
55 |
-
self.tmp_path= Path(os.environ.get("APP_TMP_PATH", '.'))
|
56 |
logger.info("DatasetService initialized")
|
57 |
|
58 |
-
|
59 |
def get_dataset(
|
60 |
self,
|
61 |
dataset_id: int,
|
@@ -83,9 +84,6 @@ class DatasetService:
|
|
83 |
session.query(Document)
|
84 |
.join(DatasetDocument, DatasetDocument.document_id == Document.id)
|
85 |
.filter(DatasetDocument.dataset_id == dataset_id)
|
86 |
-
.filter(
|
87 |
-
Document.status.in_(['Актуальный', 'Требует актуализации', 'Упразднён'])
|
88 |
-
)
|
89 |
.filter(Document.title.like(f'%{search}%'))
|
90 |
)
|
91 |
|
@@ -98,7 +96,9 @@ class DatasetService:
|
|
98 |
.join(DatasetDocument, DatasetDocument.document_id == Document.id)
|
99 |
.filter(DatasetDocument.dataset_id == dataset_id)
|
100 |
.filter(
|
101 |
-
Document.status.in_(
|
|
|
|
|
102 |
)
|
103 |
.filter(Document.title.like(f'%{search}%'))
|
104 |
.count()
|
@@ -142,7 +142,7 @@ class DatasetService:
|
|
142 |
name=dataset.name,
|
143 |
isDraft=dataset.is_draft,
|
144 |
isActive=dataset.is_active,
|
145 |
-
dateCreated=dataset.date_created
|
146 |
)
|
147 |
for dataset in datasets
|
148 |
]
|
@@ -198,8 +198,10 @@ class DatasetService:
|
|
198 |
self.raise_if_processing()
|
199 |
|
200 |
with self.db() as session:
|
201 |
-
dataset: Dataset =
|
202 |
-
|
|
|
|
|
203 |
if not dataset:
|
204 |
raise HTTPException(status_code=404, detail='Dataset not found')
|
205 |
|
@@ -222,36 +224,42 @@ class DatasetService:
|
|
222 |
"""
|
223 |
try:
|
224 |
with self.db() as session:
|
225 |
-
dataset =
|
|
|
|
|
226 |
if not dataset:
|
227 |
-
raise HTTPException(
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
232 |
dataset.is_draft = False
|
233 |
dataset.is_active = True
|
234 |
if active_dataset:
|
235 |
active_dataset.is_active = False
|
236 |
-
|
237 |
session.commit()
|
238 |
except Exception as e:
|
239 |
logger.error(f"Error applying draft: {e}")
|
240 |
raise
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
244 |
"""
|
245 |
Активировать датасет в фоновой задаче.
|
246 |
"""
|
247 |
-
|
248 |
logger.info(f"Activating dataset {dataset_id}")
|
249 |
self.raise_if_processing()
|
250 |
|
251 |
with self.db() as session:
|
252 |
-
dataset = (
|
253 |
-
session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
254 |
-
)
|
255 |
active_dataset = session.query(Dataset).filter(Dataset.is_active).first()
|
256 |
if not dataset:
|
257 |
raise HTTPException(status_code=404, detail='Dataset not found')
|
@@ -329,7 +337,7 @@ class DatasetService:
|
|
329 |
|
330 |
dataset = self.create_dataset_from_directory(
|
331 |
is_default=False,
|
332 |
-
|
333 |
directory_with_ready_dataset=None,
|
334 |
)
|
335 |
|
@@ -341,10 +349,12 @@ class DatasetService:
|
|
341 |
def apply_draft(
|
342 |
self,
|
343 |
dataset: Dataset,
|
344 |
-
session,
|
345 |
) -> None:
|
346 |
"""
|
347 |
Сохранить черновик как полноценный датасет.
|
|
|
|
|
|
|
348 |
"""
|
349 |
torch.set_num_threads(1)
|
350 |
logger.info(f"Applying draft dataset {dataset.id}")
|
@@ -363,9 +373,7 @@ class DatasetService:
|
|
363 |
if current % log_step != 0:
|
364 |
return
|
365 |
if (total > 10) and (current % (total // 10) == 0):
|
366 |
-
logger.info(
|
367 |
-
f"Processing dataset {dataset.id}: {current}/{total}"
|
368 |
-
)
|
369 |
with open(TMP_PATH, 'w', encoding='utf-8') as f:
|
370 |
json.dump(
|
371 |
{
|
@@ -381,34 +389,25 @@ class DatasetService:
|
|
381 |
document_ids = [
|
382 |
doc_dataset_link.document_id for doc_dataset_link in dataset.documents
|
383 |
]
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
)
|
403 |
-
progress_callback(0, 1000)
|
404 |
-
|
405 |
-
try:
|
406 |
-
pipeline.run(progress_callback)
|
407 |
-
except Exception as e:
|
408 |
-
logger.error(f"Error running pipeline: {e}")
|
409 |
-
raise HTTPException(status_code=500, detail=str(e))
|
410 |
-
finally:
|
411 |
-
TMP_PATH.unlink()
|
412 |
|
413 |
def raise_if_processing(self) -> None:
|
414 |
"""
|
@@ -423,7 +422,7 @@ class DatasetService:
|
|
423 |
def create_dataset_from_directory(
|
424 |
self,
|
425 |
is_default: bool,
|
426 |
-
|
427 |
directory_with_ready_dataset: Path | None = None,
|
428 |
) -> Dataset:
|
429 |
"""
|
@@ -438,7 +437,7 @@ class DatasetService:
|
|
438 |
Dataset: Созданный датасет.
|
439 |
"""
|
440 |
logger.info(
|
441 |
-
f"Creating {'default' if is_default else 'new'} dataset from directory {
|
442 |
)
|
443 |
with self.db() as session:
|
444 |
documents = []
|
@@ -453,9 +452,9 @@ class DatasetService:
|
|
453 |
)
|
454 |
session.add(dataset)
|
455 |
|
456 |
-
for subpath in self._get_recursive_dirlist(
|
457 |
document, relation = self._create_document(
|
458 |
-
|
459 |
)
|
460 |
if document is None:
|
461 |
continue
|
@@ -484,7 +483,8 @@ class DatasetService:
|
|
484 |
old_filename = document.filename
|
485 |
new_filename = '{}.{}'.format(document.id, document.source_format)
|
486 |
shutil.copy(
|
487 |
-
|
|
|
488 |
)
|
489 |
document.filename = new_filename
|
490 |
|
@@ -495,16 +495,8 @@ class DatasetService:
|
|
495 |
|
496 |
dataset_id = dataset.id
|
497 |
|
498 |
-
|
499 |
logger.info(f"Dataset {dataset_id} created")
|
500 |
|
501 |
-
df = self.dataset_to_pandas(dataset_id)
|
502 |
-
|
503 |
-
(self.regulations_path / str(dataset_id)).mkdir(parents=True, exist_ok=True)
|
504 |
-
df.to_csv(
|
505 |
-
self.regulations_path / str(dataset_id) / 'documents.csv', index=False
|
506 |
-
)
|
507 |
-
|
508 |
return dataset
|
509 |
|
510 |
def create_empty_dataset(self, is_default: bool) -> Dataset:
|
@@ -526,20 +518,6 @@ class DatasetService:
|
|
526 |
session.commit()
|
527 |
session.refresh(dataset)
|
528 |
|
529 |
-
self.documents_path.mkdir(exist_ok=True)
|
530 |
-
|
531 |
-
dataset_id = dataset.id
|
532 |
-
|
533 |
-
|
534 |
-
folder = self.regulations_path / str(dataset_id)
|
535 |
-
folder.mkdir(parents=True, exist_ok=True)
|
536 |
-
|
537 |
-
pickle_creator = DocumentsDataset([])
|
538 |
-
pickle_creator.to_pickle(folder / 'dataset.pkl')
|
539 |
-
|
540 |
-
df = self.dataset_to_pandas(dataset_id)
|
541 |
-
df.to_csv(folder / 'documents.csv', index=False)
|
542 |
-
|
543 |
return dataset
|
544 |
|
545 |
@staticmethod
|
@@ -553,10 +531,10 @@ class DatasetService:
|
|
553 |
Returns:
|
554 |
list[Path]: Список путей к xml-файлам относительно path.
|
555 |
"""
|
556 |
-
xml_files = set()
|
557 |
for ext in ('*.xml', '*.XML', '*.docx', '*.DOCX'):
|
558 |
xml_files.update(path.glob(f'**/{ext}'))
|
559 |
-
|
560 |
return [p.relative_to(path) for p in xml_files]
|
561 |
|
562 |
def _create_document(
|
@@ -580,19 +558,19 @@ class DatasetService:
|
|
580 |
|
581 |
try:
|
582 |
source_format = get_source_format(str(subpath))
|
583 |
-
|
584 |
-
documents_path / subpath
|
585 |
)
|
586 |
|
587 |
-
if not
|
588 |
logger.warning(f"Failed to parse file: {subpath}")
|
589 |
return None, None
|
590 |
|
591 |
document = Document(
|
592 |
filename=str(subpath),
|
593 |
-
title=
|
594 |
-
status=
|
595 |
-
owner=
|
596 |
source_format=source_format,
|
597 |
)
|
598 |
relation = DatasetDocument(
|
@@ -606,36 +584,6 @@ class DatasetService:
|
|
606 |
logger.error(f"Error creating document from {subpath}: {e}")
|
607 |
return None, None
|
608 |
|
609 |
-
def dataset_to_pandas(self, dataset_id: int) -> pd.DataFrame:
|
610 |
-
"""
|
611 |
-
Преобразовать датасет в pandas DataFrame.
|
612 |
-
"""
|
613 |
-
with self.db() as session:
|
614 |
-
links = (
|
615 |
-
session.query(DatasetDocument)
|
616 |
-
.filter(DatasetDocument.dataset_id == dataset_id)
|
617 |
-
.all()
|
618 |
-
)
|
619 |
-
documents = (
|
620 |
-
session.query(Document)
|
621 |
-
.filter(Document.id.in_([link.document_id for link in links]))
|
622 |
-
.all()
|
623 |
-
)
|
624 |
-
|
625 |
-
return pd.DataFrame(
|
626 |
-
[
|
627 |
-
{
|
628 |
-
'id': document.id,
|
629 |
-
'filename': document.filename,
|
630 |
-
'title': document.title,
|
631 |
-
'status': document.status,
|
632 |
-
'owner': document.owner,
|
633 |
-
}
|
634 |
-
for document in documents
|
635 |
-
],
|
636 |
-
columns=['id', 'filename', 'title', 'status', 'owner'],
|
637 |
-
)
|
638 |
-
|
639 |
def get_current_dataset(self) -> Dataset | None:
|
640 |
with self.db() as session:
|
641 |
print(session)
|
|
|
4 |
import shutil
|
5 |
import zipfile
|
6 |
from datetime import datetime
|
|
|
7 |
from pathlib import Path
|
|
|
|
|
8 |
|
9 |
import pandas as pd
|
10 |
import torch
|
11 |
from fastapi import BackgroundTasks, HTTPException, UploadFile
|
12 |
+
from ntr_fileparser import ParsedDocument, UniversalParser
|
13 |
+
from sqlalchemy.orm import Session
|
14 |
|
15 |
from common.common import get_source_format
|
16 |
from common.configuration import Configuration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from components.dbo.models.dataset import Dataset
|
18 |
from components.dbo.models.dataset_document import DatasetDocument
|
19 |
from components.dbo.models.document import Document
|
20 |
+
from components.services.entity import EntityService
|
21 |
from schemas.dataset import Dataset as DatasetSchema
|
22 |
from schemas.dataset import DatasetExpanded as DatasetExpandedSchema
|
23 |
from schemas.dataset import DatasetProcessing
|
24 |
from schemas.dataset import DocumentsPage as DocumentsPageSchema
|
25 |
from schemas.dataset import SortQueryList
|
26 |
from schemas.document import Document as DocumentSchema
|
27 |
+
|
28 |
logger = logging.getLogger(__name__)
|
29 |
|
30 |
|
|
|
32 |
"""
|
33 |
Сервис для работы с датасетами.
|
34 |
"""
|
35 |
+
|
36 |
def __init__(
|
37 |
+
self,
|
38 |
+
entity_service: EntityService,
|
39 |
config: Configuration,
|
40 |
+
db: Session,
|
41 |
) -> None:
|
42 |
+
"""
|
43 |
+
Инициализация сервиса.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
entity_service: Сервис для работы с сущностями
|
47 |
+
config: Конфигурация приложения
|
48 |
+
db: SQLAlchemy сессия
|
49 |
+
"""
|
50 |
logger.info("DatasetService initializing")
|
51 |
self.db = db
|
52 |
self.config = config
|
53 |
+
self.parser = UniversalParser()
|
54 |
+
self.entity_service = entity_service
|
55 |
self.regulations_path = Path(config.db_config.files.regulations_path)
|
56 |
self.documents_path = Path(config.db_config.files.documents_path)
|
57 |
+
self.tmp_path = Path(os.environ.get("APP_TMP_PATH", '.'))
|
58 |
logger.info("DatasetService initialized")
|
59 |
|
|
|
60 |
def get_dataset(
|
61 |
self,
|
62 |
dataset_id: int,
|
|
|
84 |
session.query(Document)
|
85 |
.join(DatasetDocument, DatasetDocument.document_id == Document.id)
|
86 |
.filter(DatasetDocument.dataset_id == dataset_id)
|
|
|
|
|
|
|
87 |
.filter(Document.title.like(f'%{search}%'))
|
88 |
)
|
89 |
|
|
|
96 |
.join(DatasetDocument, DatasetDocument.document_id == Document.id)
|
97 |
.filter(DatasetDocument.dataset_id == dataset_id)
|
98 |
.filter(
|
99 |
+
Document.status.in_(
|
100 |
+
['Актуальный', 'Требует актуализации', 'Упразднён']
|
101 |
+
)
|
102 |
)
|
103 |
.filter(Document.title.like(f'%{search}%'))
|
104 |
.count()
|
|
|
142 |
name=dataset.name,
|
143 |
isDraft=dataset.is_draft,
|
144 |
isActive=dataset.is_active,
|
145 |
+
dateCreated=dataset.date_created,
|
146 |
)
|
147 |
for dataset in datasets
|
148 |
]
|
|
|
198 |
self.raise_if_processing()
|
199 |
|
200 |
with self.db() as session:
|
201 |
+
dataset: Dataset = (
|
202 |
+
session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
203 |
+
)
|
204 |
+
|
205 |
if not dataset:
|
206 |
raise HTTPException(status_code=404, detail='Dataset not found')
|
207 |
|
|
|
224 |
"""
|
225 |
try:
|
226 |
with self.db() as session:
|
227 |
+
dataset = (
|
228 |
+
session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
229 |
+
)
|
230 |
if not dataset:
|
231 |
+
raise HTTPException(
|
232 |
+
status_code=404,
|
233 |
+
detail=f"Dataset with id {dataset_id} not found",
|
234 |
+
)
|
235 |
+
|
236 |
+
active_dataset = (
|
237 |
+
session.query(Dataset).filter(Dataset.is_active == True).first()
|
238 |
+
)
|
239 |
+
|
240 |
+
self.apply_draft(dataset)
|
241 |
dataset.is_draft = False
|
242 |
dataset.is_active = True
|
243 |
if active_dataset:
|
244 |
active_dataset.is_active = False
|
245 |
+
|
246 |
session.commit()
|
247 |
except Exception as e:
|
248 |
logger.error(f"Error applying draft: {e}")
|
249 |
raise
|
250 |
+
|
251 |
+
def activate_dataset(
|
252 |
+
self, dataset_id: int, background_tasks: BackgroundTasks
|
253 |
+
) -> DatasetExpandedSchema:
|
254 |
"""
|
255 |
Активировать датасет в фоновой задаче.
|
256 |
"""
|
257 |
+
|
258 |
logger.info(f"Activating dataset {dataset_id}")
|
259 |
self.raise_if_processing()
|
260 |
|
261 |
with self.db() as session:
|
262 |
+
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
|
|
|
|
263 |
active_dataset = session.query(Dataset).filter(Dataset.is_active).first()
|
264 |
if not dataset:
|
265 |
raise HTTPException(status_code=404, detail='Dataset not found')
|
|
|
337 |
|
338 |
dataset = self.create_dataset_from_directory(
|
339 |
is_default=False,
|
340 |
+
directory_with_documents=file_location.parent,
|
341 |
directory_with_ready_dataset=None,
|
342 |
)
|
343 |
|
|
|
349 |
def apply_draft(
|
350 |
self,
|
351 |
dataset: Dataset,
|
|
|
352 |
) -> None:
|
353 |
"""
|
354 |
Сохранить черновик как полноценный датасет.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
dataset: Датасет для применения
|
358 |
"""
|
359 |
torch.set_num_threads(1)
|
360 |
logger.info(f"Applying draft dataset {dataset.id}")
|
|
|
373 |
if current % log_step != 0:
|
374 |
return
|
375 |
if (total > 10) and (current % (total // 10) == 0):
|
376 |
+
logger.info(f"Processing dataset {dataset.id}: {current}/{total}")
|
|
|
|
|
377 |
with open(TMP_PATH, 'w', encoding='utf-8') as f:
|
378 |
json.dump(
|
379 |
{
|
|
|
389 |
document_ids = [
|
390 |
doc_dataset_link.document_id for doc_dataset_link in dataset.documents
|
391 |
]
|
392 |
+
|
393 |
+
for document_id in document_ids:
|
394 |
+
path = self.documents_path / f'{document_id}.DOCX'
|
395 |
+
parsed = self.parser.parse_by_path(str(path))
|
396 |
+
if parsed is None:
|
397 |
+
logger.warning(f"Failed to parse document {document_id}")
|
398 |
+
continue
|
399 |
+
|
400 |
+
# Используем EntityService для обработки документа с callback
|
401 |
+
self.entity_service.process_document(
|
402 |
+
parsed,
|
403 |
+
dataset.id,
|
404 |
+
progress_callback=progress_callback,
|
405 |
+
words_per_chunk=50,
|
406 |
+
overlap_words=25,
|
407 |
+
respect_sentence_boundaries=True,
|
408 |
+
)
|
409 |
+
|
410 |
+
TMP_PATH.unlink()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
412 |
def raise_if_processing(self) -> None:
|
413 |
"""
|
|
|
422 |
def create_dataset_from_directory(
|
423 |
self,
|
424 |
is_default: bool,
|
425 |
+
directory_with_documents: Path,
|
426 |
directory_with_ready_dataset: Path | None = None,
|
427 |
) -> Dataset:
|
428 |
"""
|
|
|
437 |
Dataset: Созданный датасет.
|
438 |
"""
|
439 |
logger.info(
|
440 |
+
f"Creating {'default' if is_default else 'new'} dataset from directory {directory_with_documents}"
|
441 |
)
|
442 |
with self.db() as session:
|
443 |
documents = []
|
|
|
452 |
)
|
453 |
session.add(dataset)
|
454 |
|
455 |
+
for subpath in self._get_recursive_dirlist(directory_with_documents):
|
456 |
document, relation = self._create_document(
|
457 |
+
directory_with_documents, subpath, dataset
|
458 |
)
|
459 |
if document is None:
|
460 |
continue
|
|
|
483 |
old_filename = document.filename
|
484 |
new_filename = '{}.{}'.format(document.id, document.source_format)
|
485 |
shutil.copy(
|
486 |
+
directory_with_documents / old_filename,
|
487 |
+
self.documents_path / new_filename,
|
488 |
)
|
489 |
document.filename = new_filename
|
490 |
|
|
|
495 |
|
496 |
dataset_id = dataset.id
|
497 |
|
|
|
498 |
logger.info(f"Dataset {dataset_id} created")
|
499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
return dataset
|
501 |
|
502 |
def create_empty_dataset(self, is_default: bool) -> Dataset:
|
|
|
518 |
session.commit()
|
519 |
session.refresh(dataset)
|
520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
return dataset
|
522 |
|
523 |
@staticmethod
|
|
|
531 |
Returns:
|
532 |
list[Path]: Список путей к xml-файлам относительно path.
|
533 |
"""
|
534 |
+
xml_files = set() # set для отбрасывания неуникальных путей
|
535 |
for ext in ('*.xml', '*.XML', '*.docx', '*.DOCX'):
|
536 |
xml_files.update(path.glob(f'**/{ext}'))
|
537 |
+
|
538 |
return [p.relative_to(path) for p in xml_files]
|
539 |
|
540 |
def _create_document(
|
|
|
558 |
|
559 |
try:
|
560 |
source_format = get_source_format(str(subpath))
|
561 |
+
parsed: ParsedDocument | None = self.parser.parse_by_path(
|
562 |
+
str(documents_path / subpath)
|
563 |
)
|
564 |
|
565 |
+
if not parsed:
|
566 |
logger.warning(f"Failed to parse file: {subpath}")
|
567 |
return None, None
|
568 |
|
569 |
document = Document(
|
570 |
filename=str(subpath),
|
571 |
+
title=parsed.name,
|
572 |
+
status=parsed.meta.status,
|
573 |
+
owner=parsed.meta.owner,
|
574 |
source_format=source_format,
|
575 |
)
|
576 |
relation = DatasetDocument(
|
|
|
584 |
logger.error(f"Error creating document from {subpath}: {e}")
|
585 |
return None, None
|
586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
587 |
def get_current_dataset(self) -> Dataset | None:
|
588 |
with self.db() as session:
|
589 |
print(session)
|
components/services/document.py
CHANGED
@@ -4,19 +4,18 @@ import shutil
|
|
4 |
from pathlib import Path
|
5 |
|
6 |
from fastapi import HTTPException, UploadFile
|
|
|
7 |
|
8 |
from sqlalchemy.orm import Session
|
9 |
from common.common import get_source_format
|
10 |
from common.configuration import Configuration
|
11 |
from common.constants import PROCESSING_FORMATS
|
12 |
-
from components.parser.xml.xml_parser import XMLParser
|
13 |
from components.dbo.models.dataset import Dataset
|
14 |
from components.dbo.models.dataset_document import DatasetDocument
|
15 |
from components.dbo.models.document import Document
|
16 |
from schemas.document import Document as DocumentSchema
|
17 |
from schemas.document import DocumentDownload
|
18 |
from components.services.dataset import DatasetService
|
19 |
-
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
|
@@ -34,7 +33,7 @@ class DocumentService:
|
|
34 |
logger.info("Initializing DocumentService")
|
35 |
self.db = db
|
36 |
self.dataset_service = dataset_service
|
37 |
-
self.
|
38 |
self.documents_path = Path(config.db_config.files.documents_path)
|
39 |
|
40 |
def get_document(
|
@@ -101,10 +100,10 @@ class DocumentService:
|
|
101 |
logger.info(f"Source format: {source_format}")
|
102 |
|
103 |
try:
|
104 |
-
parsed = self.
|
105 |
except Exception:
|
106 |
raise HTTPException(
|
107 |
-
status_code=400, detail="Invalid
|
108 |
)
|
109 |
|
110 |
with self.db() as session:
|
@@ -118,9 +117,10 @@ class DocumentService:
|
|
118 |
raise HTTPException(status_code=403, detail='Dataset is not draft')
|
119 |
|
120 |
document = Document(
|
|
|
121 |
title=parsed.name,
|
122 |
-
owner=parsed.owner,
|
123 |
-
status=parsed.status,
|
124 |
source_format=source_format,
|
125 |
)
|
126 |
|
@@ -129,21 +129,21 @@ class DocumentService:
|
|
129 |
session.add(document)
|
130 |
session.flush()
|
131 |
|
132 |
-
logger.info(f"Document ID: {document.
|
133 |
|
134 |
link = DatasetDocument(
|
135 |
dataset_id=dataset_id,
|
136 |
-
document_id=document.
|
137 |
)
|
138 |
session.add(link)
|
139 |
|
140 |
if source_format in PROCESSING_FORMATS:
|
141 |
logger.info(
|
142 |
-
f"Moving file to: {self.documents_path / f'{document.
|
143 |
)
|
144 |
shutil.move(
|
145 |
file_location,
|
146 |
-
self.documents_path / f'{document.
|
147 |
)
|
148 |
else:
|
149 |
logger.error(f"Unknown source format: {source_format}")
|
@@ -156,7 +156,7 @@ class DocumentService:
|
|
156 |
session.refresh(document)
|
157 |
|
158 |
result = DocumentSchema(
|
159 |
-
id=document.
|
160 |
name=document.title,
|
161 |
owner=document.owner,
|
162 |
status=document.status,
|
|
|
4 |
from pathlib import Path
|
5 |
|
6 |
from fastapi import HTTPException, UploadFile
|
7 |
+
from ntr_fileparser import UniversalParser
|
8 |
|
9 |
from sqlalchemy.orm import Session
|
10 |
from common.common import get_source_format
|
11 |
from common.configuration import Configuration
|
12 |
from common.constants import PROCESSING_FORMATS
|
|
|
13 |
from components.dbo.models.dataset import Dataset
|
14 |
from components.dbo.models.dataset_document import DatasetDocument
|
15 |
from components.dbo.models.document import Document
|
16 |
from schemas.document import Document as DocumentSchema
|
17 |
from schemas.document import DocumentDownload
|
18 |
from components.services.dataset import DatasetService
|
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
|
|
|
33 |
logger.info("Initializing DocumentService")
|
34 |
self.db = db
|
35 |
self.dataset_service = dataset_service
|
36 |
+
self.parser = UniversalParser()
|
37 |
self.documents_path = Path(config.db_config.files.documents_path)
|
38 |
|
39 |
def get_document(
|
|
|
100 |
logger.info(f"Source format: {source_format}")
|
101 |
|
102 |
try:
|
103 |
+
parsed = self.parser.parse_by_path(str(file_location))
|
104 |
except Exception:
|
105 |
raise HTTPException(
|
106 |
+
status_code=400, detail="Invalid file, service can't parse it"
|
107 |
)
|
108 |
|
109 |
with self.db() as session:
|
|
|
117 |
raise HTTPException(status_code=403, detail='Dataset is not draft')
|
118 |
|
119 |
document = Document(
|
120 |
+
filename=file.filename,
|
121 |
title=parsed.name,
|
122 |
+
owner=parsed.meta.owner,
|
123 |
+
status=parsed.meta.status,
|
124 |
source_format=source_format,
|
125 |
)
|
126 |
|
|
|
129 |
session.add(document)
|
130 |
session.flush()
|
131 |
|
132 |
+
logger.info(f"Document ID: {document.id}")
|
133 |
|
134 |
link = DatasetDocument(
|
135 |
dataset_id=dataset_id,
|
136 |
+
document_id=document.id,
|
137 |
)
|
138 |
session.add(link)
|
139 |
|
140 |
if source_format in PROCESSING_FORMATS:
|
141 |
logger.info(
|
142 |
+
f"Moving file to: {self.documents_path / f'{document.id}.{source_format}'}"
|
143 |
)
|
144 |
shutil.move(
|
145 |
file_location,
|
146 |
+
self.documents_path / f'{document.id}.{source_format}',
|
147 |
)
|
148 |
else:
|
149 |
logger.error(f"Unknown source format: {source_format}")
|
|
|
156 |
session.refresh(document)
|
157 |
|
158 |
result = DocumentSchema(
|
159 |
+
id=document.id,
|
160 |
name=document.title,
|
161 |
owner=document.owner,
|
162 |
status=document.status,
|
components/services/entity.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Callable, Optional
|
3 |
+
from uuid import UUID
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from ntr_fileparser import ParsedDocument
|
7 |
+
from ntr_text_fragmentation import Destructurer, InjectionBuilder, LinkerEntity
|
8 |
+
|
9 |
+
from common.configuration import Configuration
|
10 |
+
from components.dbo.chunk_repository import ChunkRepository
|
11 |
+
from components.embedding_extraction import EmbeddingExtractor
|
12 |
+
from components.nmd.faiss_vector_search import FaissVectorSearch
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class EntityService:
|
18 |
+
"""
|
19 |
+
Сервис для работы с сущностями.
|
20 |
+
Объединяет функциональность chunk_repository, destructurer, injection_builder и faiss_vector_search.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
vectorizer: EmbeddingExtractor,
|
26 |
+
chunk_repository: ChunkRepository,
|
27 |
+
config: Configuration,
|
28 |
+
) -> None:
|
29 |
+
"""
|
30 |
+
Инициализация сервиса.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
vectorizer: Модель для извлечения эмбеддингов
|
34 |
+
chunk_repository: Репозиторий для работы с чанками
|
35 |
+
config: Конфигурация приложения
|
36 |
+
"""
|
37 |
+
self.vectorizer = vectorizer
|
38 |
+
self.config = config
|
39 |
+
self.chunk_repository = chunk_repository
|
40 |
+
self.faiss_search = None # Инициализируется при необходимости
|
41 |
+
self.current_dataset_id = None # Текущий dataset_id
|
42 |
+
|
43 |
+
def _ensure_faiss_initialized(self, dataset_id: int) -> None:
|
44 |
+
"""
|
45 |
+
Проверяет и при необходимости инициализирует или обновляет FAISS индекс.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
dataset_id: ID датасета для инициализации
|
49 |
+
"""
|
50 |
+
# Если индекс не инициализирован или датасет изменился
|
51 |
+
if self.faiss_search is None or self.current_dataset_id != dataset_id:
|
52 |
+
logger.info(f'Initializing FAISS for dataset {dataset_id}')
|
53 |
+
entities, embeddings = self.chunk_repository.get_searching_entities(dataset_id)
|
54 |
+
if entities:
|
55 |
+
# Создаем словарь только из не-None эмбеддингов
|
56 |
+
embeddings_dict = {
|
57 |
+
str(entity.id): embedding # Преобразуем UUID в строку для ключа
|
58 |
+
for entity, embedding in zip(entities, embeddings)
|
59 |
+
if embedding is not None
|
60 |
+
}
|
61 |
+
if embeddings_dict: # Проверяем, что есть хотя бы один эмбеддинг
|
62 |
+
self.faiss_search = FaissVectorSearch(
|
63 |
+
self.vectorizer,
|
64 |
+
embeddings_dict,
|
65 |
+
self.config.db_config,
|
66 |
+
)
|
67 |
+
self.current_dataset_id = dataset_id
|
68 |
+
logger.info(f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings')
|
69 |
+
else:
|
70 |
+
logger.warning(f'No valid embeddings found for dataset {dataset_id}')
|
71 |
+
self.faiss_search = None
|
72 |
+
self.current_dataset_id = None
|
73 |
+
else:
|
74 |
+
logger.warning(f'No entities found for dataset {dataset_id}')
|
75 |
+
self.faiss_search = None
|
76 |
+
self.current_dataset_id = None
|
77 |
+
|
78 |
+
def process_document(
|
79 |
+
self,
|
80 |
+
document: ParsedDocument,
|
81 |
+
dataset_id: int,
|
82 |
+
progress_callback: Optional[Callable] = None,
|
83 |
+
**destructurer_kwargs,
|
84 |
+
) -> None:
|
85 |
+
"""
|
86 |
+
Обработка документа: разбиение на чанки и сохранение в базу.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
document: Документ для обработки
|
90 |
+
dataset_id: ID датасета
|
91 |
+
progress_callback: Функция для отслеживания прогресса
|
92 |
+
**destructurer_kwargs: Дополнительные параметры для Destructurer
|
93 |
+
"""
|
94 |
+
logger.info(f"Processing document {document.name} for dataset {dataset_id}")
|
95 |
+
|
96 |
+
# Создаем деструктуризатор с параметрами по умолчанию
|
97 |
+
destructurer = Destructurer(
|
98 |
+
document,
|
99 |
+
strategy_name="fixed_size",
|
100 |
+
process_tables=True,
|
101 |
+
**{
|
102 |
+
"words_per_chunk": 50,
|
103 |
+
"overlap_words": 25,
|
104 |
+
"respect_sentence_boundaries": True,
|
105 |
+
**destructurer_kwargs,
|
106 |
+
}
|
107 |
+
)
|
108 |
+
|
109 |
+
# Получаем сущности
|
110 |
+
entities = destructurer.destructure()
|
111 |
+
|
112 |
+
# Фильтруем сущности для поиска
|
113 |
+
filtering_entities = [entity for entity in entities if entity.in_search_text is not None]
|
114 |
+
filtering_texts = [entity.in_search_text for entity in filtering_entities]
|
115 |
+
|
116 |
+
# Получаем эмбеддинги с поддержкой callback
|
117 |
+
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
|
118 |
+
embeddings_dict = {
|
119 |
+
str(entity.id): embedding # Преобразуем UUID в строку для ключа
|
120 |
+
for entity, embedding in zip(filtering_entities, embeddings)
|
121 |
+
}
|
122 |
+
|
123 |
+
# Сохраняем в базу
|
124 |
+
self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict)
|
125 |
+
|
126 |
+
# Переинициализируем FAISS индекс, если это текущий датасет
|
127 |
+
if self.current_dataset_id == dataset_id:
|
128 |
+
self._ensure_faiss_initialized(dataset_id)
|
129 |
+
|
130 |
+
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
|
131 |
+
|
132 |
+
def build_text(
|
133 |
+
self,
|
134 |
+
entities: list[LinkerEntity],
|
135 |
+
chunk_scores: Optional[list[float]] = None,
|
136 |
+
include_tables: bool = True,
|
137 |
+
max_documents: Optional[int] = None,
|
138 |
+
) -> str:
|
139 |
+
"""
|
140 |
+
Сборка текста из сущностей.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
entities: Список сущностей
|
144 |
+
chunk_scores: Список весов чанков
|
145 |
+
include_tables: Флаг включения таблиц
|
146 |
+
max_documents: Максимальное количество документов
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
Собранный текст
|
150 |
+
"""
|
151 |
+
logger.info(f"Building text for {len(entities)} entities")
|
152 |
+
if chunk_scores is not None:
|
153 |
+
chunk_scores = {entity.id: score for entity, score in zip(entities, chunk_scores)}
|
154 |
+
builder = InjectionBuilder(self.chunk_repository)
|
155 |
+
return builder.build(
|
156 |
+
[entity.id for entity in entities], # Передаем UUID напрямую
|
157 |
+
chunk_scores=chunk_scores,
|
158 |
+
include_tables=include_tables,
|
159 |
+
max_documents=max_documents,
|
160 |
+
)
|
161 |
+
|
162 |
+
def search_similar(
|
163 |
+
self,
|
164 |
+
query: str,
|
165 |
+
dataset_id: int,
|
166 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
167 |
+
"""
|
168 |
+
Поиск похожих сущностей.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
query: Текст запроса
|
172 |
+
dataset_id: ID датасета
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
tuple[np.ndarray, np.ndarray, np.ndarray]:
|
176 |
+
- Вектор запроса
|
177 |
+
- Оценки сходства
|
178 |
+
- Идентификаторы найденных сущностей
|
179 |
+
"""
|
180 |
+
# Убеждаемся, что FAISS инициализирован для текущего датасета
|
181 |
+
self._ensure_faiss_initialized(dataset_id)
|
182 |
+
|
183 |
+
if self.faiss_search is None:
|
184 |
+
return np.array([]), np.array([]), np.array([])
|
185 |
+
|
186 |
+
# Выполняем поиск
|
187 |
+
return self.faiss_search.search_vectors(query)
|
188 |
+
|
189 |
+
def add_neighboring_chunks(
|
190 |
+
self,
|
191 |
+
entities: list[LinkerEntity],
|
192 |
+
max_distance: int = 1,
|
193 |
+
) -> list[LinkerEntity]:
|
194 |
+
"""
|
195 |
+
Добавление соседних чанков.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
entities: Список сущностей
|
199 |
+
max_distance: Максимальное расстояние для поиска соседей
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
Расширенный список сущностей
|
203 |
+
"""
|
204 |
+
# Убедимся, что все ID представлены в UUID формате
|
205 |
+
for entity in entities:
|
206 |
+
if not isinstance(entity.id, UUID):
|
207 |
+
entity.id = UUID(str(entity.id))
|
208 |
+
|
209 |
+
builder = InjectionBuilder(self.chunk_repository)
|
210 |
+
return builder.add_neighboring_chunks(entities, max_distance)
|
lib/extractor/.cursor/rules/project-description.mdc
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
description:
|
3 |
+
globs:
|
4 |
+
alwaysApply: true
|
5 |
+
---
|
6 |
+
|
7 |
+
# Project description
|
8 |
+
|
9 |
+
Данный проект представляет собой библиотеку, предоставляющую возможности для чанкинга и сборки
|
10 |
+
инъекций в промпт LLM для дальнейшего использования в RAG-системах. Основная логика описана в README.md и в architectures, если они не устарели. Ядро системы представляют классы LinkerEntity, Destructurer, EntityRepository, InjectionBuilder, ChunkingStrategy.
|
11 |
+
|
12 |
+
- LinkerEntity – основная сущность, от которой затем наследуются Chunk и DocumentAsEntity. Реализует триплетный подход, при котором один и тот же класс задаёт и сущности, и связи, и при этом сущности-ассоциации реализуются одним экземпляром, а не множеством.
|
13 |
+
- Destructurer – реализует логику разбиения документа на множество LinkerEntity, во многом делегируя работу различным ChunkingStrategy (но не всю).
|
14 |
+
- EntityRepository – интерфейс. Предполагается, что после извлечения всех сущностей посредством Destructurer пользователь библиотеки сохранит все свои сущности некоторым произвольным образом, например, в csv-файл или PostgreSQL. Библиотека не знает, как работать с пользовательскими хранилищами данных, поэтому пользователь должен сам написать реализацию EntityRepository для своего решения, и предоставить её в InjectionBuilder
|
15 |
+
- InjectionBuilder – сборщик промпт-инъекции. Принимает на вход отфильтрованный и (в отдельных случаях) оценённый некоторым скором набор сущностей, сортирует их, распределяет по документам и собирает всё в единый текст, пользуясь EntityRepository, чтобы достать связанные полезные сущности
|
16 |
+
|
17 |
+
Данная библиотека ориентируется на ParsedDocument из библиотеки ntr_fileparser, структура которого примерно соответствует следующему:
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class ParsedDocument(ParsedStructure):
|
21 |
+
"""
|
22 |
+
Документ, полученный в результате парсинга.
|
23 |
+
"""
|
24 |
+
name: str = ""
|
25 |
+
type: str = ""
|
26 |
+
meta: ParsedMeta = field(default_factory=ParsedMeta)
|
27 |
+
paragraphs: list[ParsedTextBlock] = field(default_factory=list)
|
28 |
+
tables: list[ParsedTable] = field(default_factory=list)
|
29 |
+
images: list[ParsedImage] = field(default_factory=list)
|
30 |
+
formulas: list[ParsedFormula] = field(default_factory=list)
|
31 |
+
|
32 |
+
def to_string() -> str:
|
33 |
+
...
|
34 |
+
|
35 |
+
def to_dict() -> dict:
|
36 |
+
...
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class ParsedTextBlock(DocumentElement):
|
41 |
+
"""
|
42 |
+
Текстовый блок документа.
|
43 |
+
"""
|
44 |
+
|
45 |
+
text: str = ""
|
46 |
+
style: TextStyle = field(default_factory=TextStyle)
|
47 |
+
anchors: list[str] = field(default_factory=list) # Список идентификаторов якорей (закладок)
|
48 |
+
links: list[str] = field(default_factory=list) # Список идентификаторов ссылок
|
49 |
+
|
50 |
+
# Технические метаданные о блоке
|
51 |
+
metadata: list[dict[str, Any]] = field(default_factory=list) # Для хранения технической информации
|
52 |
+
|
53 |
+
# Примечания и сноски к тексту
|
54 |
+
footnotes: list[dict[str, Any]] = field(default_factory=list) # Для хранения сносок
|
55 |
+
|
56 |
+
title_of_table: int | None = None
|
57 |
+
|
58 |
+
def to_string() -> str:
|
59 |
+
...
|
60 |
+
|
61 |
+
def to_dict() -> dict:
|
62 |
+
...
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class ParsedTable(DocumentElement):
|
67 |
+
"""
|
68 |
+
Таблица из документа.
|
69 |
+
"""
|
70 |
+
|
71 |
+
title: str | None = None
|
72 |
+
note: str | None = None
|
73 |
+
classified_tags: list[TableTag] = field(default_factory=list)
|
74 |
+
index: list[str] = field(default_factory=list)
|
75 |
+
headers: list[ParsedRow] = field(default_factory=list)
|
76 |
+
subtables: list[ParsedSubtable] = field(default_factory=list)
|
77 |
+
table_style: dict[str, Any] = field(default_factory=dict)
|
78 |
+
title_index_in_paragraphs: int | None = None
|
79 |
+
|
80 |
+
def to_string() -> str:
|
81 |
+
...
|
82 |
+
|
83 |
+
def to_dict() -> dict:
|
84 |
+
...
|
85 |
+
|
86 |
+
(Дальнейшую информацию о вложенных классах ты можешь уточнить у по��ьзователя, если это будет нужно)
|
lib/extractor/.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use_it/*
|
2 |
+
test_output/
|
3 |
+
test_input/
|
4 |
+
__pycache__/
|
5 |
+
*.pyc
|
6 |
+
*.pyo
|
7 |
+
*.pyd
|
8 |
+
*.pyw
|
9 |
+
*.pyz
|
10 |
+
|
11 |
+
*.egg-info/
|
lib/extractor/README.md
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Библиотека извлечения и сборки документов
|
2 |
+
|
3 |
+
Библиотека для извлечения структурированной информации из документов и их последующей сборки.
|
4 |
+
|
5 |
+
## Основные компоненты
|
6 |
+
|
7 |
+
- **Destructurer**: Разбивает документ на чанки и связи между ними, а также извлекает дополнительные сущности
|
8 |
+
- **Builder**: Собирает документ из чанков и связей
|
9 |
+
- **Entity**: Базовый класс для всех сущностей (Document, Chunk, Acronym и т.д.)
|
10 |
+
- **Link**: Класс для представления связей между сущностями
|
11 |
+
- **ChunkingStrategy**: Интерфейс для различных стратегий чанкинга
|
12 |
+
- **TablesProcessor**: Процессор для извлечения таблиц из документа
|
13 |
+
|
14 |
+
## Установка
|
15 |
+
|
16 |
+
```bash
|
17 |
+
pip install -e .
|
18 |
+
```
|
19 |
+
|
20 |
+
## Использование
|
21 |
+
|
22 |
+
```python
|
23 |
+
from ntr_text_fragmentation.core import Destructurer, Builder
|
24 |
+
from ntr_fileparser import ParsedDocument
|
25 |
+
|
26 |
+
# Пример использования Destructurer с обработкой таблиц
|
27 |
+
document = ParsedDocument(...)
|
28 |
+
destructurer = Destructurer(
|
29 |
+
document=document,
|
30 |
+
strategy_name="fixed_size",
|
31 |
+
process_tables=True
|
32 |
+
)
|
33 |
+
entities = destructurer.destructure()
|
34 |
+
|
35 |
+
# Пример использования Builder
|
36 |
+
builder = Builder(document)
|
37 |
+
builder.configure({"chunking_strategy": "fixed_size"})
|
38 |
+
reconstructed_document = builder.build()
|
39 |
+
```
|
40 |
+
|
41 |
+
## Модули
|
42 |
+
|
43 |
+
### Core
|
44 |
+
Основные классы для работы с документами:
|
45 |
+
- **Destructurer**: Разбивает документ на чанки и другие сущности
|
46 |
+
- **Builder**: Собирает документ из чанков и связей
|
47 |
+
|
48 |
+
### Chunking
|
49 |
+
Различные стратегии разбиения документа на чанки:
|
50 |
+
- **FixedSizeChunkingStrategy**: Разбиение на чанки фиксированного размера
|
51 |
+
|
52 |
+
### Additors
|
53 |
+
Дополнительные обработчики для извлечения сущностей:
|
54 |
+
- **TablesProcessor**: Извлекает таблицы из документа и создает для них сущности
|
55 |
+
|
56 |
+
### Models
|
57 |
+
Модели данных для представления сущностей и связей:
|
58 |
+
- **LinkerEntity**: Базовый класс для всех сущностей и связей
|
59 |
+
- **DocumentAsEntity**: Представление документа как сущности
|
60 |
+
- **TableEntity**: Представление таблицы как сущности
|
lib/extractor/docs/architecture.puml
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@startuml "NTR Text Fragmentation Architecture"
|
2 |
+
|
3 |
+
' Использование CSS-стилей вместо skinparams
|
4 |
+
<style>
|
5 |
+
.concrete {
|
6 |
+
BackgroundColor #FFFFFF
|
7 |
+
BorderColor #795548
|
8 |
+
}
|
9 |
+
|
10 |
+
.models {
|
11 |
+
BackgroundColor #E8F5E9
|
12 |
+
BorderColor #4CAF50
|
13 |
+
}
|
14 |
+
|
15 |
+
.strategies {
|
16 |
+
BackgroundColor #E1F5FE
|
17 |
+
BorderColor #03A9F4
|
18 |
+
}
|
19 |
+
|
20 |
+
.core {
|
21 |
+
BackgroundColor #FFEBEE
|
22 |
+
BorderColor #F44336
|
23 |
+
}
|
24 |
+
|
25 |
+
note {
|
26 |
+
BackgroundColor #FFF9C4
|
27 |
+
BorderColor #FFD54F
|
28 |
+
FontSize 10
|
29 |
+
}
|
30 |
+
</style>
|
31 |
+
|
32 |
+
' Легенда
|
33 |
+
legend
|
34 |
+
<b>Легенда</b>
|
35 |
+
|
36 |
+
| Цвет | Описание |
|
37 |
+
| <back:#E8F5E9>Зеленый</back> | Модели данных |
|
38 |
+
| <back:#E1F5FE>Голубой</back> | Стратегии чанкинга |
|
39 |
+
| <back:#FFEBEE>Красный</back> | Основные компоненты |
|
40 |
+
endlegend
|
41 |
+
|
42 |
+
' Разделение на пакеты
|
43 |
+
|
44 |
+
package "models" {
|
45 |
+
class LinkerEntity <<models>> {
|
46 |
+
+ id: UUID
|
47 |
+
+ name: str
|
48 |
+
+ text: str
|
49 |
+
+ in_search_text: str | None
|
50 |
+
+ metadata: dict
|
51 |
+
+ source_id: UUID | None
|
52 |
+
+ target_id: UUID | None
|
53 |
+
+ number_in_relation: int | None
|
54 |
+
+ type: str
|
55 |
+
+ serialize(): LinkerEntity
|
56 |
+
+ {abstract} deserialize(data: LinkerEntity): Self
|
57 |
+
}
|
58 |
+
|
59 |
+
class Chunk <<models>> extends LinkerEntity {
|
60 |
+
+ chunk_index: int | None
|
61 |
+
}
|
62 |
+
|
63 |
+
class DocumentAsEntity <<models>> extends LinkerEntity {
|
64 |
+
}
|
65 |
+
|
66 |
+
note right of LinkerEntity
|
67 |
+
Базовая сущность для всех элементов системы.
|
68 |
+
in_search_text определяет текст, используемый
|
69 |
+
при поиске, если None - данная сущность не должна попасть
|
70 |
+
в поиск и используется только для вспомогательных целей.
|
71 |
+
end note
|
72 |
+
}
|
73 |
+
|
74 |
+
package "chunking_strategies" as chunking_strategies {
|
75 |
+
abstract class ChunkingStrategy <<abstract>> {
|
76 |
+
+ {abstract} chunk(document: ParsedDocument, doc_entity: DocumentAsEntity): list[LinkerEntity]
|
77 |
+
+ dechunk(entities: list[LinkerEntity], links: list[LinkerEntity]): str
|
78 |
+
}
|
79 |
+
|
80 |
+
package "specific_strategies" {
|
81 |
+
class FixedSizeChunkingStrategy <<strategies>> extends chunking_strategies.ChunkingStrategy {
|
82 |
+
+ chunk(document: ParsedDocument, doc_entity: DocumentAsEntity): list[LinkerEntity]
|
83 |
+
+ dechunk(entities: list[LinkerEntity], links: list[LinkerEntity]): str
|
84 |
+
}
|
85 |
+
|
86 |
+
class SentenceChunkingStrategy <<strategies>> extends chunking_strategies.ChunkingStrategy {
|
87 |
+
+ chunk(document: ParsedDocument, doc_entity: DocumentAsEntity): list[LinkerEntity]
|
88 |
+
+ dechunk(entities: list[LinkerEntity], links: list[LinkerEntity]): str
|
89 |
+
}
|
90 |
+
|
91 |
+
class NumberedItemsChunkingStrategy <<strategies>> extends chunking_strategies.ChunkingStrategy {
|
92 |
+
+ chunk(document: ParsedDocument, doc_entity: DocumentAsEntity): list[LinkerEntity]
|
93 |
+
+ dechunk(entities: list[LinkerEntity], links: list[LinkerEntity]): str
|
94 |
+
}
|
95 |
+
}
|
96 |
+
|
97 |
+
note right of ChunkingStrategy
|
98 |
+
Базовая реализация dechunk сортирует чанки по chunk_index.
|
99 |
+
Стратегии могут переопределить, если им нужна
|
100 |
+
специфическая логика сборки
|
101 |
+
end note
|
102 |
+
}
|
103 |
+
|
104 |
+
package "core" {
|
105 |
+
class Destructurer <<core>> {
|
106 |
+
+ __init__(document: ParsedDocument, strategy_name: str)
|
107 |
+
+ configure(strategy_name: str, **kwargs)
|
108 |
+
+ destructure(): list[LinkerEntity]
|
109 |
+
}
|
110 |
+
|
111 |
+
class InjectionBuilder <<core>> {
|
112 |
+
+ __init__(entities: list[LinkerEntity], config: dict)
|
113 |
+
+ register_strategy(doc_type: str, strategy: ChunkingStrategy)
|
114 |
+
+ build(filtered_entities: list[LinkerEntity]): str
|
115 |
+
- _group_chunks_by_document(chunks, links): dict
|
116 |
+
}
|
117 |
+
|
118 |
+
note right of Destructurer
|
119 |
+
Основной класс библиотеки, используется для разбиения
|
120 |
+
документа на чанки и вспомогательные сущности. В
|
121 |
+
полученной конфигурации содержатся in_search сущности
|
122 |
+
и множество вспомогательных сущностей. Предполагается,
|
123 |
+
что первые будут отфильтрованы векторным или иным поиском,
|
124 |
+
а вторые можно будет использовать для обогащения и сборки
|
125 |
+
итоговой инъекции в промпт.
|
126 |
+
end note
|
127 |
+
|
128 |
+
note right of InjectionBuilder
|
129 |
+
Класс-единая точка входа для сборки итоговой инъекции
|
130 |
+
в промпт. Принимает в себя все сущности и конфигурацию
|
131 |
+
в конструкторе, а в методе build принимает отфильтрованные
|
132 |
+
сущности. Может частично делегировать сборку стратегиям для
|
133 |
+
специфических ти��ов чанкинга.
|
134 |
+
end note
|
135 |
+
|
136 |
+
}
|
137 |
+
|
138 |
+
' Композиционные отношения
|
139 |
+
core.Destructurer --> chunking_strategies.ChunkingStrategy
|
140 |
+
core.InjectionBuilder --> chunking_strategies.ChunkingStrategy
|
141 |
+
|
142 |
+
' Отношения между компонентами
|
143 |
+
chunking_strategies.ChunkingStrategy ..> models
|
144 |
+
|
145 |
+
' Дополнительные отношения
|
146 |
+
core.InjectionBuilder ..> models.LinkerEntity
|
147 |
+
core.Destructurer ..> models.LinkerEntity
|
148 |
+
|
149 |
+
@enduml
|
lib/extractor/ntr_text_fragmentation/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль извлечения и сборки документов.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .core.destructurer import Destructurer
|
6 |
+
from .core.entity_repository import EntityRepository, InMemoryEntityRepository
|
7 |
+
from .core.injection_builder import InjectionBuilder
|
8 |
+
from .models import Chunk, DocumentAsEntity, LinkerEntity
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"Destructurer",
|
12 |
+
"InjectionBuilder",
|
13 |
+
"EntityRepository",
|
14 |
+
"InMemoryEntityRepository",
|
15 |
+
"LinkerEntity",
|
16 |
+
"Chunk",
|
17 |
+
"DocumentAsEntity",
|
18 |
+
"integrations",
|
19 |
+
]
|
lib/extractor/ntr_text_fragmentation/additors/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль для дополнительных обработчиков документа.
|
3 |
+
|
4 |
+
Содержит обработчики, которые извлекают дополнительные сущности из документа,
|
5 |
+
например, таблицы, изображения и т.д.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from .tables_processor import TablesProcessor
|
9 |
+
|
10 |
+
__all__ = ["TablesProcessor"]
|
lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .table_entity import TableEntity
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
'TableEntity',
|
5 |
+
]
|
lib/extractor/ntr_text_fragmentation/additors/tables/table_entity.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
from uuid import UUID
|
4 |
+
|
5 |
+
from ...models import LinkerEntity
|
6 |
+
from ...models.linker_entity import register_entity
|
7 |
+
|
8 |
+
|
9 |
+
@register_entity
|
10 |
+
@dataclass
|
11 |
+
class TableEntity(LinkerEntity):
|
12 |
+
"""
|
13 |
+
Сущность таблицы из документа.
|
14 |
+
|
15 |
+
Расширяет основную сущность LinkerEntity, добавляя информацию о таблице.
|
16 |
+
"""
|
17 |
+
|
18 |
+
table_index: Optional[int] = None
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def deserialize(cls, entity: LinkerEntity) -> "TableEntity":
|
22 |
+
"""
|
23 |
+
Десериализует сущность из базового LinkerEntity.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
entity: Базовая сущность LinkerEntity
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
Десериализованная сущность TableEntity
|
30 |
+
"""
|
31 |
+
if entity.type != cls.__name__:
|
32 |
+
raise ValueError(f"Неверный тип сущности: {entity.type}, ожидался {cls.__name__}")
|
33 |
+
|
34 |
+
# Извлекаем дополнительные поля из метаданных
|
35 |
+
metadata = entity.metadata or {}
|
36 |
+
table_index = metadata.get("table_index")
|
37 |
+
|
38 |
+
return cls(
|
39 |
+
id=entity.id if isinstance(entity.id, UUID) else UUID(entity.id),
|
40 |
+
name=entity.name,
|
41 |
+
text=entity.text,
|
42 |
+
in_search_text=entity.in_search_text,
|
43 |
+
metadata=entity.metadata,
|
44 |
+
source_id=entity.source_id,
|
45 |
+
target_id=entity.target_id,
|
46 |
+
number_in_relation=entity.number_in_relation,
|
47 |
+
type=entity.type,
|
48 |
+
table_index=table_index,
|
49 |
+
)
|
50 |
+
|
51 |
+
def serialize(self) -> LinkerEntity:
|
52 |
+
"""
|
53 |
+
Сериализует сущность в базовый LinkerEntity.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
Сериализованная сущность LinkerEntity
|
57 |
+
"""
|
58 |
+
metadata = self.metadata or {}
|
59 |
+
|
60 |
+
# Добавляем дополнительные поля в метаданные
|
61 |
+
if self.table_index is not None:
|
62 |
+
metadata["table_index"] = self.table_index
|
63 |
+
|
64 |
+
return LinkerEntity(
|
65 |
+
id=self.id,
|
66 |
+
name=self.name,
|
67 |
+
text=self.text,
|
68 |
+
in_search_text=self.in_search_text,
|
69 |
+
metadata=metadata,
|
70 |
+
source_id=self.source_id,
|
71 |
+
target_id=self.target_id,
|
72 |
+
number_in_relation=self.number_in_relation,
|
73 |
+
type=self.__class__.__name__,
|
74 |
+
)
|
lib/extractor/ntr_text_fragmentation/additors/tables_processor.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Процессор таблиц из документа.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from uuid import uuid4
|
6 |
+
|
7 |
+
from ntr_fileparser import ParsedDocument
|
8 |
+
|
9 |
+
from ..models import LinkerEntity
|
10 |
+
from .tables import TableEntity
|
11 |
+
|
12 |
+
|
13 |
+
class TablesProcessor:
|
14 |
+
"""
|
15 |
+
Процессор для извлечения таблиц из документа и создания связанных сущностей.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self):
|
19 |
+
"""Инициализация процессора таблиц."""
|
20 |
+
pass
|
21 |
+
|
22 |
+
def process(
|
23 |
+
self,
|
24 |
+
document: ParsedDocument,
|
25 |
+
doc_entity: LinkerEntity,
|
26 |
+
) -> list[LinkerEntity]:
|
27 |
+
"""
|
28 |
+
Извлекает таблицы из документа и создает для них сущности.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
document: Документ для обработки
|
32 |
+
doc_entity: Сущность документа для связи с таблицами
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Список сущностей TableEntity и связей
|
36 |
+
"""
|
37 |
+
if not document.tables:
|
38 |
+
return []
|
39 |
+
|
40 |
+
table_entities = []
|
41 |
+
links = []
|
42 |
+
|
43 |
+
rows = '\n\n'.join([table.to_string() for table in document.tables]).split(
|
44 |
+
'\n\n'
|
45 |
+
)
|
46 |
+
|
47 |
+
# Обрабатываем каждую таблицу
|
48 |
+
for idx, row in enumerate(rows):
|
49 |
+
# Создаем сущность таблицы
|
50 |
+
table_entity = self._create_table_entity(
|
51 |
+
table_text=row,
|
52 |
+
table_index=idx,
|
53 |
+
doc_name=doc_entity.name,
|
54 |
+
)
|
55 |
+
|
56 |
+
# Создаем связь между документом и таблицей
|
57 |
+
link = self._create_link(doc_entity, table_entity, idx)
|
58 |
+
|
59 |
+
table_entities.append(table_entity)
|
60 |
+
links.append(link)
|
61 |
+
|
62 |
+
# Возвращаем список таблиц и связей
|
63 |
+
return table_entities + links
|
64 |
+
|
65 |
+
def _create_table_entity(
|
66 |
+
self,
|
67 |
+
table_text: str,
|
68 |
+
table_index: int,
|
69 |
+
doc_name: str,
|
70 |
+
) -> TableEntity:
|
71 |
+
"""
|
72 |
+
Создает сущность таблицы.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
table_text: Текст таблицы
|
76 |
+
table_index: Индекс таблицы в документе
|
77 |
+
doc_name: Имя документа
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Сущность TableEntity
|
81 |
+
"""
|
82 |
+
entity_name = f"{doc_name}_table_{table_index}"
|
83 |
+
|
84 |
+
return TableEntity(
|
85 |
+
id=uuid4(),
|
86 |
+
name=entity_name,
|
87 |
+
text=table_text,
|
88 |
+
in_search_text=table_text,
|
89 |
+
metadata={},
|
90 |
+
type=TableEntity.__name__,
|
91 |
+
table_index=table_index,
|
92 |
+
)
|
93 |
+
|
94 |
+
def _create_link(
|
95 |
+
self, doc_entity: LinkerEntity, table_entity: TableEntity, index: int
|
96 |
+
) -> LinkerEntity:
|
97 |
+
"""
|
98 |
+
Создает связь между документом и таблицей.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
doc_entity: Сущность документа
|
102 |
+
table_entity: Сущность таблицы
|
103 |
+
index: Индекс таблицы в документе
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
Объект связи LinkerEntity
|
107 |
+
"""
|
108 |
+
return LinkerEntity(
|
109 |
+
id=uuid4(),
|
110 |
+
name="document_to_table",
|
111 |
+
text="",
|
112 |
+
metadata={},
|
113 |
+
source_id=doc_entity.id,
|
114 |
+
target_id=table_entity.id,
|
115 |
+
number_in_relation=index,
|
116 |
+
type="Link",
|
117 |
+
)
|
lib/extractor/ntr_text_fragmentation/chunking/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль для определения стратегий чанкинга.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .chunking_strategy import ChunkingStrategy
|
6 |
+
from .specific_strategies import FixedSizeChunkingStrategy
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
"ChunkingStrategy",
|
10 |
+
"FixedSizeChunkingStrategy",
|
11 |
+
]
|
lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Базовый класс для всех стратегий чанкинга.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
|
7 |
+
from ntr_fileparser import ParsedDocument
|
8 |
+
|
9 |
+
from ..models import Chunk, DocumentAsEntity, LinkerEntity
|
10 |
+
|
11 |
+
|
12 |
+
class ChunkingStrategy(ABC):
|
13 |
+
"""
|
14 |
+
Базовый абстрактный класс для всех стратегий чанкинга.
|
15 |
+
"""
|
16 |
+
|
17 |
+
@abstractmethod
|
18 |
+
def chunk(self, document: ParsedDocument, doc_entity: DocumentAsEntity | None = None) -> list[LinkerEntity]:
|
19 |
+
"""
|
20 |
+
Разбивает документ на чанки в соответствии со стратегией.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
document: ParsedDocument для извлечения текста
|
24 |
+
doc_entity: Опциональная сущность документа для привязки чанков.
|
25 |
+
Если не указана, будет создана новая.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
list[LinkerEntity]: Список сущностей (документ, чанки, связи)
|
29 |
+
"""
|
30 |
+
raise NotImplementedError("Стратегия чанкинга должна реализовать метод chunk")
|
31 |
+
|
32 |
+
def dechunk(self, chunks: list[LinkerEntity], repository: 'EntityRepository' = None) -> str:
|
33 |
+
"""
|
34 |
+
Собирает документ из чанков и связей.
|
35 |
+
|
36 |
+
Базовая реализация сортирует чанки по chunk_index и объединяет их тексты,
|
37 |
+
сохраняя структуру параграфов и избегая дублирования текста.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
chunks: Список отфильтрованных чанков в случайном порядке
|
41 |
+
repository: Репозиторий сущностей для получения дополнительной информации (может быть None)
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Восстановленный текст документа
|
45 |
+
"""
|
46 |
+
import re
|
47 |
+
|
48 |
+
# Проверяем, есть ли чанки для сборки
|
49 |
+
if not chunks:
|
50 |
+
return ""
|
51 |
+
|
52 |
+
# Отбираем только чанки
|
53 |
+
valid_chunks = [c for c in chunks if isinstance(c, Chunk)]
|
54 |
+
|
55 |
+
# Сортируем чанки по chunk_index
|
56 |
+
sorted_chunks = sorted(valid_chunks, key=lambda c: c.chunk_index or 0)
|
57 |
+
|
58 |
+
# Собираем текст документа с учетом структуры параграфов
|
59 |
+
result_text = ""
|
60 |
+
|
61 |
+
for chunk in sorted_chunks:
|
62 |
+
# Получаем текст чанка (предпочитаем text, а не in_search_text для избежания дублирования)
|
63 |
+
chunk_text = chunk.text if hasattr(chunk, 'text') and chunk.text else ""
|
64 |
+
|
65 |
+
# Добавляем текст чанка с сохранением структуры параграфов
|
66 |
+
if result_text and result_text[-1] != "\n" and chunk_text and chunk_text[0] != "\n":
|
67 |
+
result_text += " "
|
68 |
+
result_text += chunk_text
|
69 |
+
|
70 |
+
# Пост-обработка результата
|
71 |
+
# Заменяем множественные переносы строк на одиночные
|
72 |
+
result_text = re.sub(r'\n+', '\n', result_text)
|
73 |
+
|
74 |
+
# Заменяем множественные пробелы на одиночные
|
75 |
+
result_text = re.sub(r' +', ' ', result_text)
|
76 |
+
|
77 |
+
# Убираем пробелы перед переносами строк
|
78 |
+
result_text = re.sub(r' +\n', '\n', result_text)
|
79 |
+
|
80 |
+
# Убираем пробелы после переносов строк
|
81 |
+
result_text = re.sub(r'\n +', '\n', result_text)
|
82 |
+
|
83 |
+
# Убираем лишние переносы строк в начале и конце текста
|
84 |
+
result_text = result_text.strip()
|
85 |
+
|
86 |
+
return result_text
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль содержащий конкретные стратегии для чанкинга текста.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .fixed_size import FixedSizeChunk
|
6 |
+
from .fixed_size_chunking import FixedSizeChunkingStrategy
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
"FixedSizeChunk",
|
10 |
+
"FixedSizeChunkingStrategy",
|
11 |
+
]
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль реализующий стратегию чанкинга с фиксированным размером.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .fixed_size_chunk import FixedSizeChunk
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"FixedSizeChunk",
|
9 |
+
]
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Класс для представления чанка фиксированного размера.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
from ....models.chunk import Chunk
|
9 |
+
from ....models.linker_entity import LinkerEntity, register_entity
|
10 |
+
|
11 |
+
|
12 |
+
@register_entity
|
13 |
+
@dataclass
|
14 |
+
class FixedSizeChunk(Chunk):
|
15 |
+
"""
|
16 |
+
Представляет чанк фиксированного размера.
|
17 |
+
|
18 |
+
Расширяет базовый класс Chunk дополнительными полями, связанными с токенами
|
19 |
+
и перекрытиями, а также добавляет методы для сборки документа с учетом
|
20 |
+
границ предложений.
|
21 |
+
"""
|
22 |
+
|
23 |
+
token_count: int = 0
|
24 |
+
|
25 |
+
# Информация о границах предложений и нахлестах
|
26 |
+
left_sentence_part: str = "" # Часть предложения слева от text
|
27 |
+
right_sentence_part: str = "" # Часть предложения справа от text
|
28 |
+
overlap_left: str = "" # Нахлест слева (без учета границ предложений)
|
29 |
+
overlap_right: str = "" # Нахлест справа (без учета границ предложений)
|
30 |
+
|
31 |
+
# Метаданные для дополнительной информации
|
32 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
33 |
+
|
34 |
+
def __str__(self) -> str:
|
35 |
+
"""
|
36 |
+
Строковое представление чанка.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
Строка с информацией о чанке.
|
40 |
+
"""
|
41 |
+
return (
|
42 |
+
f"FixedSizeChunk(id={self.id}, chunk_index={self.chunk_index}, "
|
43 |
+
f"tokens={self.token_count}, "
|
44 |
+
f"text='{self.text[:30]}{'...' if len(self.text) > 30 else ''}'"
|
45 |
+
f")"
|
46 |
+
)
|
47 |
+
|
48 |
+
def get_adjacent_chunks_indices(self, max_distance: int = 1) -> list[int]:
|
49 |
+
"""
|
50 |
+
Возвращает индексы соседних чанков в пределах указанного расстояния.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
max_distance: Максимальное расстояние от текущего чанка
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
Список индексов соседних чанков
|
57 |
+
"""
|
58 |
+
indices = []
|
59 |
+
for i in range(1, max_distance + 1):
|
60 |
+
# Добавляем предыдущие чанки
|
61 |
+
if self.chunk_index - i >= 0:
|
62 |
+
indices.append(self.chunk_index - i)
|
63 |
+
# Добавляем следующие чанки
|
64 |
+
indices.append(self.chunk_index + i)
|
65 |
+
|
66 |
+
return sorted(indices)
|
67 |
+
|
68 |
+
@classmethod
|
69 |
+
def deserialize(cls, entity: LinkerEntity) -> 'FixedSizeChunk':
|
70 |
+
"""
|
71 |
+
Десериализует FixedSizeChunk из объекта LinkerEntity.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
entity: Объект LinkerEntity для преобразования в FixedSizeChunk
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Десериализованный объект FixedSizeChunk
|
78 |
+
"""
|
79 |
+
metadata = entity.metadata or {}
|
80 |
+
|
81 |
+
# Извлекаем параметры из метаданных
|
82 |
+
# Сначала проверяем в метаданных под ключом _chunk_index
|
83 |
+
chunk_index = metadata.get('_chunk_index')
|
84 |
+
if chunk_index is None:
|
85 |
+
# Затем пробуем получить как атрибут объекта
|
86 |
+
chunk_index = getattr(entity, 'chunk_index', None)
|
87 |
+
if chunk_index is None:
|
88 |
+
# Если и там нет, пробуем обычный поиск по метаданным
|
89 |
+
chunk_index = metadata.get('chunk_index')
|
90 |
+
|
91 |
+
# Преобразуем к int, если значение найдено
|
92 |
+
if chunk_index is not None:
|
93 |
+
try:
|
94 |
+
chunk_index = int(chunk_index)
|
95 |
+
except (ValueError, TypeError):
|
96 |
+
chunk_index = None
|
97 |
+
|
98 |
+
start_token = metadata.get('start_token', 0)
|
99 |
+
end_token = metadata.get('end_token', 0)
|
100 |
+
token_count = metadata.get(
|
101 |
+
'_token_count', metadata.get('token_count', end_token - start_token + 1)
|
102 |
+
)
|
103 |
+
|
104 |
+
# Извлекаем параметры для границ предложений и нахлестов
|
105 |
+
# Сначала ищем в метаданных с префиксом _
|
106 |
+
left_sentence_part = metadata.get('_left_sentence_part')
|
107 |
+
if left_sentence_part is None:
|
108 |
+
# Затем пробуем получить как атрибут объекта
|
109 |
+
left_sentence_part = getattr(entity, 'left_sentence_part', '')
|
110 |
+
|
111 |
+
right_sentence_part = metadata.get('_right_sentence_part')
|
112 |
+
if right_sentence_part is None:
|
113 |
+
right_sentence_part = getattr(entity, 'right_sentence_part', '')
|
114 |
+
|
115 |
+
overlap_left = metadata.get('_overlap_left')
|
116 |
+
if overlap_left is None:
|
117 |
+
overlap_left = getattr(entity, 'overlap_left', '')
|
118 |
+
|
119 |
+
overlap_right = metadata.get('_overlap_right')
|
120 |
+
if overlap_right is None:
|
121 |
+
overlap_right = getattr(entity, 'overlap_right', '')
|
122 |
+
|
123 |
+
# Создаем чистые метаданные без служебных полей
|
124 |
+
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
125 |
+
|
126 |
+
# Создаем и возвращаем новый экземпляр FixedSizeChunk
|
127 |
+
return cls(
|
128 |
+
id=entity.id,
|
129 |
+
name=entity.name,
|
130 |
+
text=entity.text,
|
131 |
+
in_search_text=entity.in_search_text,
|
132 |
+
metadata=clean_metadata,
|
133 |
+
source_id=entity.source_id,
|
134 |
+
target_id=entity.target_id,
|
135 |
+
number_in_relation=entity.number_in_relation,
|
136 |
+
chunk_index=chunk_index,
|
137 |
+
token_count=token_count,
|
138 |
+
left_sentence_part=left_sentence_part,
|
139 |
+
right_sentence_part=right_sentence_part,
|
140 |
+
overlap_left=overlap_left,
|
141 |
+
overlap_right=overlap_right,
|
142 |
+
type="FixedSizeChunk",
|
143 |
+
)
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Стратегия чанкинга фиксированного размера.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
from typing import NamedTuple, TypeVar
|
7 |
+
from uuid import uuid4
|
8 |
+
|
9 |
+
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
10 |
+
|
11 |
+
from ...chunking.chunking_strategy import ChunkingStrategy
|
12 |
+
from ...models import DocumentAsEntity, LinkerEntity
|
13 |
+
from .fixed_size.fixed_size_chunk import FixedSizeChunk
|
14 |
+
|
15 |
+
T = TypeVar('T')
|
16 |
+
|
17 |
+
|
18 |
+
class _FixedSizeChunkingStrategyParams(NamedTuple):
|
19 |
+
words_per_chunk: int = 50
|
20 |
+
overlap_words: int = 25
|
21 |
+
respect_sentence_boundaries: bool = True
|
22 |
+
|
23 |
+
|
24 |
+
class FixedSizeChunkingStrategy(ChunkingStrategy):
|
25 |
+
"""
|
26 |
+
Стратегия чанкинга, разбивающая текст на чанки фиксированного размера.
|
27 |
+
|
28 |
+
Преимущества:
|
29 |
+
- Простое и предсказуемое разбиение
|
30 |
+
- Равные по размеру чанки
|
31 |
+
|
32 |
+
Недостатки:
|
33 |
+
- Может разрезать предложения и абзацы в середине (компенсируется сборкой - как для модели поиска, так и для LLM)
|
34 |
+
- Не учитывает смысловую структуру текста
|
35 |
+
|
36 |
+
Особенности реализации:
|
37 |
+
- В поле `text` чанков хранится текст без нахлеста (для удобства сборки)
|
38 |
+
- В поле `in_search_text` хранится текст с нахлестом (для улучшения векторизации)
|
39 |
+
"""
|
40 |
+
|
41 |
+
name = "fixed_size"
|
42 |
+
description = (
|
43 |
+
"Стратегия чанкинга, разбивающая текст на чанки фиксированного размера."
|
44 |
+
)
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
words_per_chunk: int = 50,
|
49 |
+
overlap_words: int = 25,
|
50 |
+
respect_sentence_boundaries: bool = True,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Инициализация стратегии чанкинга с фиксированным размером.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
words_per_chunk: Количество слов в чанке
|
57 |
+
overlap_words: Количество слов перекрытия между чанками
|
58 |
+
respect_sentence_boundaries: Флаг учета границ предложений
|
59 |
+
"""
|
60 |
+
|
61 |
+
self.params = _FixedSizeChunkingStrategyParams(
|
62 |
+
words_per_chunk=words_per_chunk,
|
63 |
+
overlap_words=overlap_words,
|
64 |
+
respect_sentence_boundaries=respect_sentence_boundaries,
|
65 |
+
)
|
66 |
+
|
67 |
+
def chunk(
|
68 |
+
self,
|
69 |
+
document: ParsedDocument | str,
|
70 |
+
doc_entity: DocumentAsEntity | None = None,
|
71 |
+
) -> list[LinkerEntity]:
|
72 |
+
"""
|
73 |
+
Разбивает документ на чанки фиксированного размера.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
document: Документ для разбиения (ParsedDocument или текст)
|
77 |
+
doc_entity: Сущность документа (опционально)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Список LinkerEntity - чанки, связи и прочие сущности
|
81 |
+
"""
|
82 |
+
doc = self._prepare_document(document)
|
83 |
+
words = self._extract_words(doc)
|
84 |
+
|
85 |
+
# Если документ пустой, возвращаем пустой список
|
86 |
+
if not words:
|
87 |
+
return []
|
88 |
+
|
89 |
+
doc_entity = self._ensure_document_entity(doc, doc_entity)
|
90 |
+
doc_name = doc_entity.name
|
91 |
+
|
92 |
+
chunks = []
|
93 |
+
links = []
|
94 |
+
|
95 |
+
step = self._calculate_step()
|
96 |
+
total_words = len(words)
|
97 |
+
|
98 |
+
# Начинаем с первого слова и идем шагами (не полным размером чанка)
|
99 |
+
for i in range(0, total_words, step):
|
100 |
+
# Создаем обычный чанк
|
101 |
+
chunk_text = self._prepare_chunk_text(words, i, step)
|
102 |
+
in_search_text = self._prepare_chunk_text(
|
103 |
+
words, i, self.params.words_per_chunk
|
104 |
+
)
|
105 |
+
|
106 |
+
chunk = self._create_chunk(
|
107 |
+
chunk_text,
|
108 |
+
in_search_text,
|
109 |
+
i,
|
110 |
+
i + self.params.words_per_chunk,
|
111 |
+
len(chunks),
|
112 |
+
words,
|
113 |
+
total_words,
|
114 |
+
doc_name,
|
115 |
+
)
|
116 |
+
|
117 |
+
chunks.append(chunk)
|
118 |
+
links.append(self._create_link(doc_entity, chunk))
|
119 |
+
|
120 |
+
# Возвращаем все сущности
|
121 |
+
return [doc_entity] + chunks + links
|
122 |
+
|
123 |
+
def _find_nearest_sentence_boundary(
|
124 |
+
self, text: str, position: int
|
125 |
+
) -> tuple[int, str, str]:
|
126 |
+
"""
|
127 |
+
Находит ближайшую границу предложения к указанной позиции.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
text: Полный текст для поиска границ
|
131 |
+
position: Позиция, для которой ищем ближайшую границу
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
tuple из (позиция границы, левая часть текста, правая часть текста)
|
135 |
+
"""
|
136 |
+
# Регулярное выражение для поиска конца предложения
|
137 |
+
sentence_end_pattern = r'[.!?](?:\s|$)'
|
138 |
+
|
139 |
+
# Ищем все совпадения в тексте
|
140 |
+
matches = list(re.finditer(sentence_end_pattern, text))
|
141 |
+
|
142 |
+
if not matches:
|
143 |
+
# Если совпадений нет, возвращаем исходную позицию
|
144 |
+
return position, text[:position], text[position:]
|
145 |
+
|
146 |
+
# Находим ближайшую границу предложения
|
147 |
+
nearest_pos = position
|
148 |
+
min_distance = float('inf')
|
149 |
+
|
150 |
+
for match in matches:
|
151 |
+
end_pos = match.end()
|
152 |
+
distance = abs(end_pos - position)
|
153 |
+
|
154 |
+
if distance < min_distance:
|
155 |
+
min_distance = distance
|
156 |
+
nearest_pos = end_pos
|
157 |
+
|
158 |
+
# Возвращаем позицию и соответствующие части текста
|
159 |
+
return nearest_pos, text[:nearest_pos], text[nearest_pos:]
|
160 |
+
|
161 |
+
def _find_sentence_boundary(self, text: str, is_left_boundary: bool) -> str:
|
162 |
+
"""
|
163 |
+
Находит часть текста на границе предложения.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
text: Текст для обработки
|
167 |
+
is_left_boundary: True для левой границы, False для правой
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Часть предложения на границе
|
171 |
+
"""
|
172 |
+
# Регулярное выражение для поиска конца предложения
|
173 |
+
sentence_end_pattern = r'[.!?](?:\s|$)'
|
174 |
+
matches = list(re.finditer(sentence_end_pattern, text))
|
175 |
+
|
176 |
+
if not matches:
|
177 |
+
return text
|
178 |
+
|
179 |
+
if is_left_boundary:
|
180 |
+
# Для левой границы берем часть после последней границы предложения
|
181 |
+
last_match = matches[-1]
|
182 |
+
return text[last_match.end() :].strip()
|
183 |
+
else:
|
184 |
+
# Для правой границы берем часть до первой границы предложения
|
185 |
+
first_match = matches[0]
|
186 |
+
return text[: first_match.end()].strip()
|
187 |
+
|
188 |
+
def dechunk(
|
189 |
+
self,
|
190 |
+
filtered_chunks: list[LinkerEntity],
|
191 |
+
repository: 'EntityRepository' = None, # type: ignore
|
192 |
+
) -> str:
|
193 |
+
"""
|
194 |
+
Собирает документ из чанков и связей.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
filtered_chunks: Список отфильтрованных чанков
|
198 |
+
repository: Репозиторий сущностей для получения дополнительной информации (может быть None)
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
Восстановленный текст документа
|
202 |
+
"""
|
203 |
+
if not filtered_chunks:
|
204 |
+
return ""
|
205 |
+
|
206 |
+
# Проверяем тип и десериализуем FixedSizeChunk
|
207 |
+
chunks = []
|
208 |
+
for chunk in filtered_chunks:
|
209 |
+
if chunk.type == "FixedSizeChunk":
|
210 |
+
chunks.append(FixedSizeChunk.deserialize(chunk))
|
211 |
+
else:
|
212 |
+
chunks.append(chunk)
|
213 |
+
|
214 |
+
# Сортируем чанки по индексу
|
215 |
+
sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index or 0)
|
216 |
+
|
217 |
+
# Инициализируем результирующий текст
|
218 |
+
result_text = ""
|
219 |
+
|
220 |
+
# Группируем последовательные чанки
|
221 |
+
current_group = []
|
222 |
+
groups = []
|
223 |
+
|
224 |
+
for i, chunk in enumerate(sorted_chunks):
|
225 |
+
current_index = chunk.chunk_index or 0
|
226 |
+
|
227 |
+
# Если первый чанк или продолжение последовательности
|
228 |
+
if i == 0 or current_index == (sorted_chunks[i - 1].chunk_index or 0) + 1:
|
229 |
+
current_group.append(chunk)
|
230 |
+
else:
|
231 |
+
# Закрываем текущую группу и начинаем новую
|
232 |
+
if current_group:
|
233 |
+
groups.append(current_group)
|
234 |
+
current_group = [chunk]
|
235 |
+
|
236 |
+
# Добавляем последнюю группу
|
237 |
+
if current_group:
|
238 |
+
groups.append(current_group)
|
239 |
+
|
240 |
+
# Обрабатываем каждую группу
|
241 |
+
for group_index, group in enumerate(groups):
|
242 |
+
# Добавляем многоточие между непоследовательными группами
|
243 |
+
if group_index > 0:
|
244 |
+
result_text += "\n\n...\n\n"
|
245 |
+
|
246 |
+
# Обрабатываем группу соседних чанков
|
247 |
+
group_text = ""
|
248 |
+
|
249 |
+
# Добавляем левую недостающую часть к первому чанку группы
|
250 |
+
first_chunk = group[0]
|
251 |
+
|
252 |
+
# До��авляем левую часть предложения к первому чанку группы
|
253 |
+
if (
|
254 |
+
hasattr(first_chunk, 'left_sentence_part')
|
255 |
+
and first_chunk.left_sentence_part
|
256 |
+
):
|
257 |
+
group_text += first_chunk.left_sentence_part
|
258 |
+
|
259 |
+
# Добавляем текст всех чанков группы
|
260 |
+
for i, chunk in enumerate(group):
|
261 |
+
current_text = chunk.text.strip() if hasattr(chunk, 'text') else ""
|
262 |
+
if not current_text:
|
263 |
+
continue
|
264 |
+
|
265 |
+
# Проверяем, нужно ли добавить пробел между предыдущим текстом и текущим чанком
|
266 |
+
if group_text:
|
267 |
+
# Если текущий чанк начинается с новой строки, не добавляем пробел
|
268 |
+
if current_text.startswith("\n"):
|
269 |
+
pass
|
270 |
+
# Если предыдущий текст заканчивается переносом строки, также не добавляем пробел
|
271 |
+
elif group_text.endswith("\n"):
|
272 |
+
pass
|
273 |
+
# Если предыдущий текст заканчивается знаком препинания без пробела, добавляем пробел
|
274 |
+
elif group_text.rstrip()[-1] not in [
|
275 |
+
"\n",
|
276 |
+
" ",
|
277 |
+
".",
|
278 |
+
",",
|
279 |
+
"!",
|
280 |
+
"?",
|
281 |
+
":",
|
282 |
+
";",
|
283 |
+
"-",
|
284 |
+
"–",
|
285 |
+
"—",
|
286 |
+
]:
|
287 |
+
group_text += " "
|
288 |
+
|
289 |
+
# Добавляем текст чанка
|
290 |
+
group_text += current_text
|
291 |
+
|
292 |
+
# Добавляем правую недостающую часть к последнему чанку группы
|
293 |
+
last_chunk = group[-1]
|
294 |
+
|
295 |
+
# Добавляем правую часть предложения к последнему чанку группы
|
296 |
+
if (
|
297 |
+
hasattr(last_chunk, 'right_sentence_part')
|
298 |
+
and last_chunk.right_sentence_part
|
299 |
+
):
|
300 |
+
right_part = last_chunk.right_sentence_part.strip()
|
301 |
+
if right_part:
|
302 |
+
# Проверяем нужен ли пробел перед правой частью
|
303 |
+
if (
|
304 |
+
group_text
|
305 |
+
and group_text[-1] not in ["\n", " "]
|
306 |
+
and right_part[0]
|
307 |
+
not in ["\n", " ", ".", ",", "!", "?", ":", ";", "-", "–", "—"]
|
308 |
+
):
|
309 |
+
group_text += " "
|
310 |
+
group_text += right_part
|
311 |
+
|
312 |
+
# Добавляем текст группы к результату
|
313 |
+
if (
|
314 |
+
result_text
|
315 |
+
and result_text[-1] not in ["\n", " "]
|
316 |
+
and group_text
|
317 |
+
and group_text[0] not in ["\n", " "]
|
318 |
+
):
|
319 |
+
result_text += " "
|
320 |
+
result_text += group_text
|
321 |
+
|
322 |
+
# Постобработка текста: удаляем лишние пробелы и символы переноса строк
|
323 |
+
|
324 |
+
# Заменяем множественные переносы строк на двойные (для разделения абзацев)
|
325 |
+
result_text = re.sub(r'\n{3,}', '\n\n', result_text)
|
326 |
+
|
327 |
+
# Заменяем множественные пробелы на одиночные
|
328 |
+
result_text = re.sub(r' +', ' ', result_text)
|
329 |
+
|
330 |
+
# Убираем пробелы перед знаками препинания
|
331 |
+
result_text = re.sub(r' ([.,!?:;)])', r'\1', result_text)
|
332 |
+
|
333 |
+
# Убираем пробелы перед переносами строк и после переносов строк
|
334 |
+
result_text = re.sub(r' +\n', '\n', result_text)
|
335 |
+
result_text = re.sub(r'\n +', '\n', result_text)
|
336 |
+
|
337 |
+
# Убираем лишние переносы строк и пробелы в начале и конце текста
|
338 |
+
result_text = result_text.strip()
|
339 |
+
|
340 |
+
return result_text
|
341 |
+
|
342 |
+
def _get_sorted_chunks(
|
343 |
+
self, chunks: list[LinkerEntity], links: list[LinkerEntity]
|
344 |
+
) -> list[LinkerEntity]:
|
345 |
+
"""
|
346 |
+
Получает отсортированные чанки на основе связей или поля chunk_index.
|
347 |
+
|
348 |
+
Args:
|
349 |
+
chunks: Список чанков для сортировки
|
350 |
+
links: Список связей для определения порядка
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
Отсортированные чанки
|
354 |
+
"""
|
355 |
+
# Сортируем чанки по порядку в связях
|
356 |
+
if links:
|
357 |
+
# Получаем словарь для быстрого доступа к чанкам по ID
|
358 |
+
chunk_dict = {c.id: c for c in chunks}
|
359 |
+
|
360 |
+
# Сортируем по порядку в связях
|
361 |
+
sorted_chunks = []
|
362 |
+
for link in sorted(links, key=lambda l: l.number_in_relation or 0):
|
363 |
+
if link.target_id in chunk_dict:
|
364 |
+
sorted_chunks.append(chunk_dict[link.target_id])
|
365 |
+
|
366 |
+
return sorted_chunks
|
367 |
+
|
368 |
+
# Если нет связей, сортируем по chunk_index
|
369 |
+
return sorted(chunks, key=lambda c: c.chunk_index or 0)
|
370 |
+
|
371 |
+
def _prepare_document(self, document: ParsedDocument | str) -> ParsedDocument:
|
372 |
+
"""
|
373 |
+
Обрабатывает входные данные и возвращает ParsedDocument.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
document: Документ (ParsedDocument или текст)
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
Обработанный документ типа ParsedDocument
|
380 |
+
"""
|
381 |
+
if isinstance(document, ParsedDocument):
|
382 |
+
return document
|
383 |
+
elif isinstance(document, str):
|
384 |
+
# Простая обработка текстового документа
|
385 |
+
return ParsedDocument(
|
386 |
+
paragraphs=[
|
387 |
+
ParsedTextBlock(text=paragraph)
|
388 |
+
for paragraph in document.split('\n')
|
389 |
+
]
|
390 |
+
)
|
391 |
+
|
392 |
+
def _extract_words(self, doc: ParsedDocument) -> list[str]:
|
393 |
+
"""
|
394 |
+
Извлекает все слова из документа.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
doc: Документ для извлечения слов
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
Список слов документа
|
401 |
+
"""
|
402 |
+
words = []
|
403 |
+
for paragraph in doc.paragraphs:
|
404 |
+
# Добавляем слова из параграфа
|
405 |
+
paragraph_words = paragraph.text.split()
|
406 |
+
words.extend(paragraph_words)
|
407 |
+
# Добавляем маркер конца параграфа как отдельный элемент
|
408 |
+
words.append("\n")
|
409 |
+
return words
|
410 |
+
|
411 |
+
def _ensure_document_entity(
|
412 |
+
self,
|
413 |
+
doc: ParsedDocument,
|
414 |
+
doc_entity: LinkerEntity | None,
|
415 |
+
) -> LinkerEntity:
|
416 |
+
"""
|
417 |
+
Создает сущность документа, если не предоставлена.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
doc: Документ
|
421 |
+
doc_entity: Сущность документа (может быть None)
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
Сущность документа
|
425 |
+
"""
|
426 |
+
if doc_entity is None:
|
427 |
+
return LinkerEntity(
|
428 |
+
id=uuid4(),
|
429 |
+
name=doc.name,
|
430 |
+
text=doc.name,
|
431 |
+
metadata={"type": doc.type},
|
432 |
+
type="Document",
|
433 |
+
)
|
434 |
+
return doc_entity
|
435 |
+
|
436 |
+
def _calculate_step(self) -> int:
|
437 |
+
"""
|
438 |
+
Вычисляет шаг для создания чанков.
|
439 |
+
|
440 |
+
Returns:
|
441 |
+
Размер шага между началами чанков
|
442 |
+
"""
|
443 |
+
return self.params.words_per_chunk - self.params.overlap_words
|
444 |
+
|
445 |
+
def _prepare_chunk_text(
|
446 |
+
self,
|
447 |
+
words: list[str],
|
448 |
+
start_idx: int,
|
449 |
+
length: int,
|
450 |
+
) -> str:
|
451 |
+
"""
|
452 |
+
Подготавливает текст чанка и текст для поиска.
|
453 |
+
|
454 |
+
Args:
|
455 |
+
words: Список слов документа
|
456 |
+
start_idx: Индекс начала чанка
|
457 |
+
end_idx: Длина текста в словах
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
Итоговый текст
|
461 |
+
"""
|
462 |
+
# Извлекаем текст чанка без нахлеста с сохранением структуры параграфов
|
463 |
+
end_idx = min(start_idx + length, len(words))
|
464 |
+
chunk_words = words[start_idx:end_idx]
|
465 |
+
chunk_text = ""
|
466 |
+
|
467 |
+
for word in chunk_words:
|
468 |
+
if word == "\n":
|
469 |
+
# Если это маркер конца параграфа, добавляем перенос строки
|
470 |
+
chunk_text += "\n"
|
471 |
+
else:
|
472 |
+
# Иначе добавляем слово с пробелом
|
473 |
+
if chunk_text and chunk_text[-1] != "\n":
|
474 |
+
chunk_text += " "
|
475 |
+
chunk_text += word
|
476 |
+
|
477 |
+
return chunk_text
|
478 |
+
|
479 |
+
def _create_chunk(
|
480 |
+
self,
|
481 |
+
chunk_text: str,
|
482 |
+
in_search_text: str,
|
483 |
+
start_idx: int,
|
484 |
+
end_idx: int,
|
485 |
+
chunk_index: int,
|
486 |
+
words: list[str],
|
487 |
+
total_words: int,
|
488 |
+
doc_name: str,
|
489 |
+
) -> FixedSizeChunk:
|
490 |
+
"""
|
491 |
+
Создает чанк фиксированного размера.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
chunk_text: Текст чанка без нахлеста
|
495 |
+
in_search_text: Текст чанка с нахлестом
|
496 |
+
start_idx: Индекс первого слова в чанке
|
497 |
+
end_idx: Индекс последнего слова в чанке
|
498 |
+
chunk_index: Индекс чанка в документе
|
499 |
+
words: Список всех слов документа
|
500 |
+
total_words: Общее количество слов в документе
|
501 |
+
doc_name: Имя документа
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
FixedSizeChunk: Созданный чанк
|
505 |
+
"""
|
506 |
+
# Определяем нахлесты без учета границ предложений
|
507 |
+
overlap_left = " ".join(
|
508 |
+
words[max(0, start_idx - self.params.overlap_words) : start_idx]
|
509 |
+
)
|
510 |
+
overlap_right = " ".join(
|
511 |
+
words[end_idx : min(total_words, end_idx + self.params.overlap_words)]
|
512 |
+
)
|
513 |
+
|
514 |
+
# Определяем границы предложений
|
515 |
+
left_sentence_part = ""
|
516 |
+
right_sentence_part = ""
|
517 |
+
|
518 |
+
if self.params.respect_sentence_boundaries:
|
519 |
+
# Находим ближайшую границу предложения слева
|
520 |
+
left_text = " ".join(
|
521 |
+
words[max(0, start_idx - self.params.overlap_words) : start_idx]
|
522 |
+
)
|
523 |
+
left_sentence_part = self._find_sentence_boundary(left_text, True)
|
524 |
+
|
525 |
+
# Находим ближайшую границу предложения справа
|
526 |
+
right_text = " ".join(
|
527 |
+
words[end_idx : min(total_words, end_idx + self.params.overlap_words)]
|
528 |
+
)
|
529 |
+
right_sentence_part = self._find_sentence_boundary(right_text, False)
|
530 |
+
|
531 |
+
# Создаем чанк с учетом границ предложений
|
532 |
+
return FixedSizeChunk(
|
533 |
+
id=uuid4(),
|
534 |
+
name=f"{doc_name}_chunk_{chunk_index}",
|
535 |
+
text=chunk_text,
|
536 |
+
chunk_index=chunk_index,
|
537 |
+
in_search_text=in_search_text,
|
538 |
+
token_count=end_idx - start_idx + 1,
|
539 |
+
left_sentence_part=left_sentence_part,
|
540 |
+
right_sentence_part=right_sentence_part,
|
541 |
+
overlap_left=overlap_left,
|
542 |
+
overlap_right=overlap_right,
|
543 |
+
metadata={},
|
544 |
+
type=FixedSizeChunk.__name__,
|
545 |
+
)
|
546 |
+
|
547 |
+
def _create_link(
|
548 |
+
self, doc_entity: LinkerEntity, chunk: LinkerEntity
|
549 |
+
) -> LinkerEntity:
|
550 |
+
"""
|
551 |
+
Создает связь между документом и чанком.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
doc_entity: Сущность документа
|
555 |
+
chunk: Сущность чанка
|
556 |
+
|
557 |
+
Returns:
|
558 |
+
Объект связи
|
559 |
+
"""
|
560 |
+
return LinkerEntity(
|
561 |
+
id=uuid4(),
|
562 |
+
name="document_to_chunk",
|
563 |
+
text="",
|
564 |
+
metadata={},
|
565 |
+
source_id=doc_entity.id,
|
566 |
+
target_id=chunk.id,
|
567 |
+
type="Link",
|
568 |
+
)
|
lib/extractor/ntr_text_fragmentation/core/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Основные классы для разбиения и сборки документов.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .destructurer import Destructurer
|
6 |
+
from .entity_repository import EntityRepository, InMemoryEntityRepository
|
7 |
+
from .injection_builder import InjectionBuilder
|
8 |
+
|
9 |
+
__all__ = ["Destructurer", "InjectionBuilder", "EntityRepository", "InMemoryEntityRepository"]
|
lib/extractor/ntr_text_fragmentation/core/destructurer.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль для деструктуризации документа.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from uuid import uuid4
|
6 |
+
|
7 |
+
# Внешние импорты
|
8 |
+
from ntr_fileparser import ParsedDocument
|
9 |
+
|
10 |
+
# Импорты из этой же библиотеки
|
11 |
+
from ..additors.tables_processor import TablesProcessor
|
12 |
+
from ..chunking.chunking_strategy import ChunkingStrategy
|
13 |
+
from ..chunking.specific_strategies.fixed_size_chunking import \
|
14 |
+
FixedSizeChunkingStrategy
|
15 |
+
from ..models import DocumentAsEntity, LinkerEntity
|
16 |
+
|
17 |
+
|
18 |
+
class Destructurer:
|
19 |
+
"""
|
20 |
+
Класс для подготовки документа для загрузки в базу данных.
|
21 |
+
Разбивает документ на чанки, создает связи между ними и
|
22 |
+
извлекает вспомогательные сущности.
|
23 |
+
"""
|
24 |
+
|
25 |
+
# Доступные стратегии чанкинга
|
26 |
+
STRATEGIES: dict[str, type[ChunkingStrategy]] = {
|
27 |
+
"fixed_size": FixedSizeChunkingStrategy,
|
28 |
+
}
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
document: ParsedDocument,
|
33 |
+
strategy_name: str = "fixed_size",
|
34 |
+
process_tables: bool = True,
|
35 |
+
**kwargs,
|
36 |
+
):
|
37 |
+
"""
|
38 |
+
Инициализация деструктуризатора.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
document: Документ для обработки
|
42 |
+
strategy_name: Имя стратегии
|
43 |
+
process_tables: Флаг обработки таблиц
|
44 |
+
**kwargs: Параметры для стратегии
|
45 |
+
"""
|
46 |
+
self.document = document
|
47 |
+
self.strategy: ChunkingStrategy | None = None
|
48 |
+
self.process_tables = process_tables
|
49 |
+
# Инициализируем процессор таблиц, если нужно
|
50 |
+
self.tables_processor = TablesProcessor() if process_tables else None
|
51 |
+
# Кеш для хранения созданных стратегий
|
52 |
+
self._strategy_cache: dict[str, ChunkingStrategy] = {}
|
53 |
+
|
54 |
+
# Конфигурируем стратегию
|
55 |
+
self.configure(strategy_name, **kwargs)
|
56 |
+
|
57 |
+
def configure(self, strategy_name: str = "fixed_size", **kwargs) -> None:
|
58 |
+
"""
|
59 |
+
Установка стратегии чанкинга.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
strategy_name: Имя стратегии
|
63 |
+
**kwargs: Параметры для стратегии
|
64 |
+
|
65 |
+
Raises:
|
66 |
+
ValueError: Если указана неизвестная стратегия
|
67 |
+
"""
|
68 |
+
# Получаем класс стратегии из словаря доступных стратегий
|
69 |
+
if strategy_name not in self.STRATEGIES:
|
70 |
+
raise ValueError(f"Неизвестная стратегия: {strategy_name}")
|
71 |
+
|
72 |
+
# Создаем ключ кеша на основе имени стратегии и параметров
|
73 |
+
cache_key = f"{strategy_name}_{hash(frozenset(kwargs.items()))}"
|
74 |
+
|
75 |
+
# Проверяем, есть ли стратегия в кеше
|
76 |
+
if cache_key in self._strategy_cache:
|
77 |
+
self.strategy = self._strategy_cache[cache_key]
|
78 |
+
return
|
79 |
+
|
80 |
+
# Создаем экземпляр стратегии с переданными параметрами
|
81 |
+
strategy_class = self.STRATEGIES[strategy_name]
|
82 |
+
self.strategy = strategy_class(**kwargs)
|
83 |
+
|
84 |
+
# Сохраняем стратегию в кеше
|
85 |
+
self._strategy_cache[cache_key] = self.strategy
|
86 |
+
|
87 |
+
def destructure(self) -> list[LinkerEntity]:
|
88 |
+
"""
|
89 |
+
Основной метод деструктуризации.
|
90 |
+
Разбивает документ на чанки и создает связи.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
list[LinkerEntity]: список сущностей, включая связи
|
94 |
+
|
95 |
+
Raises:
|
96 |
+
RuntimeError: Если стратегия не была сконфигурирована
|
97 |
+
"""
|
98 |
+
# Проверяем, что стратегия сконфигурирована
|
99 |
+
if self.strategy is None:
|
100 |
+
raise RuntimeError("Стратегия не была сконфигурирована")
|
101 |
+
|
102 |
+
# Создаем сущность документа с метаданными
|
103 |
+
doc_entity = self._create_document_entity()
|
104 |
+
|
105 |
+
# Применяем стратегию чанкинга
|
106 |
+
entities = self.strategy.chunk(self.document, doc_entity)
|
107 |
+
|
108 |
+
# Обрабатываем таблицы, если это включено
|
109 |
+
if self.process_tables and self.tables_processor and self.document.tables:
|
110 |
+
table_entities = self.tables_processor.process(self.document, doc_entity)
|
111 |
+
entities.extend(table_entities)
|
112 |
+
|
113 |
+
# Сериализуем все сущности в простейшую форму LinkerEntity
|
114 |
+
serialized_entities = [entity.serialize() for entity in entities]
|
115 |
+
|
116 |
+
return serialized_entities
|
117 |
+
|
118 |
+
def _create_document_entity(self) -> DocumentAsEntity:
|
119 |
+
"""
|
120 |
+
Создает сущность документа с метаданными.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
DocumentAsEntity: сущность документа
|
124 |
+
"""
|
125 |
+
# Получаем имя документа или используем значение по умолчанию
|
126 |
+
doc_name = self.document.name or "Document"
|
127 |
+
|
128 |
+
# Создаем метаданные, включая информацию о стратегии чанкинга
|
129 |
+
metadata = {
|
130 |
+
"type": self.document.type,
|
131 |
+
"chunking_strategy": (
|
132 |
+
self.strategy.__class__.__name__ if self.strategy else "unknown"
|
133 |
+
),
|
134 |
+
}
|
135 |
+
|
136 |
+
# Создаем сущность документа
|
137 |
+
return DocumentAsEntity(
|
138 |
+
id=uuid4(),
|
139 |
+
name=doc_name,
|
140 |
+
text="",
|
141 |
+
metadata=metadata,
|
142 |
+
type="Document",
|
143 |
+
)
|
lib/extractor/ntr_text_fragmentation/core/entity_repository.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Интерфейс репозитория сущностей.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
from collections import defaultdict
|
7 |
+
from typing import Iterable
|
8 |
+
from uuid import UUID
|
9 |
+
|
10 |
+
from ..models import Chunk, LinkerEntity
|
11 |
+
from ..models.document import DocumentAsEntity
|
12 |
+
|
13 |
+
|
14 |
+
class EntityRepository(ABC):
|
15 |
+
"""
|
16 |
+
Абстрактный интерфейс для доступа к хранилищу сущностей.
|
17 |
+
Позволяет InjectionBuilder получать нужные сущности независимо от их хранилища.
|
18 |
+
|
19 |
+
Этот интерфейс определяет только методы для получения сущностей.
|
20 |
+
Логика сохранения и изменения сущностей остается за пределами этого интерфейса
|
21 |
+
и должна быть реализована в конкретных классах, расширяющих данный интерфейс.
|
22 |
+
"""
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def get_entities_by_ids(self, entity_ids: Iterable[UUID]) -> list[LinkerEntity]:
|
26 |
+
"""
|
27 |
+
Получить сущности по списку идентификаторов.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
entity_ids: Список идентификаторов сущностей
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Список сущностей, соответствующих указанным идентификаторам
|
34 |
+
"""
|
35 |
+
pass
|
36 |
+
|
37 |
+
@abstractmethod
|
38 |
+
def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> list[LinkerEntity]:
|
39 |
+
"""
|
40 |
+
Получить документы, которым принадлежат указанные чанки.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
chunk_ids: Список идентификаторов чанков
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Список документов, которым принадлежат указанные чанки
|
47 |
+
"""
|
48 |
+
pass
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def get_neighboring_chunks(self,
|
52 |
+
chunk_ids: Iterable[UUID],
|
53 |
+
max_distance: int = 1) -> list[LinkerEntity]:
|
54 |
+
"""
|
55 |
+
Получить соседние чанки для указанных чанков.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
chunk_ids: Список идентификаторов чанков
|
59 |
+
max_distance: Максимальное расстояние до соседа
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
Список соседних чанков
|
63 |
+
"""
|
64 |
+
pass
|
65 |
+
|
66 |
+
@abstractmethod
|
67 |
+
def get_related_entities(self,
|
68 |
+
entity_ids: Iterable[UUID],
|
69 |
+
relation_name: str | None = None,
|
70 |
+
as_source: bool = False,
|
71 |
+
as_target: bool = False) -> list[LinkerEntity]:
|
72 |
+
"""
|
73 |
+
Получить сущности, связанные с указанными сущностями.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
entity_ids: Список идентификаторов сущностей
|
77 |
+
relation_name: Опциональное имя отношения для фильтрации
|
78 |
+
as_source: Если True, ищем связи, где указанные entity_ids являются
|
79 |
+
источниками (source_id)
|
80 |
+
as_target: Если True, ищем связи, где указанные entity_ids являются
|
81 |
+
целевыми (target_id)
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
Список связанных сущностей и связей
|
85 |
+
"""
|
86 |
+
pass
|
87 |
+
|
88 |
+
|
89 |
+
class InMemoryEntityRepository(EntityRepository):
|
90 |
+
"""
|
91 |
+
Реализация EntityRepository, хранящая все сущности в памяти.
|
92 |
+
Обеспечивает обратную совместимость и используется для тестирования.
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, entities: list[LinkerEntity] | None = None):
|
96 |
+
"""
|
97 |
+
Инициализация репозитория с начальным списком сущностей.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
entities: Начальный список сущностей
|
101 |
+
"""
|
102 |
+
self.entities = entities or []
|
103 |
+
self._build_indices()
|
104 |
+
|
105 |
+
def _build_indices(self) -> None:
|
106 |
+
"""
|
107 |
+
Строит индексы для быстрого доступа к сущностям.
|
108 |
+
"""
|
109 |
+
self.entities_by_id = {e.id: e for e in self.entities}
|
110 |
+
self.chunks = [e for e in self.entities if isinstance(e, Chunk)]
|
111 |
+
self.docs = [e for e in self.entities if isinstance(e, DocumentAsEntity)]
|
112 |
+
|
113 |
+
# Индексы для быстрого поиска связей
|
114 |
+
self.doc_to_chunks = defaultdict(list)
|
115 |
+
self.chunk_to_doc = {}
|
116 |
+
self.entity_relations = defaultdict(list)
|
117 |
+
self.entity_targets = defaultdict(list)
|
118 |
+
|
119 |
+
# Заполняем индексы
|
120 |
+
for e in self.entities:
|
121 |
+
if e.is_link():
|
122 |
+
self.entity_relations[e.source_id].append(e)
|
123 |
+
self.entity_targets[e.target_id].append(e)
|
124 |
+
if e.name == "document_to_chunk":
|
125 |
+
self.doc_to_chunks[e.source_id].append(e.target_id)
|
126 |
+
self.chunk_to_doc[e.target_id] = e.source_id
|
127 |
+
if e.name == "document_to_table":
|
128 |
+
self.entity_relations
|
129 |
+
self.entity_targets[e.source_id].append(e.target_id)
|
130 |
+
|
131 |
+
# Этот метод не является частью интерфейса EntityRepository,
|
132 |
+
# но он полезен для тестирования и реализации обратной совместимости
|
133 |
+
def add_entities(self, entities: list[LinkerEntity]) -> None:
|
134 |
+
"""
|
135 |
+
Добавляет сущности в репозиторий.
|
136 |
+
|
137 |
+
Примечание: Этот метод не является частью интерфейса EntityRepository.
|
138 |
+
Он добавлен для удобства тестирования и обратной совместимости.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
entities: Список сущностей для добавления
|
142 |
+
"""
|
143 |
+
self.entities.extend(entities)
|
144 |
+
self._build_indices()
|
145 |
+
|
146 |
+
def get_entities_by_ids(self, entity_ids: Iterable[UUID]) -> list[LinkerEntity]:
|
147 |
+
result = [self.entities_by_id.get(eid) for eid in entity_ids if eid in self.entities_by_id]
|
148 |
+
return result
|
149 |
+
|
150 |
+
def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> list[LinkerEntity]:
|
151 |
+
result = []
|
152 |
+
for chunk_id in chunk_ids:
|
153 |
+
doc_id = self.chunk_to_doc.get(chunk_id)
|
154 |
+
if doc_id and doc_id in self.entities_by_id:
|
155 |
+
doc = self.entities_by_id[doc_id]
|
156 |
+
if doc not in result:
|
157 |
+
result.append(doc)
|
158 |
+
return result
|
159 |
+
|
160 |
+
def get_neighboring_chunks(self,
|
161 |
+
chunk_ids: Iterable[UUID],
|
162 |
+
max_distance: int = 1) -> list[LinkerEntity]:
|
163 |
+
result = []
|
164 |
+
chunk_indices = {}
|
165 |
+
|
166 |
+
# Сначала собираем индексы всех указанных чанков
|
167 |
+
for chunk_id in chunk_ids:
|
168 |
+
if chunk_id in self.entities_by_id:
|
169 |
+
chunk = self.entities_by_id[chunk_id]
|
170 |
+
if hasattr(chunk, 'chunk_index') and chunk.chunk_index is not None:
|
171 |
+
chunk_indices[chunk_id] = chunk.chunk_index
|
172 |
+
|
173 |
+
# Если нет чанков с индексами, возвращаем пустой список
|
174 |
+
if not chunk_indices:
|
175 |
+
return []
|
176 |
+
|
177 |
+
# Затем для каждого документа находим соседние чанки
|
178 |
+
for doc_id, doc_chunk_ids in self.doc_to_chunks.items():
|
179 |
+
# Проверяем, принадлежит ли хоть один из чанков этому документу
|
180 |
+
has_chunks = any(chunk_id in doc_chunk_ids for chunk_id in chunk_ids)
|
181 |
+
if not has_chunks:
|
182 |
+
continue
|
183 |
+
|
184 |
+
# Для каждого чанка в документе проверяем, является ли он соседом
|
185 |
+
for doc_chunk_id in doc_chunk_ids:
|
186 |
+
if doc_chunk_id in self.entities_by_id:
|
187 |
+
chunk = self.entities_by_id[doc_chunk_id]
|
188 |
+
|
189 |
+
# Если у чанка нет индекса, пропускаем его
|
190 |
+
if not hasattr(chunk, 'chunk_index') or chunk.chunk_index is None:
|
191 |
+
continue
|
192 |
+
|
193 |
+
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
|
194 |
+
for orig_chunk_id, orig_index in chunk_indices.items():
|
195 |
+
if abs(chunk.chunk_index - orig_index) <= max_distance and doc_chunk_id not in chunk_ids:
|
196 |
+
result.append(chunk)
|
197 |
+
break
|
198 |
+
|
199 |
+
return result
|
200 |
+
|
201 |
+
def get_related_entities(self,
|
202 |
+
entity_ids: Iterable[UUID],
|
203 |
+
relation_name: str | None = None,
|
204 |
+
as_source: bool = False,
|
205 |
+
as_target: bool = False) -> list[LinkerEntity]:
|
206 |
+
"""
|
207 |
+
Получить сущности, связанные с указанными сущностями.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
entity_ids: Список идентификаторов сущностей
|
211 |
+
relation_name: Опциональное имя отношения для фильтрации
|
212 |
+
as_source: Если True, ищем связи, где указанные entity_ids являются источниками
|
213 |
+
as_target: Если True, ищем связи, где указанные entity_ids являются целями
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
Список связанных сущностей и связей
|
217 |
+
"""
|
218 |
+
result = []
|
219 |
+
|
220 |
+
# Если не указано ни as_source, ни as_target, по умолчанию ищем связи,
|
221 |
+
# где указанные entity_ids являются источниками
|
222 |
+
if not as_source and not as_target:
|
223 |
+
as_source = True
|
224 |
+
|
225 |
+
for entity_id in entity_ids:
|
226 |
+
if as_source:
|
227 |
+
# Ищем связи, где сущность является источником
|
228 |
+
relations = self.entity_relations.get(entity_id, [])
|
229 |
+
|
230 |
+
for link in relations:
|
231 |
+
if relation_name is None or link.name == relation_name:
|
232 |
+
# Добавляем саму связь
|
233 |
+
if link not in result:
|
234 |
+
result.append(link)
|
235 |
+
|
236 |
+
# Добавляем целевую сущность
|
237 |
+
if link.target_id in self.entities_by_id:
|
238 |
+
related_entity = self.entities_by_id[link.target_id]
|
239 |
+
if related_entity not in result:
|
240 |
+
result.append(related_entity)
|
241 |
+
|
242 |
+
if as_target:
|
243 |
+
# Ищем связи, где сущность является целью
|
244 |
+
relations = self.entity_targets.get(entity_id, [])
|
245 |
+
|
246 |
+
for link in relations:
|
247 |
+
if relation_name is None or link.name == relation_name:
|
248 |
+
# Добавляем саму связь
|
249 |
+
if link not in result:
|
250 |
+
result.append(link)
|
251 |
+
|
252 |
+
# Добавляем исходную сущность
|
253 |
+
if link.source_id in self.entities_by_id:
|
254 |
+
related_entity = self.entities_by_id[link.source_id]
|
255 |
+
if related_entity not in result:
|
256 |
+
result.append(related_entity)
|
257 |
+
|
258 |
+
return result
|
lib/extractor/ntr_text_fragmentation/core/injection_builder.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Класс для сборки документа из чанков.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from collections import defaultdict
|
6 |
+
from typing import Optional, Type
|
7 |
+
from uuid import UUID
|
8 |
+
|
9 |
+
from ..chunking.chunking_strategy import ChunkingStrategy
|
10 |
+
from ..models.chunk import Chunk
|
11 |
+
from ..models.linker_entity import LinkerEntity
|
12 |
+
from .entity_repository import EntityRepository, InMemoryEntityRepository
|
13 |
+
|
14 |
+
|
15 |
+
class InjectionBuilder:
|
16 |
+
"""
|
17 |
+
Класс для сборки документов из чанков и связей.
|
18 |
+
|
19 |
+
Отвечает за:
|
20 |
+
- Сборку текста из чанков с учетом порядка
|
21 |
+
- Ранжирование документов на основе весов чанков
|
22 |
+
- Добавление соседних чанков для улучшения сборки
|
23 |
+
- Сборку данных из таблиц и других сущностей
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
repository: EntityRepository | None = None,
|
29 |
+
entities: list[LinkerEntity] | None = None,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Инициализация сборщика инъекций.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
repository: Репозиторий сущностей (если None, используется InMemoryEntityRepository)
|
36 |
+
entities: Список всех сущностей (опционально, для обратной совместимости)
|
37 |
+
"""
|
38 |
+
# Для обратной совместимости
|
39 |
+
if repository is None and entities is not None:
|
40 |
+
repository = InMemoryEntityRepository(entities)
|
41 |
+
|
42 |
+
self.repository = repository or InMemoryEntityRepository()
|
43 |
+
self.strategy_map: dict[str, Type[ChunkingStrategy]] = {}
|
44 |
+
|
45 |
+
def register_strategy(
|
46 |
+
self,
|
47 |
+
doc_type: str,
|
48 |
+
strategy: Type[ChunkingStrategy],
|
49 |
+
) -> None:
|
50 |
+
"""
|
51 |
+
Регистрирует стратегию для определенного типа документа.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
doc_type: Тип документа
|
55 |
+
strategy: Стратегия чанкинга
|
56 |
+
"""
|
57 |
+
self.strategy_map[doc_type] = strategy
|
58 |
+
|
59 |
+
def build(
|
60 |
+
self,
|
61 |
+
filtered_entities: list[LinkerEntity] | list[UUID],
|
62 |
+
chunk_scores: dict[str, float] | None = None,
|
63 |
+
include_tables: bool = True,
|
64 |
+
max_documents: Optional[int] = None,
|
65 |
+
) -> str:
|
66 |
+
"""
|
67 |
+
Собирает текст из всех документов, связанных с предоставленными чанками.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
filtered_entities: Список чанков или их идентификаторов
|
71 |
+
chunk_scores: Словарь весов чанков {chunk_id: score}
|
72 |
+
include_tables: Флаг для включения таблиц в результат
|
73 |
+
max_documents: Максимальное количество документов (None = все)
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Собранный текст со всеми документами
|
77 |
+
"""
|
78 |
+
# Преобразуем входные данные в список идентификаторов
|
79 |
+
entity_ids = [
|
80 |
+
entity.id if isinstance(entity, LinkerEntity) else entity
|
81 |
+
for entity in filtered_entities
|
82 |
+
]
|
83 |
+
|
84 |
+
print(f"entity_ids: {entity_ids[:3]}...{entity_ids[-3:]}")
|
85 |
+
|
86 |
+
if not entity_ids:
|
87 |
+
return ""
|
88 |
+
|
89 |
+
# Получаем сущности по их идентификаторам
|
90 |
+
entities = self.repository.get_entities_by_ids(entity_ids)
|
91 |
+
|
92 |
+
print(f"entities: {entities[:3]}...{entities[-3:]}")
|
93 |
+
|
94 |
+
# Десериализуем сущности в их специализированные типы
|
95 |
+
deserialized_entities = []
|
96 |
+
for entity in entities:
|
97 |
+
# Используем статический метод десериализации
|
98 |
+
deserialized_entity = LinkerEntity.deserialize(entity)
|
99 |
+
deserialized_entities.append(deserialized_entity)
|
100 |
+
|
101 |
+
print(f"deserialized_entities: {deserialized_entities[:3]}...{deserialized_entities[-3:]}")
|
102 |
+
|
103 |
+
# Фильтруем сущности на чанки и таблицы
|
104 |
+
chunks = [e for e in deserialized_entities if "Chunk" in e.type]
|
105 |
+
tables = [e for e in deserialized_entities if "Table" in e.type]
|
106 |
+
|
107 |
+
# Группируем таблицы по документам
|
108 |
+
table_ids = {table.id for table in tables}
|
109 |
+
doc_tables = self._group_tables_by_document(table_ids)
|
110 |
+
|
111 |
+
if not chunks and not tables:
|
112 |
+
return ""
|
113 |
+
|
114 |
+
# Получаем идентификаторы чанков
|
115 |
+
chunk_ids = [chunk.id for chunk in chunks]
|
116 |
+
|
117 |
+
# Получаем связи для чанков (чанки являются целями связей)
|
118 |
+
links = self.repository.get_related_entities(
|
119 |
+
chunk_ids,
|
120 |
+
relation_name="document_to_chunk",
|
121 |
+
as_target=True,
|
122 |
+
)
|
123 |
+
|
124 |
+
print(f"links: {links[:3]}...{links[-3:]}")
|
125 |
+
|
126 |
+
# Группируем чанки по документам
|
127 |
+
doc_chunks = self._group_chunks_by_document(chunks, links)
|
128 |
+
|
129 |
+
print(f"doc_chunks: {doc_chunks}")
|
130 |
+
|
131 |
+
# Получаем все документы для чанков и таблиц
|
132 |
+
doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
|
133 |
+
docs = self.repository.get_entities_by_ids(doc_ids)
|
134 |
+
|
135 |
+
# Десериализуем документы
|
136 |
+
deserialized_docs = []
|
137 |
+
for doc in docs:
|
138 |
+
deserialized_doc = LinkerEntity.deserialize(doc)
|
139 |
+
deserialized_docs.append(deserialized_doc)
|
140 |
+
|
141 |
+
print(f"deserialized_docs: {deserialized_docs[:3]}...{deserialized_docs[-3:]}")
|
142 |
+
|
143 |
+
# Вычисляем веса документов на основе весов чанков
|
144 |
+
doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
|
145 |
+
|
146 |
+
# Сортируем документы по весам (по убыванию)
|
147 |
+
sorted_docs = sorted(
|
148 |
+
deserialized_docs,
|
149 |
+
key=lambda d: doc_scores.get(str(d.id), 0.0),
|
150 |
+
reverse=True
|
151 |
+
)
|
152 |
+
|
153 |
+
print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
|
154 |
+
|
155 |
+
# Ограничиваем количество документов, если указано
|
156 |
+
if max_documents:
|
157 |
+
sorted_docs = sorted_docs[:max_documents]
|
158 |
+
|
159 |
+
print(f"sorted_docs: {sorted_docs[:3]}...{sorted_docs[-3:]}")
|
160 |
+
|
161 |
+
# Собираем текст для каждого документа
|
162 |
+
result_parts = []
|
163 |
+
for doc in sorted_docs:
|
164 |
+
doc_text = self._build_document_text(
|
165 |
+
doc,
|
166 |
+
doc_chunks.get(doc.id, []),
|
167 |
+
doc_tables.get(doc.id, []),
|
168 |
+
include_tables
|
169 |
+
)
|
170 |
+
if doc_text:
|
171 |
+
result_parts.append(doc_text)
|
172 |
+
|
173 |
+
# Объединяем результаты
|
174 |
+
return "\n\n".join(result_parts)
|
175 |
+
|
176 |
+
def _build_document_text(
|
177 |
+
self,
|
178 |
+
doc: LinkerEntity,
|
179 |
+
chunks: list[LinkerEntity],
|
180 |
+
tables: list[LinkerEntity],
|
181 |
+
include_tables: bool
|
182 |
+
) -> str:
|
183 |
+
"""
|
184 |
+
Собирает текст документа из чанков и таблиц.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
doc: Сущность документа
|
188 |
+
chunks: Список чанков документа
|
189 |
+
tables: Список таблиц документа
|
190 |
+
include_tables: Флаг для включения таблиц
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
Собранный текст документа
|
194 |
+
"""
|
195 |
+
# Получаем стратегию чанкинга
|
196 |
+
strategy_name = doc.metadata.get("chunking_strategy", "fixed_size")
|
197 |
+
strategy = self._get_strategy_instance(strategy_name)
|
198 |
+
|
199 |
+
# Собираем текст из чанков
|
200 |
+
chunks_text = strategy.dechunk(chunks, self.repository) if chunks else ""
|
201 |
+
|
202 |
+
# Собираем текст из таблиц, если нужно
|
203 |
+
tables_text = ""
|
204 |
+
if include_tables and tables:
|
205 |
+
# Сортируем таблицы по индексу, если он есть
|
206 |
+
sorted_tables = sorted(
|
207 |
+
tables,
|
208 |
+
key=lambda t: t.metadata.get("table_index", 0) if t.metadata else 0
|
209 |
+
)
|
210 |
+
|
211 |
+
# Собираем текст таблиц
|
212 |
+
tables_text = "\n\n".join(table.text for table in sorted_tables if hasattr(table, 'text'))
|
213 |
+
|
214 |
+
# Формируем результат
|
215 |
+
result = f"[Источник] - {doc.name}\n"
|
216 |
+
if chunks_text:
|
217 |
+
result += chunks_text
|
218 |
+
if tables_text:
|
219 |
+
if chunks_text:
|
220 |
+
result += "\n\n"
|
221 |
+
result += tables_text
|
222 |
+
|
223 |
+
return result
|
224 |
+
|
225 |
+
def _group_chunks_by_document(
|
226 |
+
self,
|
227 |
+
chunks: list[LinkerEntity],
|
228 |
+
links: list[LinkerEntity]
|
229 |
+
) -> dict[UUID, list[LinkerEntity]]:
|
230 |
+
"""
|
231 |
+
Группирует чанки по документам.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
chunks: Список чанков
|
235 |
+
links: Список связей между документами и чанками
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
Словарь {doc_id: [chunks]}
|
239 |
+
"""
|
240 |
+
result = defaultdict(list)
|
241 |
+
|
242 |
+
# Создаем словарь для быстрого доступа к чанкам по ID
|
243 |
+
chunk_dict = {chunk.id: chunk for chunk in chunks}
|
244 |
+
|
245 |
+
# Группируем чанки по документам на основе связей
|
246 |
+
for link in links:
|
247 |
+
if link.target_id in chunk_dict and link.source_id:
|
248 |
+
result[link.source_id].append(chunk_dict[link.target_id])
|
249 |
+
|
250 |
+
return result
|
251 |
+
|
252 |
+
def _group_tables_by_document(
|
253 |
+
self,
|
254 |
+
table_ids: set[UUID]
|
255 |
+
) -> dict[UUID, list[LinkerEntity]]:
|
256 |
+
"""
|
257 |
+
Группирует таблицы по документам.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
table_ids: Множество идентификаторов таблиц
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
Словарь {doc_id: [tables]}
|
264 |
+
"""
|
265 |
+
result = defaultdict(list)
|
266 |
+
|
267 |
+
table_ids = [str(table_id) for table_id in table_ids]
|
268 |
+
|
269 |
+
# Получаем связи для таблиц (таблицы являются целями связей)
|
270 |
+
if not table_ids:
|
271 |
+
return result
|
272 |
+
|
273 |
+
links = self.repository.get_related_entities(
|
274 |
+
table_ids,
|
275 |
+
relation_name="document_to_table",
|
276 |
+
as_target=True,
|
277 |
+
)
|
278 |
+
|
279 |
+
# Получаем сами таблицы
|
280 |
+
tables = self.repository.get_entities_by_ids(table_ids)
|
281 |
+
|
282 |
+
# Десериализуем таблицы
|
283 |
+
deserialized_tables = []
|
284 |
+
for table in tables:
|
285 |
+
deserialized_table = LinkerEntity.deserialize(table)
|
286 |
+
deserialized_tables.append(deserialized_table)
|
287 |
+
|
288 |
+
# Создаем словарь для быстрого доступа к таблицам по ID
|
289 |
+
table_dict = {str(table.id): table for table in deserialized_tables}
|
290 |
+
|
291 |
+
# Группируем таблицы по документам на основе связей
|
292 |
+
for link in links:
|
293 |
+
if link.target_id in table_dict and link.source_id:
|
294 |
+
result[link.source_id].append(table_dict[link.target_id])
|
295 |
+
|
296 |
+
return result
|
297 |
+
|
298 |
+
def _calculate_document_scores(
|
299 |
+
self,
|
300 |
+
doc_chunks: dict[UUID, list[LinkerEntity]],
|
301 |
+
chunk_scores: Optional[dict[str, float]]
|
302 |
+
) -> dict[str, float]:
|
303 |
+
"""
|
304 |
+
Вычисляет веса документов на основе весов чанков.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
doc_chunks: Словарь {doc_id: [chunks]}
|
308 |
+
chunk_scores: Словарь весов чанков {chunk_id: score}
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
Словарь весов документов {doc_id: score}
|
312 |
+
"""
|
313 |
+
if not chunk_scores:
|
314 |
+
return {str(doc_id): 1.0 for doc_id in doc_chunks.keys()}
|
315 |
+
|
316 |
+
result = {}
|
317 |
+
for doc_id, chunks in doc_chunks.items():
|
318 |
+
# Берем максимальный вес среди чанков документа
|
319 |
+
chunk_weights = [chunk_scores.get(str(c.id), 0.0) for c in chunks]
|
320 |
+
result[str(doc_id)] = max(chunk_weights) if chunk_weights else 0.0
|
321 |
+
|
322 |
+
return result
|
323 |
+
|
324 |
+
def add_neighboring_chunks(
|
325 |
+
self, entities: list[LinkerEntity] | list[UUID], max_distance: int = 1
|
326 |
+
) -> list[LinkerEntity]:
|
327 |
+
"""
|
328 |
+
Добавляет соседние чанки к отфильтрованному списку чанков.
|
329 |
+
|
330 |
+
Args:
|
331 |
+
entities: Список сущностей или их идентификаторов
|
332 |
+
max_distance: Максимальное расстояние для поиска соседей
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
Расширенный список сущностей
|
336 |
+
"""
|
337 |
+
# Преобразуем входные данные в список идентификаторов
|
338 |
+
entity_ids = [
|
339 |
+
entity.id if isinstance(entity, LinkerEntity) else entity
|
340 |
+
for entity in entities
|
341 |
+
]
|
342 |
+
|
343 |
+
if not entity_ids:
|
344 |
+
return []
|
345 |
+
|
346 |
+
# Получаем исходные сущности
|
347 |
+
original_entities = self.repository.get_entities_by_ids(entity_ids)
|
348 |
+
|
349 |
+
# Фильтруем только чанки
|
350 |
+
chunk_entities = [e for e in original_entities if isinstance(e, Chunk)]
|
351 |
+
|
352 |
+
if not chunk_entities:
|
353 |
+
return original_entities
|
354 |
+
|
355 |
+
# Получаем идентификаторы чанков
|
356 |
+
chunk_ids = [chunk.id for chunk in chunk_entities]
|
357 |
+
|
358 |
+
# Получаем соседние чанки
|
359 |
+
neighboring_chunks = self.repository.get_neighboring_chunks(
|
360 |
+
chunk_ids, max_distance
|
361 |
+
)
|
362 |
+
|
363 |
+
# Объединяем исходные сущности с соседними чанками
|
364 |
+
result = list(original_entities)
|
365 |
+
for chunk in neighboring_chunks:
|
366 |
+
if chunk not in result:
|
367 |
+
result.append(chunk)
|
368 |
+
|
369 |
+
# Получаем документы и связи для всех чанков
|
370 |
+
all_chunk_ids = [chunk.id for chunk in result if isinstance(chunk, Chunk)]
|
371 |
+
|
372 |
+
docs = self.repository.get_document_for_chunks(all_chunk_ids)
|
373 |
+
links = self.repository.get_related_entities(
|
374 |
+
all_chunk_ids, relation_name="document_to_chunk", as_target=True
|
375 |
+
)
|
376 |
+
|
377 |
+
# Добавляем документы и связи в результат
|
378 |
+
for doc in docs:
|
379 |
+
if doc not in result:
|
380 |
+
result.append(doc)
|
381 |
+
|
382 |
+
for link in links:
|
383 |
+
if link not in result:
|
384 |
+
result.append(link)
|
385 |
+
|
386 |
+
return result
|
387 |
+
|
388 |
+
def _get_strategy_instance(self, strategy_name: str) -> ChunkingStrategy:
|
389 |
+
"""
|
390 |
+
Создает экземпляр стратегии чанкинга по имени.
|
391 |
+
|
392 |
+
Args:
|
393 |
+
strategy_name: Имя стратегии
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
Экземпляр соответствующей стратегии
|
397 |
+
"""
|
398 |
+
# Используем словарь для маппинга имен стратегий на их классы
|
399 |
+
strategies = {
|
400 |
+
"fixed_size": "..chunking.specific_strategies.fixed_size_chunking.FixedSizeChunkingStrategy",
|
401 |
+
}
|
402 |
+
|
403 |
+
# Если стратегия зарегистрирована в self.strategy_map, используем её
|
404 |
+
if strategy_name in self.strategy_map:
|
405 |
+
return self.strategy_map[strategy_name]()
|
406 |
+
|
407 |
+
# Если стратегия известна, импортируем и инициализируем её
|
408 |
+
if strategy_name in strategies:
|
409 |
+
import importlib
|
410 |
+
|
411 |
+
module_path, class_name = strategies[strategy_name].rsplit(".", 1)
|
412 |
+
try:
|
413 |
+
# Конвертируем относительный путь в абсолютный
|
414 |
+
abs_module_path = f"ntr_text_fragmentation{module_path[2:]}"
|
415 |
+
module = importlib.import_module(abs_module_path)
|
416 |
+
strategy_class = getattr(module, class_name)
|
417 |
+
return strategy_class()
|
418 |
+
except (ImportError, AttributeError) as e:
|
419 |
+
# Если импорт не удался, используем стратегию по умолчанию
|
420 |
+
from ..chunking.specific_strategies.fixed_size_chunking import \
|
421 |
+
FixedSizeChunkingStrategy
|
422 |
+
|
423 |
+
return FixedSizeChunkingStrategy()
|
424 |
+
|
425 |
+
# По умолчанию используем стратегию с фиксированным размером
|
426 |
+
from ..chunking.specific_strategies.fixed_size_chunking import \
|
427 |
+
FixedSizeChunkingStrategy
|
428 |
+
|
429 |
+
return FixedSizeChunkingStrategy()
|
lib/extractor/ntr_text_fragmentation/integrations/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль интеграций с внешними хранилищами данных и ORM системами.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .sqlalchemy_repository import SQLAlchemyEntityRepository
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"SQLAlchemyEntityRepository",
|
9 |
+
]
|
lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy_repository.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Реализация EntityRepository для работы с SQLAlchemy.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from abc import abstractmethod
|
6 |
+
from typing import Any, Iterable, List, Optional, Type
|
7 |
+
from uuid import UUID
|
8 |
+
|
9 |
+
from sqlalchemy import and_, select
|
10 |
+
from sqlalchemy.ext.declarative import declarative_base
|
11 |
+
from sqlalchemy.orm import Session
|
12 |
+
|
13 |
+
from ..core.entity_repository import EntityRepository
|
14 |
+
from ..models import Chunk, LinkerEntity
|
15 |
+
|
16 |
+
Base = declarative_base()
|
17 |
+
|
18 |
+
|
19 |
+
class SQLAlchemyEntityRepository(EntityRepository):
|
20 |
+
"""
|
21 |
+
Реализация EntityRepository для работы с базой данных через SQLAlchemy.
|
22 |
+
|
23 |
+
Эта реализация предполагает, что таблицы для хранения сущностей уже созданы
|
24 |
+
в базе данных и соответствуют определенной структуре.
|
25 |
+
|
26 |
+
Вы можете наследоваться от этого класса и определить свою структуру моделей,
|
27 |
+
переопределив абстрактные методы.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, db: Session):
|
31 |
+
"""
|
32 |
+
Инициализирует репозиторий с указанной сессией SQLAlchemy.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
db: Сессия SQLAlchemy для работы с базой данных
|
36 |
+
"""
|
37 |
+
self.db = db
|
38 |
+
|
39 |
+
@abstractmethod
|
40 |
+
def _entity_model_class(self) -> Type['Base']:
|
41 |
+
"""
|
42 |
+
Возвращает класс модели SQLAlchemy для сущностей.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
Класс модели SQLAlchemy для сущностей
|
46 |
+
"""
|
47 |
+
pass
|
48 |
+
|
49 |
+
@abstractmethod
|
50 |
+
def _map_db_entity_to_linker_entity(self, db_entity: Any) -> LinkerEntity:
|
51 |
+
"""
|
52 |
+
Преобразует сущность из базы данных в LinkerEntity.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
db_entity: Сущность из базы данных
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Сущность LinkerEntity
|
59 |
+
"""
|
60 |
+
pass
|
61 |
+
|
62 |
+
def get_entities_by_ids(self, entity_ids: Iterable[UUID]) -> List[LinkerEntity]:
|
63 |
+
"""
|
64 |
+
Получить сущности по списку идентификаторов.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
entity_ids: Список идентификаторов сущностей
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
Список сущностей, соответствующих указанным идентификаторам
|
71 |
+
"""
|
72 |
+
if not entity_ids:
|
73 |
+
return []
|
74 |
+
|
75 |
+
with self.db() as session:
|
76 |
+
entity_model = self._entity_model_class()
|
77 |
+
db_entities = session.execute(
|
78 |
+
select(entity_model).where(entity_model.uuid.in_(list(entity_ids)))
|
79 |
+
).scalars().all()
|
80 |
+
print(f"db_entities: {db_entities[:3]}...{db_entities[-3:]}")
|
81 |
+
|
82 |
+
mapped_entities = [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
|
83 |
+
print(f"mapped_entities: {mapped_entities[:3]}...{mapped_entities[-3:]}")
|
84 |
+
return mapped_entities
|
85 |
+
|
86 |
+
def get_document_for_chunks(self, chunk_ids: Iterable[UUID]) -> List[LinkerEntity]:
|
87 |
+
"""
|
88 |
+
Получить документы, которым принадлежат указанные чанки.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
chunk_ids: Список идентификаторов чанков
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
Список документов, которым принадлежат указанные чанки
|
95 |
+
"""
|
96 |
+
if not chunk_ids:
|
97 |
+
return []
|
98 |
+
|
99 |
+
with self.db() as session:
|
100 |
+
entity_model = self._entity_model_class()
|
101 |
+
|
102 |
+
string_ids = [str(id) for id in chunk_ids]
|
103 |
+
|
104 |
+
# Получаем все сущности-связи между документами и чанками
|
105 |
+
links = session.execute(
|
106 |
+
select(entity_model).where(
|
107 |
+
and_(
|
108 |
+
entity_model.target_id.in_(string_ids),
|
109 |
+
entity_model.name == "document_to_chunk",
|
110 |
+
entity_model.target_id.isnot(None) # Проверяем, что это связь
|
111 |
+
)
|
112 |
+
)
|
113 |
+
).scalars().all()
|
114 |
+
|
115 |
+
if not links:
|
116 |
+
return []
|
117 |
+
|
118 |
+
# Извлекаем ID документов
|
119 |
+
doc_ids = [link.source_id for link in links]
|
120 |
+
|
121 |
+
# Получаем документы по их ID
|
122 |
+
documents = session.execute(
|
123 |
+
select(entity_model).where(
|
124 |
+
and_(
|
125 |
+
entity_model.uuid.in_(doc_ids),
|
126 |
+
entity_model.entity_type == "DocumentAsEntity"
|
127 |
+
)
|
128 |
+
)
|
129 |
+
).scalars().all()
|
130 |
+
|
131 |
+
return [self._map_db_entity_to_linker_entity(doc) for doc in documents]
|
132 |
+
|
133 |
+
def get_neighboring_chunks(self,
|
134 |
+
chunk_ids: Iterable[UUID],
|
135 |
+
max_distance: int = 1) -> List[LinkerEntity]:
|
136 |
+
"""
|
137 |
+
Получить соседние чанки для указанных чанков.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
chunk_ids: Список идентификаторов чанков
|
141 |
+
max_distance: Максимальное расстояние до соседа
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Список соседних чанков
|
145 |
+
"""
|
146 |
+
if not chunk_ids:
|
147 |
+
return []
|
148 |
+
|
149 |
+
string_ids = [str(id) for id in chunk_ids]
|
150 |
+
|
151 |
+
with self.db() as session:
|
152 |
+
entity_model = self._entity_model_class()
|
153 |
+
result = []
|
154 |
+
|
155 |
+
# Сначала получаем указанные чанки, чтобы узнать их индексы и документы
|
156 |
+
chunks = session.execute(
|
157 |
+
select(entity_model).where(
|
158 |
+
and_(
|
159 |
+
entity_model.uuid.in_(string_ids),
|
160 |
+
entity_model.entity_type.like("%Chunk") # Используем LIKE для поиска всех типов чанков
|
161 |
+
)
|
162 |
+
)
|
163 |
+
).scalars().all()
|
164 |
+
|
165 |
+
print(f"chunks: {chunks[:3]}...{chunks[-3:]}")
|
166 |
+
|
167 |
+
if not chunks:
|
168 |
+
return []
|
169 |
+
|
170 |
+
# Находим документы для чанков через связи
|
171 |
+
doc_ids = set()
|
172 |
+
chunk_indices = {}
|
173 |
+
|
174 |
+
for chunk in chunks:
|
175 |
+
mapped_chunk = self._map_db_entity_to_linker_entity(chunk)
|
176 |
+
if not isinstance(mapped_chunk, Chunk):
|
177 |
+
continue
|
178 |
+
|
179 |
+
chunk_indices[chunk.uuid] = mapped_chunk.chunk_index
|
180 |
+
|
181 |
+
# Находим связь от документа к чанку
|
182 |
+
links = session.execute(
|
183 |
+
select(entity_model).where(
|
184 |
+
and_(
|
185 |
+
entity_model.target_id == chunk.uuid,
|
186 |
+
entity_model.name == "document_to_chunk"
|
187 |
+
)
|
188 |
+
)
|
189 |
+
).scalars().all()
|
190 |
+
|
191 |
+
print(f"links: {links[:3]}...{links[-3:]}")
|
192 |
+
|
193 |
+
for link in links:
|
194 |
+
doc_ids.add(link.source_id)
|
195 |
+
|
196 |
+
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
|
197 |
+
return []
|
198 |
+
|
199 |
+
# Для каждого документа находим все его чанки
|
200 |
+
for doc_id in doc_ids:
|
201 |
+
# Находим все связи от документа к чанкам
|
202 |
+
links = session.execute(
|
203 |
+
select(entity_model).where(
|
204 |
+
and_(
|
205 |
+
entity_model.source_id == doc_id,
|
206 |
+
entity_model.name == "document_to_chunk"
|
207 |
+
)
|
208 |
+
)
|
209 |
+
).scalars().all()
|
210 |
+
|
211 |
+
doc_chunk_ids = [link.target_id for link in links]
|
212 |
+
|
213 |
+
print(f"doc_chunk_ids: {doc_chunk_ids[:3]}...{doc_chunk_ids[-3:]}")
|
214 |
+
|
215 |
+
# Получаем все чанки документа
|
216 |
+
doc_chunks = session.execute(
|
217 |
+
select(entity_model).where(
|
218 |
+
and_(
|
219 |
+
entity_model.uuid.in_(doc_chunk_ids),
|
220 |
+
entity_model.entity_type.like("%Chunk") # Используем LIKE для поиска всех типов чанков
|
221 |
+
)
|
222 |
+
)
|
223 |
+
).scalars().all()
|
224 |
+
|
225 |
+
print(f"doc_chunks: {doc_chunks[:3]}...{doc_chunks[-3:]}")
|
226 |
+
|
227 |
+
# Для каждого чанка в документе проверяем, является ли он соседом
|
228 |
+
for doc_chunk in doc_chunks:
|
229 |
+
if doc_chunk.uuid in chunk_ids:
|
230 |
+
continue
|
231 |
+
|
232 |
+
mapped_chunk = self._map_db_entity_to_linker_entity(doc_chunk)
|
233 |
+
if not isinstance(mapped_chunk, Chunk):
|
234 |
+
continue
|
235 |
+
|
236 |
+
chunk_index = mapped_chunk.chunk_index
|
237 |
+
if chunk_index is None:
|
238 |
+
continue
|
239 |
+
|
240 |
+
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
|
241 |
+
is_neighbor = False
|
242 |
+
for orig_chunk_id, orig_index in chunk_indices.items():
|
243 |
+
if orig_index is not None and abs(chunk_index - orig_index) <= max_distance:
|
244 |
+
is_neighbor = True
|
245 |
+
break
|
246 |
+
|
247 |
+
if is_neighbor:
|
248 |
+
result.append(mapped_chunk)
|
249 |
+
|
250 |
+
return result
|
251 |
+
|
252 |
+
def get_related_entities(self,
|
253 |
+
entity_ids: Iterable[UUID],
|
254 |
+
relation_name: Optional[str] = None,
|
255 |
+
as_source: bool = False,
|
256 |
+
as_target: bool = False) -> List[LinkerEntity]:
|
257 |
+
"""
|
258 |
+
Получить сущности, связанные с указанными сущностями.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
entity_ids: Список идентификаторов сущностей
|
262 |
+
relation_name: Опциональное имя отношения для фильтрации
|
263 |
+
as_source: Если True, ищем связи, где указанные entity_ids являются источниками
|
264 |
+
as_target: Если True, ищем связи, где указанные entity_ids являются целями
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Список связанных сущностей и связей
|
268 |
+
"""
|
269 |
+
if not entity_ids:
|
270 |
+
return []
|
271 |
+
|
272 |
+
entity_model = self._entity_model_class()
|
273 |
+
result = []
|
274 |
+
|
275 |
+
# Если не указано ни as_source, ни as_target, по умолчанию ищем связи,
|
276 |
+
# где указанные entity_ids являются источниками
|
277 |
+
if not as_source and not as_target:
|
278 |
+
as_source = True
|
279 |
+
|
280 |
+
string_ids = [str(id) for id in entity_ids]
|
281 |
+
|
282 |
+
with self.db() as session:
|
283 |
+
# Поиск связей, где указанные entity_ids являются источниками
|
284 |
+
if as_source:
|
285 |
+
conditions = [
|
286 |
+
entity_model.source_id.in_(string_ids)
|
287 |
+
]
|
288 |
+
|
289 |
+
if relation_name:
|
290 |
+
conditions.append(entity_model.name == relation_name)
|
291 |
+
|
292 |
+
links = session.execute(
|
293 |
+
select(entity_model).where(and_(*conditions))
|
294 |
+
).scalars().all()
|
295 |
+
|
296 |
+
for link in links:
|
297 |
+
# Добавляем связь
|
298 |
+
link_entity = self._map_db_entity_to_linker_entity(link)
|
299 |
+
result.append(link_entity)
|
300 |
+
|
301 |
+
# Добавляем целевую сущность
|
302 |
+
target_entities = session.execute(
|
303 |
+
select(entity_model).where(entity_model.uuid == link.target_id)
|
304 |
+
).scalars().all()
|
305 |
+
|
306 |
+
for target in target_entities:
|
307 |
+
target_entity = self._map_db_entity_to_linker_entity(target)
|
308 |
+
if target_entity not in result:
|
309 |
+
result.append(target_entity)
|
310 |
+
|
311 |
+
# Поиск связей, где указанные entity_ids являются целями
|
312 |
+
if as_target:
|
313 |
+
conditions = [
|
314 |
+
entity_model.target_id.in_(string_ids)
|
315 |
+
]
|
316 |
+
|
317 |
+
if relation_name:
|
318 |
+
conditions.append(entity_model.name == relation_name)
|
319 |
+
|
320 |
+
links = session.execute(
|
321 |
+
select(entity_model).where(and_(*conditions))
|
322 |
+
).scalars().all()
|
323 |
+
|
324 |
+
for link in links:
|
325 |
+
# Добавляем связь
|
326 |
+
link_entity = self._map_db_entity_to_linker_entity(link)
|
327 |
+
result.append(link_entity)
|
328 |
+
|
329 |
+
# Добавляем исходную сущность
|
330 |
+
source_entities = session.execute(
|
331 |
+
select(entity_model).where(entity_model.uuid == link.source_id)
|
332 |
+
).scalars().all()
|
333 |
+
|
334 |
+
for source in source_entities:
|
335 |
+
source_entity = self._map_db_entity_to_linker_entity(source)
|
336 |
+
if source_entity not in result:
|
337 |
+
result.append(source_entity)
|
338 |
+
|
339 |
+
return result
|
lib/extractor/ntr_text_fragmentation/models/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль моделей данных.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .chunk import Chunk
|
6 |
+
from .document import DocumentAsEntity
|
7 |
+
from .linker_entity import LinkerEntity
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
"LinkerEntity",
|
11 |
+
"DocumentAsEntity",
|
12 |
+
"Chunk",
|
13 |
+
]
|
lib/extractor/ntr_text_fragmentation/models/chunk.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Класс для представления чанка документа.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
from .linker_entity import LinkerEntity, register_entity
|
8 |
+
|
9 |
+
|
10 |
+
@register_entity
|
11 |
+
@dataclass
|
12 |
+
class Chunk(LinkerEntity):
|
13 |
+
"""
|
14 |
+
Класс для представления чанка документа в системе извлечения и сборки.
|
15 |
+
|
16 |
+
Attributes:
|
17 |
+
chunk_index: Порядковый номер чанка в документе (0-based).
|
18 |
+
Используется для восстановления порядка при сборке.
|
19 |
+
"""
|
20 |
+
|
21 |
+
chunk_index: int | None = None
|
22 |
+
|
23 |
+
@classmethod
|
24 |
+
def deserialize(cls, data: LinkerEntity) -> 'Chunk':
|
25 |
+
"""
|
26 |
+
Десериализует Chunk из объекта LinkerEntity.
|
27 |
+
|
28 |
+
Базовый класс Chunk не должен использоваться напрямую,
|
29 |
+
все конкретные реализации должны переопределить этот метод.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
data: Объект LinkerEntity для преобразования в Chunk
|
33 |
+
|
34 |
+
Raises:
|
35 |
+
NotImplementedError: Метод должен быть переопределен в дочерних классах
|
36 |
+
"""
|
37 |
+
if cls == Chunk:
|
38 |
+
# Если это прямой вызов на базовом классе Chunk, выбрасываем исключение
|
39 |
+
raise NotImplementedError(
|
40 |
+
"Базовый класс Chunk не поддерживает десериализацию. "
|
41 |
+
"Используйте конкретную реализацию Chunk (например, FixedSizeChunk)."
|
42 |
+
)
|
43 |
+
|
44 |
+
# Если вызывается из дочернего класса, который не переопределил метод,
|
45 |
+
# выбрасываем более конкретную ошибку
|
46 |
+
raise NotImplementedError(
|
47 |
+
f"Класс {cls.__name__} должен реализовать метод deserialize."
|
48 |
+
)
|
lib/extractor/ntr_text_fragmentation/models/document.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Класс для представления документа как сущности.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
from .linker_entity import LinkerEntity, register_entity
|
8 |
+
|
9 |
+
|
10 |
+
@register_entity
|
11 |
+
@dataclass
|
12 |
+
class DocumentAsEntity(LinkerEntity):
|
13 |
+
"""
|
14 |
+
Класс для представления документа как сущности в системе извлечения и сборки.
|
15 |
+
"""
|
16 |
+
|
17 |
+
doc_type: str = "unknown"
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def deserialize(cls, data: LinkerEntity) -> 'DocumentAsEntity':
|
21 |
+
"""
|
22 |
+
Десериализует DocumentAsEntity из объекта LinkerEntity.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
data: Объект LinkerEntity для преобразования в DocumentAsEntity
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Десериализованный объект DocumentAsEntity
|
29 |
+
"""
|
30 |
+
metadata = data.metadata or {}
|
31 |
+
|
32 |
+
# Получаем тип документа из метаданных или используем значение по умолчанию
|
33 |
+
doc_type = metadata.get('_doc_type', 'unknown')
|
34 |
+
|
35 |
+
# Создаем чистые метаданные без служебных полей
|
36 |
+
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
37 |
+
|
38 |
+
return cls(
|
39 |
+
id=data.id,
|
40 |
+
name=data.name,
|
41 |
+
text=data.text,
|
42 |
+
in_search_text=data.in_search_text,
|
43 |
+
metadata=clean_metadata,
|
44 |
+
source_id=data.source_id,
|
45 |
+
target_id=data.target_id,
|
46 |
+
number_in_relation=data.number_in_relation,
|
47 |
+
type="DocumentAsEntity",
|
48 |
+
doc_type=doc_type
|
49 |
+
)
|
lib/extractor/ntr_text_fragmentation/models/linker_entity.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Базовый абстрактный класс для всех сущностей с поддержкой триплетного подхода.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import uuid
|
6 |
+
from abc import abstractmethod
|
7 |
+
from dataclasses import dataclass, field, fields
|
8 |
+
from uuid import UUID
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class LinkerEntity:
|
13 |
+
"""
|
14 |
+
Общий класс для всех сущностей в системе извлечения и сборки.
|
15 |
+
Поддерживает триплетный подход, где каждая сущность может опционально связывать две другие сущности.
|
16 |
+
|
17 |
+
Attributes:
|
18 |
+
id (UUID): Уникальный идентификатор сущности.
|
19 |
+
name (str): Название сущности.
|
20 |
+
text (str): Текстое представление сущности.
|
21 |
+
in_search_text (str | None): Текст для поиска. Если задан, используется в __str__, иначе используется обычное представление.
|
22 |
+
metadata (dict): Метаданные сущности.
|
23 |
+
source_id (UUID | None): Опциональный идентификатор исходной сущности.
|
24 |
+
Если указан, эта сущность является связью.
|
25 |
+
target_id (UUID | None): Опциональный идентификатор целевой сущности.
|
26 |
+
Если указан, эта сущность является связью.
|
27 |
+
number_in_relation (int | None): Используется в случае связей один-ко-многим,
|
28 |
+
указывает номер целевой сущности в списке.
|
29 |
+
type (str): Тип сущности.
|
30 |
+
"""
|
31 |
+
|
32 |
+
id: UUID
|
33 |
+
name: str
|
34 |
+
text: str
|
35 |
+
metadata: dict # JSON с метаданными
|
36 |
+
in_search_text: str | None = None
|
37 |
+
source_id: UUID | None = None
|
38 |
+
target_id: UUID | None = None
|
39 |
+
number_in_relation: int | None = None
|
40 |
+
type: str = field(default_factory=lambda: "Entity")
|
41 |
+
|
42 |
+
def __post_init__(self):
|
43 |
+
if self.id is None:
|
44 |
+
self.id = uuid.uuid4()
|
45 |
+
|
46 |
+
# Проверяем корректность полей связи
|
47 |
+
if (self.source_id is not None and self.target_id is None) or \
|
48 |
+
(self.source_id is None and self.target_id is not None):
|
49 |
+
raise ValueError("source_id и target_id должны быть либо оба указаны, либо оба None")
|
50 |
+
|
51 |
+
def is_link(self) -> bool:
|
52 |
+
"""
|
53 |
+
Проверяет, является ли сущность связью (имеет и source_id, и target_id).
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
bool: True, если сущность является связью, иначе False
|
57 |
+
"""
|
58 |
+
return self.source_id is not None and self.target_id is not None
|
59 |
+
|
60 |
+
def __str__(self) -> str:
|
61 |
+
"""
|
62 |
+
Возвращает строковое представление сущности.
|
63 |
+
Если задан in_search_text, возвращает его, иначе возвращает стандартное представление.
|
64 |
+
"""
|
65 |
+
if self.in_search_text is not None:
|
66 |
+
return self.in_search_text
|
67 |
+
return f"{self.name}: {self.text}"
|
68 |
+
|
69 |
+
def __eq__(self, other: 'LinkerEntity') -> bool:
|
70 |
+
"""
|
71 |
+
Сравнивает текущую сущность с другой.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
other: Другая сущность для сравнения
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
bool: True если сущности совпадают, иначе False
|
78 |
+
"""
|
79 |
+
if not isinstance(other, self.__class__):
|
80 |
+
return False
|
81 |
+
|
82 |
+
basic_equality = (
|
83 |
+
self.id == other.id
|
84 |
+
and self.name == other.name
|
85 |
+
and self.text == other.text
|
86 |
+
and self.type == other.type
|
87 |
+
)
|
88 |
+
|
89 |
+
# Если мы имеем дело со связями, также проверяем поля связи
|
90 |
+
if self.is_link() or other.is_link():
|
91 |
+
return (
|
92 |
+
basic_equality
|
93 |
+
and self.source_id == other.source_id
|
94 |
+
and self.target_id == other.target_id
|
95 |
+
)
|
96 |
+
|
97 |
+
return basic_equality
|
98 |
+
|
99 |
+
def serialize(self) -> 'LinkerEntity':
|
100 |
+
"""
|
101 |
+
Сериализует сущность в простейшую форму сущности, передавая все дополнительные поля в метаданные.
|
102 |
+
"""
|
103 |
+
# Получаем список полей базового класса
|
104 |
+
known_fields = {field.name for field in fields(LinkerEntity)}
|
105 |
+
|
106 |
+
# Получаем все атрибуты текущего объекта
|
107 |
+
dict_entity = {}
|
108 |
+
for attr_name in dir(self):
|
109 |
+
# Пропускаем служебные атрибуты, методы и уже известные поля
|
110 |
+
if (
|
111 |
+
attr_name.startswith('_')
|
112 |
+
or attr_name in known_fields
|
113 |
+
or callable(getattr(self, attr_name))
|
114 |
+
):
|
115 |
+
continue
|
116 |
+
|
117 |
+
# Добавляем дополнительные поля в словарь
|
118 |
+
dict_entity[attr_name] = getattr(self, attr_name)
|
119 |
+
|
120 |
+
# Преобразуем имена дополнительных полей, добавляя префикс "_"
|
121 |
+
dict_entity = {f'_{name}': value for name, value in dict_entity.items()}
|
122 |
+
|
123 |
+
# Объединяем с существующими метаданными
|
124 |
+
dict_entity = {**dict_entity, **self.metadata}
|
125 |
+
|
126 |
+
result_type = self.type
|
127 |
+
if result_type == "Entity":
|
128 |
+
result_type = self.__class__.__name__
|
129 |
+
|
130 |
+
# Создаем базовый объект LinkerEntity с новыми метаданными
|
131 |
+
return LinkerEntity(
|
132 |
+
id=self.id,
|
133 |
+
name=self.name,
|
134 |
+
text=self.text,
|
135 |
+
in_search_text=self.in_search_text,
|
136 |
+
metadata=dict_entity,
|
137 |
+
source_id=self.source_id,
|
138 |
+
target_id=self.target_id,
|
139 |
+
number_in_relation=self.number_in_relation,
|
140 |
+
type=result_type,
|
141 |
+
)
|
142 |
+
|
143 |
+
@classmethod
|
144 |
+
@abstractmethod
|
145 |
+
def deserialize(cls, data: 'LinkerEntity') -> 'Self':
|
146 |
+
"""
|
147 |
+
Десериализует сущность из простейшей формы сущности, учитывая все дополнительные поля в метаданных.
|
148 |
+
"""
|
149 |
+
raise NotImplementedError(
|
150 |
+
f"Метод deserialize для класса {cls.__class__.__name__} не реализован"
|
151 |
+
)
|
152 |
+
|
153 |
+
# Реестр для хранения всех наследников LinkerEntity
|
154 |
+
_entity_classes = {}
|
155 |
+
|
156 |
+
@classmethod
|
157 |
+
def register_entity_class(cls, entity_class):
|
158 |
+
"""
|
159 |
+
Регистрирует класс-наследник в реестре.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
entity_class: Класс для регистрации
|
163 |
+
"""
|
164 |
+
entity_type = entity_class.__name__
|
165 |
+
cls._entity_classes[entity_type] = entity_class
|
166 |
+
# Также регистрируем по типу, если он отличается от имени класса
|
167 |
+
if hasattr(entity_class, 'type') and isinstance(entity_class.type, str):
|
168 |
+
cls._entity_classes[entity_class.type] = entity_class
|
169 |
+
|
170 |
+
@classmethod
|
171 |
+
def deserialize(cls, data: 'LinkerEntity') -> 'LinkerEntity':
|
172 |
+
"""
|
173 |
+
Десериализует сущность в нужный тип на основе поля type.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
data: Сериализованная сущность типа LinkerEntity
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
Десериализованная сущность правильного типа
|
180 |
+
"""
|
181 |
+
# Получаем тип сущности
|
182 |
+
entity_type = data.type
|
183 |
+
|
184 |
+
# Проверяем реестр классов
|
185 |
+
if entity_type in cls._entity_classes:
|
186 |
+
try:
|
187 |
+
return cls._entity_classes[entity_type].deserialize(data)
|
188 |
+
except (AttributeError, NotImplementedError) as e:
|
189 |
+
# Если метод не реализован, возвращаем исходную сущность
|
190 |
+
return data
|
191 |
+
|
192 |
+
# Если тип не найден в реестре, просто возвращаем исходную сущность
|
193 |
+
# Больше не используем опасное сканирование sys.modules
|
194 |
+
return data
|
195 |
+
|
196 |
+
|
197 |
+
# Декоратор для регистрации производных классов
|
198 |
+
def register_entity(cls):
|
199 |
+
"""
|
200 |
+
Декоратор для регистрации классов-наследников LinkerEntity.
|
201 |
+
|
202 |
+
Пример использования:
|
203 |
+
|
204 |
+
@register_entity
|
205 |
+
class MyEntity(LinkerEntity):
|
206 |
+
type = "my_entity"
|
207 |
+
|
208 |
+
Args:
|
209 |
+
cls: Класс, который нужно зарегистрировать
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
Исходный класс (без изменений)
|
213 |
+
"""
|
214 |
+
# Регистрируем класс в реестр, используя его имя или указанный тип
|
215 |
+
entity_type = getattr(cls, 'type', cls.__name__)
|
216 |
+
LinkerEntity._entity_classes[entity_type] = cls
|
217 |
+
return cls
|
lib/extractor/pyproject.toml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
build-backend = "setuptools.build_meta"
|
3 |
+
requires = ["setuptools>=61"]
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "ntr_text_fragmentation"
|
7 |
+
version = "0.1.0"
|
8 |
+
dependencies = [
|
9 |
+
"uuid==1.30",
|
10 |
+
"ntr_fileparser @ git+ssh://[email protected]/textai/parsers/parser.git@master"
|
11 |
+
]
|
12 |
+
|
13 |
+
[project.optional-dependencies]
|
14 |
+
test = [
|
15 |
+
"pytest>=7.0.0",
|
16 |
+
"pytest-cov>=4.0.0"
|
17 |
+
]
|
18 |
+
|
19 |
+
[tool.setuptools.packages.find]
|
20 |
+
where = ["."]
|
21 |
+
|
22 |
+
[tool.pytest]
|
23 |
+
testpaths = ["tests"]
|
24 |
+
python_files = "test_*.py"
|
25 |
+
python_classes = "Test*"
|
26 |
+
python_functions = "test_*"
|
lib/extractor/scripts/README_test_chunking.md
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Тестирование чанкинга и сборки документов
|
2 |
+
|
3 |
+
Скрипт `test_chunking.py` позволяет тестировать различные стратегии чанкинга документов и их последующую сборку.
|
4 |
+
|
5 |
+
## Возможности
|
6 |
+
|
7 |
+
1. **Разбивка документов** - применение различных стратегий чанкинга к документам
|
8 |
+
2. **Сохранение результатов** - сохранение чанков и метаданных в CSV
|
9 |
+
3. **Сборка документов** - загрузка чанков из CSV и сборка документа с помощью InjectionBuilder
|
10 |
+
4. **Фильтрация чанков** - возможность фильтровать чанки по индексу или ключевым словам
|
11 |
+
|
12 |
+
## Режимы работы
|
13 |
+
|
14 |
+
Скрипт поддерживает три режима работы:
|
15 |
+
|
16 |
+
1. **chunk** - только разбивка документа на чанки и сохранение в CSV
|
17 |
+
2. **build** - загрузка чанков из CSV и сборка документа
|
18 |
+
3. **full** - разбивка документа, сохранение в CSV и последующая сборка
|
19 |
+
|
20 |
+
## Примеры использования
|
21 |
+
|
22 |
+
### Разбивка документа на чанки (стратегия fixed_size)
|
23 |
+
|
24 |
+
```bash
|
25 |
+
python scripts/test_chunking.py --mode chunk --input test_input/test.docx --strategy fixed_size --words 50 --overlap 25
|
26 |
+
```
|
27 |
+
|
28 |
+
### Разбивка документа на чанки (стратегия sentence)
|
29 |
+
|
30 |
+
```bash
|
31 |
+
python scripts/test_chunking.py --mode chunk --input test_input/test.docx --strategy sentence
|
32 |
+
```
|
33 |
+
|
34 |
+
### Загрузка чанков из CSV и сборка документа (все чанки)
|
35 |
+
|
36 |
+
```bash
|
37 |
+
python scripts/test_chunking.py --mode build --csv test_output/test_fixed_size_w50_o25.csv
|
38 |
+
```
|
39 |
+
|
40 |
+
### Загрузка чанков из CSV и сборка документа (с фильтрацией по индексу)
|
41 |
+
|
42 |
+
```bash
|
43 |
+
python scripts/test_chunking.py --mode build --csv test_output/test_fixed_size_w50_o25.csv --filter index --filter-value "0,2,4"
|
44 |
+
```
|
45 |
+
|
46 |
+
### Загрузка чанков из CSV и сборка документа (с фильтрацией по ключевому слову)
|
47 |
+
|
48 |
+
```bash
|
49 |
+
python scripts/test_chunking.py --mode build --csv test_output/test_fixed_size_w50_o25.csv --filter keyword --filter-value "важно"
|
50 |
+
```
|
51 |
+
|
52 |
+
### Полный цикл: разбивка, сохранение и сборка
|
53 |
+
|
54 |
+
```bash
|
55 |
+
python scripts/test_chunking.py --mode full --input test_input/test.docx --strategy fixed_size --words 50 --overlap 25
|
56 |
+
```
|
57 |
+
|
58 |
+
## Параметры командной строки
|
59 |
+
|
60 |
+
### Основные параметры
|
61 |
+
|
62 |
+
| Параметр | Описание | Значения по умолчанию |
|
63 |
+
|----------|----------|------------------------|
|
64 |
+
| `--mode` | Режим работы | `chunk` |
|
65 |
+
| `--input` | Путь к входному файлу | `test_input/test.docx` |
|
66 |
+
| `--csv` | Путь к CSV файлу с сущностями | None |
|
67 |
+
| `--output-dir` | Директория для выходных файлов | `test_output` |
|
68 |
+
|
69 |
+
### Параметры стратегии чанкинга
|
70 |
+
|
71 |
+
| Параметр | Описание | Значения по умолчанию |
|
72 |
+
|----------|----------|------------------------|
|
73 |
+
| `--strategy` | Стратегия чанкинга | `fixed_size` |
|
74 |
+
| `--words` | Количество слов в чанке (для fixed_size) | 50 |
|
75 |
+
| `--overlap` | Перекрытие в словах (для fixed_size) | 25 |
|
76 |
+
| `--debug` | Режим отладки (для numbered_items) | False |
|
77 |
+
|
78 |
+
### Параметры фильтрации
|
79 |
+
|
80 |
+
| Параметр | Описание | Значения по умолчанию |
|
81 |
+
|----------|----------|------------------------|
|
82 |
+
| `--filter` | Тип фильтрации чанков | `none` |
|
83 |
+
| `--filter-value` | Значение для фильтрации | None |
|
84 |
+
|
85 |
+
## Подготовка тестовых данных
|
86 |
+
|
87 |
+
Для тестирования скрипта вам понадобится документ в формате docx, txt, pdf или другом поддерживаемом формате. Поместите тестовый документ в папку `test_input`.
|
88 |
+
|
89 |
+
## Результаты работы
|
90 |
+
|
91 |
+
После выполнения скрипта в папке `test_output` будут созданы следующие файлы:
|
92 |
+
|
93 |
+
1. **test_{strategy}_....csv** - CSV файл с сущностями (документ, чанки, связи)
|
94 |
+
2. **rebuilt_document_{filter}_{filter_value}.txt** - собранный текст документа (при использовании режимов build или full)
|
95 |
+
|
96 |
+
## Примечания
|
97 |
+
|
98 |
+
- Для различных стратегий чанкинга доступны разные пара��етры
|
99 |
+
- При сборке документа можно использовать фильтрацию чанков по индексу или ключевому слову
|
100 |
+
- Собранный документ будет отличаться от исходного, если использовалась фильтрация чанков
|
101 |
+
|
102 |
+
## Требования
|
103 |
+
|
104 |
+
- Python 3.8+
|
105 |
+
- pandas
|
106 |
+
- ntr_fileparser
|
107 |
+
- ntr_text_fragmentation
|
lib/extractor/scripts/analyze_missing_puncts.py
ADDED
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для анализа ненайденных пунктов по лучшему подходу чанкинга (200 слов, 75 перекрытие, baai/bge-m3, top-100).
|
4 |
+
Формирует отчет в формате Markdown с топ-5 наиболее похожими чанками для каждого ненайденного пункта.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
from fuzzywuzzy import fuzz
|
16 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
# Константы
|
20 |
+
DATA_FOLDER = "data/docs" # Путь к папке с документами
|
21 |
+
MODEL_NAME = "BAAI/bge-m3" # Название лучшей модели
|
22 |
+
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами
|
23 |
+
OUTPUT_DIR = "data" # Директория для сохранения результатов
|
24 |
+
MARKDOWN_FILE = "missing_puncts_analysis.md" # Имя выходного MD-файла
|
25 |
+
SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения
|
26 |
+
WORDS_PER_CHUNK = 200 # Размер чанка в словах
|
27 |
+
OVERLAP_WORDS = 75 # Перекрытие в словах
|
28 |
+
TOP_N = 100 # Количество чанков в топе
|
29 |
+
|
30 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
31 |
+
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
"""
|
35 |
+
Парсит аргументы командной строки.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
Аргументы командной строки
|
39 |
+
"""
|
40 |
+
parser = argparse.ArgumentParser(description="Анализ ненайденных пунктов для лучшего подхода чанкинга")
|
41 |
+
|
42 |
+
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
|
43 |
+
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
|
44 |
+
parser.add_argument("--model-name", type=str, default=MODEL_NAME,
|
45 |
+
help=f"Название модели (по умолчанию: {MODEL_NAME})")
|
46 |
+
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
|
47 |
+
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
|
48 |
+
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
|
49 |
+
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
|
50 |
+
parser.add_argument("--markdown-file", type=str, default=MARKDOWN_FILE,
|
51 |
+
help=f"Имя выходного MD-файла (по умолчанию: {MARKDOWN_FILE})")
|
52 |
+
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
|
53 |
+
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
|
54 |
+
parser.add_argument("--words-per-chunk", type=int, default=WORDS_PER_CHUNK,
|
55 |
+
help=f"Размер чанка в словах (по умолчанию: {WORDS_PER_CHUNK})")
|
56 |
+
parser.add_argument("--overlap-words", type=int, default=OVERLAP_WORDS,
|
57 |
+
help=f"Перекрытие в словах (по умолчанию: {OVERLAP_WORDS})")
|
58 |
+
parser.add_argument("--top-n", type=int, default=TOP_N,
|
59 |
+
help=f"Количество чанков в топе (по умолчанию: {TOP_N})")
|
60 |
+
|
61 |
+
return parser.parse_args()
|
62 |
+
|
63 |
+
|
64 |
+
def load_questions_dataset(file_path: str) -> pd.DataFrame:
|
65 |
+
"""
|
66 |
+
Загружает датасет с вопросами из Excel-файла.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
file_path: Путь к Excel-файлу
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
DataFrame с вопросами и пунктами
|
73 |
+
"""
|
74 |
+
print(f"Загрузка датасета из {file_path}...")
|
75 |
+
|
76 |
+
df = pd.read_excel(file_path)
|
77 |
+
print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
|
78 |
+
|
79 |
+
# Преобразуем NaN в пустые строки для текстовых полей
|
80 |
+
text_columns = ['question', 'text', 'item_type']
|
81 |
+
for col in text_columns:
|
82 |
+
if col in df.columns:
|
83 |
+
df[col] = df[col].fillna('')
|
84 |
+
|
85 |
+
return df
|
86 |
+
|
87 |
+
|
88 |
+
def load_chunks_and_embeddings(output_dir: str, words_per_chunk: int, overlap_words: int, model_name: str) -> tuple:
|
89 |
+
"""
|
90 |
+
Загружает чанки и эмбеддинги из файлов.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
output_dir: Директория с файлами
|
94 |
+
words_per_chunk: Размер чанка в словах
|
95 |
+
overlap_words: Перекрытие в словах
|
96 |
+
model_name: Название модели
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Кортеж (чанки, эмбе��динги чанков, эмбеддинги вопросов, данные вопросов)
|
100 |
+
"""
|
101 |
+
# Формируем уникальное имя для файлов на основе параметров
|
102 |
+
model_name_safe = model_name.replace('/', '_')
|
103 |
+
strategy_config_str = f"fixed_size_w{words_per_chunk}_o{overlap_words}"
|
104 |
+
chunks_filename = f"chunks_{strategy_config_str}_{model_name_safe}"
|
105 |
+
questions_filename = f"questions_{model_name_safe}"
|
106 |
+
|
107 |
+
# Пути к файлам
|
108 |
+
chunks_embeddings_path = os.path.join(output_dir, f"{chunks_filename}_embeddings.npy")
|
109 |
+
chunks_data_path = os.path.join(output_dir, f"{chunks_filename}_data.csv")
|
110 |
+
questions_embeddings_path = os.path.join(output_dir, f"{questions_filename}_embeddings.npy")
|
111 |
+
questions_data_path = os.path.join(output_dir, f"{questions_filename}_data.csv")
|
112 |
+
|
113 |
+
# Проверяем наличие всех файлов
|
114 |
+
for path in [chunks_embeddings_path, chunks_data_path, questions_embeddings_path, questions_data_path]:
|
115 |
+
if not os.path.exists(path):
|
116 |
+
raise FileNotFoundError(f"Файл {path} не найден")
|
117 |
+
|
118 |
+
# Загружаем данные
|
119 |
+
print(f"Загрузка данных из {output_dir}...")
|
120 |
+
chunks_embeddings = np.load(chunks_embeddings_path)
|
121 |
+
chunks_df = pd.read_csv(chunks_data_path)
|
122 |
+
questions_embeddings = np.load(questions_embeddings_path)
|
123 |
+
questions_df = pd.read_csv(questions_data_path)
|
124 |
+
|
125 |
+
print(f"Загружено {len(chunks_df)} чанков и {len(questions_df)} вопросов")
|
126 |
+
|
127 |
+
return chunks_df, chunks_embeddings, questions_embeddings, questions_df
|
128 |
+
|
129 |
+
|
130 |
+
def load_top_chunks(top_chunks_dir: str) -> dict:
|
131 |
+
"""
|
132 |
+
Загружает JSON-файлы с топ-чанками для вопросов.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
top_chunks_dir: Директория с JSON-файлами
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
Словарь {question_id: данные из JSON}
|
139 |
+
"""
|
140 |
+
print(f"Загрузка топ-чанков из {top_chunks_dir}...")
|
141 |
+
|
142 |
+
top_chunks_data = {}
|
143 |
+
json_files = list(Path(top_chunks_dir).glob("question_*_top_chunks.json"))
|
144 |
+
|
145 |
+
for json_file in tqdm(json_files, desc="Загрузка JSON-файлов"):
|
146 |
+
try:
|
147 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
148 |
+
data = json.load(f)
|
149 |
+
question_id = data.get('question_id')
|
150 |
+
if question_id is not None:
|
151 |
+
top_chunks_data[question_id] = data
|
152 |
+
except Exception as e:
|
153 |
+
print(f"Ошибка при загрузке файла {json_file}: {e}")
|
154 |
+
|
155 |
+
print(f"Загружены данные для {len(top_chunks_data)} вопросов")
|
156 |
+
|
157 |
+
return top_chunks_data
|
158 |
+
|
159 |
+
|
160 |
+
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
|
161 |
+
"""
|
162 |
+
Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
chunk_text: Текст чанка
|
166 |
+
punct_text: Текст пункта
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
Коэффициент перекрытия от 0 до 1
|
170 |
+
"""
|
171 |
+
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
|
172 |
+
if chunk_text in punct_text:
|
173 |
+
return 1.0
|
174 |
+
|
175 |
+
# Если пункт входит в чанк, возвращаем соотношение длин
|
176 |
+
if punct_text in chunk_text:
|
177 |
+
return len(punct_text) / len(chunk_text)
|
178 |
+
|
179 |
+
# Используем partial_ratio из fuzzywuzzy
|
180 |
+
partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0
|
181 |
+
|
182 |
+
return partial_ratio_score
|
183 |
+
|
184 |
+
|
185 |
+
def find_most_similar_chunks(punct_text: str, chunks_df: pd.DataFrame, chunks_embeddings: np.ndarray, punct_embedding: np.ndarray, top_n: int = 5) -> list:
|
186 |
+
"""
|
187 |
+
Находит топ-N наиболее похожих чанков для заданного пункта.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
punct_text: Текст пункта
|
191 |
+
chunks_df: DataFrame с чанками
|
192 |
+
chunks_embeddings: Эмбеддинги чанков
|
193 |
+
punct_embedding: Эмбеддинг пункта
|
194 |
+
top_n: Количество похожих чанков (по умолчанию 5)
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
Список словарей с информацией о похожих чанках
|
198 |
+
"""
|
199 |
+
# Вычисляем косинусную близость между пунктом и всеми чанками
|
200 |
+
similarities = cosine_similarity([punct_embedding], chunks_embeddings)[0]
|
201 |
+
|
202 |
+
# Получаем индексы топ-N чанков по косинусной близости
|
203 |
+
top_indices = np.argsort(similarities)[-top_n:][::-1]
|
204 |
+
|
205 |
+
similar_chunks = []
|
206 |
+
for idx in top_indices:
|
207 |
+
chunk = chunks_df.iloc[idx]
|
208 |
+
overlap = calculate_chunk_overlap(chunk['text'], punct_text)
|
209 |
+
|
210 |
+
similar_chunks.append({
|
211 |
+
'chunk_id': chunk['id'],
|
212 |
+
'doc_name': chunk['doc_name'],
|
213 |
+
'text': chunk['text'],
|
214 |
+
'similarity': float(similarities[idx]),
|
215 |
+
'overlap': overlap
|
216 |
+
})
|
217 |
+
|
218 |
+
return similar_chunks
|
219 |
+
|
220 |
+
|
221 |
+
def analyze_missing_puncts(questions_df: pd.DataFrame, chunks_df: pd.DataFrame,
|
222 |
+
questions_embeddings: np.ndarray, chunks_embeddings: np.ndarray,
|
223 |
+
similarity_threshold: float, top_n: int = 100) -> dict:
|
224 |
+
"""
|
225 |
+
Анализирует ненайденные пункты и находит для них наиболее похожие чанки.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
questions_df: DataFrame с вопросами и пунктами
|
229 |
+
chunks_df: DataFrame с чанками
|
230 |
+
questions_embeddings: Эмбеддинги вопросов
|
231 |
+
chunks_embeddings: Эмбеддинги чанков
|
232 |
+
similarity_threshold: Порог для определения найденных пунктов
|
233 |
+
top_n: Количество чанков для проверки (по умолчанию 100)
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
Словарь с результатами анализа
|
237 |
+
"""
|
238 |
+
print("Анализ ненайденных пунктов...")
|
239 |
+
|
240 |
+
# Проверяем соответствие количества вопросов и эмбеддингов
|
241 |
+
unique_question_ids = questions_df['id'].unique()
|
242 |
+
if len(unique_question_ids) != questions_embeddings.shape[0]:
|
243 |
+
print(f"ВНИМАНИЕ: Количество уникальных ID вопросов ({len(unique_question_ids)}) не соответствует размеру массива эмбеддингов ({questions_embeddings.shape[0]}).")
|
244 |
+
print("Будут анализироваться только вопросы, имеющие соответствующие эмбеддинги.")
|
245 |
+
|
246 |
+
# Создаем маппинг id вопроса -> индекс в DataFrame с метаданными
|
247 |
+
# Используем порядковый номер в списке уникальных ID, а не порядок строк в DataFrame
|
248 |
+
question_id_to_idx = {qid: idx for idx, qid in enumerate(unique_question_ids)}
|
249 |
+
|
250 |
+
# Вычисляем косинусную близость между вопросами и чанками
|
251 |
+
similarity_matrix = cosine_similarity(questions_embeddings, chunks_embeddings)
|
252 |
+
|
253 |
+
# Результаты анализа
|
254 |
+
analysis_results = {}
|
255 |
+
|
256 |
+
# Обрабатываем только те вопросы, для которых у нас есть эмбеддинги
|
257 |
+
valid_question_ids = [qid for qid in unique_question_ids if qid in question_id_to_idx and question_id_to_idx[qid] < len(questions_embeddings)]
|
258 |
+
|
259 |
+
# Группируем датасет по id вопроса
|
260 |
+
for question_id in tqdm(valid_question_ids, desc="Анализ вопросов"):
|
261 |
+
# Получаем строки для текущего вопроса
|
262 |
+
question_rows = questions_df[questions_df['id'] == question_id]
|
263 |
+
|
264 |
+
# Если нет строк с таким id, пропускаем
|
265 |
+
if len(question_rows) == 0:
|
266 |
+
continue
|
267 |
+
|
268 |
+
# Получаем индекс вопроса в массиве эмбеддингов
|
269 |
+
question_idx = question_id_to_idx[question_id]
|
270 |
+
|
271 |
+
# Если индекс выходит за границы массива эмбеддингов, пропускаем
|
272 |
+
if question_idx >= questions_embeddings.shape[0]:
|
273 |
+
print(f"ВНИМАНИЕ: Индекс {question_idx} для вопроса {question_id} выходит за границы массива эмбеддингов размера {questions_embeddings.shape[0]}. Пропускаем.")
|
274 |
+
continue
|
275 |
+
|
276 |
+
# Получаем текст вопроса и пункты
|
277 |
+
question_text = question_rows['question'].iloc[0]
|
278 |
+
|
279 |
+
# Собираем пункты с информацией о документе
|
280 |
+
puncts = []
|
281 |
+
for _, row in question_rows.iterrows():
|
282 |
+
punct_doc = row.get('filename', '') if 'filename' in row else ''
|
283 |
+
if pd.isna(punct_doc):
|
284 |
+
punct_doc = ''
|
285 |
+
puncts.append({
|
286 |
+
'text': row['text'],
|
287 |
+
'doc_name': punct_doc
|
288 |
+
})
|
289 |
+
|
290 |
+
# Получаем связанные документы
|
291 |
+
relevant_docs = []
|
292 |
+
if 'filename' in question_rows.columns:
|
293 |
+
relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)]
|
294 |
+
else:
|
295 |
+
relevant_docs = chunks_df['doc_name'].unique().tolist()
|
296 |
+
|
297 |
+
# Если для вопроса нет релевантных документов, пропускаем
|
298 |
+
if not relevant_docs:
|
299 |
+
continue
|
300 |
+
|
301 |
+
# Для отслеживания найденных и ненайденных пунктов
|
302 |
+
found_puncts = []
|
303 |
+
missing_puncts = []
|
304 |
+
|
305 |
+
# Собираем все чанки для документов вопроса
|
306 |
+
all_question_chunks = []
|
307 |
+
all_question_similarities = []
|
308 |
+
|
309 |
+
for filename in relevant_docs:
|
310 |
+
if not filename or pd.isna(filename):
|
311 |
+
continue
|
312 |
+
|
313 |
+
# Фильтруем чанки по имени файла
|
314 |
+
doc_chunks = chunks_df[chunks_df['doc_name'] == filename]
|
315 |
+
|
316 |
+
if doc_chunks.empty:
|
317 |
+
continue
|
318 |
+
|
319 |
+
# Индексы чанков для текущего файла
|
320 |
+
doc_chunk_indices = doc_chunks.index.tolist()
|
321 |
+
|
322 |
+
# Проверяем, что индексы чанков существуют в chunks_df
|
323 |
+
valid_indices = [idx for idx in doc_chunk_indices if idx in chunks_df.index]
|
324 |
+
|
325 |
+
# Получаем значения близости для чанков текущего файла
|
326 |
+
doc_similarities = []
|
327 |
+
for idx in valid_indices:
|
328 |
+
try:
|
329 |
+
chunk_loc = chunks_df.index.get_loc(idx)
|
330 |
+
doc_similarities.append(similarity_matrix[question_idx, chunk_loc])
|
331 |
+
except (KeyError, IndexError) as e:
|
332 |
+
print(f"Ошибка при получении индекса для чанка {idx}: {e}")
|
333 |
+
continue
|
334 |
+
|
335 |
+
# Добавляем чанки и их схожести к общему списку для вопроса
|
336 |
+
for i, idx in enumerate(valid_indices):
|
337 |
+
if i < len(doc_similarities): # проверяем, что у нас есть соответствующее значение similarity
|
338 |
+
try:
|
339 |
+
chunk_row = doc_chunks.loc[idx]
|
340 |
+
all_question_chunks.append((idx, chunk_row))
|
341 |
+
all_question_similarities.append(doc_similarities[i])
|
342 |
+
except KeyError as e:
|
343 |
+
print(f"Ошибка при доступе к строке с индексом {idx}: {e}")
|
344 |
+
|
345 |
+
# Если нет чанков для вопроса, пропускаем
|
346 |
+
if not all_question_chunks:
|
347 |
+
continue
|
348 |
+
|
349 |
+
# Сортируем все чанки по убыванию схожести и берем top_n
|
350 |
+
sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1]
|
351 |
+
top_chunks = []
|
352 |
+
top_similarities = []
|
353 |
+
|
354 |
+
# Собираем топ-N чанков и их схожести
|
355 |
+
for i in sorted_indices:
|
356 |
+
idx, chunk = all_question_chunks[i]
|
357 |
+
top_chunks.append({
|
358 |
+
'id': chunk['id'],
|
359 |
+
'doc_name': chunk['doc_name'],
|
360 |
+
'text': chunk['text']
|
361 |
+
})
|
362 |
+
top_similarities.append(all_question_similarities[i])
|
363 |
+
|
364 |
+
# Проверяем каждый пункт на наличие в топ-чанках
|
365 |
+
for i, punct in enumerate(puncts):
|
366 |
+
is_found = False
|
367 |
+
punct_text = punct['text']
|
368 |
+
punct_doc = punct['doc_name']
|
369 |
+
|
370 |
+
# Для каждого чанка из топ-N рассчитываем partial_ratio с пунктом
|
371 |
+
chunk_overlaps = []
|
372 |
+
for j, chunk in enumerate(top_chunks):
|
373 |
+
overlap = calculate_chunk_overlap(chunk['text'], punct_text)
|
374 |
+
|
375 |
+
# Если перекрытие больше порога, пункт найден
|
376 |
+
if overlap >= similarity_threshold:
|
377 |
+
is_found = True
|
378 |
+
|
379 |
+
# Сохраняем информацию о перекрытии для каждого чанка
|
380 |
+
chunk_overlaps.append({
|
381 |
+
'chunk_id': chunk['id'],
|
382 |
+
'doc_name': chunk['doc_name'],
|
383 |
+
'text': chunk['text'],
|
384 |
+
'overlap': overlap,
|
385 |
+
'similarity': float(top_similarities[j])
|
386 |
+
})
|
387 |
+
|
388 |
+
# Если пункт найден, добавляем в список найденных
|
389 |
+
if is_found:
|
390 |
+
found_puncts.append({
|
391 |
+
'index': i,
|
392 |
+
'text': punct_text,
|
393 |
+
'doc_name': punct_doc
|
394 |
+
})
|
395 |
+
else:
|
396 |
+
# Сортируем чанки по убыванию перекрытия с пунктом и берем топ-5
|
397 |
+
chunk_overlaps.sort(key=lambda x: x['overlap'], reverse=True)
|
398 |
+
top_overlaps = chunk_overlaps[:5]
|
399 |
+
|
400 |
+
missing_puncts.append({
|
401 |
+
'index': i,
|
402 |
+
'text': punct_text,
|
403 |
+
'doc_name': punct_doc,
|
404 |
+
'similar_chunks': top_overlaps
|
405 |
+
})
|
406 |
+
|
407 |
+
# Добавляем результаты для текущего вопроса
|
408 |
+
analysis_results[question_id] = {
|
409 |
+
'question_id': question_id,
|
410 |
+
'question_text': question_text,
|
411 |
+
'found_puncts_count': len(found_puncts),
|
412 |
+
'missing_puncts_count': len(missing_puncts),
|
413 |
+
'total_puncts_count': len(puncts),
|
414 |
+
'found_puncts': found_puncts,
|
415 |
+
'missing_puncts': missing_puncts
|
416 |
+
}
|
417 |
+
|
418 |
+
return analysis_results
|
419 |
+
|
420 |
+
|
421 |
+
def generate_markdown_report(analysis_results: dict, output_file: str,
|
422 |
+
words_per_chunk: int, overlap_words: int, model_name: str, top_n: int):
|
423 |
+
"""
|
424 |
+
Генерирует отчет в формате Markdown.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
analysis_results: Результаты анализа
|
428 |
+
output_file: Путь к выходному файлу
|
429 |
+
words_per_chunk: Размер чанка в словах
|
430 |
+
overlap_words: Перекрытие в словах
|
431 |
+
model_name: Название модели
|
432 |
+
top_n: Количество чанков в топе
|
433 |
+
"""
|
434 |
+
print(f"Генерация отчета в формате Markdown в {output_file}...")
|
435 |
+
|
436 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
437 |
+
# Заголовок отчета
|
438 |
+
f.write(f"# Анализ ненайденных пунктов для оптимальной конфигурации чанкинга\n\n")
|
439 |
+
|
440 |
+
# Параметры анализа
|
441 |
+
f.write("## Параметры анализа\n\n")
|
442 |
+
f.write(f"- **Модель**: {model_name}\n")
|
443 |
+
f.write(f"- **Размер чанка**: {words_per_chunk} слов\n")
|
444 |
+
f.write(f"- **Перекрытие**: {overlap_words} слов ({round(overlap_words/words_per_chunk*100, 1)}%)\n")
|
445 |
+
f.write(f"- **Количество чанков в топе**: {top_n}\n\n")
|
446 |
+
|
447 |
+
# Сводная статистика
|
448 |
+
total_questions = len(analysis_results)
|
449 |
+
total_puncts = sum(q['total_puncts_count'] for q in analysis_results.values())
|
450 |
+
total_found = sum(q['found_puncts_count'] for q in analysis_results.values())
|
451 |
+
total_missing = sum(q['missing_puncts_count'] for q in analysis_results.values())
|
452 |
+
|
453 |
+
f.write("## Сводная статистика\n\n")
|
454 |
+
f.write(f"- **Всего вопросов**: {total_questions}\n")
|
455 |
+
f.write(f"- **Всего пунктов**: {total_puncts}\n")
|
456 |
+
f.write(f"- **Найдено пунктов**: {total_found} ({round(total_found/total_puncts*100, 1)}%)\n")
|
457 |
+
f.write(f"- **Ненайдено пунктов**: {total_missing} ({round(total_missing/total_puncts*100, 1)}%)\n\n")
|
458 |
+
|
459 |
+
# Детали по каждому вопросу
|
460 |
+
f.write("## Детальный анализ по вопросам\n\n")
|
461 |
+
|
462 |
+
# Сортируем вопросы по количеству ненайденных пунктов (по убыванию)
|
463 |
+
sorted_questions = sorted(
|
464 |
+
analysis_results.values(),
|
465 |
+
key=lambda x: x['missing_puncts_count'],
|
466 |
+
reverse=True
|
467 |
+
)
|
468 |
+
|
469 |
+
for question_data in sorted_questions:
|
470 |
+
question_id = question_data['question_id']
|
471 |
+
question_text = question_data['question_text']
|
472 |
+
missing_count = question_data['missing_puncts_count']
|
473 |
+
total_count = question_data['total_puncts_count']
|
474 |
+
|
475 |
+
# Если нет ненайденных пунктов, пропускаем
|
476 |
+
if missing_count == 0:
|
477 |
+
continue
|
478 |
+
|
479 |
+
f.write(f"### Вопрос {question_id}\n\n")
|
480 |
+
f.write(f"**Текст вопроса**: {question_text}\n\n")
|
481 |
+
f.write(f"**Статистика**: найдено {question_data['found_puncts_count']} из {total_count} пунктов ")
|
482 |
+
f.write(f"({round(question_data['found_puncts_count']/total_count*100, 1)}%)\n\n")
|
483 |
+
|
484 |
+
# Детали по ненайденным пунктам
|
485 |
+
f.write("#### Ненайденные пункты\n\n")
|
486 |
+
|
487 |
+
for i, punct in enumerate(question_data['missing_puncts']):
|
488 |
+
punct_text = punct['text']
|
489 |
+
punct_doc = punct.get('doc_name', '')
|
490 |
+
similar_chunks = punct['similar_chunks']
|
491 |
+
|
492 |
+
f.write(f"##### Пункт {i+1}\n\n")
|
493 |
+
f.write(f"**Текст пункта**: {punct_text}\n\n")
|
494 |
+
if punct_doc:
|
495 |
+
f.write(f"**Документ пункта**: {punct_doc}\n\n")
|
496 |
+
f.write("**Топ-5 наиболее похожих чанков**:\n\n")
|
497 |
+
|
498 |
+
# Таблица с похожими чанками
|
499 |
+
f.write("| № | Документ | С��ожесть (с вопросом) | Перекрытие (с пунктом) | Текст чанка |\n")
|
500 |
+
f.write("|---|----------|----------|------------|------------|\n")
|
501 |
+
|
502 |
+
for j, chunk in enumerate(similar_chunks):
|
503 |
+
# Используем полный текст чанка без обрезки
|
504 |
+
chunk_text = chunk['text']
|
505 |
+
|
506 |
+
f.write(f"| {j+1} | {chunk['doc_name']} | {chunk['similarity']:.4f} | ")
|
507 |
+
f.write(f"{chunk['overlap']:.4f} | {chunk_text} |\n")
|
508 |
+
|
509 |
+
f.write("\n")
|
510 |
+
|
511 |
+
f.write("\n")
|
512 |
+
|
513 |
+
print(f"Отчет успешно сгенерирован: {output_file}")
|
514 |
+
|
515 |
+
|
516 |
+
def main():
|
517 |
+
"""
|
518 |
+
Основная функция скрипта.
|
519 |
+
"""
|
520 |
+
args = parse_args()
|
521 |
+
|
522 |
+
# Загружаем датасет с вопросами
|
523 |
+
questions_df = load_questions_dataset(args.dataset_path)
|
524 |
+
|
525 |
+
# Загружаем чанки и эмбеддинги
|
526 |
+
chunks_df, chunks_embeddings, questions_embeddings, questions_meta = load_chunks_and_embeddings(
|
527 |
+
args.output_dir, args.words_per_chunk, args.overlap_words, args.model_name
|
528 |
+
)
|
529 |
+
|
530 |
+
# Анализируем ненайденные пункты
|
531 |
+
analysis_results = analyze_missing_puncts(
|
532 |
+
questions_df, chunks_df, questions_embeddings, chunks_embeddings,
|
533 |
+
args.similarity_threshold, args.top_n
|
534 |
+
)
|
535 |
+
|
536 |
+
# Генерируем отчет в формате Markdown
|
537 |
+
output_file = os.path.join(args.output_dir, args.markdown_file)
|
538 |
+
generate_markdown_report(
|
539 |
+
analysis_results, output_file,
|
540 |
+
args.words_per_chunk, args.overlap_words, args.model_name, args.top_n
|
541 |
+
)
|
542 |
+
|
543 |
+
print(f"Анализ ненайденных пунктов завершен. Результаты сохранены в {output_file}")
|
544 |
+
|
545 |
+
|
546 |
+
if __name__ == "__main__":
|
547 |
+
main()
|
lib/extractor/scripts/combine_results.py
ADDED
@@ -0,0 +1,1352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для объединения результатов всех экспериментов в одну Excel-таблицу с форматированием.
|
4 |
+
Анализирует результаты экспериментов и создает сводную таблицу с метриками в различных разрезах.
|
5 |
+
Также строит графики через seaborn и сохраняет их в отдельную директорию.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import glob
|
10 |
+
import os
|
11 |
+
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
from openpyxl import Workbook
|
16 |
+
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
|
17 |
+
from openpyxl.utils import get_column_letter
|
18 |
+
from openpyxl.utils.dataframe import dataframe_to_rows
|
19 |
+
|
20 |
+
|
21 |
+
def setup_plot_directory(plots_dir: str) -> None:
|
22 |
+
"""
|
23 |
+
Создает директорию для сохранения графиков, если она не существует.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
plots_dir: Путь к директории для графиков
|
27 |
+
"""
|
28 |
+
if not os.path.exists(plots_dir):
|
29 |
+
os.makedirs(plots_dir)
|
30 |
+
print(f"Создана директория для графиков: {plots_dir}")
|
31 |
+
else:
|
32 |
+
print(f"Директория для графиков: {plots_dir}")
|
33 |
+
|
34 |
+
|
35 |
+
def parse_args():
|
36 |
+
"""Парсит аргументы командной строки."""
|
37 |
+
parser = argparse.ArgumentParser(description="Объединение результатов экспериментов в одну Excel-таблицу")
|
38 |
+
|
39 |
+
parser.add_argument("--results-dir", type=str, default="data",
|
40 |
+
help="Директория с результатами экспериментов (по умолчанию: data)")
|
41 |
+
parser.add_argument("--output-file", type=str, default="combined_results.xlsx",
|
42 |
+
help="Путь к выходному Excel-файлу (по умолчанию: combined_results.xlsx)")
|
43 |
+
parser.add_argument("--plots-dir", type=str, default="plots",
|
44 |
+
help="Директория для сохранения графиков (по умолчанию: plots)")
|
45 |
+
|
46 |
+
return parser.parse_args()
|
47 |
+
|
48 |
+
|
49 |
+
def parse_file_name(file_name: str) -> dict:
|
50 |
+
"""
|
51 |
+
Парсит имя файла и извлекает параметры эксперимента.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
file_name: Имя файла для парсинга
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Словарь с параметрами (words_per_chunk, overlap_words, model) или None при ошибке
|
58 |
+
"""
|
59 |
+
try:
|
60 |
+
# Извлекаем параметры из имени файла
|
61 |
+
parts = file_name.split('_')
|
62 |
+
if len(parts) < 4:
|
63 |
+
return None
|
64 |
+
|
65 |
+
# Ищем части с w (words) и o (overlap)
|
66 |
+
words_part = None
|
67 |
+
overlap_part = None
|
68 |
+
|
69 |
+
for part in parts:
|
70 |
+
if part.startswith('w') and part[1:].isdigit():
|
71 |
+
words_part = part[1:]
|
72 |
+
elif part.startswith('o') and part[1:].isdigit():
|
73 |
+
# Убираем потенциальную часть .csv или .xlsx из overlap_part
|
74 |
+
overlap_part = part[1:].split('.')[0]
|
75 |
+
|
76 |
+
if words_part is None or overlap_part is None:
|
77 |
+
return None
|
78 |
+
|
79 |
+
# Пытаемся извлечь имя модели из оставшейся части имени файла
|
80 |
+
model_part = file_name.split(f"_w{words_part}_o{overlap_part}_", 1)
|
81 |
+
if len(model_part) < 2:
|
82 |
+
return None
|
83 |
+
|
84 |
+
# Получаем имя модели и удаляем возможное расширение файла
|
85 |
+
model_name_parts = model_part[1].split('.')
|
86 |
+
if len(model_name_parts) > 1:
|
87 |
+
model_name_parts = model_name_parts[:-1]
|
88 |
+
|
89 |
+
model_name_parts = '_'.join(model_name_parts).split('_')
|
90 |
+
model_name = '/'.join(model_name_parts)
|
91 |
+
|
92 |
+
return {
|
93 |
+
'words_per_chunk': int(words_part),
|
94 |
+
'overlap_words': int(overlap_part),
|
95 |
+
'model': model_name,
|
96 |
+
'overlap_percentage': round(int(overlap_part) / int(words_part) * 100, 1)
|
97 |
+
}
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Ошибка при парсинге файла {file_name}: {e}")
|
100 |
+
return None
|
101 |
+
|
102 |
+
|
103 |
+
def load_data_files(results_dir: str, pattern: str, file_type: str, load_function) -> pd.DataFrame:
|
104 |
+
"""
|
105 |
+
Общая функция для загрузки файлов данных с определенным паттерном имени.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
results_dir: Директория с результатами
|
109 |
+
pattern: Glob-паттерн для поиска файлов
|
110 |
+
file_type: Тип файлов для сообщений (напр. "результатов", "метрик")
|
111 |
+
load_function: Функция для загрузки конкретного типа файла
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
DataFrame с объединенными данными или None при ошибке
|
115 |
+
"""
|
116 |
+
print(f"Загрузка {file_type} из {results_dir}...")
|
117 |
+
|
118 |
+
# Ищем все файлы с указанным паттерном
|
119 |
+
data_files = glob.glob(os.path.join(results_dir, pattern))
|
120 |
+
|
121 |
+
if not data_files:
|
122 |
+
print(f"В директории {results_dir} не найдены файлы {file_type}")
|
123 |
+
return None
|
124 |
+
|
125 |
+
print(f"Найдено {len(data_files)} файлов {file_type}")
|
126 |
+
|
127 |
+
all_data = []
|
128 |
+
|
129 |
+
for file_path in data_files:
|
130 |
+
# Извлекаем информацию о стратегии и модели из имени файла
|
131 |
+
file_name = os.path.basename(file_path)
|
132 |
+
print(f"Обрабатываю файл: {file_name}")
|
133 |
+
|
134 |
+
# Парсим параметры из имени файла
|
135 |
+
params = parse_file_name(file_name)
|
136 |
+
|
137 |
+
if params is None:
|
138 |
+
print(f"Пропуск файла {file_name}: не удалось извлечь параметры")
|
139 |
+
continue
|
140 |
+
|
141 |
+
words_part = params['words_per_chunk']
|
142 |
+
overlap_part = params['overlap_words']
|
143 |
+
model_name = params['model']
|
144 |
+
overlap_percentage = params['overlap_percentage']
|
145 |
+
|
146 |
+
print(f" Параметры: words={words_part}, overlap={overlap_part}, model={model_name}")
|
147 |
+
|
148 |
+
try:
|
149 |
+
# Загружаем данные, используя переданную функцию
|
150 |
+
df = load_function(file_path)
|
151 |
+
|
152 |
+
# Добавляем информацию о стратегии и модели
|
153 |
+
df['model'] = model_name
|
154 |
+
df['words_per_chunk'] = words_part
|
155 |
+
df['overlap_words'] = overlap_part
|
156 |
+
df['overlap_percentage'] = overlap_percentage
|
157 |
+
|
158 |
+
all_data.append(df)
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Ошибка при обработке файла {file_path}: {e}")
|
161 |
+
|
162 |
+
if not all_data:
|
163 |
+
print(f"Не удалось загрузить ни один файл {file_type}")
|
164 |
+
return None
|
165 |
+
|
166 |
+
# Объединяем все данные
|
167 |
+
combined_data = pd.concat(all_data, ignore_index=True)
|
168 |
+
|
169 |
+
return combined_data
|
170 |
+
|
171 |
+
|
172 |
+
def load_results_files(results_dir: str) -> pd.DataFrame:
|
173 |
+
"""
|
174 |
+
Загружает все файлы результатов из указанной директории.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
results_dir: Директория с результатами
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
DataFrame с объединенными результатами
|
181 |
+
"""
|
182 |
+
# Используем общую функцию для загрузки CSV файлов
|
183 |
+
data = load_data_files(
|
184 |
+
results_dir,
|
185 |
+
"results_*.csv",
|
186 |
+
"результатов",
|
187 |
+
lambda f: pd.read_csv(f)
|
188 |
+
)
|
189 |
+
|
190 |
+
if data is None:
|
191 |
+
raise ValueError("Не удалось загрузить файлы с результатами")
|
192 |
+
|
193 |
+
return data
|
194 |
+
|
195 |
+
|
196 |
+
def load_question_metrics_files(results_dir: str) -> pd.DataFrame:
|
197 |
+
"""
|
198 |
+
Загружает все файлы с метриками по вопросам из указанной директории.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
results_dir: Директория с результатами
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
DataFrame с объединенными метриками по вопросам или None, если файлов нет
|
205 |
+
"""
|
206 |
+
# Используем общую функцию для загрузки Excel файлов
|
207 |
+
return load_data_files(
|
208 |
+
results_dir,
|
209 |
+
"question_metrics_*.xlsx",
|
210 |
+
"метрик по вопросам",
|
211 |
+
lambda f: pd.read_excel(f)
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
def prepare_summary_by_model_top_n(df: pd.DataFrame, macro_metrics: pd.DataFrame = None) -> pd.DataFrame:
|
216 |
+
"""
|
217 |
+
Подготавливает сводную таблицу по моделям и top_n значениям.
|
218 |
+
Если доступны macro метрики, они также включаются в сводную таблицу.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
df: DataFrame с объединенными результатами
|
222 |
+
macro_metrics: DataFrame с macro метриками (опционально)
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
DataFrame со сводной таблицей
|
226 |
+
"""
|
227 |
+
# Определяем группировочные колонки и метрики
|
228 |
+
group_by_columns = ['model', 'top_n']
|
229 |
+
metrics = ['text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']
|
230 |
+
|
231 |
+
# Используем общую функцию для подготовки сводки
|
232 |
+
return prepare_summary(df, group_by_columns, metrics, macro_metrics)
|
233 |
+
|
234 |
+
|
235 |
+
def prepare_summary_by_chunking_params_top_n(df: pd.DataFrame, macro_metrics: pd.DataFrame = None) -> pd.DataFrame:
|
236 |
+
"""
|
237 |
+
Подготавливает сводную таблицу по параметрам чанкинга и top_n значениям.
|
238 |
+
Если доступны macro метрики, они также включаются в сводную таблицу.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
df: DataFrame с объединенными результатами
|
242 |
+
macro_metrics: DataFrame с macro метриками (опционально)
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
DataFrame со сводной таблицей
|
246 |
+
"""
|
247 |
+
# Определяем группировочные колонки и метрики
|
248 |
+
group_by_columns = ['words_per_chunk', 'overlap_words', 'top_n']
|
249 |
+
metrics = ['text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']
|
250 |
+
|
251 |
+
# Используем общую функцию для подготовки сводки
|
252 |
+
return prepare_summary(df, group_by_columns, metrics, macro_metrics)
|
253 |
+
|
254 |
+
|
255 |
+
def prepare_summary(df: pd.DataFrame, group_by_columns: list, metrics: list, macro_metrics: pd.DataFrame = None) -> pd.DataFrame:
|
256 |
+
"""
|
257 |
+
Общая функция для подготовки сводной таблицы по указанным группировочным колонкам.
|
258 |
+
Если доступны macro метрики, они также включаются в сводную таблицу.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
df: DataFrame с объединенными результатами
|
262 |
+
group_by_columns: Колонки для группировки
|
263 |
+
metrics: Список метрик для расчета среднего
|
264 |
+
macro_metrics: DataFrame с macro метриками (опционально)
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
DataFrame со сводной таблицей
|
268 |
+
"""
|
269 |
+
# Группируем по указанным колонкам, вычисляем средние значения метрик
|
270 |
+
summary = df.groupby(group_by_columns).agg({
|
271 |
+
metric: 'mean' for metric in metrics
|
272 |
+
}).reset_index()
|
273 |
+
|
274 |
+
# Если среди группировочных колонок есть 'overlap_words' и 'words_per_chunk',
|
275 |
+
# добавляем процент перекрытия
|
276 |
+
if 'overlap_words' in group_by_columns and 'words_per_chunk' in group_by_columns:
|
277 |
+
summary['overlap_percentage'] = (summary['overlap_words'] / summary['words_per_chunk'] * 100).round(1)
|
278 |
+
|
279 |
+
# Если доступны macro метрики, объединяем их с summary
|
280 |
+
if macro_metrics is not None:
|
281 |
+
# Преобразуем метрики в macro_метрики
|
282 |
+
macro_metric_names = [f"macro_{metric}" for metric in metrics]
|
283 |
+
|
284 |
+
# Группируем macro метрики по тем же колонкам
|
285 |
+
macro_summary = macro_metrics.groupby(group_by_columns).agg({
|
286 |
+
metric: 'mean' for metric in macro_metric_names
|
287 |
+
}).reset_index()
|
288 |
+
|
289 |
+
# Если нужно, добавляем процент перекрытия для согласованности
|
290 |
+
if 'overlap_words' in group_by_columns and 'words_per_chunk' in group_by_columns:
|
291 |
+
macro_summary['overlap_percentage'] = (macro_summary['overlap_words'] / macro_summary['words_per_chunk'] * 100).round(1)
|
292 |
+
merge_on = group_by_columns + ['overlap_percentage']
|
293 |
+
else:
|
294 |
+
merge_on = group_by_columns
|
295 |
+
|
296 |
+
# Объединяем с основной сводкой
|
297 |
+
summary = pd.merge(summary, macro_summary, on=merge_on, how='left')
|
298 |
+
|
299 |
+
# Сортируем по группировочным колонкам
|
300 |
+
summary = summary.sort_values(group_by_columns)
|
301 |
+
|
302 |
+
# Округляем метрики до 4 знаков после запятой
|
303 |
+
for col in summary.columns:
|
304 |
+
if any(col.endswith(suffix) for suffix in ['precision', 'recall', 'f1']):
|
305 |
+
summary[col] = summary[col].round(4)
|
306 |
+
|
307 |
+
return summary
|
308 |
+
|
309 |
+
|
310 |
+
def prepare_best_configurations(df: pd.DataFrame, macro_metrics: pd.DataFrame = None) -> pd.DataFrame:
|
311 |
+
"""
|
312 |
+
Подготавливает таблицу с лучшими конфигурациями для каждой модели и различных top_n.
|
313 |
+
Выбирает конфигурацию только на основе macro_text_recall и text_recall (weighted),
|
314 |
+
игнорируя F1 метрики как менее важные.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
df: DataFrame с объединенными результатами
|
318 |
+
macro_metrics: DataFrame с macro метриками (опционально)
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
DataFrame с лучшими конфигурациями
|
322 |
+
"""
|
323 |
+
# Выбираем ключевые значения top_n
|
324 |
+
key_top_n = [10, 20, 50, 100]
|
325 |
+
|
326 |
+
# Определяем источник метрик и акцентируем только на recall-метриках
|
327 |
+
if macro_metrics is not None:
|
328 |
+
print("Выбор лучших конфигураций на основе macro метрик (macro_text_recall)")
|
329 |
+
metrics_source = macro_metrics
|
330 |
+
text_recall_metric = 'macro_text_recall'
|
331 |
+
doc_recall_metric = 'macro_doc_recall'
|
332 |
+
else:
|
333 |
+
print("Выбор лучших конфигураций на основе weighted метрик (text_recall)")
|
334 |
+
metrics_source = df
|
335 |
+
text_recall_metric = 'text_recall'
|
336 |
+
doc_recall_metric = 'doc_recall'
|
337 |
+
|
338 |
+
# Фильтруем только по ключевым значениям top_n
|
339 |
+
filtered_df = metrics_source[metrics_source['top_n'].isin(key_top_n)]
|
340 |
+
|
341 |
+
# Для каждой модели и top_n находим конфигурацию только с лучшим recall
|
342 |
+
best_configs = []
|
343 |
+
|
344 |
+
for model in metrics_source['model'].unique():
|
345 |
+
for top_n in key_top_n:
|
346 |
+
model_top_n_df = filtered_df[(filtered_df['model'] == model) & (filtered_df['top_n'] == top_n)]
|
347 |
+
|
348 |
+
if len(model_top_n_df) == 0:
|
349 |
+
continue
|
350 |
+
|
351 |
+
# Находим конфигурацию с лучшим text_recall
|
352 |
+
best_text_recall_idx = model_top_n_df[text_recall_metric].idxmax()
|
353 |
+
best_text_recall_config = model_top_n_df.loc[best_text_recall_idx].copy()
|
354 |
+
best_text_recall_config['metric_type'] = 'text_recall'
|
355 |
+
|
356 |
+
# Находим конфигурацию с лучшим doc_recall
|
357 |
+
best_doc_recall_idx = model_top_n_df[doc_recall_metric].idxmax()
|
358 |
+
best_doc_recall_config = model_top_n_df.loc[best_doc_recall_idx].copy()
|
359 |
+
best_doc_recall_config['metric_type'] = 'doc_recall'
|
360 |
+
|
361 |
+
best_configs.append(best_text_recall_config)
|
362 |
+
best_configs.append(best_doc_recall_config)
|
363 |
+
|
364 |
+
if not best_configs:
|
365 |
+
return pd.DataFrame()
|
366 |
+
|
367 |
+
best_configs_df = pd.DataFrame(best_configs)
|
368 |
+
|
369 |
+
# Выбираем и сортируем нужные столбцы
|
370 |
+
cols_to_keep = ['model', 'top_n', 'metric_type', 'words_per_chunk', 'overlap_words', 'overlap_percentage']
|
371 |
+
|
372 |
+
# Добавляем столбцы метрик в зависимости от того, какие доступны
|
373 |
+
if macro_metrics is not None:
|
374 |
+
# Для macro метрик сначала выбираем recall-метрики
|
375 |
+
recall_cols = [col for col in best_configs_df.columns if col.endswith('recall')]
|
376 |
+
# Затем добавляем остальные метрики
|
377 |
+
other_cols = [col for col in best_configs_df.columns if any(col.endswith(m) for m in
|
378 |
+
['precision', 'f1']) and col.startswith('macro_')]
|
379 |
+
metric_cols = recall_cols + other_cols
|
380 |
+
else:
|
381 |
+
# Для weighted метрик сначала выбираем recall-метрики
|
382 |
+
recall_cols = [col for col in best_configs_df.columns if col.endswith('recall')]
|
383 |
+
# Затем добавляем остальные метрики
|
384 |
+
other_cols = [col for col in best_configs_df.columns if any(col.endswith(m) for m in
|
385 |
+
['precision', 'f1']) and not col.startswith('macro_')]
|
386 |
+
metric_cols = recall_cols + other_cols
|
387 |
+
|
388 |
+
result = best_configs_df[cols_to_keep + metric_cols].sort_values(['model', 'top_n', 'metric_type'])
|
389 |
+
|
390 |
+
return result
|
391 |
+
|
392 |
+
|
393 |
+
def get_grouping_columns(sheet) -> dict:
|
394 |
+
"""
|
395 |
+
Определяет подходящие колонки для группировки данных на листе.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
sheet: Лист Excel
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
Словарь с данными о группировке или None
|
402 |
+
"""
|
403 |
+
# Возможные варианты группировки
|
404 |
+
grouping_possibilities = [
|
405 |
+
{'columns': ['model', 'words_per_chunk', 'overlap_words']},
|
406 |
+
{'columns': ['model']},
|
407 |
+
{'columns': ['words_per_chunk', 'overlap_words']},
|
408 |
+
{'columns': ['top_n']},
|
409 |
+
{'columns': ['model', 'top_n', 'metric_type']}
|
410 |
+
]
|
411 |
+
|
412 |
+
# Для каждого варианта группировки проверяем наличие всех колонок
|
413 |
+
for grouping in grouping_possibilities:
|
414 |
+
column_indices = {}
|
415 |
+
all_columns_present = True
|
416 |
+
|
417 |
+
for column_name in grouping['columns']:
|
418 |
+
column_idx = None
|
419 |
+
for col_idx, cell in enumerate(sheet[1], start=1):
|
420 |
+
if cell.value == column_name:
|
421 |
+
column_idx = col_idx
|
422 |
+
break
|
423 |
+
|
424 |
+
if column_idx is None:
|
425 |
+
all_columns_present = False
|
426 |
+
break
|
427 |
+
else:
|
428 |
+
column_indices[column_name] = column_idx
|
429 |
+
|
430 |
+
if all_columns_present:
|
431 |
+
return {
|
432 |
+
'columns': grouping['columns'],
|
433 |
+
'indices': column_indices
|
434 |
+
}
|
435 |
+
|
436 |
+
return None
|
437 |
+
|
438 |
+
|
439 |
+
def apply_header_formatting(sheet):
|
440 |
+
"""
|
441 |
+
Применяет форматирование к заголовкам.
|
442 |
+
|
443 |
+
Args:
|
444 |
+
sheet: Лист Excel
|
445 |
+
"""
|
446 |
+
# Форматирование заголовков
|
447 |
+
for cell in sheet[1]:
|
448 |
+
cell.font = Font(bold=True)
|
449 |
+
cell.fill = PatternFill(start_color="D9D9D9", end_color="D9D9D9", fill_type="solid")
|
450 |
+
cell.alignment = Alignment(horizontal='center', vertical='center', wrap_text=True)
|
451 |
+
|
452 |
+
|
453 |
+
def adjust_column_width(sheet):
|
454 |
+
"""
|
455 |
+
Настраивает ширину столбцов на основе содержимого.
|
456 |
+
|
457 |
+
Args:
|
458 |
+
sheet: Лист Excel
|
459 |
+
"""
|
460 |
+
# Авторазмер столбцов
|
461 |
+
for column in sheet.columns:
|
462 |
+
max_length = 0
|
463 |
+
column_letter = get_column_letter(column[0].column)
|
464 |
+
|
465 |
+
for cell in column:
|
466 |
+
if cell.value:
|
467 |
+
try:
|
468 |
+
if len(str(cell.value)) > max_length:
|
469 |
+
max_length = len(str(cell.value))
|
470 |
+
except:
|
471 |
+
pass
|
472 |
+
|
473 |
+
adjusted_width = (max_length + 2) * 1.1
|
474 |
+
sheet.column_dimensions[column_letter].width = adjusted_width
|
475 |
+
|
476 |
+
|
477 |
+
def apply_cell_formatting(sheet):
|
478 |
+
"""
|
479 |
+
Применяет форматирование к ячейкам (границы, выравнивание и т.д.).
|
480 |
+
|
481 |
+
Args:
|
482 |
+
sheet: Лист Excel
|
483 |
+
"""
|
484 |
+
# Тонкие границы для всех ячеек
|
485 |
+
thin_border = Border(
|
486 |
+
left=Side(style='thin'),
|
487 |
+
right=Side(style='thin'),
|
488 |
+
top=Side(style='thin'),
|
489 |
+
bottom=Side(style='thin')
|
490 |
+
)
|
491 |
+
|
492 |
+
for row in sheet.iter_rows(min_row=1, max_row=sheet.max_row, min_col=1, max_col=sheet.max_column):
|
493 |
+
for cell in row:
|
494 |
+
cell.border = thin_border
|
495 |
+
|
496 |
+
# Форматирование числовых значений
|
497 |
+
numeric_columns = [
|
498 |
+
'text_precision', 'text_recall', 'text_f1',
|
499 |
+
'doc_precision', 'doc_recall', 'doc_f1',
|
500 |
+
'macro_text_precision', 'macro_text_recall', 'macro_text_f1',
|
501 |
+
'macro_doc_precision', 'macro_doc_recall', 'macro_doc_f1'
|
502 |
+
]
|
503 |
+
|
504 |
+
for col_idx, header in enumerate(sheet[1], start=1):
|
505 |
+
if header.value in numeric_columns or (header.value and str(header.value).endswith(('precision', 'recall', 'f1'))):
|
506 |
+
for row_idx in range(2, sheet.max_row + 1):
|
507 |
+
cell = sheet.cell(row=row_idx, column=col_idx)
|
508 |
+
if isinstance(cell.value, (int, float)):
|
509 |
+
cell.number_format = '0.0000'
|
510 |
+
|
511 |
+
# Выравнивание для всех ячеек
|
512 |
+
for row in sheet.iter_rows(min_row=2, max_row=sheet.max_row, min_col=1, max_col=sheet.max_column):
|
513 |
+
for cell in row:
|
514 |
+
cell.alignment = Alignment(horizontal='center', vertical='center')
|
515 |
+
|
516 |
+
|
517 |
+
def apply_group_formatting(sheet, grouping):
|
518 |
+
"""
|
519 |
+
Применяет форматирование к группам строк.
|
520 |
+
|
521 |
+
Args:
|
522 |
+
sheet: Лист Excel
|
523 |
+
grouping: Словарь с данными о группировке
|
524 |
+
"""
|
525 |
+
if not grouping or sheet.max_row <= 1:
|
526 |
+
return
|
527 |
+
|
528 |
+
# Для каждой строки проверяем изменение значений группировочных колонок
|
529 |
+
last_values = {column: None for column in grouping['columns']}
|
530 |
+
|
531 |
+
# Применяем жирную верхнюю границу к первой строке данных
|
532 |
+
for col_idx in range(1, sheet.max_column + 1):
|
533 |
+
cell = sheet.cell(row=2, column=col_idx)
|
534 |
+
cell.border = Border(
|
535 |
+
left=cell.border.left,
|
536 |
+
right=cell.border.right,
|
537 |
+
top=Side(style='thick'),
|
538 |
+
bottom=cell.border.bottom
|
539 |
+
)
|
540 |
+
|
541 |
+
for row_idx in range(2, sheet.max_row + 1):
|
542 |
+
current_values = {}
|
543 |
+
for column in grouping['columns']:
|
544 |
+
col_idx = grouping['indices'][column]
|
545 |
+
current_values[column] = sheet.cell(row=row_idx, column=col_idx).value
|
546 |
+
|
547 |
+
# Если значения изменились, добавляем жирные границы
|
548 |
+
values_changed = False
|
549 |
+
for column in grouping['columns']:
|
550 |
+
if current_values[column] != last_values[column]:
|
551 |
+
values_changed = True
|
552 |
+
break
|
553 |
+
|
554 |
+
if values_changed and row_idx > 2:
|
555 |
+
# Жирная верхняя граница для текущей строки
|
556 |
+
for col_idx in range(1, sheet.max_column + 1):
|
557 |
+
cell = sheet.cell(row=row_idx, column=col_idx)
|
558 |
+
cell.border = Border(
|
559 |
+
left=cell.border.left,
|
560 |
+
right=cell.border.right,
|
561 |
+
top=Side(style='thick'),
|
562 |
+
bottom=cell.border.bottom
|
563 |
+
)
|
564 |
+
|
565 |
+
# Жирная нижняя граница для предыдущей строки
|
566 |
+
for col_idx in range(1, sheet.max_column + 1):
|
567 |
+
cell = sheet.cell(row=row_idx-1, column=col_idx)
|
568 |
+
cell.border = Border(
|
569 |
+
left=cell.border.left,
|
570 |
+
right=cell.border.right,
|
571 |
+
top=cell.border.top,
|
572 |
+
bottom=Side(style='thick')
|
573 |
+
)
|
574 |
+
|
575 |
+
# Запоминаем текущие значения для следующей итерации
|
576 |
+
for column in grouping['columns']:
|
577 |
+
last_values[column] = current_values[column]
|
578 |
+
|
579 |
+
# Добавляем жирную нижнюю границу для последней строки
|
580 |
+
for col_idx in range(1, sheet.max_column + 1):
|
581 |
+
cell = sheet.cell(row=sheet.max_row, column=col_idx)
|
582 |
+
cell.border = Border(
|
583 |
+
left=cell.border.left,
|
584 |
+
right=cell.border.right,
|
585 |
+
top=cell.border.top,
|
586 |
+
bottom=Side(style='thick')
|
587 |
+
)
|
588 |
+
|
589 |
+
|
590 |
+
def apply_formatting(workbook: Workbook) -> None:
|
591 |
+
"""
|
592 |
+
Применяет форматирование к Excel-файлу.
|
593 |
+
Добавляет автофильтры для всех столбцов и улучшает визуальное представление.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
workbook: Workbook-объект openpyxl
|
597 |
+
"""
|
598 |
+
for sheet_name in workbook.sheetnames:
|
599 |
+
sheet = workbook[sheet_name]
|
600 |
+
|
601 |
+
# Добавляем автофильтры для всех столбцов
|
602 |
+
if sheet.max_row > 1: # Проверяем, что в листе есть данные
|
603 |
+
sheet.auto_filter.ref = sheet.dimensions
|
604 |
+
|
605 |
+
# Применяем форматирование
|
606 |
+
apply_header_formatting(sheet)
|
607 |
+
adjust_column_width(sheet)
|
608 |
+
apply_cell_formatting(sheet)
|
609 |
+
|
610 |
+
# Определяем группирующие колонки и применяем форматирование к группам
|
611 |
+
grouping = get_grouping_columns(sheet)
|
612 |
+
if grouping:
|
613 |
+
apply_group_formatting(sheet, grouping)
|
614 |
+
|
615 |
+
|
616 |
+
def create_model_comparison_plot(df: pd.DataFrame, metrics: list | str, top_n: int, plots_dir: str) -> None:
|
617 |
+
"""
|
618 |
+
Создает график сравнения моделей по указанным метрикам для заданного top_n.
|
619 |
+
|
620 |
+
Args:
|
621 |
+
df: DataFrame с данными
|
622 |
+
metrics: Список метрик или одна метрика для сравнения
|
623 |
+
top_n: Значение top_n для фильтрации
|
624 |
+
plots_dir: Директория для сохранения графиков
|
625 |
+
"""
|
626 |
+
if isinstance(metrics, str):
|
627 |
+
metrics = [metrics]
|
628 |
+
|
629 |
+
# Фильтруем данные
|
630 |
+
filtered_df = df[df['top_n'] == top_n]
|
631 |
+
|
632 |
+
if len(filtered_df) == 0:
|
633 |
+
print(f"Нет данных для top_n={top_n}")
|
634 |
+
return
|
635 |
+
|
636 |
+
# Определяем тип метрик (macro или weighted)
|
637 |
+
metrics_type = "macro" if metrics[0].startswith("macro_") else "weighted"
|
638 |
+
|
639 |
+
# Создаем фигуру с несколькими подграфиками
|
640 |
+
fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 8))
|
641 |
+
|
642 |
+
# Если только одна метрика, преобразуем axes в список для единообразного обращения
|
643 |
+
if len(metrics) == 1:
|
644 |
+
axes = [axes]
|
645 |
+
|
646 |
+
# Для каждой метрики создаем subplot
|
647 |
+
for i, metric in enumerate(metrics):
|
648 |
+
# Группируем данные по модели
|
649 |
+
columns_to_agg = {metric: 'mean'}
|
650 |
+
model_data = filtered_df.groupby('model').agg(columns_to_agg).reset_index()
|
651 |
+
|
652 |
+
# Сортируем по значению метрики (по убыванию)
|
653 |
+
model_data = model_data.sort_values(metric, ascending=False)
|
654 |
+
|
655 |
+
# Определяем цветовую схему
|
656 |
+
palette = sns.color_palette("viridis", len(model_data))
|
657 |
+
|
658 |
+
# Строим столбчатую диаграмму на соответствующем subplot
|
659 |
+
ax = sns.barplot(x='model', y=metric, data=model_data, palette=palette, ax=axes[i])
|
660 |
+
|
661 |
+
# Добавляем значения над столбцами
|
662 |
+
for j, v in enumerate(model_data[metric]):
|
663 |
+
ax.text(j, v + 0.01, f"{v:.4f}", ha='center', fontsize=8)
|
664 |
+
|
665 |
+
# Устанавливаем заголовок и метки осей
|
666 |
+
ax.set_title(f"{metric} (top_n={top_n})", fontsize=12)
|
667 |
+
ax.set_xlabel("Модель", fontsize=10)
|
668 |
+
ax.set_ylabel(f"{metric}", fontsize=10)
|
669 |
+
|
670 |
+
# Поворачиваем подписи по оси X для лучшей читаемости
|
671 |
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=8)
|
672 |
+
|
673 |
+
# Настраиваем макет
|
674 |
+
plt.tight_layout()
|
675 |
+
|
676 |
+
# Сохраняем график
|
677 |
+
metric_names = '_'.join([m.replace('macro_', '') for m in metrics])
|
678 |
+
file_name = f"model_comparison_{metrics_type}_{metric_names}_top{top_n}.png"
|
679 |
+
plt.savefig(os.path.join(plots_dir, file_name), dpi=300)
|
680 |
+
plt.close()
|
681 |
+
|
682 |
+
print(f"Создан график сравнения моделей: {file_name}")
|
683 |
+
|
684 |
+
|
685 |
+
def create_top_n_plot(df: pd.DataFrame, models: list | str, metric: str, plots_dir: str) -> None:
|
686 |
+
"""
|
687 |
+
Создает график зависимости метрики от top_n для заданных моделей.
|
688 |
+
|
689 |
+
Args:
|
690 |
+
df: DataFrame с данными
|
691 |
+
models: Список моделей или одна модель для сравнения
|
692 |
+
metric: Название метрики
|
693 |
+
plots_dir: Директория для сохранения графиков
|
694 |
+
"""
|
695 |
+
if isinstance(models, str):
|
696 |
+
models = [models]
|
697 |
+
|
698 |
+
# Создаем фигуру
|
699 |
+
plt.figure(figsize=(12, 8))
|
700 |
+
|
701 |
+
# Определяем цветовую схему
|
702 |
+
palette = sns.color_palette("viridis", len(models))
|
703 |
+
|
704 |
+
# Ограничиваем количество моделей для читаемости
|
705 |
+
if len(models) > 5:
|
706 |
+
models = models[:5]
|
707 |
+
print("Слишком много моделей для графика, ограничиваем до 5")
|
708 |
+
|
709 |
+
# Для каждой модели строим линию
|
710 |
+
for i, model in enumerate(models):
|
711 |
+
# Находим наиболее часто используемые параметры чанкинга для этой модели
|
712 |
+
model_df = df[df['model'] == model]
|
713 |
+
|
714 |
+
if len(model_df) == 0:
|
715 |
+
print(f"Нет данных для модели {model}")
|
716 |
+
continue
|
717 |
+
|
718 |
+
# Группируем по параметрам чанкинга и подсчитываем частоту
|
719 |
+
common_configs = model_df.groupby(['words_per_chunk', 'overlap_words']).size().reset_index(name='count')
|
720 |
+
|
721 |
+
if len(common_configs) == 0:
|
722 |
+
continue
|
723 |
+
|
724 |
+
# Берем наиболее частую конфигурацию
|
725 |
+
common_config = common_configs.sort_values('count', ascending=False).iloc[0]
|
726 |
+
|
727 |
+
# Фильтруем для этой конфигурации
|
728 |
+
config_df = model_df[
|
729 |
+
(model_df['words_per_chunk'] == common_config['words_per_chunk']) &
|
730 |
+
(model_df['overlap_words'] == common_config['overlap_words'])
|
731 |
+
].sort_values('top_n')
|
732 |
+
|
733 |
+
if len(config_df) <= 1:
|
734 |
+
continue
|
735 |
+
|
736 |
+
# Строим линию
|
737 |
+
plt.plot(config_df['top_n'], config_df[metric], marker='o', linewidth=2,
|
738 |
+
label=f"{model} (w={common_config['words_per_chunk']}, o={common_config['overlap_words']})",
|
739 |
+
color=palette[i])
|
740 |
+
|
741 |
+
# Добавляем легенду, заголовок и метки осей
|
742 |
+
plt.legend(title="Модель (параметры)", fontsize=10, loc='best')
|
743 |
+
plt.title(f"Зависимость {metric} от top_n для разных моделей", fontsize=16)
|
744 |
+
plt.xlabel("top_n", fontsize=14)
|
745 |
+
plt.ylabel(metric, fontsize=14)
|
746 |
+
|
747 |
+
# Включаем сетку
|
748 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
749 |
+
|
750 |
+
# Настраиваем макет
|
751 |
+
plt.tight_layout()
|
752 |
+
|
753 |
+
# Сохраняем график
|
754 |
+
is_macro = "macro" if "macro" in metric else "weighted"
|
755 |
+
file_name = f"top_n_comparison_{is_macro}_{metric.replace('macro_', '')}.png"
|
756 |
+
plt.savefig(os.path.join(plots_dir, file_name), dpi=300)
|
757 |
+
plt.close()
|
758 |
+
|
759 |
+
print(f"Создан график зависимости от top_n: {file_name}")
|
760 |
+
|
761 |
+
|
762 |
+
def create_chunk_size_plot(df: pd.DataFrame, model: str, metrics: list | str, top_n: int, plots_dir: str) -> None:
|
763 |
+
"""
|
764 |
+
Создает график зависимости метрик от размера чанка для заданной модели и top_n.
|
765 |
+
|
766 |
+
Args:
|
767 |
+
df: DataFrame с данными
|
768 |
+
model: Название модели
|
769 |
+
metrics: Список метрик или одна метрика
|
770 |
+
top_n: Значение top_n
|
771 |
+
plots_dir: Директория для сохранения графиков
|
772 |
+
"""
|
773 |
+
if isinstance(metrics, str):
|
774 |
+
metrics = [metrics]
|
775 |
+
|
776 |
+
# Фильтруем данные
|
777 |
+
filtered_df = df[(df['model'] == model) & (df['top_n'] == top_n)]
|
778 |
+
|
779 |
+
if len(filtered_df) <= 1:
|
780 |
+
print(f"Недостаточно данных для модели {model} и top_n={top_n}")
|
781 |
+
return
|
782 |
+
|
783 |
+
# Создаем фигуру
|
784 |
+
plt.figure(figsize=(14, 8))
|
785 |
+
|
786 |
+
# Определяем цветовую схему для метрик
|
787 |
+
palette = sns.color_palette("viridis", len(metrics))
|
788 |
+
|
789 |
+
# Группируем по размеру чанка и проценту перекрытия
|
790 |
+
# Вычисляем среднее только для указанных метрик, а не для всех столбцов
|
791 |
+
columns_to_agg = {metric: 'mean' for metric in metrics}
|
792 |
+
chunk_data = filtered_df.groupby(['words_per_chunk', 'overlap_percentage']).agg(columns_to_agg).reset_index()
|
793 |
+
|
794 |
+
# Получаем уникальные значения процента перекрытия
|
795 |
+
overlap_percentages = sorted(chunk_data['overlap_percentage'].unique())
|
796 |
+
|
797 |
+
# Настраиваем маркеры и линии для разных перекрытий
|
798 |
+
markers = ['o', 's', '^', 'D', 'x', '*']
|
799 |
+
|
800 |
+
# Для каждого перекрытия строим линии с разными метриками
|
801 |
+
for i, overlap in enumerate(overlap_percentages):
|
802 |
+
subset = chunk_data[chunk_data['overlap_percentage'] == overlap].sort_values('words_per_chunk')
|
803 |
+
|
804 |
+
for j, metric in enumerate(metrics):
|
805 |
+
plt.plot(subset['words_per_chunk'], subset[metric],
|
806 |
+
marker=markers[i % len(markers)], linewidth=2,
|
807 |
+
label=f"{metric}, overlap={overlap}%",
|
808 |
+
color=palette[j])
|
809 |
+
|
810 |
+
# Добавляем легенду и заголовок
|
811 |
+
plt.legend(title="Метрика и перекрытие", fontsize=10, loc='best')
|
812 |
+
plt.title(f"Зависимость метрик от размера чанка для {model} (top_n={top_n})", fontsize=16)
|
813 |
+
plt.xlabel("Размер чанка (слов)", fontsize=14)
|
814 |
+
plt.ylabel("Значение метрики", fontsize=14)
|
815 |
+
|
816 |
+
# Включаем сетку
|
817 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
818 |
+
|
819 |
+
# Настраиваем макет
|
820 |
+
plt.tight_layout()
|
821 |
+
|
822 |
+
# Сохраняем график
|
823 |
+
metrics_type = "macro" if metrics[0].startswith("macro_") else "weighted"
|
824 |
+
model_name = model.replace('/', '_')
|
825 |
+
metric_names = '_'.join([m.replace('macro_', '') for m in metrics])
|
826 |
+
file_name = f"chunk_size_{metrics_type}_{metric_names}_{model_name}_top{top_n}.png"
|
827 |
+
plt.savefig(os.path.join(plots_dir, file_name), dpi=300)
|
828 |
+
plt.close()
|
829 |
+
|
830 |
+
print(f"Создан график зависимости от размера чанка: {file_name}")
|
831 |
+
|
832 |
+
|
833 |
+
def create_heatmap(df: pd.DataFrame, models: list | str, metric: str, top_n: int, plots_dir: str) -> None:
|
834 |
+
"""
|
835 |
+
Создает тепловые карты зависимости метрики от размера чанка и процента перекрытия
|
836 |
+
для заданных моделей.
|
837 |
+
|
838 |
+
Args:
|
839 |
+
df: DataFrame с данными
|
840 |
+
models: Список моделей или одна модель
|
841 |
+
metric: Название метрики
|
842 |
+
top_n: Значение top_n
|
843 |
+
plots_dir: Директория для сохранения графиков
|
844 |
+
"""
|
845 |
+
if isinstance(models, str):
|
846 |
+
models = [models]
|
847 |
+
|
848 |
+
# Ограничиваем количество моделей для наглядности
|
849 |
+
if len(models) > 4:
|
850 |
+
models = models[:4]
|
851 |
+
|
852 |
+
# Создаем фигуру с подграфиками
|
853 |
+
fig, axes = plt.subplots(1, len(models), figsize=(6 * len(models), 6), squeeze=False)
|
854 |
+
|
855 |
+
# Для каждой модели создаем тепловую карту
|
856 |
+
for i, model in enumerate(models):
|
857 |
+
# Фильтруем данные для указанной модели и top_n
|
858 |
+
filtered_df = df[(df['model'] == model) & (df['top_n'] == top_n)]
|
859 |
+
|
860 |
+
# Проверяем, достаточно ли данных для построения тепловой карты
|
861 |
+
chunk_sizes = filtered_df['words_per_chunk'].unique()
|
862 |
+
overlap_percentages = filtered_df['overlap_percentage'].unique()
|
863 |
+
|
864 |
+
if len(chunk_sizes) <= 1 or len(overlap_percentages) <= 1:
|
865 |
+
print(f"Недостаточно данных для построения тепловой карты для модели {model} и top_n={top_n}")
|
866 |
+
# Пропускаем этот subplot
|
867 |
+
axes[0, i].text(0.5, 0.5, f"Недостаточно данных для {model}",
|
868 |
+
horizontalalignment='center', verticalalignment='center')
|
869 |
+
axes[0, i].set_title(model)
|
870 |
+
axes[0, i].axis('off')
|
871 |
+
continue
|
872 |
+
|
873 |
+
# Создаем сводную таблицу для тепловой карты, используя только нужную метрику
|
874 |
+
# Сначала выберем только колонки для pivot_table
|
875 |
+
pivot_columns = ['words_per_chunk', 'overlap_percentage', metric]
|
876 |
+
pivot_df = filtered_df[pivot_columns].copy()
|
877 |
+
|
878 |
+
# Теперь создаем сводную таблицу
|
879 |
+
pivot_data = pivot_df.pivot_table(
|
880 |
+
index='words_per_chunk',
|
881 |
+
columns='overlap_percentage',
|
882 |
+
values=metric,
|
883 |
+
aggfunc='mean'
|
884 |
+
)
|
885 |
+
|
886 |
+
# Строим тепловую карту
|
887 |
+
sns.heatmap(pivot_data, annot=True, fmt=".4f", cmap="viridis",
|
888 |
+
linewidths=.5, annot_kws={"size": 8}, ax=axes[0, i])
|
889 |
+
|
890 |
+
# Устанавливаем заголовок и метки осей
|
891 |
+
axes[0, i].set_title(model, fontsize=12)
|
892 |
+
axes[0, i].set_xlabel("Процент перекрытия (%)", fontsize=10)
|
893 |
+
axes[0, i].set_ylabel("Размер чанка (слов)", fontsize=10)
|
894 |
+
|
895 |
+
# Добавляем общий заголовок
|
896 |
+
plt.suptitle(f"Тепловые карты {metric} для разных моделей (top_n={top_n})", fontsize=16)
|
897 |
+
|
898 |
+
# Настраиваем макет
|
899 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96]) # Оставляем место для общего заголовка
|
900 |
+
|
901 |
+
# Сохраняем график
|
902 |
+
is_macro = "macro" if "macro" in metric else "weighted"
|
903 |
+
file_name = f"heatmap_{is_macro}_{metric.replace('macro_', '')}_top{top_n}.png"
|
904 |
+
plt.savefig(os.path.join(plots_dir, file_name), dpi=300)
|
905 |
+
plt.close()
|
906 |
+
|
907 |
+
print(f"Созданы тепловые карты: {file_name}")
|
908 |
+
|
909 |
+
|
910 |
+
def find_best_combinations(df: pd.DataFrame, metrics: list | str = None) -> pd.DataFrame:
|
911 |
+
"""
|
912 |
+
Находит наилучшие комбинации параметров на основе агрегированных recall-метрик.
|
913 |
+
|
914 |
+
Args:
|
915 |
+
df: DataFrame с данными
|
916 |
+
metrics: Список метрик для анализа или None (тогда используются все recall-метрики)
|
917 |
+
|
918 |
+
Returns:
|
919 |
+
DataFrame с лучшими комбинациями параметров
|
920 |
+
"""
|
921 |
+
if metrics is None:
|
922 |
+
# По умолчанию выбираем все метрики с "recall" в названии
|
923 |
+
metrics = [col for col in df.columns if "recall" in col]
|
924 |
+
elif isinstance(metrics, str):
|
925 |
+
metrics = [metrics]
|
926 |
+
|
927 |
+
print(f"Поиск лучших комбинаций на основе метрик: {metrics}")
|
928 |
+
|
929 |
+
# Создаем новую метрику - сумму всех указанных recall-метрик
|
930 |
+
df_copy = df.copy()
|
931 |
+
df_copy['combined_recall'] = df_copy[metrics].sum(axis=1)
|
932 |
+
|
933 |
+
# Находим лучшие комбинации для различных значений top_n
|
934 |
+
best_combinations = []
|
935 |
+
|
936 |
+
for top_n in df_copy['top_n'].unique():
|
937 |
+
top_n_df = df_copy[df_copy['top_n'] == top_n]
|
938 |
+
|
939 |
+
if len(top_n_df) == 0:
|
940 |
+
continue
|
941 |
+
|
942 |
+
# Находим строку с максимальным combined_recall
|
943 |
+
best_idx = top_n_df['combined_recall'].idxmax()
|
944 |
+
best_row = top_n_df.loc[best_idx].copy()
|
945 |
+
best_row['best_for_top_n'] = top_n
|
946 |
+
|
947 |
+
best_combinations.append(best_row)
|
948 |
+
|
949 |
+
# Находим лучшие комбинации для разных моделей
|
950 |
+
for model in df_copy['model'].unique():
|
951 |
+
model_df = df_copy[df_copy['model'] == model]
|
952 |
+
|
953 |
+
if len(model_df) == 0:
|
954 |
+
continue
|
955 |
+
|
956 |
+
# Находим строку с максимальным combined_recall
|
957 |
+
best_idx = model_df['combined_recall'].idxmax()
|
958 |
+
best_row = model_df.loc[best_idx].copy()
|
959 |
+
best_row['best_for_model'] = model
|
960 |
+
|
961 |
+
best_combinations.append(best_row)
|
962 |
+
|
963 |
+
# Находим лучшие комбинации для разных размеров чанков
|
964 |
+
for chunk_size in df_copy['words_per_chunk'].unique():
|
965 |
+
chunk_df = df_copy[df_copy['words_per_chunk'] == chunk_size]
|
966 |
+
|
967 |
+
if len(chunk_df) == 0:
|
968 |
+
continue
|
969 |
+
|
970 |
+
# Находим строку с максимальным combined_recall
|
971 |
+
best_idx = chunk_df['combined_recall'].idxmax()
|
972 |
+
best_row = chunk_df.loc[best_idx].copy()
|
973 |
+
best_row['best_for_chunk_size'] = chunk_size
|
974 |
+
|
975 |
+
best_combinations.append(best_row)
|
976 |
+
|
977 |
+
# Находим абсолютно лучшую комбинацию
|
978 |
+
if len(df_copy) > 0:
|
979 |
+
best_idx = df_copy['combined_recall'].idxmax()
|
980 |
+
best_row = df_copy.loc[best_idx].copy()
|
981 |
+
best_row['absolute_best'] = True
|
982 |
+
|
983 |
+
best_combinations.append(best_row)
|
984 |
+
|
985 |
+
if not best_combinations:
|
986 |
+
return pd.DataFrame()
|
987 |
+
|
988 |
+
result = pd.DataFrame(best_combinations)
|
989 |
+
|
990 |
+
# Сортируем по combined_recall (по убыванию)
|
991 |
+
result = result.sort_values('combined_recall', ascending=False)
|
992 |
+
|
993 |
+
print(f"Найдено {len(result)} лучших комбинаций")
|
994 |
+
|
995 |
+
return result
|
996 |
+
|
997 |
+
|
998 |
+
def create_best_combinations_plot(best_df: pd.DataFrame, metrics: list | str, plots_dir: str) -> None:
|
999 |
+
"""
|
1000 |
+
Создает график сравнения лучших комбинаций параметров.
|
1001 |
+
|
1002 |
+
Args:
|
1003 |
+
best_df: DataFrame с лучшими комбинациями
|
1004 |
+
metrics: Список метрик для визуализаци��
|
1005 |
+
plots_dir: Директория для сохранения графиков
|
1006 |
+
"""
|
1007 |
+
if isinstance(metrics, str):
|
1008 |
+
metrics = [metrics]
|
1009 |
+
|
1010 |
+
if len(best_df) == 0:
|
1011 |
+
print("Нет данных для построения графика лучших комбинаций")
|
1012 |
+
return
|
1013 |
+
|
1014 |
+
# Создаем новый признак для идентификации комбинаций
|
1015 |
+
best_df['combo_label'] = best_df.apply(
|
1016 |
+
lambda row: f"{row['model']} (w={row['words_per_chunk']}, o={row['overlap_words']}, top_n={row['top_n']})",
|
1017 |
+
axis=1
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
# Берем только лучшие N комбинаций для читаемости
|
1021 |
+
max_combos = 10
|
1022 |
+
if len(best_df) > max_combos:
|
1023 |
+
plot_df = best_df.head(max_combos).copy()
|
1024 |
+
print(f"Ограничиваем график до {max_combos} лучших комбинаций")
|
1025 |
+
else:
|
1026 |
+
plot_df = best_df.copy()
|
1027 |
+
|
1028 |
+
# Создаем длинный формат данных для seaborn
|
1029 |
+
plot_data = plot_df.melt(
|
1030 |
+
id_vars=['combo_label', 'combined_recall'],
|
1031 |
+
value_vars=metrics,
|
1032 |
+
var_name='metric',
|
1033 |
+
value_name='value'
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
# Сортируем по суммарному recall (комбинации) и метрике (для группировки)
|
1037 |
+
plot_data = plot_data.sort_values(['combined_recall', 'metric'], ascending=[False, True])
|
1038 |
+
|
1039 |
+
# Создаем фигуру для графика
|
1040 |
+
plt.figure(figsize=(14, 10))
|
1041 |
+
|
1042 |
+
# Создаем bar plot
|
1043 |
+
sns.barplot(
|
1044 |
+
x='combo_label',
|
1045 |
+
y='value',
|
1046 |
+
hue='metric',
|
1047 |
+
data=plot_data,
|
1048 |
+
palette='viridis'
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
# Настраиваем оси и заголовок
|
1052 |
+
plt.title('Лучшие комбинации параметров по recall-метрикам', fontsize=16)
|
1053 |
+
plt.xlabel('Комбинация параметров', fontsize=14)
|
1054 |
+
plt.ylabel('Значение метрики', fontsize=14)
|
1055 |
+
|
1056 |
+
# Поворачиваем подписи по оси X для лучшей читаемости
|
1057 |
+
plt.xticks(rotation=45, ha='right', fontsize=10)
|
1058 |
+
|
1059 |
+
# Настраиваем легенду
|
1060 |
+
plt.legend(title='Метрика', fontsize=12)
|
1061 |
+
|
1062 |
+
# Добавляем сетку
|
1063 |
+
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
1064 |
+
|
1065 |
+
# Настраиваем макет
|
1066 |
+
plt.tight_layout()
|
1067 |
+
|
1068 |
+
# Сохраняем график
|
1069 |
+
file_name = f"best_combinations_comparison.png"
|
1070 |
+
plt.savefig(os.path.join(plots_dir, file_name), dpi=300)
|
1071 |
+
plt.close()
|
1072 |
+
|
1073 |
+
print(f"Создан график сравнения лучших комбинаций: {file_name}")
|
1074 |
+
|
1075 |
+
|
1076 |
+
def generate_plots(combined_results: pd.DataFrame, macro_metrics: pd.DataFrame, plots_dir: str) -> None:
|
1077 |
+
"""
|
1078 |
+
Генерирует набор графиков с помощью seaborn и сохраняет их в указанную директорию.
|
1079 |
+
Фокусируется в первую очередь на recall-метриках как наиболее важных.
|
1080 |
+
|
1081 |
+
Args:
|
1082 |
+
combined_results: DataFrame с объединенными результатами (weighted метрики)
|
1083 |
+
macro_metrics: DataFrame с macro метриками
|
1084 |
+
plots_dir: Директория для сохранения графиков
|
1085 |
+
"""
|
1086 |
+
# Создаем директорию для графиков, если она не существует
|
1087 |
+
setup_plot_directory(plots_dir)
|
1088 |
+
|
1089 |
+
# Настраиваем стиль для графиков
|
1090 |
+
sns.set_style("whitegrid")
|
1091 |
+
plt.rcParams['font.family'] = 'DejaVu Sans'
|
1092 |
+
|
1093 |
+
# Получаем список моделей для построения графиков
|
1094 |
+
models = combined_results['model'].unique()
|
1095 |
+
top_n_values = [10, 20, 50, 100]
|
1096 |
+
|
1097 |
+
print(f"Генерация графиков для {len(models)} моделей...")
|
1098 |
+
|
1099 |
+
# 0. Добавляем анализ наилучших комбинаций параметров
|
1100 |
+
# Определяем метрики для анализа - фокусируемся на recall
|
1101 |
+
weighted_recall_metrics = ['text_recall', 'doc_recall']
|
1102 |
+
|
1103 |
+
# Находим лучшие комбинации параметров
|
1104 |
+
best_combinations = find_best_combinations(combined_results, weighted_recall_metrics)
|
1105 |
+
|
1106 |
+
# Создаем график сравнения лучших комбинаций
|
1107 |
+
if not best_combinations.empty:
|
1108 |
+
create_best_combinations_plot(best_combinations, weighted_recall_metrics, plots_dir)
|
1109 |
+
|
1110 |
+
# Если доступны macro метрики, делаем то же самое для них
|
1111 |
+
if macro_metrics is not None:
|
1112 |
+
macro_recall_metrics = ['macro_text_recall', 'macro_doc_recall']
|
1113 |
+
macro_best_combinations = find_best_combinations(macro_metrics, macro_recall_metrics)
|
1114 |
+
|
1115 |
+
if not macro_best_combinations.empty:
|
1116 |
+
create_best_combinations_plot(macro_best_combinations, macro_recall_metrics, plots_dir)
|
1117 |
+
|
1118 |
+
# 1. Создаем графики сравнения моделей для weighted метрик
|
1119 |
+
# Фокусируемся на recall-метриках
|
1120 |
+
weighted_metrics = {
|
1121 |
+
'text': ['text_recall'], # Только text_recall
|
1122 |
+
'doc': ['doc_recall'] # Только doc_recall
|
1123 |
+
}
|
1124 |
+
|
1125 |
+
for top_n in top_n_values:
|
1126 |
+
for metrics_group, metrics in weighted_metrics.items():
|
1127 |
+
create_model_comparison_plot(combined_results, metrics, top_n, plots_dir)
|
1128 |
+
|
1129 |
+
# 2. Если доступны macro метрики, создаем графики на их основе
|
1130 |
+
if macro_metrics is not None:
|
1131 |
+
print("Создание графиков на основе macro метрик...")
|
1132 |
+
macro_metrics_groups = {
|
1133 |
+
'text': ['macro_text_recall'], # Только macro_text_recall
|
1134 |
+
'doc': ['macro_doc_recall'] # Только macro_doc_recall
|
1135 |
+
}
|
1136 |
+
|
1137 |
+
for top_n in top_n_values:
|
1138 |
+
for metrics_group, metrics in macro_metrics_groups.items():
|
1139 |
+
create_model_comparison_plot(macro_metrics, metrics, top_n, plots_dir)
|
1140 |
+
|
1141 |
+
# 3. Создаем графики зависимости от top_n
|
1142 |
+
for metrics_type, df in [("weighted", combined_results), ("macro", macro_metrics)]:
|
1143 |
+
if df is None:
|
1144 |
+
continue
|
1145 |
+
|
1146 |
+
metrics_to_plot = []
|
1147 |
+
if metrics_type == "weighted":
|
1148 |
+
metrics_to_plot = ['text_recall', 'doc_recall'] # Только recall-метрики
|
1149 |
+
else:
|
1150 |
+
metrics_to_plot = ['macro_text_recall', 'macro_doc_recall'] # Только macro recall-метрики
|
1151 |
+
|
1152 |
+
for metric in metrics_to_plot:
|
1153 |
+
create_top_n_plot(df, models, metric, plots_dir)
|
1154 |
+
|
1155 |
+
# 4. Для каждой модели создаем графики по размеру чанка
|
1156 |
+
for model in models:
|
1157 |
+
# Выбираем 2 значения top_n для анализа
|
1158 |
+
for top_n in [20, 50]:
|
1159 |
+
# Создаем графики с recall-метриками
|
1160 |
+
weighted_metrics_to_combine = ['text_recall']
|
1161 |
+
create_chunk_size_plot(combined_results, model, weighted_metrics_to_combine, top_n, plots_dir)
|
1162 |
+
|
1163 |
+
doc_metrics_to_combine = ['doc_recall']
|
1164 |
+
create_chunk_size_plot(combined_results, model, doc_metrics_to_combine, top_n, plots_dir)
|
1165 |
+
|
1166 |
+
# Если есть macro метрики, создаем соответствующие графики
|
1167 |
+
if macro_metrics is not None:
|
1168 |
+
macro_metrics_to_combine = ['macro_text_recall']
|
1169 |
+
create_chunk_size_plot(macro_metrics, model, macro_metrics_to_combine, top_n, plots_dir)
|
1170 |
+
|
1171 |
+
macro_doc_metrics_to_combine = ['macro_doc_recall']
|
1172 |
+
create_chunk_size_plot(macro_metrics, model, macro_doc_metrics_to_combine, top_n, plots_dir)
|
1173 |
+
|
1174 |
+
# 5. Создаем тепловые карты для моделей
|
1175 |
+
for top_n in [20, 50]:
|
1176 |
+
for metric_prefix in ["", "macro_"]:
|
1177 |
+
for metric_type in ["text_recall", "doc_recall"]:
|
1178 |
+
metric = f"{metric_prefix}{metric_type}"
|
1179 |
+
# Используем соответствующий DataFrame
|
1180 |
+
if metric_prefix and macro_metrics is None:
|
1181 |
+
continue
|
1182 |
+
df_to_use = macro_metrics if metric_prefix else combined_results
|
1183 |
+
create_heatmap(df_to_use, models, metric, top_n, plots_dir)
|
1184 |
+
|
1185 |
+
print(f"Создание графиков завершено в директории {plots_dir}")
|
1186 |
+
|
1187 |
+
|
1188 |
+
def print_best_combinations(best_df: pd.DataFrame) -> None:
|
1189 |
+
"""
|
1190 |
+
Выводит информацию о лучших комбинациях параметров.
|
1191 |
+
|
1192 |
+
Args:
|
1193 |
+
best_df: DataFrame с лучшими комбинациями
|
1194 |
+
"""
|
1195 |
+
if best_df.empty:
|
1196 |
+
print("Не найдено лучших комбинаций")
|
1197 |
+
return
|
1198 |
+
|
1199 |
+
print("\n=== ЛУЧШИЕ КОМБИНАЦИИ ПАРАМЕТРОВ ===")
|
1200 |
+
|
1201 |
+
# Выводим абсолютно лучшую комбинацию, если она есть
|
1202 |
+
absolute_best = best_df[best_df.get('absolute_best', False) == True]
|
1203 |
+
if not absolute_best.empty:
|
1204 |
+
row = absolute_best.iloc[0]
|
1205 |
+
print(f"\nАБСОЛЮТНО ЛУЧШАЯ КОМБИНАЦИЯ:")
|
1206 |
+
print(f" Модель: {row['model']}")
|
1207 |
+
print(f" Размер чанка: {row['words_per_chunk']} слов")
|
1208 |
+
print(f" Перекрытие: {row['overlap_words']} слов ({row['overlap_percentage']}%)")
|
1209 |
+
print(f" top_n: {row['top_n']}")
|
1210 |
+
|
1211 |
+
# Выводим значения метрик
|
1212 |
+
recall_metrics = [col for col in best_df.columns if 'recall' in col and col != 'combined_recall']
|
1213 |
+
for metric in recall_metrics:
|
1214 |
+
print(f" {metric}: {row[metric]:.4f}")
|
1215 |
+
|
1216 |
+
print("\n=== ТОП-5 ЛУЧШИХ КОМБИНАЦИЙ ===")
|
1217 |
+
for i, row in best_df.head(5).iterrows():
|
1218 |
+
print(f"\n#{i+1}: {row['model']}, w={row['words_per_chunk']}, o={row['overlap_words']}, top_n={row['top_n']}")
|
1219 |
+
|
1220 |
+
# Выводим значения метрик
|
1221 |
+
recall_metrics = [col for col in best_df.columns if 'recall' in col and col != 'combined_recall']
|
1222 |
+
for metric in recall_metrics:
|
1223 |
+
print(f" {metric}: {row[metric]:.4f}")
|
1224 |
+
|
1225 |
+
print("\n=======================================")
|
1226 |
+
|
1227 |
+
|
1228 |
+
def create_combined_excel(combined_results: pd.DataFrame, question_metrics: pd.DataFrame,
|
1229 |
+
macro_metrics: pd.DataFrame = None, output_file: str = "combined_results.xlsx") -> None:
|
1230 |
+
"""
|
1231 |
+
Создает Excel-файл с несколькими листами, содержащими различные срезы данных.
|
1232 |
+
Добавляет автофильтры и применяет форматирование.
|
1233 |
+
|
1234 |
+
Args:
|
1235 |
+
combined_results: DataFrame с объединенными результатами
|
1236 |
+
question_metrics: DataFrame с метриками по вопросам
|
1237 |
+
macro_metrics: DataFrame с macro метриками (опционально)
|
1238 |
+
output_file: Путь к выходному Excel-файлу
|
1239 |
+
"""
|
1240 |
+
print(f"Создание Excel-файла {output_file}...")
|
1241 |
+
|
1242 |
+
# Создаем новый Excel-файл
|
1243 |
+
workbook = Workbook()
|
1244 |
+
|
1245 |
+
# Удаляем стандартный лист
|
1246 |
+
default_sheet = workbook.active
|
1247 |
+
workbook.remove(default_sheet)
|
1248 |
+
|
1249 |
+
# Подготавливаем данные для различных листов
|
1250 |
+
sheets_data = {
|
1251 |
+
"Исходные данные": combined_results,
|
1252 |
+
"Сводка по моделям": prepare_summary_by_model_top_n(combined_results, macro_metrics),
|
1253 |
+
"Сводка по чанкингу": prepare_summary_by_chunking_params_top_n(combined_results, macro_metrics),
|
1254 |
+
"Лучшие конфигурации": prepare_best_configurations(combined_results, macro_metrics)
|
1255 |
+
}
|
1256 |
+
|
1257 |
+
# Если есть метрики по вопросам, добавляем лист с ними
|
1258 |
+
if question_metrics is not None:
|
1259 |
+
sheets_data["Метрики по вопросам"] = question_metrics
|
1260 |
+
|
1261 |
+
# Если есть macro метрики, добавляем лист с ними
|
1262 |
+
if macro_metrics is not None:
|
1263 |
+
sheets_data["Macro метрики"] = macro_metrics
|
1264 |
+
|
1265 |
+
# Создаем листы и добавляем данные
|
1266 |
+
for sheet_name, data in sheets_data.items():
|
1267 |
+
if data is not None and not data.empty:
|
1268 |
+
sheet = workbook.create_sheet(title=sheet_name)
|
1269 |
+
for r in dataframe_to_rows(data, index=False, header=True):
|
1270 |
+
sheet.append(r)
|
1271 |
+
|
1272 |
+
# Применяем форматирование
|
1273 |
+
apply_formatting(workbook)
|
1274 |
+
|
1275 |
+
# Сохраняем файл
|
1276 |
+
workbook.save(output_file)
|
1277 |
+
print(f"Excel-файл создан: {output_file}")
|
1278 |
+
|
1279 |
+
|
1280 |
+
def calculate_macro_metrics(question_metrics: pd.DataFrame) -> pd.DataFrame:
|
1281 |
+
"""
|
1282 |
+
Вычисляет macro метрики на основе результатов по вопросам.
|
1283 |
+
|
1284 |
+
Args:
|
1285 |
+
question_metrics: DataFrame с метриками по вопросам
|
1286 |
+
|
1287 |
+
Returns:
|
1288 |
+
DataFrame с macro метриками
|
1289 |
+
"""
|
1290 |
+
if question_metrics is None:
|
1291 |
+
return None
|
1292 |
+
|
1293 |
+
print("Вычисление macro метрик на основе метрик по вопросам...")
|
1294 |
+
|
1295 |
+
# Группируем по конфигурации (модель, параметры чанкинга, top_n)
|
1296 |
+
grouped_metrics = question_metrics.groupby(['model', 'words_per_chunk', 'overlap_words', 'top_n'])
|
1297 |
+
|
1298 |
+
# Для каждой группы вычисляем среднее значение метрик (macro)
|
1299 |
+
macro_metrics = grouped_metrics.agg({
|
1300 |
+
'text_precision': 'mean', # Macro precision = среднее precision по всем вопросам
|
1301 |
+
'text_recall': 'mean', # Macro recall = среднее recall по всем вопросам
|
1302 |
+
'text_f1': 'mean', # Macro F1 = среднее F1 по всем вопросам
|
1303 |
+
'doc_precision': 'mean',
|
1304 |
+
'doc_recall': 'mean',
|
1305 |
+
'doc_f1': 'mean'
|
1306 |
+
}).reset_index()
|
1307 |
+
|
1308 |
+
# Добавляем префикс "macro_" к названиям метрик для ясности
|
1309 |
+
for col in ['text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']:
|
1310 |
+
macro_metrics.rename(columns={col: f'macro_{col}'}, inplace=True)
|
1311 |
+
|
1312 |
+
# Добавляем процент перекрытия
|
1313 |
+
macro_metrics['overlap_percentage'] = (macro_metrics['overlap_words'] / macro_metrics['words_per_chunk'] * 100).round(1)
|
1314 |
+
|
1315 |
+
print(f"Вычислено {len(macro_metrics)} набо��ов macro метрик")
|
1316 |
+
|
1317 |
+
return macro_metrics
|
1318 |
+
|
1319 |
+
|
1320 |
+
def main():
|
1321 |
+
"""Основная функция скрипта."""
|
1322 |
+
args = parse_args()
|
1323 |
+
|
1324 |
+
# Загружаем результаты из CSV-файлов
|
1325 |
+
combined_results = load_results_files(args.results_dir)
|
1326 |
+
|
1327 |
+
# Загружаем метрики по вопросам (если есть)
|
1328 |
+
question_metrics = load_question_metrics_files(args.results_dir)
|
1329 |
+
|
1330 |
+
# Вычисляем macro метрики на основе метрик по вопросам
|
1331 |
+
macro_metrics = calculate_macro_metrics(question_metrics)
|
1332 |
+
|
1333 |
+
# Находим лучшие комбинации параметров
|
1334 |
+
best_combinations_weighted = find_best_combinations(combined_results, ['text_recall', 'doc_recall'])
|
1335 |
+
print_best_combinations(best_combinations_weighted)
|
1336 |
+
|
1337 |
+
if macro_metrics is not None:
|
1338 |
+
best_combinations_macro = find_best_combinations(macro_metrics, ['macro_text_recall', 'macro_doc_recall'])
|
1339 |
+
print_best_combinations(best_combinations_macro)
|
1340 |
+
|
1341 |
+
# Создаем объединенный Excel-файл с данными
|
1342 |
+
create_combined_excel(combined_results, question_metrics, macro_metrics, args.output_file)
|
1343 |
+
|
1344 |
+
# Генерируем графики с помощью seaborn
|
1345 |
+
print(f"Генерация графиков и сохранение их в директорию: {args.plots_dir}")
|
1346 |
+
generate_plots(combined_results, macro_metrics, args.plots_dir)
|
1347 |
+
|
1348 |
+
print("Готово! Результаты сохранены в Excel и графики созданы.")
|
1349 |
+
|
1350 |
+
|
1351 |
+
if __name__ == "__main__":
|
1352 |
+
main()
|
lib/extractor/scripts/debug_question_chunks.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для отладки и анализа чанков, найденных для конкретного вопроса.
|
4 |
+
Показывает, какие чанки находятся, какие пункты ожидаются и значения метрик нечеткого сравнения.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from difflib import SequenceMatcher
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
17 |
+
|
18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
19 |
+
|
20 |
+
|
21 |
+
# Константы для настройки
|
22 |
+
DATA_FOLDER = "data/docs" # Путь к папке с документами
|
23 |
+
MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации
|
24 |
+
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами
|
25 |
+
OUTPUT_DIR = "data" # Директория для сохранения результатов
|
26 |
+
TOP_N_VALUES = [5, 10, 20, 30, 50, 100] # Значения N для анализа
|
27 |
+
THRESHOLD = 0.6
|
28 |
+
|
29 |
+
|
30 |
+
def parse_args():
|
31 |
+
"""
|
32 |
+
Парсит аргументы командной строки.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Аргументы командной строки
|
36 |
+
"""
|
37 |
+
parser = argparse.ArgumentParser(description="Скрипт для отладки чанкинга на конкретном вопросе")
|
38 |
+
|
39 |
+
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
|
40 |
+
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
|
41 |
+
parser.add_argument("--model-name", type=str, default=MODEL_NAME,
|
42 |
+
help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})")
|
43 |
+
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
|
44 |
+
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
|
45 |
+
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
|
46 |
+
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
|
47 |
+
parser.add_argument("--question-id", type=int, required=True,
|
48 |
+
help="ID вопроса для отладки")
|
49 |
+
parser.add_argument("--top-n", type=int, default=20,
|
50 |
+
help="Количество чанков в топе для отладки (по умолчанию: 20)")
|
51 |
+
parser.add_argument("--words-per-chunk", type=int, default=50,
|
52 |
+
help="Количество слов в чанке для fixed_size стратегии (по умолчанию: 50)")
|
53 |
+
parser.add_argument("--overlap-words", type=int, default=25,
|
54 |
+
help="Количество слов перекрытия для fixed_size стратегии (по умолчанию: 25)")
|
55 |
+
|
56 |
+
return parser.parse_args()
|
57 |
+
|
58 |
+
|
59 |
+
def load_questions_dataset(file_path: str) -> pd.DataFrame:
|
60 |
+
"""
|
61 |
+
Загружает датасет с вопросами из Excel-файла.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
file_path: Путь к Excel-файлу
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
DataFrame с вопросами и пунктами
|
68 |
+
"""
|
69 |
+
print(f"Загрузка датасета из {file_path}...")
|
70 |
+
|
71 |
+
df = pd.read_excel(file_path)
|
72 |
+
print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
|
73 |
+
|
74 |
+
# Преобразуем NaN в пустые строки для текстовых полей
|
75 |
+
text_columns = ['question', 'text', 'item_type']
|
76 |
+
for col in text_columns:
|
77 |
+
if col in df.columns:
|
78 |
+
df[col] = df[col].fillna('')
|
79 |
+
|
80 |
+
return df
|
81 |
+
|
82 |
+
|
83 |
+
def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]:
|
84 |
+
"""
|
85 |
+
Загружает эмбеддинги и соответствующие данные из файлов.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
filename: Базовое имя файла
|
89 |
+
output_dir: Директория, где хранятся файлы
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены
|
93 |
+
"""
|
94 |
+
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
|
95 |
+
data_path = os.path.join(output_dir, f"{filename}_data.csv")
|
96 |
+
|
97 |
+
if os.path.exists(embeddings_path) and os.path.exists(data_path):
|
98 |
+
print(f"Загрузка данных из {embeddings_path} и {data_path}...")
|
99 |
+
embeddings = np.load(embeddings_path)
|
100 |
+
data = pd.read_csv(data_path)
|
101 |
+
return embeddings, data
|
102 |
+
|
103 |
+
print(f"Ошибка: файлы {embeddings_path} и {data_path} не найдены.")
|
104 |
+
print("Сначала запустите скрип�� evaluate_chunking.py для создания эмбеддингов.")
|
105 |
+
sys.exit(1)
|
106 |
+
|
107 |
+
|
108 |
+
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
|
109 |
+
"""
|
110 |
+
Рассчитывает степень перекрытия между чанком и пунктом.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
chunk_text: Текст чанка
|
114 |
+
punct_text: Текст пункта
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
Коэффициент перекрытия от 0 до 1
|
118 |
+
"""
|
119 |
+
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
|
120 |
+
if chunk_text in punct_text:
|
121 |
+
return 1.0
|
122 |
+
|
123 |
+
# Если пункт входит в чанк, возвращаем соотношение длин
|
124 |
+
if punct_text in chunk_text:
|
125 |
+
return len(punct_text) / len(chunk_text)
|
126 |
+
|
127 |
+
# Используем SequenceMatcher для нечеткого сравнения
|
128 |
+
matcher = SequenceMatcher(None, chunk_text, punct_text)
|
129 |
+
|
130 |
+
# Находим наибольшую общую подстроку
|
131 |
+
match = matcher.find_longest_match(0, len(chunk_text), 0, len(punct_text))
|
132 |
+
|
133 |
+
# Если совпадений нет
|
134 |
+
if match.size == 0:
|
135 |
+
return 0.0
|
136 |
+
|
137 |
+
# Возвращаем соотношение длины совпадения к минимальной длине
|
138 |
+
return match.size / min(len(chunk_text), len(punct_text))
|
139 |
+
|
140 |
+
|
141 |
+
def format_text_for_display(text: str, max_length: int = 100) -> str:
|
142 |
+
"""
|
143 |
+
Форматирует текст для отображения, обрезая его при необходимости.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
text: Исходный текст
|
147 |
+
max_length: Максимальная длина для отображения
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
Отформатированный текст
|
151 |
+
"""
|
152 |
+
if len(text) <= max_length:
|
153 |
+
return text
|
154 |
+
return text[:max_length] + "..."
|
155 |
+
|
156 |
+
|
157 |
+
def analyze_question(
|
158 |
+
question_id: int,
|
159 |
+
questions_df: pd.DataFrame,
|
160 |
+
chunks_df: pd.DataFrame,
|
161 |
+
question_embeddings: np.ndarray,
|
162 |
+
chunk_embeddings: np.ndarray,
|
163 |
+
question_id_to_idx: dict,
|
164 |
+
top_n: int
|
165 |
+
) -> dict:
|
166 |
+
"""
|
167 |
+
Анализирует конкретный вопрос и его релевантные чанки.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
question_id: ID вопроса для анализа
|
171 |
+
questions_df: DataFrame с вопросами
|
172 |
+
chunks_df: DataFrame с чанками
|
173 |
+
question_embeddings: Эмбеддинги вопросов
|
174 |
+
chunk_embeddings: Эмбеддинги чанков
|
175 |
+
question_id_to_idx: Словарь соответствия ID вопроса и его индекса
|
176 |
+
top_n: Количество чанков в топе
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
Словарь с результатами анализа
|
180 |
+
"""
|
181 |
+
# Проверяем, есть ли вопрос с таким ID
|
182 |
+
if question_id not in question_id_to_idx:
|
183 |
+
print(f"Ошибка: вопрос с ID {question_id} не найден в данных")
|
184 |
+
sys.exit(1)
|
185 |
+
|
186 |
+
# Получаем строки для выбранного вопроса
|
187 |
+
question_rows = questions_df[questions_df['id'] == question_id]
|
188 |
+
if len(question_rows) == 0:
|
189 |
+
print(f"Ошибка: вопрос с ID {question_id} не найден в исходном датасете")
|
190 |
+
sys.exit(1)
|
191 |
+
|
192 |
+
# Получаем текст вопроса и его индекс в массиве эмбеддингов
|
193 |
+
question_text = question_rows['question'].iloc[0]
|
194 |
+
question_idx = question_id_to_idx[question_id]
|
195 |
+
|
196 |
+
# Получаем ожидаемые пункты для вопроса
|
197 |
+
expected_puncts = question_rows['text'].tolist()
|
198 |
+
|
199 |
+
# Вычисляем косинусную близость между вопросом и всеми чанками
|
200 |
+
similarity = cosine_similarity([question_embeddings[question_idx]], chunk_embeddings)[0]
|
201 |
+
|
202 |
+
# Получаем связанные документы, если есть
|
203 |
+
related_docs = []
|
204 |
+
if 'filename' in question_rows.columns:
|
205 |
+
related_docs = question_rows['filename'].unique().tolist()
|
206 |
+
related_docs = [doc for doc in related_docs if doc and not pd.isna(doc)]
|
207 |
+
|
208 |
+
# Результаты для всех документов
|
209 |
+
all_results = []
|
210 |
+
|
211 |
+
# Обрабатываем каждый связанный документ
|
212 |
+
if related_docs:
|
213 |
+
for doc_name in related_docs:
|
214 |
+
# Фильтруем чанки по имени документа
|
215 |
+
doc_chunks = chunks_df[chunks_df['doc_name'] == doc_name]
|
216 |
+
if doc_chunks.empty:
|
217 |
+
continue
|
218 |
+
|
219 |
+
# Индексы чанков для документа
|
220 |
+
doc_chunk_indices = doc_chunks.index.tolist()
|
221 |
+
|
222 |
+
# Получаем значения близости для чанков документа
|
223 |
+
doc_similarities = [similarity[chunks_df.index.get_loc(idx)] for idx in doc_chunk_indices]
|
224 |
+
|
225 |
+
# Создаем словарь индекс -> схожесть
|
226 |
+
similarity_dict = {idx: sim for idx, sim in zip(doc_chunk_indices, doc_similarities)}
|
227 |
+
|
228 |
+
# Сортируем индексы по убыванию похожести
|
229 |
+
sorted_indices = sorted(similarity_dict.keys(), key=lambda x: similarity_dict[x], reverse=True)
|
230 |
+
|
231 |
+
# Берем топ-N
|
232 |
+
top_indices = sorted_indices[:min(top_n, len(sorted_indices))]
|
233 |
+
|
234 |
+
# Получаем топ-N чанков
|
235 |
+
top_chunks = chunks_df.iloc[top_indices]
|
236 |
+
|
237 |
+
# Формируем результаты для документа
|
238 |
+
doc_results = {
|
239 |
+
'doc_name': doc_name,
|
240 |
+
'top_chunks': []
|
241 |
+
}
|
242 |
+
|
243 |
+
# Для каждого чанка
|
244 |
+
for idx, chunk in top_chunks.iterrows():
|
245 |
+
# Вычисляем перекрытие с каждым пунктом
|
246 |
+
overlaps = []
|
247 |
+
for punct in expected_puncts:
|
248 |
+
overlap = calculate_chunk_overlap(chunk['text'], punct)
|
249 |
+
overlaps.append({
|
250 |
+
'punct': format_text_for_display(punct),
|
251 |
+
'overlap': overlap
|
252 |
+
})
|
253 |
+
|
254 |
+
# Находим максимальное перекрытие
|
255 |
+
max_overlap = max(overlaps, key=lambda x: x['overlap']) if overlaps else {'overlap': 0}
|
256 |
+
|
257 |
+
# Добавляем в результаты
|
258 |
+
doc_results['top_chunks'].append({
|
259 |
+
'chunk_id': chunk['id'],
|
260 |
+
'chunk_text': format_text_for_display(chunk['text']),
|
261 |
+
'similarity': similarity_dict[idx],
|
262 |
+
'overlaps': overlaps,
|
263 |
+
'max_overlap': max_overlap['overlap'],
|
264 |
+
'is_relevant': max_overlap['overlap'] >= THRESHOLD # Используем порог 0.7
|
265 |
+
})
|
266 |
+
|
267 |
+
all_results.append(doc_results)
|
268 |
+
else:
|
269 |
+
# Если нет связанных документов, анализируем чанки из всех документов
|
270 |
+
# Получаем индексы для топ-N чанков по близости
|
271 |
+
top_indices = np.argsort(similarity)[-top_n:][::-1]
|
272 |
+
|
273 |
+
# Получаем топ-N чанков
|
274 |
+
top_chunks = chunks_df.iloc[top_indices]
|
275 |
+
|
276 |
+
# Группируем чанки по документам
|
277 |
+
doc_groups = top_chunks.groupby('doc_name')
|
278 |
+
|
279 |
+
for doc_name, group in doc_groups:
|
280 |
+
doc_results = {
|
281 |
+
'doc_name': doc_name,
|
282 |
+
'top_chunks': []
|
283 |
+
}
|
284 |
+
|
285 |
+
for idx, chunk in group.iterrows():
|
286 |
+
# Вычисляем перекрытие с каждым пунктом
|
287 |
+
overlaps = []
|
288 |
+
for punct in expected_puncts:
|
289 |
+
overlap = calculate_chunk_overlap(chunk['text'], punct)
|
290 |
+
overlaps.append({
|
291 |
+
'punct': format_text_for_display(punct),
|
292 |
+
'overlap': overlap
|
293 |
+
})
|
294 |
+
|
295 |
+
# Находим максимальное перекрытие
|
296 |
+
max_overlap = max(overlaps, key=lambda x: x['overlap']) if overlaps else {'overlap': 0}
|
297 |
+
|
298 |
+
# Добавляем в результаты
|
299 |
+
doc_results['top_chunks'].append({
|
300 |
+
'chunk_id': chunk['id'],
|
301 |
+
'chunk_text': format_text_for_display(chunk['text']),
|
302 |
+
'similarity': similarity[chunks_df.index.get_loc(idx)],
|
303 |
+
'overlaps': overlaps,
|
304 |
+
'max_overlap': max_overlap['overlap'],
|
305 |
+
'is_relevant': max_overlap['overlap'] >= THRESHOLD # Используем порог 0.7
|
306 |
+
})
|
307 |
+
|
308 |
+
all_results.append(doc_results)
|
309 |
+
|
310 |
+
# Формируем общие результаты для вопроса
|
311 |
+
results = {
|
312 |
+
'question_id': question_id,
|
313 |
+
'question_text': question_text,
|
314 |
+
'expected_puncts': [format_text_for_display(punct) for punct in expected_puncts],
|
315 |
+
'related_docs': related_docs,
|
316 |
+
'results_by_doc': all_results
|
317 |
+
}
|
318 |
+
|
319 |
+
return results
|
320 |
+
|
321 |
+
|
322 |
+
def main():
|
323 |
+
"""
|
324 |
+
Основная функция скрипта.
|
325 |
+
"""
|
326 |
+
args = parse_args()
|
327 |
+
|
328 |
+
# Загружаем датасет с вопросами
|
329 |
+
questions_df = load_questions_dataset(args.dataset_path)
|
330 |
+
|
331 |
+
# Формируем уникальное имя для сохраненных файлов на основе параметров стратегии и модел��
|
332 |
+
strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}"
|
333 |
+
chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}"
|
334 |
+
questions_filename = f"questions_{args.model_name.replace('/', '_')}"
|
335 |
+
|
336 |
+
# Загружаем сохраненные эмбеддинги и данные
|
337 |
+
chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir)
|
338 |
+
question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir)
|
339 |
+
|
340 |
+
# Создаем словарь соответствия id вопроса и его индекса в эмбеддингах
|
341 |
+
question_id_to_idx = {
|
342 |
+
int(row['id']): i
|
343 |
+
for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows())
|
344 |
+
}
|
345 |
+
|
346 |
+
# Анализируем выбранный вопрос для указанного top_n
|
347 |
+
results = analyze_question(
|
348 |
+
args.question_id,
|
349 |
+
questions_df,
|
350 |
+
chunks_df,
|
351 |
+
question_embeddings,
|
352 |
+
chunk_embeddings,
|
353 |
+
question_id_to_idx,
|
354 |
+
args.top_n
|
355 |
+
)
|
356 |
+
|
357 |
+
# Сохраняем результаты в JSON файл
|
358 |
+
output_filename = f"debug_question_{args.question_id}_top{args.top_n}.json"
|
359 |
+
output_path = os.path.join(args.output_dir, output_filename)
|
360 |
+
|
361 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
362 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
363 |
+
|
364 |
+
print(f"Результаты сохранены в {output_path}")
|
365 |
+
|
366 |
+
# Выводим краткую информацию
|
367 |
+
print(f"\nАнализ вопроса ID {args.question_id}: {results['question_text']}")
|
368 |
+
print(f"Ожидаемые пункты: {len(results['expected_puncts'])}")
|
369 |
+
print(f"Связанные документы: {results['related_docs']}")
|
370 |
+
|
371 |
+
# Статистика релевантности
|
372 |
+
relevant_chunks = 0
|
373 |
+
total_chunks = 0
|
374 |
+
|
375 |
+
for doc_result in results['results_by_doc']:
|
376 |
+
doc_relevant = sum(1 for chunk in doc_result['top_chunks'] if chunk['is_relevant'])
|
377 |
+
doc_total = len(doc_result['top_chunks'])
|
378 |
+
|
379 |
+
print(f"\nДокумент: {doc_result['doc_name']}")
|
380 |
+
print(f"Релевантных чанков: {doc_relevant} из {doc_total} ({doc_relevant/doc_total*100:.1f}%)")
|
381 |
+
|
382 |
+
relevant_chunks += doc_relevant
|
383 |
+
total_chunks += doc_total
|
384 |
+
|
385 |
+
if total_chunks > 0:
|
386 |
+
print(f"\nОбщая точность: {relevant_chunks/total_chunks*100:.1f}%")
|
387 |
+
else:
|
388 |
+
print("\nНе найдено чанков для анализа")
|
389 |
+
|
390 |
+
|
391 |
+
if __name__ == "__main__":
|
392 |
+
main()
|
lib/extractor/scripts/evaluate_chunking.py
ADDED
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для оценки качества различных стратегий чанкинга.
|
4 |
+
Сравнивает стратегии на основе релевантности чанков к вопросам.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
import torch
|
16 |
+
from fuzzywuzzy import fuzz
|
17 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
18 |
+
from tqdm import tqdm
|
19 |
+
from transformers import AutoModel, AutoTokenizer
|
20 |
+
|
21 |
+
# Константы для настройки
|
22 |
+
DATA_FOLDER = "data/docs" # Путь к папке с документами
|
23 |
+
MODEL_NAME = "intfloat/e5-base" # Название модели для векторизации
|
24 |
+
DATASET_PATH = "data/dataset.xlsx" # Путь к Excel-датасету с вопросами
|
25 |
+
BATCH_SIZE = 8 # Размер батча для векторизации
|
26 |
+
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu" # Устройство для вычислений
|
27 |
+
SIMILARITY_THRESHOLD = 0.7 # Порог для нечеткого сравнения
|
28 |
+
OUTPUT_DIR = "data" # Директория для сохранения результатов
|
29 |
+
TOP_CHUNKS_DIR = "data/top_chunks" # Директория для сохранения топ-чанков
|
30 |
+
TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100] # Значения N для оценки
|
31 |
+
|
32 |
+
# Параметры стратегий чанкинга
|
33 |
+
FIXED_SIZE_CONFIG = {
|
34 |
+
"words_per_chunk": 50, # Количество слов в чанке
|
35 |
+
"overlap_words": 25 # Количество слов перекрытия
|
36 |
+
}
|
37 |
+
|
38 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
39 |
+
from ntr_fileparser import UniversalParser
|
40 |
+
|
41 |
+
from ntr_text_fragmentation import Destructurer
|
42 |
+
|
43 |
+
|
44 |
+
def _average_pool(
|
45 |
+
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
46 |
+
) -> torch.Tensor:
|
47 |
+
"""
|
48 |
+
Расчёт усредненного эмбеддинга по всем токенам
|
49 |
+
|
50 |
+
Args:
|
51 |
+
last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой
|
52 |
+
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size)
|
56 |
+
"""
|
57 |
+
last_hidden = last_hidden_states.masked_fill(
|
58 |
+
~attention_mask[..., None].bool(), 0.0
|
59 |
+
)
|
60 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
61 |
+
|
62 |
+
|
63 |
+
def parse_args():
|
64 |
+
"""
|
65 |
+
Парсит аргументы командной строки.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Аргументы командной строки
|
69 |
+
"""
|
70 |
+
parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга")
|
71 |
+
|
72 |
+
parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
|
73 |
+
help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
|
74 |
+
parser.add_argument("--model-name", type=str, default=MODEL_NAME,
|
75 |
+
help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})")
|
76 |
+
parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
|
77 |
+
help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
|
78 |
+
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
|
79 |
+
help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})")
|
80 |
+
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
|
81 |
+
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
|
82 |
+
parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
|
83 |
+
help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
|
84 |
+
parser.add_argument("--force-recompute", action="store_true",
|
85 |
+
help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные")
|
86 |
+
parser.add_argument("--use-sentence-transformers", action="store_true",
|
87 |
+
help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)")
|
88 |
+
parser.add_argument("--device", type=str, default=DEVICE,
|
89 |
+
help=f"Устройст��о для вычислений (по умолчанию: {DEVICE})")
|
90 |
+
|
91 |
+
# Параметры для fixed_size стратегии
|
92 |
+
parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"],
|
93 |
+
help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})")
|
94 |
+
parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"],
|
95 |
+
help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})")
|
96 |
+
|
97 |
+
return parser.parse_args()
|
98 |
+
|
99 |
+
|
100 |
+
def read_documents(folder_path: str) -> dict:
|
101 |
+
"""
|
102 |
+
Читает все документы из указанной папки.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
folder_path: Путь к папке с документами
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
Словарь {имя_файла: parsed_document}
|
109 |
+
"""
|
110 |
+
print(f"Чтение документов из {folder_path}...")
|
111 |
+
parser = UniversalParser()
|
112 |
+
documents = {}
|
113 |
+
|
114 |
+
for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"):
|
115 |
+
try:
|
116 |
+
doc_name = file_path.stem
|
117 |
+
documents[doc_name] = parser.parse_by_path(str(file_path))
|
118 |
+
except Exception as e:
|
119 |
+
print(f"Ошибка при чтении файла {file_path}: {e}")
|
120 |
+
|
121 |
+
return documents
|
122 |
+
|
123 |
+
|
124 |
+
def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame:
|
125 |
+
"""
|
126 |
+
Обрабатывает документы со стратегией fixed_size для чанкинга.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
documents: Словарь с распарсенными документами
|
130 |
+
fixed_size_config: Конфигурация для fixed_size стратегии
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
DataFrame с чанками
|
134 |
+
"""
|
135 |
+
print("Обработка документов стратегией fixed_size...")
|
136 |
+
|
137 |
+
all_data = []
|
138 |
+
|
139 |
+
for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"):
|
140 |
+
# Стратегия fixed_size для чанкинга
|
141 |
+
destructurer = Destructurer(document)
|
142 |
+
destructurer.configure('fixed_size',
|
143 |
+
words_per_chunk=fixed_size_config["words_per_chunk"],
|
144 |
+
overlap_words=fixed_size_config["overlap_words"])
|
145 |
+
fixed_size_entities, _ = destructurer.destructure()
|
146 |
+
|
147 |
+
# Обрабатываем только сущности для поиска
|
148 |
+
for entity in fixed_size_entities:
|
149 |
+
if hasattr(entity, 'use_in_search') and entity.use_in_search:
|
150 |
+
entity_data = {
|
151 |
+
'id': str(entity.id),
|
152 |
+
'doc_name': doc_name,
|
153 |
+
'name': entity.name,
|
154 |
+
'text': entity.text,
|
155 |
+
'type': entity.type,
|
156 |
+
'strategy': 'fixed_size',
|
157 |
+
'metadata': json.dumps(entity.metadata, ensure_ascii=False)
|
158 |
+
}
|
159 |
+
all_data.append(entity_data)
|
160 |
+
|
161 |
+
# Создаем DataFrame
|
162 |
+
df = pd.DataFrame(all_data)
|
163 |
+
|
164 |
+
# Фильтруем по типу, исключая Document
|
165 |
+
df = df[df['type'] != 'Document']
|
166 |
+
|
167 |
+
return df
|
168 |
+
|
169 |
+
|
170 |
+
def load_questions_dataset(file_path: str) -> pd.DataFrame:
|
171 |
+
"""
|
172 |
+
Загружает датасет с вопросами из Excel-файла.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
file_path: Путь к Excel-файлу
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
DataFrame с вопросами и пунктами
|
179 |
+
"""
|
180 |
+
print(f"Загрузка датасета из {file_path}...")
|
181 |
+
|
182 |
+
df = pd.read_excel(file_path)
|
183 |
+
print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
|
184 |
+
|
185 |
+
# Преобразуем NaN в пустые строки для текстовых полей
|
186 |
+
text_columns = ['question', 'text', 'item_type']
|
187 |
+
for col in text_columns:
|
188 |
+
if col in df.columns:
|
189 |
+
df[col] = df[col].fillna('')
|
190 |
+
|
191 |
+
return df
|
192 |
+
|
193 |
+
|
194 |
+
def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE):
|
195 |
+
"""
|
196 |
+
Инициализирует модель и токенизатор.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
model_name: Название предобученной модели
|
200 |
+
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
|
201 |
+
device: Устройство для вычислений
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
Кортеж (модель, токенизатор) или объект SentenceTransformer
|
205 |
+
"""
|
206 |
+
print(f"Загрузка модели {model_name} на устройство {device}...")
|
207 |
+
|
208 |
+
if use_sentence_transformers:
|
209 |
+
try:
|
210 |
+
from sentence_transformers import SentenceTransformer
|
211 |
+
model = SentenceTransformer(model_name, device=device)
|
212 |
+
return model, None
|
213 |
+
except ImportError:
|
214 |
+
print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers")
|
215 |
+
raise
|
216 |
+
else:
|
217 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
218 |
+
model = AutoModel.from_pretrained(model_name).to(device)
|
219 |
+
model.eval()
|
220 |
+
|
221 |
+
return model, tokenizer
|
222 |
+
|
223 |
+
|
224 |
+
def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray:
|
225 |
+
"""
|
226 |
+
Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
texts: Список текстов
|
230 |
+
model: Модель для векторизации или SentenceTransformer
|
231 |
+
tokenizer: Токенизатор (None для sentence_transformers)
|
232 |
+
batch_size: Размер батча
|
233 |
+
use_sentence_transformers: Использовать ли библиотеку sentence_transformers
|
234 |
+
device: Устройство для вычислений
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
Массив эмбеддингов
|
238 |
+
"""
|
239 |
+
if use_sentence_transformers:
|
240 |
+
# Используем sentence_transformers для получения эмбеддингов
|
241 |
+
all_embeddings = []
|
242 |
+
|
243 |
+
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"):
|
244 |
+
batch_texts = texts[i:i+batch_size]
|
245 |
+
|
246 |
+
# Получаем эмбеддинги с помощью sentence_transformers
|
247 |
+
embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False)
|
248 |
+
all_embeddings.append(embeddings)
|
249 |
+
|
250 |
+
return np.vstack(all_embeddings)
|
251 |
+
else:
|
252 |
+
# Используем стандартный подход с average pooling
|
253 |
+
all_embeddings = []
|
254 |
+
|
255 |
+
for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"):
|
256 |
+
batch_texts = texts[i:i+batch_size]
|
257 |
+
|
258 |
+
# Токенизация с обрезкой и padding
|
259 |
+
encoding = tokenizer(
|
260 |
+
batch_texts,
|
261 |
+
padding=True,
|
262 |
+
truncation=True,
|
263 |
+
max_length=512,
|
264 |
+
return_tensors="pt"
|
265 |
+
).to(device)
|
266 |
+
|
267 |
+
# Получаем эмбеддинги с average pooling
|
268 |
+
with torch.no_grad():
|
269 |
+
outputs = model(**encoding)
|
270 |
+
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
|
271 |
+
all_embeddings.append(embeddings.cpu().numpy())
|
272 |
+
|
273 |
+
return np.vstack(all_embeddings)
|
274 |
+
|
275 |
+
|
276 |
+
def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
|
277 |
+
"""
|
278 |
+
Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
chunk_text: Текст чанка
|
282 |
+
punct_text: Текст пункта
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
Коэффициент перекрытия от 0 до 1
|
286 |
+
"""
|
287 |
+
# Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
|
288 |
+
if chunk_text in punct_text:
|
289 |
+
return 1.0
|
290 |
+
|
291 |
+
# Если пункт входит в чанк, возвращаем соотношение длин
|
292 |
+
if punct_text in chunk_text:
|
293 |
+
return len(punct_text) / len(chunk_text)
|
294 |
+
|
295 |
+
# Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает
|
296 |
+
# случаи, когда один текст является подстрокой другого, даже с небольшими различиями
|
297 |
+
partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0
|
298 |
+
|
299 |
+
return partial_ratio_score
|
300 |
+
|
301 |
+
|
302 |
+
def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str):
|
303 |
+
"""
|
304 |
+
Сохраняет эмбеддинги и соответствующие данные в файлы.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
embeddings: Массив эмбеддингов
|
308 |
+
data: DataFrame с данными
|
309 |
+
filename: Базовое имя файла
|
310 |
+
output_dir: Директория для сохранения
|
311 |
+
"""
|
312 |
+
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
|
313 |
+
data_path = os.path.join(output_dir, f"{filename}_data.csv")
|
314 |
+
|
315 |
+
# Сохраняем эмбеддинги
|
316 |
+
np.save(embeddings_path, embeddings)
|
317 |
+
print(f"Эмбеддинги сохранены в {embeddings_path}")
|
318 |
+
|
319 |
+
# Сохраняем данные
|
320 |
+
data.to_csv(data_path, index=False)
|
321 |
+
print(f"Данные сохранены в {data_path}")
|
322 |
+
|
323 |
+
|
324 |
+
def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]:
|
325 |
+
"""
|
326 |
+
Загружает эмбеддинги и соответствующие данные из файлов.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
filename: Базовое имя файла
|
330 |
+
output_dir: Директория, где хранятся файлы
|
331 |
+
|
332 |
+
Returns:
|
333 |
+
Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены
|
334 |
+
"""
|
335 |
+
embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
|
336 |
+
data_path = os.path.join(output_dir, f"{filename}_data.csv")
|
337 |
+
|
338 |
+
if os.path.exists(embeddings_path) and os.path.exists(data_path):
|
339 |
+
print(f"Загрузка данных из {embeddings_path} и {data_path}...")
|
340 |
+
embeddings = np.load(embeddings_path)
|
341 |
+
data = pd.read_csv(data_path)
|
342 |
+
return embeddings, data
|
343 |
+
|
344 |
+
return None, None
|
345 |
+
|
346 |
+
|
347 |
+
def save_top_chunks_for_question(
|
348 |
+
question_id: int,
|
349 |
+
question_text: str,
|
350 |
+
question_puncts: list[str],
|
351 |
+
top_chunks: pd.DataFrame,
|
352 |
+
similarities: dict,
|
353 |
+
overlap_data: list,
|
354 |
+
output_dir: str
|
355 |
+
):
|
356 |
+
"""
|
357 |
+
Сохраняет топ-чанки для конкретного вопроса в JSON-файл.
|
358 |
+
|
359 |
+
Args:
|
360 |
+
question_id: ID вопроса
|
361 |
+
question_text: Текст вопроса
|
362 |
+
question_puncts: Список пунктов, относящихся к вопросу
|
363 |
+
top_chunks: DataFrame с топ-чанками
|
364 |
+
similarities: Словарь с косинусными схожестями для чанков
|
365 |
+
overlap_data: Данные о перекрытии чанков с пунктами
|
366 |
+
output_dir: Директория для сохранения
|
367 |
+
"""
|
368 |
+
# Подготавливаем результаты для сохранения
|
369 |
+
chunks_data = []
|
370 |
+
|
371 |
+
for i, (idx, chunk) in enumerate(top_chunks.iterrows()):
|
372 |
+
# Получаем данные о перекрытии для текущего чанка
|
373 |
+
chunk_overlaps = overlap_data[i] if i < len(overlap_data) else []
|
374 |
+
|
375 |
+
# Преобразуем numpy типы в стандартные типы Python
|
376 |
+
similarity = float(similarities.get(idx, 0.0))
|
377 |
+
|
378 |
+
# Формируем данные чанка
|
379 |
+
chunk_data = {
|
380 |
+
'chunk_id': chunk['id'],
|
381 |
+
'doc_name': chunk['doc_name'],
|
382 |
+
'text': chunk['text'],
|
383 |
+
'similarity': similarity,
|
384 |
+
'overlaps': chunk_overlaps
|
385 |
+
}
|
386 |
+
chunks_data.append(chunk_data)
|
387 |
+
|
388 |
+
# Преобразуем numpy.int64 в int для question_id
|
389 |
+
question_id = int(question_id)
|
390 |
+
|
391 |
+
# Формируем общий результат
|
392 |
+
result = {
|
393 |
+
'question_id': question_id,
|
394 |
+
'question_text': question_text,
|
395 |
+
'puncts': question_puncts,
|
396 |
+
'chunks': chunks_data
|
397 |
+
}
|
398 |
+
|
399 |
+
# Создаем имя файла
|
400 |
+
filename = f"question_{question_id}_top_chunks.json"
|
401 |
+
filepath = os.path.join(output_dir, filename)
|
402 |
+
|
403 |
+
# Класс для сериализации numpy типов
|
404 |
+
class NumpyEncoder(json.JSONEncoder):
|
405 |
+
def default(self, obj):
|
406 |
+
if isinstance(obj, np.integer):
|
407 |
+
return int(obj)
|
408 |
+
if isinstance(obj, np.floating):
|
409 |
+
return float(obj)
|
410 |
+
if isinstance(obj, np.ndarray):
|
411 |
+
return obj.tolist()
|
412 |
+
return super().default(obj)
|
413 |
+
|
414 |
+
# Сохраняем в JSON с кастомным энкодером
|
415 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
416 |
+
json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
|
417 |
+
|
418 |
+
print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}")
|
419 |
+
|
420 |
+
|
421 |
+
def evaluate_for_top_n_with_mapping(
|
422 |
+
questions_df: pd.DataFrame,
|
423 |
+
chunks_df: pd.DataFrame,
|
424 |
+
question_embeddings: np.ndarray,
|
425 |
+
chunk_embeddings: np.ndarray,
|
426 |
+
question_id_to_idx: dict,
|
427 |
+
top_n: int,
|
428 |
+
similarity_threshold: float,
|
429 |
+
top_chunks_dir: str = None
|
430 |
+
) -> tuple[dict[str, float], pd.DataFrame]:
|
431 |
+
"""
|
432 |
+
Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет)
|
436 |
+
chunks_df: DataFrame с чанками
|
437 |
+
question_embeddings: Эмбеддинги вопросов
|
438 |
+
chunk_embeddings: Эмбеддинги чанков
|
439 |
+
question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов
|
440 |
+
top_n: Количество чанков в топе ��ля каждого вопроса
|
441 |
+
similarity_threshold: Порог для нечеткого сравнения
|
442 |
+
top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем)
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам)
|
446 |
+
"""
|
447 |
+
print(f"Оценка для top-{top_n}...")
|
448 |
+
|
449 |
+
# Вычисляем косинусную близость между вопросами и чанками
|
450 |
+
similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
|
451 |
+
|
452 |
+
# Счетчики для метрик на основе текста
|
453 |
+
total_puncts = 0
|
454 |
+
found_puncts = 0
|
455 |
+
total_chunks = 0
|
456 |
+
relevant_chunks = 0
|
457 |
+
|
458 |
+
# Счетчики для метрик на основе документов
|
459 |
+
total_docs_required = 0
|
460 |
+
found_relevant_docs = 0
|
461 |
+
total_docs_found = 0
|
462 |
+
|
463 |
+
# Для сохранения метрик по отдельным вопросам
|
464 |
+
question_metrics = []
|
465 |
+
|
466 |
+
# Выводим информацию о столбцах для отладки
|
467 |
+
print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}")
|
468 |
+
|
469 |
+
# Группируем вопросы по id (у нас 20 уникальных вопросов)
|
470 |
+
for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"):
|
471 |
+
# Получаем строки для текущего вопроса из исходного датасета
|
472 |
+
question_rows = questions_df[questions_df['id'] == question_id]
|
473 |
+
|
474 |
+
# Проверяем, есть ли вопрос с таким id в нашем маппинге
|
475 |
+
if question_id not in question_id_to_idx:
|
476 |
+
print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге")
|
477 |
+
continue
|
478 |
+
|
479 |
+
# Если нет строк с таким id, пропускаем
|
480 |
+
if len(question_rows) == 0:
|
481 |
+
continue
|
482 |
+
|
483 |
+
# Получаем индекс вопроса в массиве эмбеддингов
|
484 |
+
question_idx = question_id_to_idx[question_id]
|
485 |
+
|
486 |
+
# Получаем текст вопроса
|
487 |
+
question_text = question_rows['question'].iloc[0]
|
488 |
+
|
489 |
+
# Получаем все пункты для этого вопроса
|
490 |
+
puncts = question_rows['text'].tolist()
|
491 |
+
question_total_puncts = len(puncts)
|
492 |
+
total_puncts += question_total_puncts
|
493 |
+
|
494 |
+
# Получаем связанные документы
|
495 |
+
relevant_docs = []
|
496 |
+
if 'filename' in question_rows.columns:
|
497 |
+
relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)]
|
498 |
+
question_total_docs_required = len(relevant_docs)
|
499 |
+
total_docs_required += question_total_docs_required
|
500 |
+
print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}")
|
501 |
+
else:
|
502 |
+
print(f"Столбец 'filename' отсутствует. Используем все документы.")
|
503 |
+
relevant_docs = chunks_df['doc_name'].unique().tolist()
|
504 |
+
question_total_docs_required = len(relevant_docs)
|
505 |
+
total_docs_required += question_total_docs_required
|
506 |
+
|
507 |
+
# Если для вопроса нет релевантных документов, пропускаем
|
508 |
+
if not relevant_docs:
|
509 |
+
print(f"Для вопроса {question_id} нет связанных документов")
|
510 |
+
continue
|
511 |
+
|
512 |
+
# Флаги для отслеживания найденных пунктов
|
513 |
+
punct_found = [False] * question_total_puncts
|
514 |
+
|
515 |
+
# Для отслеживания найденных документов
|
516 |
+
docs_found_for_question = set()
|
517 |
+
|
518 |
+
# Для хранения всех чанков вопроса для ограничения top_n
|
519 |
+
all_question_chunks = []
|
520 |
+
all_question_similarities = []
|
521 |
+
|
522 |
+
# Собираем чанки для всех документов по этому вопросу
|
523 |
+
for filename in relevant_docs:
|
524 |
+
if not filename or pd.isna(filename):
|
525 |
+
continue
|
526 |
+
|
527 |
+
# Фильтруем чанки по имени файла
|
528 |
+
doc_chunks = chunks_df[chunks_df['doc_name'] == filename]
|
529 |
+
|
530 |
+
if doc_chunks.empty:
|
531 |
+
print(f"Предупреждение: документ {filename} не содержит чанков")
|
532 |
+
continue
|
533 |
+
|
534 |
+
# Индексы чанков для текущего файла
|
535 |
+
doc_chunk_indices = doc_chunks.index.tolist()
|
536 |
+
|
537 |
+
# Получаем значения близости для чанков текущего файла
|
538 |
+
doc_similarities = [
|
539 |
+
similarity_matrix[question_idx, chunks_df.index.get_loc(idx)]
|
540 |
+
for idx in doc_chunk_indices
|
541 |
+
]
|
542 |
+
|
543 |
+
# Добавляем чанки и их схожести к общему списку для вопроса
|
544 |
+
for i, idx in enumerate(doc_chunk_indices):
|
545 |
+
all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]]))
|
546 |
+
all_question_similarities.append(doc_similarities[i])
|
547 |
+
|
548 |
+
# Сортируем все чанки по убыванию схожести и берем top_n
|
549 |
+
sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1]
|
550 |
+
top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices]
|
551 |
+
top_chunks = [all_question_chunks[i][1] for i in sorted_indices]
|
552 |
+
|
553 |
+
# Увеличиваем счетчик общего числа чанков
|
554 |
+
question_total_chunks = len(top_chunks)
|
555 |
+
total_chunks += question_total_chunks
|
556 |
+
|
557 |
+
# Для сохранения данных топ-чанков
|
558 |
+
all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks])
|
559 |
+
all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])}
|
560 |
+
all_chunk_overlaps = []
|
561 |
+
|
562 |
+
# Для каждого чанка проверяем его релевантность к пунктам
|
563 |
+
question_relevant_chunks = 0
|
564 |
+
|
565 |
+
for i, chunk in enumerate(top_chunks):
|
566 |
+
is_relevant = False
|
567 |
+
chunk_overlaps = []
|
568 |
+
|
569 |
+
# Добавляем документ в найденные
|
570 |
+
docs_found_for_question.add(chunk['doc_name'])
|
571 |
+
|
572 |
+
# Проверяем перекрытие с каждым пунктом
|
573 |
+
for j, punct in enumerate(puncts):
|
574 |
+
overlap = calculate_chunk_overlap(chunk['text'], punct)
|
575 |
+
|
576 |
+
# Если нужно сохранить топ-чанки и top_n == 20
|
577 |
+
if top_chunks_dir and top_n == 20:
|
578 |
+
chunk_overlaps.append({
|
579 |
+
'punct_index': j,
|
580 |
+
'punct_text': punct[:100] + '...' if len(punct) > 100 else punct,
|
581 |
+
'overlap': overlap
|
582 |
+
})
|
583 |
+
|
584 |
+
# Если перекрытие больше порога, чанк релевантен
|
585 |
+
if overlap >= similarity_threshold:
|
586 |
+
is_relevant = True
|
587 |
+
punct_found[j] = True
|
588 |
+
|
589 |
+
if is_relevant:
|
590 |
+
question_relevant_chunks += 1
|
591 |
+
|
592 |
+
# Если нужно сохранить топ-чанки и top_n == 20
|
593 |
+
if top_chunks_dir and top_n == 20:
|
594 |
+
all_chunk_overlaps.append(chunk_overlaps)
|
595 |
+
|
596 |
+
# Если нужно сохранить топ-чанки и top_n == 20
|
597 |
+
if top_chunks_dir and top_n == 20 and not all_top_chunks.empty:
|
598 |
+
save_top_chunks_for_question(
|
599 |
+
question_id,
|
600 |
+
question_text,
|
601 |
+
puncts,
|
602 |
+
all_top_chunks,
|
603 |
+
all_chunk_similarities,
|
604 |
+
all_chunk_overlaps,
|
605 |
+
top_chunks_dir
|
606 |
+
)
|
607 |
+
|
608 |
+
# Подсчитываем метрики для текущего вопроса
|
609 |
+
question_found_puncts = sum(punct_found)
|
610 |
+
found_puncts += question_found_puncts
|
611 |
+
|
612 |
+
relevant_chunks += question_relevant_chunks
|
613 |
+
|
614 |
+
# Обновляем метрики для документов
|
615 |
+
question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs)
|
616 |
+
found_relevant_docs += question_found_relevant_docs
|
617 |
+
question_total_docs_found = len(docs_found_for_question)
|
618 |
+
total_docs_found += question_total_docs_found
|
619 |
+
|
620 |
+
# Вычисляем метрики для текущего вопроса
|
621 |
+
question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0
|
622 |
+
question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0
|
623 |
+
question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0
|
624 |
+
|
625 |
+
question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0
|
626 |
+
question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0
|
627 |
+
question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0
|
628 |
+
|
629 |
+
# Сохраняем метрики вопроса
|
630 |
+
question_metrics.append({
|
631 |
+
'question_id': question_id,
|
632 |
+
'question_text': question_text,
|
633 |
+
'top_n': top_n,
|
634 |
+
'text_precision': question_text_precision,
|
635 |
+
'text_recall': question_text_recall,
|
636 |
+
'text_f1': question_text_f1,
|
637 |
+
'doc_precision': question_doc_precision,
|
638 |
+
'doc_recall': question_doc_recall,
|
639 |
+
'doc_f1': question_doc_f1,
|
640 |
+
'found_puncts': question_found_puncts,
|
641 |
+
'total_puncts': question_total_puncts,
|
642 |
+
'relevant_chunks': question_relevant_chunks,
|
643 |
+
'total_chunks': question_total_chunks,
|
644 |
+
'found_relevant_docs': question_found_relevant_docs,
|
645 |
+
'total_docs_required': question_total_docs_required,
|
646 |
+
'total_docs_found': question_total_docs_found
|
647 |
+
})
|
648 |
+
|
649 |
+
# Вычисляем метрики для текста
|
650 |
+
text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0
|
651 |
+
text_recall = found_puncts / total_puncts if total_puncts > 0 else 0
|
652 |
+
text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0
|
653 |
+
|
654 |
+
# Вычисляем метрики для документов
|
655 |
+
doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0
|
656 |
+
doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0
|
657 |
+
doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0
|
658 |
+
|
659 |
+
aggregated_metrics = {
|
660 |
+
'top_n': top_n,
|
661 |
+
'text_precision': text_precision,
|
662 |
+
'text_recall': text_recall,
|
663 |
+
'text_f1': text_f1,
|
664 |
+
'doc_precision': doc_precision,
|
665 |
+
'doc_recall': doc_recall,
|
666 |
+
'doc_f1': doc_f1,
|
667 |
+
'found_puncts': found_puncts,
|
668 |
+
'total_puncts': total_puncts,
|
669 |
+
'relevant_chunks': relevant_chunks,
|
670 |
+
'total_chunks': total_chunks,
|
671 |
+
'found_relevant_docs': found_relevant_docs,
|
672 |
+
'total_docs_required': total_docs_required,
|
673 |
+
'total_docs_found': total_docs_found
|
674 |
+
}
|
675 |
+
|
676 |
+
return aggregated_metrics, pd.DataFrame(question_metrics)
|
677 |
+
|
678 |
+
|
679 |
+
def main():
|
680 |
+
"""
|
681 |
+
Основная функция скрипта.
|
682 |
+
"""
|
683 |
+
args = parse_args()
|
684 |
+
|
685 |
+
# Устанавливаем устройство из аргументов
|
686 |
+
device = args.device
|
687 |
+
|
688 |
+
# Создаем выходной каталог, если его нет
|
689 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
690 |
+
|
691 |
+
# Создаем директорию для топ-чанков
|
692 |
+
top_chunks_dir = os.path.join(args.output_dir, "top_chunks")
|
693 |
+
os.makedirs(top_chunks_dir, exist_ok=True)
|
694 |
+
|
695 |
+
# Загружаем датасет с вопросами
|
696 |
+
questions_df = load_questions_dataset(args.dataset_path)
|
697 |
+
|
698 |
+
# Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели
|
699 |
+
strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}"
|
700 |
+
chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}"
|
701 |
+
questions_filename = f"questions_{args.model_name.replace('/', '_')}"
|
702 |
+
|
703 |
+
# Пытаемся загрузить сохраненные эмбеддинги и данные
|
704 |
+
chunk_embeddings, chunks_df = None, None
|
705 |
+
question_embeddings, questions_df_with_embeddings = None, None
|
706 |
+
|
707 |
+
if not args.force_recompute:
|
708 |
+
chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir)
|
709 |
+
question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir)
|
710 |
+
|
711 |
+
# Если не удалось загрузить данные или включен режим принудительного пересчета
|
712 |
+
if chunk_embeddings is None or chunks_df is None:
|
713 |
+
# Читаем и обрабатываем документы
|
714 |
+
documents = read_documents(args.data_folder)
|
715 |
+
|
716 |
+
# Формируем конфигурацию для стратегии fixed_size
|
717 |
+
fixed_size_config = {
|
718 |
+
"words_per_chunk": args.words_per_chunk,
|
719 |
+
"overlap_words": args.overlap_words
|
720 |
+
}
|
721 |
+
|
722 |
+
# Получаем DataFrame с чанками
|
723 |
+
chunks_df = process_documents(documents, fixed_size_config)
|
724 |
+
|
725 |
+
# Настраиваем модель и токенизатор
|
726 |
+
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
|
727 |
+
|
728 |
+
# Получаем эмбеддинги для чанков
|
729 |
+
chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
|
730 |
+
|
731 |
+
# Сохраняем эмбеддинги и данные
|
732 |
+
save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir)
|
733 |
+
|
734 |
+
# Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета
|
735 |
+
if question_embeddings is None or questions_df_with_embeddings is None:
|
736 |
+
# Получаем уникальные вопросы (по id)
|
737 |
+
unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']]
|
738 |
+
|
739 |
+
# Настраиваем модель и токенизатор (если еще не настроены)
|
740 |
+
if 'model' not in locals() or 'tokenizer' not in locals():
|
741 |
+
model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
|
742 |
+
|
743 |
+
# Получаем эмбеддинги для вопросов
|
744 |
+
question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
|
745 |
+
|
746 |
+
# Сохраняем эмбеддинги и данные
|
747 |
+
save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir)
|
748 |
+
|
749 |
+
# Устанавливаем questions_df_with_embeddings для дальнейшего использования
|
750 |
+
questions_df_with_embeddings = unique_questions
|
751 |
+
|
752 |
+
# Создаем словарь соответствия id вопроса и его индекса в эмбеддингах
|
753 |
+
question_id_to_idx = {
|
754 |
+
row['id']: i
|
755 |
+
for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows())
|
756 |
+
}
|
757 |
+
|
758 |
+
# Оцениваем стратегию чанкинга для разных значений top_n
|
759 |
+
aggregated_results = []
|
760 |
+
all_question_metrics = []
|
761 |
+
|
762 |
+
for top_n in TOP_N_VALUES:
|
763 |
+
metrics, question_metrics = evaluate_for_top_n_with_mapping(
|
764 |
+
questions_df, # Исходный датасет с связью между вопросами и документами
|
765 |
+
chunks_df, # Датасет с чанками
|
766 |
+
question_embeddings, # Эмбеддинги вопросов
|
767 |
+
chunk_embeddings, # Эмбеддинги чанков
|
768 |
+
question_id_to_idx, # Маппинг id вопроса к индексу в эмбеддингах
|
769 |
+
top_n, # Количество чанков в топе
|
770 |
+
args.similarity_threshold, # Порог для определения перекрытия
|
771 |
+
top_chunks_dir if top_n == 20 else None # Сохраняем топ-чанки только для top_n=20
|
772 |
+
)
|
773 |
+
aggregated_results.append(metrics)
|
774 |
+
all_question_metrics.append(question_metrics)
|
775 |
+
|
776 |
+
# Объединяем все метрики по вопросам
|
777 |
+
all_question_metrics_df = pd.concat(all_question_metrics)
|
778 |
+
|
779 |
+
# Создаем DataFrame с агрегированными результатами
|
780 |
+
aggregated_results_df = pd.DataFrame(aggregated_results)
|
781 |
+
|
782 |
+
# Сохраняем результаты
|
783 |
+
results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv"
|
784 |
+
results_path = os.path.join(args.output_dir, results_filename)
|
785 |
+
aggregated_results_df.to_csv(results_path, index=False)
|
786 |
+
|
787 |
+
# Сохраняем метрики по вопросам
|
788 |
+
question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx"
|
789 |
+
question_metrics_path = os.path.join(args.output_dir, question_metrics_filename)
|
790 |
+
all_question_metrics_df.to_excel(question_metrics_path, index=False)
|
791 |
+
|
792 |
+
print(f"\nРезультаты сохранены в {results_path}")
|
793 |
+
print(f"Метрики по вопросам сохранены в {question_metrics_path}")
|
794 |
+
print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}")
|
795 |
+
print("\nМетрики для различных значений top_n:")
|
796 |
+
print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']])
|
797 |
+
|
798 |
+
|
799 |
+
if __name__ == "__main__":
|
800 |
+
main()
|
lib/extractor/scripts/plot_macro_metrics.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для построения специализированных графиков на основе макрометрик из Excel-файла.
|
4 |
+
Строит несколько типов графиков:
|
5 |
+
1. Зависимость macro_text_recall от top_N для разных моделей при фиксированных параметрах чанкинга
|
6 |
+
2. Зависимость macro_text_recall от top_N для разных подходов к чанкингу при фиксированных моделях
|
7 |
+
3. Зависимость macro_text_recall от подхода к чанкингу для разных моделей при фиксированных top_N
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
|
16 |
+
# Константы
|
17 |
+
EXCEL_FILE_PATH = "../../Белагропромбанк/test_vectors/combined_results.xlsx"
|
18 |
+
PLOTS_DIR = "../../Белагропромбанк/test_vectors/plots"
|
19 |
+
|
20 |
+
# Настройки для графиков
|
21 |
+
plt.rcParams['font.family'] = 'DejaVu Sans'
|
22 |
+
sns.set_style("whitegrid")
|
23 |
+
FIGSIZE = (14, 10)
|
24 |
+
DPI = 300
|
25 |
+
|
26 |
+
|
27 |
+
def setup_plots_directory(plots_dir: str) -> None:
|
28 |
+
"""
|
29 |
+
Создает директорию для сохранения графиков, если она не существует.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
plots_dir: Путь к директории для графиков
|
33 |
+
"""
|
34 |
+
if not os.path.exists(plots_dir):
|
35 |
+
os.makedirs(plots_dir)
|
36 |
+
print(f"Создана директория для графиков: {plots_dir}")
|
37 |
+
else:
|
38 |
+
print(f"Использование существующей директории для графиков: {plots_dir}")
|
39 |
+
|
40 |
+
|
41 |
+
def load_macro_metrics(excel_path: str) -> pd.DataFrame:
|
42 |
+
"""
|
43 |
+
Загружает макрометрики из Excel-файла.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
excel_path: Путь к Excel-файлу с данными
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
DataFrame с макрометриками
|
50 |
+
"""
|
51 |
+
try:
|
52 |
+
df = pd.read_excel(excel_path, sheet_name="Macro метрики")
|
53 |
+
print(f"Загружены данные из {excel_path}, лист 'Macro метрики'")
|
54 |
+
print(f"Количество строк: {len(df)}")
|
55 |
+
return df
|
56 |
+
except Exception as e:
|
57 |
+
print(f"Ошибка при загрузке данных: {e}")
|
58 |
+
raise
|
59 |
+
|
60 |
+
|
61 |
+
def plot_top_n_vs_recall_by_model(df: pd.DataFrame, plots_dir: str) -> None:
|
62 |
+
"""
|
63 |
+
Строит графики зависимости macro_text_recall от top_N для разных моделей
|
64 |
+
при фиксированных параметрах чанкинга (50/25 и 200/75).
|
65 |
+
|
66 |
+
Args:
|
67 |
+
df: DataFrame с данными
|
68 |
+
plots_dir: Директория для сохранения графиков
|
69 |
+
"""
|
70 |
+
# Фиксированные параметры чанкинга
|
71 |
+
chunking_params = [
|
72 |
+
{"words": 50, "overlap": 25, "title": "Чанкинг 50/25"},
|
73 |
+
{"words": 200, "overlap": 75, "title": "Чанкинг 200/75"}
|
74 |
+
]
|
75 |
+
|
76 |
+
# Создаем субплоты: 1 строка, 2 столбца
|
77 |
+
fig, axes = plt.subplots(1, 2, figsize=FIGSIZE, sharey=True)
|
78 |
+
|
79 |
+
for i, params in enumerate(chunking_params):
|
80 |
+
# Фильтруем данные для текущих параметров чанкинга
|
81 |
+
filtered_df = df[
|
82 |
+
(df['words_per_chunk'] == params['words']) &
|
83 |
+
(df['overlap_words'] == params['overlap'])
|
84 |
+
]
|
85 |
+
|
86 |
+
if len(filtered_df) == 0:
|
87 |
+
print(f"Предупреждение: нет данных для чанкинга {params['words']}/{params['overlap']}")
|
88 |
+
axes[i].text(0.5, 0.5, f"Нет данных для чанкинга {params['words']}/{params['overlap']}",
|
89 |
+
ha='center', va='center', fontsize=12)
|
90 |
+
axes[i].set_title(params['title'])
|
91 |
+
continue
|
92 |
+
|
93 |
+
# Находим уникальные модели
|
94 |
+
models = filtered_df['model'].unique()
|
95 |
+
|
96 |
+
# Создаем палитру цветов
|
97 |
+
palette = sns.color_palette("viridis", len(models))
|
98 |
+
|
99 |
+
# Строим график для каждой модели
|
100 |
+
for j, model in enumerate(models):
|
101 |
+
model_df = filtered_df[filtered_df['model'] == model].sort_values('top_n')
|
102 |
+
|
103 |
+
if len(model_df) <= 1:
|
104 |
+
print(f"Предупреждение: недостаточно данных для модели {model} при чанкинге {params['words']}/{params['overlap']}")
|
105 |
+
continue
|
106 |
+
|
107 |
+
# Строим ломаную линию
|
108 |
+
axes[i].plot(model_df['top_n'], model_df['macro_text_recall'],
|
109 |
+
marker='o', linestyle='-', linewidth=2,
|
110 |
+
label=model, color=palette[j])
|
111 |
+
|
112 |
+
# Настраиваем оси и заголовок
|
113 |
+
axes[i].set_title(params['title'], fontsize=14)
|
114 |
+
axes[i].set_xlabel('top_N', fontsize=12)
|
115 |
+
if i == 0:
|
116 |
+
axes[i].set_ylabel('macro_text_recall', fontsize=12)
|
117 |
+
|
118 |
+
# Добавляем сетку
|
119 |
+
axes[i].grid(True, linestyle='--', alpha=0.7)
|
120 |
+
|
121 |
+
# Добавляем легенду
|
122 |
+
axes[i].legend(title="Модель", fontsize=10, loc='best')
|
123 |
+
|
124 |
+
# Общий заголовок
|
125 |
+
plt.suptitle('Зависимость macro_text_recall от top_N для разных моделей', fontsize=16)
|
126 |
+
|
127 |
+
# Настраиваем макет
|
128 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
129 |
+
|
130 |
+
# Сохраняем график
|
131 |
+
file_path = os.path.join(plots_dir, "top_n_vs_recall_by_model.png")
|
132 |
+
plt.savefig(file_path, dpi=DPI)
|
133 |
+
plt.close()
|
134 |
+
|
135 |
+
print(f"Создан график: {file_path}")
|
136 |
+
|
137 |
+
|
138 |
+
def plot_top_n_vs_recall_by_chunking(df: pd.DataFrame, plots_dir: str) -> None:
|
139 |
+
"""
|
140 |
+
Строит графики зависимости macro_text_recall от top_N для разных параметров чанкинга
|
141 |
+
при фиксированных моделях (bge и frida).
|
142 |
+
|
143 |
+
Args:
|
144 |
+
df: DataFrame с данными
|
145 |
+
plots_dir: Директория для сохранения графиков
|
146 |
+
"""
|
147 |
+
# Фиксированные модели
|
148 |
+
models = ["BAAI/bge", "frida"]
|
149 |
+
|
150 |
+
# Создаем субплоты: 1 строка, 2 столбца
|
151 |
+
fig, axes = plt.subplots(1, 2, figsize=FIGSIZE, sharey=True)
|
152 |
+
|
153 |
+
for i, model_name in enumerate(models):
|
154 |
+
# Находим все строки с моделями, содержащими указанное название
|
155 |
+
model_df = df[df['model'].str.contains(model_name, case=False)]
|
156 |
+
|
157 |
+
if len(model_df) == 0:
|
158 |
+
print(f"Предупреждение: нет данных для модели {model_name}")
|
159 |
+
axes[i].text(0.5, 0.5, f"Нет данных для модели {model_name}",
|
160 |
+
ha='center', va='center', fontsize=12)
|
161 |
+
axes[i].set_title(f"Модель: {model_name}")
|
162 |
+
continue
|
163 |
+
|
164 |
+
# Находим уникальные комбинации параметров чанкинга
|
165 |
+
chunking_combinations = model_df.drop_duplicates(['words_per_chunk', 'overlap_words'])[['words_per_chunk', 'overlap_words']]
|
166 |
+
|
167 |
+
# Ограничиваем количество комбинаций до 7 для читаемости
|
168 |
+
if len(chunking_combinations) > 7:
|
169 |
+
print(f"Предупреждение: слишком много комбинаций чанкинга для модели {model_name}, ограничиваем до 7")
|
170 |
+
chunking_combinations = chunking_combinations.head(7)
|
171 |
+
|
172 |
+
# Создаем палитру цветов
|
173 |
+
palette = sns.color_palette("viridis", len(chunking_combinations))
|
174 |
+
|
175 |
+
# Строим график для каждой комбинации параметров чанкинга
|
176 |
+
for j, (_, row) in enumerate(chunking_combinations.iterrows()):
|
177 |
+
words = row['words_per_chunk']
|
178 |
+
overlap = row['overlap_words']
|
179 |
+
|
180 |
+
# Фильтруем данные для текущей модели и параметров чанкинга
|
181 |
+
chunking_df = model_df[
|
182 |
+
(model_df['words_per_chunk'] == words) &
|
183 |
+
(model_df['overlap_words'] == overlap)
|
184 |
+
].sort_values('top_n')
|
185 |
+
|
186 |
+
if len(chunking_df) <= 1:
|
187 |
+
print(f"Предупреждение: недостаточно данных для модели {model_name} с чанкингом {words}/{overlap}")
|
188 |
+
continue
|
189 |
+
|
190 |
+
# Строим ломаную линию
|
191 |
+
axes[i].plot(chunking_df['top_n'], chunking_df['macro_text_recall'],
|
192 |
+
marker='o', linestyle='-', linewidth=2,
|
193 |
+
label=f"w={words}, o={overlap}", color=palette[j])
|
194 |
+
|
195 |
+
# Настраиваем оси и заголовок
|
196 |
+
axes[i].set_title(f"Модель: {model_name}", fontsize=14)
|
197 |
+
axes[i].set_xlabel('top_N', fontsize=12)
|
198 |
+
if i == 0:
|
199 |
+
axes[i].set_ylabel('macro_text_recall', fontsize=12)
|
200 |
+
|
201 |
+
# Добавляем сетку
|
202 |
+
axes[i].grid(True, linestyle='--', alpha=0.7)
|
203 |
+
|
204 |
+
# Добавляем легенду
|
205 |
+
axes[i].legend(title="Чанкинг", fontsize=10, loc='best')
|
206 |
+
|
207 |
+
# Общий заголовок
|
208 |
+
plt.suptitle('Зависимость macro_text_recall от top_N для разных параметров чанкинга', fontsize=16)
|
209 |
+
|
210 |
+
# Настраиваем макет
|
211 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
212 |
+
|
213 |
+
# Сохраняем график
|
214 |
+
file_path = os.path.join(plots_dir, "top_n_vs_recall_by_chunking.png")
|
215 |
+
plt.savefig(file_path, dpi=DPI)
|
216 |
+
plt.close()
|
217 |
+
|
218 |
+
print(f"Создан график: {file_path}")
|
219 |
+
|
220 |
+
|
221 |
+
def plot_chunking_vs_recall_by_model(df: pd.DataFrame, plots_dir: str) -> None:
|
222 |
+
"""
|
223 |
+
Строит графики зависимости macro_text_recall от подхода к чанкингу
|
224 |
+
для разных моделей при фиксированных top_N (5, 20, 100).
|
225 |
+
|
226 |
+
Args:
|
227 |
+
df: DataFrame с данными
|
228 |
+
plots_dir: Директория для сохранения графиков
|
229 |
+
"""
|
230 |
+
# Фиксированные значения top_N
|
231 |
+
top_n_values = [5, 20, 100]
|
232 |
+
|
233 |
+
# Создаем субплоты: 1 строка, 3 столбца
|
234 |
+
fig, axes = plt.subplots(1, 3, figsize=FIGSIZE, sharey=True)
|
235 |
+
|
236 |
+
# Создаем порядок чанкинга - сортируем по возрастанию размера и оверлапа
|
237 |
+
chunking_order = df.drop_duplicates(['words_per_chunk', 'overlap_words'])[['words_per_chunk', 'overlap_words']]
|
238 |
+
chunking_order = chunking_order.sort_values(['words_per_chunk', 'overlap_words'])
|
239 |
+
|
240 |
+
# Создаем словарь для маппинга комбинаций чанкинга на индексы
|
241 |
+
chunking_labels = [f"{row['words_per_chunk']}/{row['overlap_words']}" for _, row in chunking_order.iterrows()]
|
242 |
+
chunking_map = {f"{row['words_per_chunk']}/{row['overlap_words']}": i for i, (_, row) in enumerate(chunking_order.iterrows())}
|
243 |
+
|
244 |
+
for i, top_n in enumerate(top_n_values):
|
245 |
+
# Фильтруем данные для текущего top_N
|
246 |
+
top_n_df = df[df['top_n'] == top_n]
|
247 |
+
|
248 |
+
if len(top_n_df) == 0:
|
249 |
+
print(f"Предупреждение: нет данных для top_N={top_n}")
|
250 |
+
axes[i].text(0.5, 0.5, f"Нет данных для top_N={top_n}",
|
251 |
+
ha='center', va='center', fontsize=12)
|
252 |
+
axes[i].set_title(f"top_N={top_n}")
|
253 |
+
continue
|
254 |
+
|
255 |
+
# Находим уникальные модели
|
256 |
+
models = top_n_df['model'].unique()
|
257 |
+
|
258 |
+
# Ограничиваем количество моделей до 5 для читаемости
|
259 |
+
if len(models) > 5:
|
260 |
+
print(f"Предупреждение: слишком много моделей для top_N={top_n}, ограничиваем до 5")
|
261 |
+
models = models[:5]
|
262 |
+
|
263 |
+
# Создаем палитру цветов
|
264 |
+
palette = sns.color_palette("viridis", len(models))
|
265 |
+
|
266 |
+
# Строим график для каждой модели
|
267 |
+
for j, model in enumerate(models):
|
268 |
+
model_df = top_n_df[top_n_df['model'] == model].copy()
|
269 |
+
|
270 |
+
if len(model_df) <= 1:
|
271 |
+
print(f"Предупреждение: недостаточно данных для модели {model} при top_N={top_n}")
|
272 |
+
continue
|
273 |
+
|
274 |
+
# Создаем новую колонку с индексом чанкинга для сортировки
|
275 |
+
model_df['chunking_index'] = model_df.apply(
|
276 |
+
lambda row: chunking_map.get(f"{row['words_per_chunk']}/{row['overlap_words']}", -1),
|
277 |
+
axis=1
|
278 |
+
)
|
279 |
+
|
280 |
+
# Отбрасываем строки с неизвестными комбинациями чанкинга
|
281 |
+
model_df = model_df[model_df['chunking_index'] >= 0]
|
282 |
+
|
283 |
+
if len(model_df) <= 1:
|
284 |
+
continue
|
285 |
+
|
286 |
+
# Сортируем по индексу чанкинга
|
287 |
+
model_df = model_df.sort_values('chunking_index')
|
288 |
+
|
289 |
+
# Создаем список индексов и значений для графика
|
290 |
+
x_indices = model_df['chunking_index'].tolist()
|
291 |
+
y_values = model_df['macro_text_recall'].tolist()
|
292 |
+
|
293 |
+
# Строим ломаную линию
|
294 |
+
axes[i].plot(x_indices, y_values, marker='o', linestyle='-', linewidth=2,
|
295 |
+
label=model, color=palette[j])
|
296 |
+
|
297 |
+
# Настраиваем оси и заголовок
|
298 |
+
axes[i].set_title(f"top_N={top_n}", fontsize=14)
|
299 |
+
axes[i].set_xlabel('Подход к чанкингу', fontsize=12)
|
300 |
+
if i == 0:
|
301 |
+
axes[i].set_ylabel('macro_text_recall', fontsize=12)
|
302 |
+
|
303 |
+
# Устанавливаем метки на оси X (подходы к чанкингу)
|
304 |
+
axes[i].set_xticks(range(len(chunking_labels)))
|
305 |
+
axes[i].set_xticklabels(chunking_labels, rotation=45, ha='right', fontsize=10)
|
306 |
+
|
307 |
+
# Добавляем сетку
|
308 |
+
axes[i].grid(True, linestyle='--', alpha=0.7)
|
309 |
+
|
310 |
+
# Добавляем легенду
|
311 |
+
axes[i].legend(title="Модель", fontsize=10, loc='best')
|
312 |
+
|
313 |
+
# Общий заголовок
|
314 |
+
plt.suptitle('Зависимость macro_text_recall от подхода к чанкингу для разных моделей', fontsize=16)
|
315 |
+
|
316 |
+
# Настраиваем макет
|
317 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
318 |
+
|
319 |
+
# Сохраняем график
|
320 |
+
file_path = os.path.join(plots_dir, "chunking_vs_recall_by_model.png")
|
321 |
+
plt.savefig(file_path, dpi=DPI)
|
322 |
+
plt.close()
|
323 |
+
|
324 |
+
print(f"Создан график: {file_path}")
|
325 |
+
|
326 |
+
|
327 |
+
def main():
|
328 |
+
"""Основная функция скрипта."""
|
329 |
+
# Создаем директорию для графиков
|
330 |
+
setup_plots_directory(PLOTS_DIR)
|
331 |
+
|
332 |
+
# Загружаем данные
|
333 |
+
try:
|
334 |
+
macro_metrics = load_macro_metrics(EXCEL_FILE_PATH)
|
335 |
+
except Exception as e:
|
336 |
+
print(f"Критическая ошибка: {e}")
|
337 |
+
return
|
338 |
+
|
339 |
+
# Строим графики
|
340 |
+
plot_top_n_vs_recall_by_model(macro_metrics, PLOTS_DIR)
|
341 |
+
plot_top_n_vs_recall_by_chunking(macro_metrics, PLOTS_DIR)
|
342 |
+
plot_chunking_vs_recall_by_model(macro_metrics, PLOTS_DIR)
|
343 |
+
|
344 |
+
print("Готово! Все графики созданы.")
|
345 |
+
|
346 |
+
|
347 |
+
if __name__ == "__main__":
|
348 |
+
main()
|
lib/extractor/scripts/prepare_dataset.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для подготовки датасета с вопросами и текстами пунктов/приложений.
|
4 |
+
Преобразует исходный датасет, содержащий списки пунктов, в расширенный датасет,
|
5 |
+
где каждому пункту/приложению соответствует отдельная строка.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, Dict
|
12 |
+
|
13 |
+
import pandas as pd
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from ntr_text_fragmentation import Destructurer
|
17 |
+
|
18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
19 |
+
from ntr_fileparser import UniversalParser
|
20 |
+
|
21 |
+
|
22 |
+
def parse_args():
|
23 |
+
"""
|
24 |
+
Парсит аргументы командной строки.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
Аргументы командной строки
|
28 |
+
"""
|
29 |
+
parser = argparse.ArgumentParser(description="Подготовка датасета с текстами пунктов")
|
30 |
+
|
31 |
+
parser.add_argument('--input-dataset', type=str, default='data/dataset.xlsx',
|
32 |
+
help='Путь к исходному датасету (Excel-файл)')
|
33 |
+
parser.add_argument('--output-dataset', type=str, default='data/dataset_with_texts.xlsx',
|
34 |
+
help='Путь для сохранения подготовленного датасета (Excel-файл)')
|
35 |
+
parser.add_argument('--data-folder', type=str, default='data/docs',
|
36 |
+
help='Путь к папке с документами')
|
37 |
+
parser.add_argument('--debug', action='store_true',
|
38 |
+
help='Включить режим отладки с дополнительным выводом информации')
|
39 |
+
|
40 |
+
return parser.parse_args()
|
41 |
+
|
42 |
+
|
43 |
+
def load_dataset(file_path: str, debug: bool = False) -> pd.DataFrame:
|
44 |
+
"""
|
45 |
+
Загружает исходный датасет с вопросами.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
file_path: Путь к Excel-файлу
|
49 |
+
debug: Режим отладки
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
DataFrame с вопросами
|
53 |
+
"""
|
54 |
+
print(f"Загрузка исходного датасета из {file_path}...")
|
55 |
+
|
56 |
+
df = pd.read_excel(file_path)
|
57 |
+
|
58 |
+
# Преобразуем строковые списки в настоящие списки
|
59 |
+
for col in ['puncts', 'appendices']:
|
60 |
+
if col in df.columns:
|
61 |
+
df[col] = df[col].apply(lambda x:
|
62 |
+
eval(x) if isinstance(x, str) and x.strip()
|
63 |
+
else ([] if pd.isna(x) else x))
|
64 |
+
|
65 |
+
# Вывод отладочной информации о форматах пунктов/приложений
|
66 |
+
if debug:
|
67 |
+
all_puncts = set()
|
68 |
+
all_appendices = set()
|
69 |
+
|
70 |
+
for _, row in df.iterrows():
|
71 |
+
if 'puncts' in row and row['puncts']:
|
72 |
+
all_puncts.update(row['puncts'])
|
73 |
+
if 'appendices' in row and row['appendices']:
|
74 |
+
all_appendices.update(row['appendices'])
|
75 |
+
|
76 |
+
print(f"\nУникальные форматы пунктов в датасете ({len(all_puncts)}):")
|
77 |
+
for i, p in enumerate(sorted(all_puncts)):
|
78 |
+
if i < 20 or i > len(all_puncts) - 20:
|
79 |
+
print(f" - {repr(p)}")
|
80 |
+
elif i == 20:
|
81 |
+
print(" ... (пропущено)")
|
82 |
+
|
83 |
+
print(f"\nУникальные форматы приложений в датасете ({len(all_appendices)}):")
|
84 |
+
for app in sorted(all_appendices):
|
85 |
+
print(f" - {repr(app)}")
|
86 |
+
|
87 |
+
print(f"Загружено {len(df)} вопросов")
|
88 |
+
return df
|
89 |
+
|
90 |
+
|
91 |
+
def read_documents(folder_path: str) -> Dict[str, Any]:
|
92 |
+
"""
|
93 |
+
Читает все документы из указанной папки.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
folder_path: Путь к папке с документами
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Словарь {имя_файла: parsed_document}
|
100 |
+
"""
|
101 |
+
print(f"Чтение документов из {folder_path}...")
|
102 |
+
parser = UniversalParser()
|
103 |
+
documents = {}
|
104 |
+
|
105 |
+
for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"):
|
106 |
+
try:
|
107 |
+
doc_name = file_path.stem
|
108 |
+
documents[doc_name] = parser.parse_by_path(str(file_path))
|
109 |
+
except Exception as e:
|
110 |
+
print(f"Ошибка при чтении файла {file_path}: {e}")
|
111 |
+
|
112 |
+
print(f"Прочитано {len(documents)} документов")
|
113 |
+
return documents
|
114 |
+
|
115 |
+
|
116 |
+
def normalize_punct_format(punct: str) -> str:
|
117 |
+
"""
|
118 |
+
Нормализует формат номера пункта для единообразного сравнения.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
punct: Номер пункта
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
Нормализованный номер пункта
|
125 |
+
"""
|
126 |
+
# Убираем пробелы
|
127 |
+
punct = punct.strip()
|
128 |
+
|
129 |
+
# Убираем завершающую точку, если она есть
|
130 |
+
if punct.endswith('.'):
|
131 |
+
punct = punct[:-1]
|
132 |
+
|
133 |
+
return punct
|
134 |
+
|
135 |
+
|
136 |
+
def normalize_appendix_format(appendix: str) -> str:
|
137 |
+
"""
|
138 |
+
Нормализует формат номера приложения для единообразного сравнения.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
appendix: Номер приложения
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Нормализованный номер приложения
|
145 |
+
"""
|
146 |
+
# Убираем пробелы
|
147 |
+
appendix = appendix.strip()
|
148 |
+
|
149 |
+
# Обработка форматов с дефисом (например, "14-1")
|
150 |
+
if "-" in appendix:
|
151 |
+
return appendix
|
152 |
+
|
153 |
+
return appendix
|
154 |
+
|
155 |
+
|
156 |
+
def find_matching_key(search_key, available_keys, item_type='punct', debug_mode=False):
|
157 |
+
"""
|
158 |
+
Ищет наиболее подходящий ключ среди доступных ключей с учетом типа элемента
|
159 |
+
|
160 |
+
Args:
|
161 |
+
search_key: Ключ для поиска
|
162 |
+
available_keys: Доступные ключи
|
163 |
+
item_type: Тип элемента ('punct' или 'appendix')
|
164 |
+
debug_mode: Режим отладки
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Найденный ключ или None
|
168 |
+
"""
|
169 |
+
if not available_keys:
|
170 |
+
return None
|
171 |
+
|
172 |
+
# Нормализуем ключ в зависимости от типа элемента
|
173 |
+
if item_type == 'punct':
|
174 |
+
normalized_search_key = normalize_punct_format(search_key)
|
175 |
+
else: # appendix
|
176 |
+
normalized_search_key = normalize_appendix_format(search_key)
|
177 |
+
|
178 |
+
# Проверяем прямое совпадение ключей
|
179 |
+
for key in available_keys:
|
180 |
+
if item_type == 'punct':
|
181 |
+
normalized_key = normalize_punct_format(key)
|
182 |
+
else: # appendix
|
183 |
+
normalized_key = normalize_appendix_format(key)
|
184 |
+
|
185 |
+
if normalized_key == normalized_search_key:
|
186 |
+
if debug_mode:
|
187 |
+
print(f"Найдено прямое совпадение для {item_type} {search_key} -> {key}")
|
188 |
+
return key
|
189 |
+
|
190 |
+
# Если прямого совпадения нет, проверяем "мягкое" совпадение
|
191 |
+
# Только для пунктов, не для приложений
|
192 |
+
if item_type == 'punct':
|
193 |
+
for key in available_keys:
|
194 |
+
normalized_key = normalize_punct_format(key)
|
195 |
+
|
196 |
+
# Если ключ содержит "/", это подпункт приложения, его не следует сопоставлять с обычным пунктом
|
197 |
+
if '/' in key and '/' not in search_key:
|
198 |
+
continue
|
199 |
+
|
200 |
+
# Проверяем совпадение конца номера (например, "1.2" и "1.2.")
|
201 |
+
if normalized_key.rstrip('.') == normalized_search_key.rstrip('.'):
|
202 |
+
if debug_mode:
|
203 |
+
print(f"Найдено мягкое совпадение для {search_key} -> {key}")
|
204 |
+
return key
|
205 |
+
|
206 |
+
return None
|
207 |
+
|
208 |
+
|
209 |
+
def extract_item_texts(documents, debug_mode=False):
|
210 |
+
"""
|
211 |
+
Извлекает тексты пунктов и приложений из документов.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
documents: Словарь с распарсенными документами {doc_name: document}
|
215 |
+
debug_mode: Включать ли режим отладки
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
Словарь с текстами пунктов и приложений, организованный по названиям документов
|
219 |
+
"""
|
220 |
+
print("Извлечение текстов пунктов и приложений...")
|
221 |
+
|
222 |
+
item_texts = {}
|
223 |
+
all_extracted_items = set()
|
224 |
+
all_extracted_appendices = set()
|
225 |
+
|
226 |
+
for doc_name, document in tqdm(documents.items(), desc="Применение стратегии numbered_items"):
|
227 |
+
# Используем стратегию numbered_items с режимом отладки
|
228 |
+
destructurer = Destructurer(document)
|
229 |
+
destructurer.configure('numbered_items', debug_mode=debug_mode)
|
230 |
+
entities, _ = destructurer.destructure()
|
231 |
+
|
232 |
+
# Инициализируем структуру для документа, если она еще не создана
|
233 |
+
if doc_name not in item_texts:
|
234 |
+
item_texts[doc_name] = {
|
235 |
+
'puncts': {}, # Для пунктов основного текста
|
236 |
+
'appendices': {} # Для приложений
|
237 |
+
}
|
238 |
+
|
239 |
+
for entity in entities:
|
240 |
+
# Пропускаем сущность документа
|
241 |
+
if entity.type == "Document":
|
242 |
+
continue
|
243 |
+
|
244 |
+
# Работаем только с чанками для поиска
|
245 |
+
if hasattr(entity, 'use_in_search') and entity.use_in_search:
|
246 |
+
metadata = entity.metadata
|
247 |
+
text = entity.text
|
248 |
+
|
249 |
+
# Для пунктов
|
250 |
+
if 'item_number' in metadata:
|
251 |
+
item_number = metadata['item_number']
|
252 |
+
|
253 |
+
# Проверяем, является ли пункт подпунктом приложения
|
254 |
+
if 'appendix_number' in metadata:
|
255 |
+
# Это подпункт приложения
|
256 |
+
appendix_number = metadata['appendix_number']
|
257 |
+
|
258 |
+
# Создаем структуру для приложения, если ее еще нет
|
259 |
+
if appendix_number not in item_texts[doc_name]['appendices']:
|
260 |
+
item_texts[doc_name]['appendices'][appendix_number] = {
|
261 |
+
'main_text': '', # Основной текст приложения
|
262 |
+
'subpuncts': {} # Подпункты приложения
|
263 |
+
}
|
264 |
+
|
265 |
+
# Добавляем подпункт в словарь подпунктов
|
266 |
+
item_texts[doc_name]['appendices'][appendix_number]['subpuncts'][item_number] = text
|
267 |
+
|
268 |
+
if debug_mode:
|
269 |
+
print(f"Извлечен подпункт {item_number} приложения {appendix_number} из {doc_name}")
|
270 |
+
|
271 |
+
all_extracted_items.add(item_number)
|
272 |
+
else:
|
273 |
+
# Обычный пункт
|
274 |
+
item_texts[doc_name]['puncts'][item_number] = text
|
275 |
+
|
276 |
+
if debug_mode:
|
277 |
+
print(f"Извлечен пункт {item_number} из {doc_name}")
|
278 |
+
|
279 |
+
all_extracted_items.add(item_number)
|
280 |
+
|
281 |
+
# Для приложений
|
282 |
+
elif 'appendix_number' in metadata and 'item_number' not in metadata:
|
283 |
+
appendix_number = metadata['appendix_number']
|
284 |
+
|
285 |
+
# Создаем структуру для приложения, если ее еще нет
|
286 |
+
if appendix_number not in item_texts[doc_name]['appendices']:
|
287 |
+
item_texts[doc_name]['appendices'][appendix_number] = {
|
288 |
+
'main_text': text, # Основной текст приложения
|
289 |
+
'subpuncts': {} # Подпункты приложения
|
290 |
+
}
|
291 |
+
else:
|
292 |
+
# Если приложение уже существует, обновляем основной текст
|
293 |
+
item_texts[doc_name]['appendices'][appendix_number]['main_text'] = text
|
294 |
+
|
295 |
+
if debug_mode:
|
296 |
+
print(f"Извлечено приложение {appendix_number} из {doc_name}")
|
297 |
+
|
298 |
+
all_extracted_appendices.add(appendix_number)
|
299 |
+
|
300 |
+
# Выводим статистику, если включен режим отладки
|
301 |
+
if debug_mode:
|
302 |
+
print(f"\nВсего извлечено уникальных пунктов: {len(all_extracted_items)}")
|
303 |
+
print(f"Примеры форматов пунктов: {', '.join(sorted(list(all_extracted_items))[:20])}")
|
304 |
+
|
305 |
+
print(f"\nВсего извлечено уникальных приложений: {len(all_extracted_appendices)}")
|
306 |
+
print(f"Форматы приложений: {', '.join(sorted(list(all_extracted_appendices)))}")
|
307 |
+
|
308 |
+
# Подсчитываем общее количество пунктов и приложений
|
309 |
+
total_puncts = sum(len(doc_data['puncts']) for doc_data in item_texts.values())
|
310 |
+
total_appendices = sum(len(doc_data['appendices']) for doc_data in item_texts.values())
|
311 |
+
|
312 |
+
print(f"Извлечено {total_puncts} пунктов и {total_appendices} приложений из {len(item_texts)} документов")
|
313 |
+
|
314 |
+
return item_texts
|
315 |
+
|
316 |
+
|
317 |
+
def is_subpunct(parent_punct: str, possible_subpunct: str) -> bool:
|
318 |
+
"""
|
319 |
+
Проверяет, является ли пункт подпунктом другого пункта.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
parent_punct: Родительский пункт (например, "14")
|
323 |
+
possible_subpunct: Возможный подпункт (например, "14.1")
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
True, если possible_subpunct является подпунктом parent_punct
|
327 |
+
"""
|
328 |
+
# Нормализуем пункты
|
329 |
+
parent = normalize_punct_format(parent_punct)
|
330 |
+
child = normalize_punct_format(possible_subpunct)
|
331 |
+
|
332 |
+
# Проверяем, начинается ли child с parent и после него идет точка или другой разделитель
|
333 |
+
if child.startswith(parent):
|
334 |
+
# Если длины равны, это тот же самый пункт
|
335 |
+
if len(child) == len(parent):
|
336 |
+
return False
|
337 |
+
|
338 |
+
# Проверяем символ после parent - должна быть точка (дефис исключен, т.к. это разные пункты)
|
339 |
+
next_char = child[len(parent)]
|
340 |
+
return next_char in ['.']
|
341 |
+
|
342 |
+
return False
|
343 |
+
|
344 |
+
|
345 |
+
def collect_subpuncts(punct: str, all_puncts: dict) -> dict:
|
346 |
+
"""
|
347 |
+
Собирает все подпункты для указанного пункта.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
punct: Пункт, для которого нужно найти подпункты (например, "14")
|
351 |
+
all_puncts: Словарь всех пунктов {punct: text}
|
352 |
+
|
353 |
+
Returns:
|
354 |
+
Словарь {punct: text} с пунктом и всеми его подпунктами
|
355 |
+
"""
|
356 |
+
result = {}
|
357 |
+
normalized_punct = normalize_punct_format(punct)
|
358 |
+
|
359 |
+
# Добавляем сам пункт, если он существует
|
360 |
+
if normalized_punct in all_puncts:
|
361 |
+
result[normalized_punct] = all_puncts[normalized_punct]
|
362 |
+
|
363 |
+
# Ищем подпункты
|
364 |
+
for possible_subpunct in all_puncts.keys():
|
365 |
+
if is_subpunct(normalized_punct, possible_subpunct):
|
366 |
+
result[possible_subpunct] = all_puncts[possible_subpunct]
|
367 |
+
|
368 |
+
return result
|
369 |
+
|
370 |
+
|
371 |
+
def prepare_expanded_dataset(df, item_texts, output_path, debug_mode=False):
|
372 |
+
"""
|
373 |
+
Подготавливает расширенный датасет, добавляя тексты пунктов и приложений.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
df: Исходный датасет
|
377 |
+
item_texts: Словарь с текстами пунктов и приложений
|
378 |
+
output_path: Путь для сохранения расширенного датасета
|
379 |
+
debug_mode: Включать ли режим отладки
|
380 |
+
|
381 |
+
Returns:
|
382 |
+
Датафрейм с расширенным датасетом
|
383 |
+
"""
|
384 |
+
rows = []
|
385 |
+
skipped_items = 0
|
386 |
+
total_items = 0
|
387 |
+
|
388 |
+
for _, row in df.iterrows():
|
389 |
+
question_id = row['id']
|
390 |
+
question = row['question']
|
391 |
+
filepath = row.get('filepath', '')
|
392 |
+
|
393 |
+
# Получаем имя файла без пути
|
394 |
+
doc_name = Path(filepath).stem if filepath else ''
|
395 |
+
|
396 |
+
# Пропускаем, если файл не найден
|
397 |
+
if not doc_name or doc_name not in item_texts:
|
398 |
+
if debug_mode and doc_name:
|
399 |
+
print(f"Документ {doc_name} не найден в извлеченных данных")
|
400 |
+
continue
|
401 |
+
|
402 |
+
# Обрабатываем пункты
|
403 |
+
puncts = row.get('puncts', [])
|
404 |
+
if isinstance(puncts, str) and puncts.strip():
|
405 |
+
# Преобразуем строковое представление в список
|
406 |
+
try:
|
407 |
+
puncts = eval(puncts)
|
408 |
+
except:
|
409 |
+
puncts = []
|
410 |
+
|
411 |
+
if not isinstance(puncts, list):
|
412 |
+
puncts = []
|
413 |
+
|
414 |
+
for punct in puncts:
|
415 |
+
total_items += 1
|
416 |
+
|
417 |
+
if debug_mode:
|
418 |
+
print(f"\nОбработка пункта {punct} для вопроса {question_id} из {doc_name}")
|
419 |
+
|
420 |
+
# Ищем соответствующий пункт в документе
|
421 |
+
available_keys = list(item_texts[doc_name]['puncts'].keys())
|
422 |
+
matching_key = find_matching_key(punct, available_keys, 'punct', debug_mode)
|
423 |
+
|
424 |
+
if matching_key:
|
425 |
+
# Сохраняем основной текст пункта
|
426 |
+
item_text = item_texts[doc_name]['puncts'][matching_key]
|
427 |
+
|
428 |
+
# Список всех включенных ключей (для отслеживания что было приконкатенировано)
|
429 |
+
matched_keys = [matching_key]
|
430 |
+
|
431 |
+
# Ищем все подпункты для этого пункта
|
432 |
+
subpuncts = {}
|
433 |
+
for key in available_keys:
|
434 |
+
if is_subpunct(matching_key, key):
|
435 |
+
subpuncts[key] = item_texts[doc_name]['puncts'][key]
|
436 |
+
matched_keys.append(key)
|
437 |
+
|
438 |
+
# Если есть подпункты, добавляем их к основному тексту
|
439 |
+
if subpuncts:
|
440 |
+
# Сортируем подпункты по номеру
|
441 |
+
sorted_subpuncts = sorted(subpuncts.items(), key=lambda x: x[0])
|
442 |
+
|
443 |
+
# Добавляем разделитель и все подпункты
|
444 |
+
combined_text = item_text
|
445 |
+
for key, subtext in sorted_subpuncts:
|
446 |
+
combined_text += f"\n\n{key} {subtext}"
|
447 |
+
|
448 |
+
item_text = combined_text
|
449 |
+
|
450 |
+
# Добавляем строку с пунктом и его подпунктами
|
451 |
+
rows.append({
|
452 |
+
'id': question_id,
|
453 |
+
'question': question,
|
454 |
+
'filename': doc_name,
|
455 |
+
'text': item_text,
|
456 |
+
'item_type': 'punct',
|
457 |
+
'item_id': punct,
|
458 |
+
'matching_keys': ", ".join(matched_keys)
|
459 |
+
})
|
460 |
+
|
461 |
+
if debug_mode:
|
462 |
+
print(f"Добавлен пункт {matching_key} для {question_id} с {len(matched_keys)} ключами")
|
463 |
+
if len(matched_keys) > 1:
|
464 |
+
print(f" Включены ключи: {', '.join(matched_keys)}")
|
465 |
+
else:
|
466 |
+
skipped_items += 1
|
467 |
+
if debug_mode:
|
468 |
+
print(f"Не найден соответствующий пункт для {punct} в {doc_name}")
|
469 |
+
|
470 |
+
# Обрабатываем приложения
|
471 |
+
appendices = row.get('appendices', [])
|
472 |
+
if isinstance(appendices, str) and appendices.strip():
|
473 |
+
# Преобразуем строковое представление в список
|
474 |
+
try:
|
475 |
+
appendices = eval(appendices)
|
476 |
+
except:
|
477 |
+
appendices = []
|
478 |
+
|
479 |
+
if not isinstance(appendices, list):
|
480 |
+
appendices = []
|
481 |
+
|
482 |
+
for appendix in appendices:
|
483 |
+
total_items += 1
|
484 |
+
|
485 |
+
if debug_mode:
|
486 |
+
print(f"\nОбработка приложения {appendix} для вопроса {question_id} из {doc_name}")
|
487 |
+
|
488 |
+
# Ищем соответствующее приложение в документе
|
489 |
+
available_keys = list(item_texts[doc_name]['appendices'].keys())
|
490 |
+
matching_key = find_matching_key(appendix, available_keys, 'appendix', debug_mode)
|
491 |
+
|
492 |
+
if matching_key:
|
493 |
+
appendix_content = item_texts[doc_name]['appendices'][matching_key]
|
494 |
+
|
495 |
+
# Список всех включенных ключей (для отслеживания что было приконкатенировано)
|
496 |
+
matched_keys = [matching_key]
|
497 |
+
|
498 |
+
# Формируем полный текст приложения, включая все подпункты
|
499 |
+
if isinstance(appendix_content, dict):
|
500 |
+
# Начинаем с основного текста
|
501 |
+
full_text = appendix_content.get('main_text', '')
|
502 |
+
|
503 |
+
# Добавляем все подпункты в отсортированном порядке
|
504 |
+
if 'subpuncts' in appendix_content and appendix_content['subpuncts']:
|
505 |
+
subpuncts = appendix_content['subpuncts']
|
506 |
+
sorted_subpuncts = sorted(subpuncts.items(), key=lambda x: x[0])
|
507 |
+
|
508 |
+
# Добавляем разделитель, если есть основной текст
|
509 |
+
if full_text:
|
510 |
+
full_text += "\n\n"
|
511 |
+
|
512 |
+
# Добавляем все подпункты
|
513 |
+
for i, (key, subtext) in enumerate(sorted_subpuncts):
|
514 |
+
matched_keys.append(f"{matching_key}/{key}")
|
515 |
+
if i > 0:
|
516 |
+
full_text += "\n\n"
|
517 |
+
full_text += f"{key} {subtext}"
|
518 |
+
else:
|
519 |
+
# Если приложение просто строка
|
520 |
+
full_text = appendix_content
|
521 |
+
|
522 |
+
# Добавляем строку с приложением
|
523 |
+
rows.append({
|
524 |
+
'id': question_id,
|
525 |
+
'question': question,
|
526 |
+
'filename': doc_name,
|
527 |
+
'text': full_text,
|
528 |
+
'item_type': 'appendix',
|
529 |
+
'item_id': appendix,
|
530 |
+
'matching_keys': ", ".join(matched_keys)
|
531 |
+
})
|
532 |
+
|
533 |
+
if debug_mode:
|
534 |
+
print(f"Добавлено приложение {matching_key} для {question_id} с {len(matched_keys)} ключами")
|
535 |
+
if len(matched_keys) > 1:
|
536 |
+
print(f" Включены ключи: {', '.join(matched_keys)}")
|
537 |
+
else:
|
538 |
+
skipped_items += 1
|
539 |
+
if debug_mode:
|
540 |
+
print(f"Не найдено соответствующее п��иложение для {appendix} в {doc_name}")
|
541 |
+
|
542 |
+
extended_df = pd.DataFrame(rows)
|
543 |
+
|
544 |
+
# Сохраняем расширенный датасет
|
545 |
+
extended_df.to_excel(output_path, index=False)
|
546 |
+
|
547 |
+
print(f"Расширенный датасет сохранен в {output_path}")
|
548 |
+
print(f"Всего обработано элементов: {total_items}")
|
549 |
+
print(f"Всего элементов в расширенном датасете: {len(extended_df)}")
|
550 |
+
print(f"Пропущено элементов из-за отсутствия соответствия: {skipped_items}")
|
551 |
+
|
552 |
+
return extended_df
|
553 |
+
|
554 |
+
|
555 |
+
def main():
|
556 |
+
# Парсим аргументы командной строки
|
557 |
+
args = parse_args()
|
558 |
+
|
559 |
+
# Определяем режим отладки
|
560 |
+
debug = args.debug
|
561 |
+
|
562 |
+
# Загружаем исходный датасет
|
563 |
+
df = load_dataset(args.input_dataset, debug)
|
564 |
+
|
565 |
+
# Читаем документы
|
566 |
+
documents = read_documents(args.data_folder)
|
567 |
+
|
568 |
+
# Извлекаем тексты пунктов и приложений
|
569 |
+
item_texts = extract_item_texts(documents, debug)
|
570 |
+
|
571 |
+
# Подготавливаем расширенный датасет
|
572 |
+
expanded_df = prepare_expanded_dataset(df, item_texts, args.output_dataset, debug)
|
573 |
+
|
574 |
+
print("Готово!")
|
575 |
+
|
576 |
+
|
577 |
+
if __name__ == "__main__":
|
578 |
+
main()
|
lib/extractor/scripts/run_chunking_experiments.sh
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Скрипт для запуска экспериментов по оценке качества чанкинга с разными моделями и параметрами
|
3 |
+
|
4 |
+
# Директории и пути по умолчанию
|
5 |
+
DATA_FOLDER="data/docs"
|
6 |
+
DATASET_PATH="data/dataset.xlsx"
|
7 |
+
OUTPUT_DIR="data"
|
8 |
+
LOG_DIR="logs"
|
9 |
+
SIMILARITY_THRESHOLD=0.7
|
10 |
+
DEVICE="cuda:1"
|
11 |
+
|
12 |
+
# Создаем директории, если они не существуют
|
13 |
+
mkdir -p "$OUTPUT_DIR"
|
14 |
+
mkdir -p "$LOG_DIR"
|
15 |
+
|
16 |
+
# Список моделей для тестирования
|
17 |
+
MODELS=(
|
18 |
+
"intfloat/e5-base"
|
19 |
+
"intfloat/e5-large"
|
20 |
+
"BAAI/bge-m3"
|
21 |
+
"deepvk/USER-bge-m3"
|
22 |
+
"ai-forever/FRIDA"
|
23 |
+
)
|
24 |
+
|
25 |
+
# Параметры чанкинга (отсортированы в запрошенном порядке)
|
26 |
+
# Формат: [слов_в_чанке]:[нахлест]:[описание]
|
27 |
+
CHUNKING_PARAMS=(
|
28 |
+
"50:25:Маленький чанкинг с нахлёстом 50%"
|
29 |
+
"50:0:Маленький чанкинг без нахлёста"
|
30 |
+
"20:10:Очень мелкий чанкинг с нахлёстом 50%"
|
31 |
+
"100:0:Средний чанкинг без нахлёста"
|
32 |
+
"100:25:Средний чанкинг с нахлёстом 25%"
|
33 |
+
"150:50:Крупный чанкинг с нахлёстом 33%"
|
34 |
+
"200:75:Очень крупный чанкинг с нахлёстом 37.5%"
|
35 |
+
)
|
36 |
+
|
37 |
+
# Функция для запуска одного эксперимента
|
38 |
+
run_experiment() {
|
39 |
+
local model="$1"
|
40 |
+
local words="$2"
|
41 |
+
local overlap="$3"
|
42 |
+
local description="$4"
|
43 |
+
|
44 |
+
# Заменяем слеши в имени модели на подчеркивания для имен файлов
|
45 |
+
local model_safe_name=$(echo "$model" | tr '/' '_')
|
46 |
+
|
47 |
+
# Формируем имя файла результатов
|
48 |
+
local results_filename="results_fixed_size_w${words}_o${overlap}_${model_safe_name}.csv"
|
49 |
+
local results_path="${OUTPUT_DIR}/${results_filename}"
|
50 |
+
|
51 |
+
# Формируем имя файла лога
|
52 |
+
local timestamp=$(date +"%Y%m%d_%H%M%S")
|
53 |
+
local log_filename="log_${model_safe_name}_w${words}_o${overlap}_${timestamp}.txt"
|
54 |
+
local log_path="${LOG_DIR}/${log_filename}"
|
55 |
+
|
56 |
+
echo "=============================================================================="
|
57 |
+
echo "Запуск эксперимента:"
|
58 |
+
echo " Модель: $model"
|
59 |
+
echo " Чанкинг: $description (words=$words, overlap=$overlap)"
|
60 |
+
echo " Устройство: $DEVICE"
|
61 |
+
echo " Результаты будут сохранены в: $results_path"
|
62 |
+
echo " Лог: $log_path"
|
63 |
+
echo "=============================================================================="
|
64 |
+
|
65 |
+
# Базовая команда запуска
|
66 |
+
local cmd="python scripts/evaluate_chunking.py \
|
67 |
+
--data-folder \"$DATA_FOLDER\" \
|
68 |
+
--model-name \"$model\" \
|
69 |
+
--dataset-path \"$DATASET_PATH\" \
|
70 |
+
--output-dir \"$OUTPUT_DIR\" \
|
71 |
+
--words-per-chunk $words \
|
72 |
+
--overlap-words $overlap \
|
73 |
+
--similarity-threshold $SIMILARITY_THRESHOLD \
|
74 |
+
--device $DEVICE \
|
75 |
+
--force-recompute"
|
76 |
+
|
77 |
+
# Специальная обработка для модели ai-forever/FRIDA
|
78 |
+
if [[ "$model" == "ai-forever/FRIDA" ]]; then
|
79 |
+
cmd="$cmd --use-sentence-transformers"
|
80 |
+
fi
|
81 |
+
|
82 |
+
# Записываем информацию о запуске в лог
|
83 |
+
echo "Эксперимент запущен в: $(date)" > "$log_path"
|
84 |
+
echo "Команда: $cmd" >> "$log_path"
|
85 |
+
echo "" >> "$log_path"
|
86 |
+
|
87 |
+
# Записываем время начала
|
88 |
+
start_time=$(date +%s)
|
89 |
+
|
90 |
+
# Запускаем команду и записываем вывод в лог
|
91 |
+
eval "$cmd" 2>&1 | tee -a "$log_path"
|
92 |
+
exit_code=${PIPESTATUS[0]}
|
93 |
+
|
94 |
+
# Записываем время окончания
|
95 |
+
end_time=$(date +%s)
|
96 |
+
duration=$((end_time - start_time))
|
97 |
+
duration_min=$(echo "scale=2; $duration/60" | bc)
|
98 |
+
|
99 |
+
# Добавляем информацию о завершении в лог
|
100 |
+
echo "" >> "$log_path"
|
101 |
+
echo "Эксперимент завершен в: $(date)" >> "$log_path"
|
102 |
+
echo "Длительность: $duration секунд ($duration_min минут)" >> "$log_path"
|
103 |
+
echo "Код возврата: $exit_code" >> "$log_path"
|
104 |
+
|
105 |
+
if [ $exit_code -eq 0 ]; then
|
106 |
+
echo "Эксперимент успешно завершен за $duration_min минут"
|
107 |
+
else
|
108 |
+
echo "Эксперимент завершился с ошибкой (код $exit_code)"
|
109 |
+
fi
|
110 |
+
}
|
111 |
+
|
112 |
+
# Основная функция
|
113 |
+
main() {
|
114 |
+
local total_experiments=$((${#MODELS[@]} * ${#CHUNKING_PARAMS[@]}))
|
115 |
+
local completed_experiments=0
|
116 |
+
|
117 |
+
echo "Запуск $total_experiments экспериментов..."
|
118 |
+
|
119 |
+
# Засекаем время начала всех экспериментов
|
120 |
+
local start_time_all=$(date +%s)
|
121 |
+
|
122 |
+
# Сначала перебираем все параметры чанкинга
|
123 |
+
for chunking_param in "${CHUNKING_PARAMS[@]}"; do
|
124 |
+
# Разбиваем строку параметров на составляющие
|
125 |
+
IFS=':' read -r words overlap description <<< "$chunking_param"
|
126 |
+
|
127 |
+
echo -e "\n=== Стратегия чанкинга: $description (words=$words, overlap=$overlap) ===\n"
|
128 |
+
|
129 |
+
# Затем перебираем все модели для текущей стратегии чанкинга
|
130 |
+
for model in "${MODELS[@]}"; do
|
131 |
+
# Запускаем эксперимент
|
132 |
+
run_experiment "$model" "$words" "$overlap" "$description"
|
133 |
+
|
134 |
+
# Увеличиваем счетчик завершенных экспериментов
|
135 |
+
completed_experiments=$((completed_experiments + 1))
|
136 |
+
remaining_experiments=$((total_experiments - completed_experiments))
|
137 |
+
|
138 |
+
if [ $remaining_experiments -gt 0 ]; then
|
139 |
+
echo "Завершено $completed_experiments/$total_experiments экспериментов. Осталось: $remaining_experiments"
|
140 |
+
fi
|
141 |
+
done
|
142 |
+
done
|
143 |
+
|
144 |
+
# Рассчитываем общее время выполнения
|
145 |
+
local end_time_all=$(date +%s)
|
146 |
+
local total_duration=$((end_time_all - start_time_all))
|
147 |
+
local total_duration_min=$(echo "scale=2; $total_duration/60" | bc)
|
148 |
+
|
149 |
+
echo ""
|
150 |
+
echo "Все эксперименты завершены за $total_duration_min минут"
|
151 |
+
echo "Результаты сохранены в $OUTPUT_DIR"
|
152 |
+
echo "Логи сохранены в $LOG_DIR"
|
153 |
+
}
|
154 |
+
|
155 |
+
# Запускаем основную функцию
|
156 |
+
main
|
lib/extractor/scripts/run_experiments.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для запуска экспериментов по оценке качества чанкинга с разными моделями и параметрами.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
from datetime import datetime
|
12 |
+
|
13 |
+
# Конфигурация моделей
|
14 |
+
MODELS = [
|
15 |
+
"intfloat/e5-base",
|
16 |
+
"intfloat/e5-large",
|
17 |
+
"BAAI/bge-m3",
|
18 |
+
"deepvk/USER-bge-m3",
|
19 |
+
"ai-forever/FRIDA"
|
20 |
+
]
|
21 |
+
|
22 |
+
# Параметры чанкинга (отсортированы в запрошенном порядке)
|
23 |
+
CHUNKING_PARAMS = [
|
24 |
+
{"words": 50, "overlap": 25, "description": "Маленький чанкинг с нахлёстом 50%"},
|
25 |
+
{"words": 50, "overlap": 0, "description": "Маленький чанкинг без нахлёста"},
|
26 |
+
{"words": 20, "overlap": 10, "description": "Очень мелкий чанкинг с нахлёстом 50%"},
|
27 |
+
{"words": 100, "overlap": 0, "description": "Средний чанкинг без нахлёста"},
|
28 |
+
{"words": 100, "overlap": 25, "description": "Средний чанкинг с нахлёстом 25%"},
|
29 |
+
{"words": 150, "overlap": 50, "description": "Крупный чанкинг с нахлёстом 33%"},
|
30 |
+
{"words": 200, "overlap": 75, "description": "Очень крупный чанкинг с нахлёстом 37.5%"}
|
31 |
+
]
|
32 |
+
|
33 |
+
# Значение порога для нечеткого сравнения
|
34 |
+
SIMILARITY_THRESHOLD = 0.7
|
35 |
+
|
36 |
+
|
37 |
+
def parse_args():
|
38 |
+
"""Парсит аргументы командной строки."""
|
39 |
+
parser = argparse.ArgumentParser(description="Запуск экспериментов для оценки качества чанкинга")
|
40 |
+
|
41 |
+
parser.add_argument("--data-folder", type=str, default="data/docs",
|
42 |
+
help="Путь к папке с документами (по умолчанию: data/docs)")
|
43 |
+
parser.add_argument("--dataset-path", type=str, default="data/dataset.xlsx",
|
44 |
+
help="Путь к Excel-датасету с вопросами (по умолчанию: data/dataset.xlsx)")
|
45 |
+
parser.add_argument("--output-dir", type=str, default="data",
|
46 |
+
help="Директория для сохранения результатов (по умолчанию: data)")
|
47 |
+
parser.add_argument("--log-dir", type=str, default="logs",
|
48 |
+
help="Директория для сохранения логов (по умолчанию: logs)")
|
49 |
+
parser.add_argument("--skip-existing", action="store_true",
|
50 |
+
help="Пропускать эксперименты, если файлы результатов уже существуют")
|
51 |
+
parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
|
52 |
+
help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
|
53 |
+
parser.add_argument("--model", type=str, default=None,
|
54 |
+
help="Запустить эксперимент только для указанной модели")
|
55 |
+
parser.add_argument("--chunking-index", type=int, default=None,
|
56 |
+
help="Запустить эксперимент только для указанного индекса конфигурации чанкинга (0-6)")
|
57 |
+
parser.add_argument("--device", type=str, default="cuda:1",
|
58 |
+
help="Устройство для вычислений (по умолчанию: cuda:1)")
|
59 |
+
|
60 |
+
return parser.parse_args()
|
61 |
+
|
62 |
+
|
63 |
+
def run_experiment(model_name, chunking_params, args):
|
64 |
+
"""
|
65 |
+
Запускает эксперимент с определенной моделью и параметрами чанкинга.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
model_name: Название модели
|
69 |
+
chunking_params: Словарь с параметрами чанкинга
|
70 |
+
args: Аргументы командной строки
|
71 |
+
"""
|
72 |
+
words = chunking_params["words"]
|
73 |
+
overlap = chunking_params["overlap"]
|
74 |
+
description = chunking_params["description"]
|
75 |
+
|
76 |
+
# Формируем имя файла результатов
|
77 |
+
results_filename = f"results_fixed_size_w{words}_o{overlap}_{model_name.replace('/', '_')}.csv"
|
78 |
+
results_path = os.path.join(args.output_dir, results_filename)
|
79 |
+
|
80 |
+
# Проверяем, существует ли файл результатов
|
81 |
+
if args.skip_existing and os.path.exists(results_path):
|
82 |
+
print(f"Пропуск: {results_path} уже существует")
|
83 |
+
return
|
84 |
+
|
85 |
+
# Создаем директорию для логов, если она не существует
|
86 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
87 |
+
|
88 |
+
# Формируем имя файла лога
|
89 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
90 |
+
log_filename = f"log_{model_name.replace('/', '_')}_w{words}_o{overlap}_{timestamp}.txt"
|
91 |
+
log_path = os.path.join(args.log_dir, log_filename)
|
92 |
+
|
93 |
+
# Используем тот же интерпретатор Python, что и текущий скрипт
|
94 |
+
python_executable = sys.executable
|
95 |
+
|
96 |
+
# Запускаем скрипт evaluate_chunking.py с нужными параметрами
|
97 |
+
cmd = [
|
98 |
+
python_executable, "scripts/evaluate_chunking.py",
|
99 |
+
"--data-folder", args.data_folder,
|
100 |
+
"--model-name", model_name,
|
101 |
+
"--dataset-path", args.dataset_path,
|
102 |
+
"--output-dir", args.output_dir,
|
103 |
+
"--words-per-chunk", str(words),
|
104 |
+
"--overlap-words", str(overlap),
|
105 |
+
"--similarity-threshold", str(args.similarity_threshold),
|
106 |
+
"--device", args.device,
|
107 |
+
"--force-recompute" # Принудительно пересчитываем эмбеддинги
|
108 |
+
]
|
109 |
+
|
110 |
+
# Специальная обработка для модели ai-forever/FRIDA
|
111 |
+
if model_name == "ai-forever/FRIDA":
|
112 |
+
cmd.append("--use-sentence-transformers") # Добавляем флаг для использования sentence_transformers
|
113 |
+
|
114 |
+
print(f"\n{'='*80}")
|
115 |
+
print(f"Запуск эксперимента:")
|
116 |
+
print(f" Интерпретатор Python: {python_executable}")
|
117 |
+
print(f" Модель: {model_name}")
|
118 |
+
print(f" Чанкинг: {description} (words={words}, overlap={overlap})")
|
119 |
+
print(f" Порог для нечеткого сравнения: {args.similarity_threshold}")
|
120 |
+
print(f" Устройство: {args.device}")
|
121 |
+
print(f" Результаты будут сохранены в: {results_path}")
|
122 |
+
print(f" Лог: {log_path}")
|
123 |
+
print(f"{'='*80}\n")
|
124 |
+
|
125 |
+
# Запись информации в лог
|
126 |
+
with open(log_path, "w", encoding="utf-8") as log_file:
|
127 |
+
log_file.write(f"Эксперимент запущен в: {datetime.now()}\n")
|
128 |
+
log_file.write(f"Интерпретатор Python: {python_executable}\n")
|
129 |
+
log_file.write(f"Команда: {' '.join(cmd)}\n\n")
|
130 |
+
|
131 |
+
start_time = time.time()
|
132 |
+
|
133 |
+
# Запускаем процесс и перенаправляем вывод в файл лога
|
134 |
+
process = subprocess.Popen(
|
135 |
+
cmd,
|
136 |
+
stdout=subprocess.PIPE,
|
137 |
+
stderr=subprocess.STDOUT,
|
138 |
+
text=True,
|
139 |
+
bufsize=1 # Построчная буферизация
|
140 |
+
)
|
141 |
+
|
142 |
+
# Читаем вывод процесса
|
143 |
+
for line in process.stdout:
|
144 |
+
print(line, end="") # Выводим в консоль
|
145 |
+
log_file.write(line) # Записываем в файл лога
|
146 |
+
|
147 |
+
# Ждем завершения процесса
|
148 |
+
process.wait()
|
149 |
+
|
150 |
+
end_time = time.time()
|
151 |
+
duration = end_time - start_time
|
152 |
+
|
153 |
+
# Записываем информацию о завершении
|
154 |
+
log_file.write(f"\nЭксперимент завершен в: {datetime.now()}\n")
|
155 |
+
log_file.write(f"Длительность: {duration:.2f} секунд ({duration/60:.2f} минут)\n")
|
156 |
+
log_file.write(f"Код возврата: {process.returncode}\n")
|
157 |
+
|
158 |
+
if process.returncode == 0:
|
159 |
+
print(f"Эксперимент успешно завершен за {duration/60:.2f} минут")
|
160 |
+
else:
|
161 |
+
print(f"Эксперимент завершился с ошибкой (код {process.returncode})")
|
162 |
+
|
163 |
+
|
164 |
+
def main():
|
165 |
+
"""Основная функция скрипта."""
|
166 |
+
args = parse_args()
|
167 |
+
|
168 |
+
# Создаем output_dir, если он не существует
|
169 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
170 |
+
|
171 |
+
# Получаем список моделей для запуска
|
172 |
+
models_to_run = [args.model] if args.model else MODELS
|
173 |
+
|
174 |
+
# Получаем список конфигураций чанкинга для запуска
|
175 |
+
chunking_configs = [CHUNKING_PARAMS[args.chunking_index]] if args.chunking_index is not None else CHUNKING_PARAMS
|
176 |
+
|
177 |
+
start_time_all = time.time()
|
178 |
+
total_experiments = len(models_to_run) * len(chunking_configs)
|
179 |
+
completed_experiments = 0
|
180 |
+
|
181 |
+
print(f"Запуск {total_experiments} экспериментов...")
|
182 |
+
|
183 |
+
# Изменен порядок: сначала идём по стратегиям, затем по моделям
|
184 |
+
for chunking_config in chunking_configs:
|
185 |
+
print(f"\n=== Стратегия чанкинга: {chunking_config['description']} (words={chunking_config['words']}, overlap={chunking_config['overlap']}) ===\n")
|
186 |
+
|
187 |
+
for model in models_to_run:
|
188 |
+
# Запускаем эксперимент
|
189 |
+
run_experiment(model, chunking_config, args)
|
190 |
+
|
191 |
+
completed_experiments += 1
|
192 |
+
remaining_experiments = total_experiments - completed_experiments
|
193 |
+
|
194 |
+
if remaining_experiments > 0:
|
195 |
+
print(f"Завершено {completed_experiments}/{total_experiments} экспериментов. Осталось: {remaining_experiments}")
|
196 |
+
|
197 |
+
end_time_all = time.time()
|
198 |
+
total_duration = end_time_all - start_time_all
|
199 |
+
|
200 |
+
print(f"\nВсе эксперименты завершены за {total_duration/60:.2f} минут")
|
201 |
+
print(f"Результаты сохранены в {args.output_dir}")
|
202 |
+
print(f"Логи сохранены в {args.log_dir}")
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
main()
|
lib/extractor/scripts/search_api.py
ADDED
@@ -0,0 +1,748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для поиска по векторизованным документам через API.
|
4 |
+
|
5 |
+
Этот скрипт:
|
6 |
+
1. Считывает все документы из заданной папки с помощью UniversalParser
|
7 |
+
2. Чанкит каждый документ через Destructurer с fixed_size-стратегией
|
8 |
+
3. Векторизует поле in_search_text через BGE-модель
|
9 |
+
4. Поднимает FastAPI с двумя эндпоинтами:
|
10 |
+
- /search/entities - возвращает найденные сущности списком словарей
|
11 |
+
- /search/text - возвращает полноценный собранный текст
|
12 |
+
"""
|
13 |
+
|
14 |
+
import logging
|
15 |
+
import os
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Dict, List, Optional
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import pandas as pd
|
21 |
+
import torch
|
22 |
+
import uvicorn
|
23 |
+
from fastapi import FastAPI, Query
|
24 |
+
from ntr_fileparser import UniversalParser
|
25 |
+
from pydantic import BaseModel
|
26 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
27 |
+
from transformers import AutoModel, AutoTokenizer
|
28 |
+
|
29 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import \
|
30 |
+
FixedSizeChunkingStrategy
|
31 |
+
from ntr_text_fragmentation.core.destructurer import Destructurer
|
32 |
+
from ntr_text_fragmentation.core.entity_repository import \
|
33 |
+
InMemoryEntityRepository
|
34 |
+
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
|
35 |
+
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
|
36 |
+
|
37 |
+
# Константы
|
38 |
+
DOCS_FOLDER = "../data/docs" # Путь к папке с документами
|
39 |
+
MODEL_NAME = "BAAI/bge-m3" # Название модели для векторизации
|
40 |
+
BATCH_SIZE = 16 # Размер батча для векторизации
|
41 |
+
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" # Устройство для вычислений
|
42 |
+
MAX_ENTITIES = 100 # Максимальное количество возвращаемых сущностей
|
43 |
+
WORDS_PER_CHUNK = 50 # Количество слов в чанке для fixed_size стратегии
|
44 |
+
OVERLAP_WORDS = 25 # Количество слов перекрытия для fixed_size стратегии
|
45 |
+
|
46 |
+
# Пути к кэшированным файлам
|
47 |
+
CACHE_DIR = "../data/cache" # Путь к папке с кэшированными данными
|
48 |
+
ENTITIES_CSV = os.path.join(CACHE_DIR, "entities.csv") # Путь к CSV с сущностями
|
49 |
+
EMBEDDINGS_NPY = os.path.join(CACHE_DIR, "embeddings.npy") # Путь к массиву эмбеддингов
|
50 |
+
|
51 |
+
# Инициализация FastAPI
|
52 |
+
app = FastAPI(title="Документный поиск API",
|
53 |
+
description="API для поиска по векторизованным документам")
|
54 |
+
|
55 |
+
# Глобальные переменные для хранения данных
|
56 |
+
entities_df = None
|
57 |
+
entity_embeddings = None
|
58 |
+
model = None
|
59 |
+
tokenizer = None
|
60 |
+
entity_repository = None
|
61 |
+
injection_builder = None
|
62 |
+
|
63 |
+
|
64 |
+
class EntityResponse(BaseModel):
|
65 |
+
"""Модель ответа для сущностей."""
|
66 |
+
id: str
|
67 |
+
name: str
|
68 |
+
text: str
|
69 |
+
type: str
|
70 |
+
score: float
|
71 |
+
doc_name: Optional[str] = None
|
72 |
+
metadata: Optional[Dict] = None
|
73 |
+
|
74 |
+
|
75 |
+
class TextResponse(BaseModel):
|
76 |
+
"""Модель ответа для собранного текста."""
|
77 |
+
text: str
|
78 |
+
entities_count: int
|
79 |
+
|
80 |
+
|
81 |
+
class TextsResponse(BaseModel):
|
82 |
+
"""Модель ответа для списка текстов."""
|
83 |
+
texts: List[str]
|
84 |
+
entities_count: int
|
85 |
+
|
86 |
+
|
87 |
+
def setup_logging() -> None:
|
88 |
+
"""Настройка логгирования."""
|
89 |
+
logging.basicConfig(
|
90 |
+
level=logging.INFO,
|
91 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def load_documents(folder_path: str) -> Dict:
|
96 |
+
"""
|
97 |
+
Загружает все документы из указанной папки.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
folder_path: Путь к папке с документами
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
Словарь {имя_файла: parsed_document}
|
104 |
+
"""
|
105 |
+
logging.info(f"Чтение документов из {folder_path}...")
|
106 |
+
parser = UniversalParser()
|
107 |
+
documents = {}
|
108 |
+
|
109 |
+
# Проверка существования папки
|
110 |
+
if not os.path.exists(folder_path):
|
111 |
+
logging.error(f"Папка {folder_path} не существует!")
|
112 |
+
return {}
|
113 |
+
|
114 |
+
for file_path in Path(folder_path).glob("**/*.docx"):
|
115 |
+
try:
|
116 |
+
doc_name = file_path.stem
|
117 |
+
logging.info(f"Обработка документа: {doc_name}")
|
118 |
+
documents[doc_name] = parser.parse_by_path(str(file_path))
|
119 |
+
except Exception as e:
|
120 |
+
logging.error(f"Ошибка при чтении файла {file_path}: {e}")
|
121 |
+
|
122 |
+
logging.info(f"Загружено {len(documents)} документов.")
|
123 |
+
return documents
|
124 |
+
|
125 |
+
|
126 |
+
def process_documents(documents: Dict) -> List[LinkerEntity]:
|
127 |
+
"""
|
128 |
+
Обрабатывает документы, применяя fixed_size стратегию чанкинга.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
documents: Словарь с распарсенными документами
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
Список сущностей из всех документов
|
135 |
+
"""
|
136 |
+
logging.info("Применение fixed_size стратегии чанкинга ко всем документам...")
|
137 |
+
|
138 |
+
all_entities = []
|
139 |
+
|
140 |
+
for doc_name, document in documents.items():
|
141 |
+
try:
|
142 |
+
# Создаем Destructurer с fixed_size стратегией
|
143 |
+
destructurer = Destructurer(
|
144 |
+
document,
|
145 |
+
strategy_name="fixed_size",
|
146 |
+
words_per_chunk=WORDS_PER_CHUNK,
|
147 |
+
overlap_words=OVERLAP_WORDS
|
148 |
+
)
|
149 |
+
|
150 |
+
# Получаем сущности
|
151 |
+
doc_entities = destructurer.destructure()
|
152 |
+
|
153 |
+
# Добавляем имя документа в метаданные всех сущностей
|
154 |
+
for entity in doc_entities:
|
155 |
+
if not hasattr(entity, 'metadata') or entity.metadata is None:
|
156 |
+
entity.metadata = {}
|
157 |
+
entity.metadata['doc_name'] = doc_name
|
158 |
+
|
159 |
+
all_entities.extend(doc_entities)
|
160 |
+
logging.info(f"Документ {doc_name}: получено {len(doc_entities)} сущностей")
|
161 |
+
|
162 |
+
except Exception as e:
|
163 |
+
logging.error(f"Ошибка при обработке документа {doc_name}: {e}")
|
164 |
+
|
165 |
+
logging.info(f"Всего получено {len(all_entities)} сущностей из всех документов")
|
166 |
+
return all_entities
|
167 |
+
|
168 |
+
|
169 |
+
def entities_to_dataframe(entities: List[LinkerEntity]) -> pd.DataFrame:
|
170 |
+
"""
|
171 |
+
Преобразует список сущностей в DataFrame для удобной работы.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
entities: Список сущностей
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
DataFrame с данными сущностей
|
178 |
+
"""
|
179 |
+
data = []
|
180 |
+
|
181 |
+
for entity in entities:
|
182 |
+
# Получаем имя документа из метаданных
|
183 |
+
doc_name = entity.metadata.get('doc_name', '') if hasattr(entity, 'metadata') and entity.metadata else ''
|
184 |
+
|
185 |
+
# Базовые поля для всех типов сущностей
|
186 |
+
entity_dict = {
|
187 |
+
"id": str(entity.id),
|
188 |
+
"type": entity.type,
|
189 |
+
"name": entity.name,
|
190 |
+
"text": entity.text,
|
191 |
+
"in_search_text": entity.in_search_text,
|
192 |
+
"doc_name": doc_name,
|
193 |
+
"source_id": entity.source_id if hasattr(entity, 'source_id') else None,
|
194 |
+
"target_id": entity.target_id if hasattr(entity, 'target_id') else None,
|
195 |
+
"metadata": entity.metadata if hasattr(entity, 'metadata') else {},
|
196 |
+
}
|
197 |
+
|
198 |
+
data.append(entity_dict)
|
199 |
+
|
200 |
+
df = pd.DataFrame(data)
|
201 |
+
return df
|
202 |
+
|
203 |
+
|
204 |
+
def setup_model_and_tokenizer():
|
205 |
+
"""
|
206 |
+
Инициализирует модель и токенизатор для векторизации.
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
Кортеж (модель, токенизатор)
|
210 |
+
"""
|
211 |
+
global model, tokenizer
|
212 |
+
|
213 |
+
logging.info(f"Загрузка модели {MODEL_NAME} на устройство {DEVICE}...")
|
214 |
+
|
215 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
216 |
+
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
|
217 |
+
model.eval()
|
218 |
+
|
219 |
+
return model, tokenizer
|
220 |
+
|
221 |
+
|
222 |
+
def _average_pool(
|
223 |
+
last_hidden_states: torch.Tensor,
|
224 |
+
attention_mask: torch.Tensor
|
225 |
+
) -> torch.Tensor:
|
226 |
+
"""
|
227 |
+
Расчёт усредненного эмбеддинга по всем токенам
|
228 |
+
|
229 |
+
Args:
|
230 |
+
last_hidden_states: Матрица эмбеддингов отдельных токенов
|
231 |
+
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
Усредненный эмбеддинг
|
235 |
+
"""
|
236 |
+
last_hidden = last_hidden_states.masked_fill(
|
237 |
+
~attention_mask[..., None].bool(), 0.0
|
238 |
+
)
|
239 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
240 |
+
|
241 |
+
|
242 |
+
def get_embeddings(texts: List[str]) -> np.ndarray:
|
243 |
+
"""
|
244 |
+
Получает эмбеддинги для списка текстов.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
texts: Список текстов для векторизации
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
Массив эмбеддингов
|
251 |
+
"""
|
252 |
+
global model, tokenizer
|
253 |
+
|
254 |
+
# Проверяем, что модель и токенизатор инициализированы
|
255 |
+
if model is None or tokenizer is None:
|
256 |
+
model, tokenizer = setup_model_and_tokenizer()
|
257 |
+
|
258 |
+
all_embeddings = []
|
259 |
+
|
260 |
+
for i in range(0, len(texts), BATCH_SIZE):
|
261 |
+
batch_texts = texts[i:i+BATCH_SIZE]
|
262 |
+
|
263 |
+
# Фильтруем None и пустые строки
|
264 |
+
batch_texts = [text for text in batch_texts if text]
|
265 |
+
|
266 |
+
if not batch_texts:
|
267 |
+
continue
|
268 |
+
|
269 |
+
# Токенизация с обрезкой и padding
|
270 |
+
encoding = tokenizer(
|
271 |
+
batch_texts,
|
272 |
+
padding=True,
|
273 |
+
truncation=True,
|
274 |
+
max_length=512,
|
275 |
+
return_tensors="pt"
|
276 |
+
).to(DEVICE)
|
277 |
+
|
278 |
+
# Получаем эмбеддинги с average pooling
|
279 |
+
with torch.no_grad():
|
280 |
+
outputs = model(**encoding)
|
281 |
+
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
|
282 |
+
all_embeddings.append(embeddings.cpu().numpy())
|
283 |
+
|
284 |
+
if not all_embeddings:
|
285 |
+
return np.array([])
|
286 |
+
|
287 |
+
return np.vstack(all_embeddings)
|
288 |
+
|
289 |
+
|
290 |
+
def init_entity_repository_and_builder(entities: List[LinkerEntity]):
|
291 |
+
"""
|
292 |
+
Инициализирует хранилище сущностей и сборщик инъекций.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
entities: Список сущностей
|
296 |
+
"""
|
297 |
+
global entity_repository, injection_builder
|
298 |
+
|
299 |
+
# Создаем хранилище сущностей
|
300 |
+
entity_repository = InMemoryEntityRepository(entities)
|
301 |
+
|
302 |
+
# Добавляем метод get_entity_by_id в InMemoryEntityRepository
|
303 |
+
# Это временное решение, в идеале нужно добавить этот метод в сам класс
|
304 |
+
def get_entity_by_id(self, entity_id):
|
305 |
+
"""Получает сущность по ID"""
|
306 |
+
for entity in self.entities:
|
307 |
+
if str(entity.id) == entity_id:
|
308 |
+
return entity
|
309 |
+
return None
|
310 |
+
|
311 |
+
# Добавляем метод в класс
|
312 |
+
InMemoryEntityRepository.get_entity_by_id = get_entity_by_id
|
313 |
+
|
314 |
+
# Создаем сборщик инъекций
|
315 |
+
injection_builder = InjectionBuilder(repository=entity_repository)
|
316 |
+
|
317 |
+
# Регистрируем стратегию
|
318 |
+
injection_builder.register_strategy("fixed_size", FixedSizeChunkingStrategy)
|
319 |
+
|
320 |
+
|
321 |
+
def search_entities(query: str, top_n: int = MAX_ENTITIES) -> List[Dict]:
|
322 |
+
"""
|
323 |
+
Ищет сущности по запросу на основе косинусной близости.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
query: Поисковый запрос
|
327 |
+
top_n: Максимальное количество возвращаемых сущностей
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
Список найденных сущностей с их скорами
|
331 |
+
"""
|
332 |
+
global entities_df, entity_embeddings
|
333 |
+
|
334 |
+
# Проверяем наличие данных
|
335 |
+
if entities_df is None or entity_embeddings is None:
|
336 |
+
logging.error("Данные не инициализированы. Запустите сначала prepare_data().")
|
337 |
+
return []
|
338 |
+
|
339 |
+
# Векторизуем запрос
|
340 |
+
query_embedding = get_embeddings([query])
|
341 |
+
|
342 |
+
if query_embedding.size == 0:
|
343 |
+
return []
|
344 |
+
|
345 |
+
# Считаем косинусную близость
|
346 |
+
similarities = cosine_similarity(query_embedding, entity_embeddings)[0]
|
347 |
+
|
348 |
+
# Получаем индексы топ-N сущностей
|
349 |
+
top_indices = np.argsort(similarities)[-top_n:][::-1]
|
350 |
+
|
351 |
+
# Фильтруем сущности, которые используются для поиска
|
352 |
+
search_df = entities_df.copy()
|
353 |
+
search_df = search_df[search_df['in_search_text'].notna()]
|
354 |
+
|
355 |
+
# Если после фильтрации нет данных, возвращаем пустой список
|
356 |
+
if search_df.empty:
|
357 |
+
return []
|
358 |
+
|
359 |
+
# Получаем топ-N сущностей
|
360 |
+
results = []
|
361 |
+
|
362 |
+
for idx in top_indices:
|
363 |
+
if idx >= len(search_df):
|
364 |
+
continue
|
365 |
+
|
366 |
+
entity = search_df.iloc[idx]
|
367 |
+
similarity = similarities[idx]
|
368 |
+
|
369 |
+
# Создаем результат
|
370 |
+
result = {
|
371 |
+
"id": entity["id"],
|
372 |
+
"name": entity["name"],
|
373 |
+
"text": entity["text"],
|
374 |
+
"type": entity["type"],
|
375 |
+
"score": float(similarity),
|
376 |
+
"doc_name": entity["doc_name"],
|
377 |
+
"metadata": entity["metadata"]
|
378 |
+
}
|
379 |
+
|
380 |
+
results.append(result)
|
381 |
+
|
382 |
+
return results
|
383 |
+
|
384 |
+
|
385 |
+
@app.get("/search/entities", response_model=List[EntityResponse])
|
386 |
+
async def api_search_entities(
|
387 |
+
query: str = Query(..., description="Поисковый запрос"),
|
388 |
+
limit: int = Query(MAX_ENTITIES, description="Максимальное количество результатов")
|
389 |
+
):
|
390 |
+
"""
|
391 |
+
Эндпоинт для поиска сущностей по запросу.
|
392 |
+
|
393 |
+
Args:
|
394 |
+
query: Поисковый запрос
|
395 |
+
limit: Максимальное количество результатов
|
396 |
+
|
397 |
+
Returns:
|
398 |
+
Список найденных сущностей
|
399 |
+
"""
|
400 |
+
results = search_entities(query, limit)
|
401 |
+
return results
|
402 |
+
|
403 |
+
|
404 |
+
@app.get("/search/text", response_model=TextResponse)
|
405 |
+
async def api_search_text(
|
406 |
+
query: str = Query(..., description="Поисковый запрос"),
|
407 |
+
limit: int = Query(MAX_ENTITIES, description="Максимальное количество учитываемых сущностей")
|
408 |
+
):
|
409 |
+
"""
|
410 |
+
Эндпоинт для поиска и сборки полного текста по запросу.
|
411 |
+
|
412 |
+
Args:
|
413 |
+
query: Поисковый запрос
|
414 |
+
limit: Максимальное количество учитываемых сущностей
|
415 |
+
|
416 |
+
Returns:
|
417 |
+
Собранный текст и количество использованных сущностей
|
418 |
+
"""
|
419 |
+
global injection_builder
|
420 |
+
|
421 |
+
# Проверяем наличие сборщика инъекций
|
422 |
+
if injection_builder is None:
|
423 |
+
logging.error("Сборщик инъекций не инициализирован.")
|
424 |
+
return {"text": "", "entities_count": 0}
|
425 |
+
|
426 |
+
# Получаем найденные сущности
|
427 |
+
entity_results = search_entities(query, limit)
|
428 |
+
|
429 |
+
if not entity_results:
|
430 |
+
return {"text": "", "entities_count": 0}
|
431 |
+
|
432 |
+
# Получаем список ID сущностей
|
433 |
+
entity_ids = [str(result["id"]) for result in entity_results]
|
434 |
+
|
435 |
+
# Собираем текст, используя напрямую ID
|
436 |
+
try:
|
437 |
+
assembled_text = injection_builder.build(entity_ids)
|
438 |
+
print('Всё ок прошло вроде бы')
|
439 |
+
return {"text": assembled_text, "entities_count": len(entity_ids)}
|
440 |
+
except ImportError as e:
|
441 |
+
# Обработка ошибки импорта модулей для работы с изображениями
|
442 |
+
logging.error(f"Ошибка импорта при сборке текста: {e}")
|
443 |
+
# Альтернативная сборка текста без использования injection_builder
|
444 |
+
simple_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")])
|
445 |
+
return {"text": simple_text, "entities_count": len(entity_ids)}
|
446 |
+
except Exception as e:
|
447 |
+
logging.error(f"Ошибка при сборке текста: {e}")
|
448 |
+
return {"text": "", "entities_count": 0}
|
449 |
+
|
450 |
+
|
451 |
+
@app.get("/search/texts", response_model=TextsResponse)
|
452 |
+
async def api_search_texts(
|
453 |
+
query: str = Query(..., description="Поисковый запрос"),
|
454 |
+
limit: int = Query(MAX_ENTITIES, description="Максимальное количество результатов")
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Эндпоинт для поиска списка текстов сущностей по запросу.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
query: Поисковый запрос
|
461 |
+
limit: Максимальное количество результатов
|
462 |
+
|
463 |
+
Returns:
|
464 |
+
Список текстов найденных сущностей и их количество
|
465 |
+
"""
|
466 |
+
# Получаем найденные сущности
|
467 |
+
entity_results = search_entities(query, limit)
|
468 |
+
|
469 |
+
if not entity_results:
|
470 |
+
return {"texts": [], "entities_count": 0}
|
471 |
+
|
472 |
+
# Извлекаем тексты из результатов
|
473 |
+
texts = [result["text"] for result in entity_results if result.get("text")]
|
474 |
+
|
475 |
+
return {"texts": texts, "entities_count": len(texts)}
|
476 |
+
|
477 |
+
|
478 |
+
@app.get("/search/text_test", response_model=TextResponse)
|
479 |
+
async def api_search_text_test(
|
480 |
+
query: str = Query(..., description="Поисковый запрос"),
|
481 |
+
limit: int = Query(MAX_ENTITIES, description="Максимальное количество учитываемых сущностей")
|
482 |
+
):
|
483 |
+
"""
|
484 |
+
Тестовый эндпоинт для поиска и сборки текста с использованием подхода из test_chunking_visualization.py.
|
485 |
+
|
486 |
+
Args:
|
487 |
+
query: Поисковый запрос
|
488 |
+
limit: Максимальное количество учитываемых сущностей
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
Собранный текст и количество использованных сущностей
|
492 |
+
"""
|
493 |
+
global entity_repository, injection_builder
|
494 |
+
|
495 |
+
# Проверяем наличие репозитория и сборщика инъекций
|
496 |
+
if entity_repository is None or injection_builder is None:
|
497 |
+
logging.error("Репозиторий или сборщик инъекций не инициализированы.")
|
498 |
+
return {"text": "", "entities_count": 0}
|
499 |
+
|
500 |
+
# Получаем найденные сущности
|
501 |
+
entity_results = search_entities(query, limit)
|
502 |
+
|
503 |
+
if not entity_results:
|
504 |
+
return {"text": "", "entities_count": 0}
|
505 |
+
|
506 |
+
try:
|
507 |
+
# Получаем объекты сущностей из репозитория по ID
|
508 |
+
entity_ids = [result["id"] for result in entity_results]
|
509 |
+
entities = []
|
510 |
+
|
511 |
+
for entity_id in entity_ids:
|
512 |
+
entity = entity_repository.get_entity_by_id(entity_id)
|
513 |
+
if entity:
|
514 |
+
entities.append(entity)
|
515 |
+
|
516 |
+
logging.info(f"Найдено {len(entities)} объектов сущностей по ID")
|
517 |
+
|
518 |
+
if not entities:
|
519 |
+
logging.error("Не удалось найти сущности в репозитории")
|
520 |
+
# Собираем простой текст из результатов поиска
|
521 |
+
simple_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")])
|
522 |
+
return {"text": simple_text, "entities_count": len(entity_results)}
|
523 |
+
|
524 |
+
# Собираем текст, как в test_chunking_visualization.py
|
525 |
+
assembled_text = injection_builder.build(entities) # Передаем сами объекты
|
526 |
+
|
527 |
+
return {"text": assembled_text, "entities_count": len(entities)}
|
528 |
+
except Exception as e:
|
529 |
+
logging.error(f"Ошибка при сборке текста: {e}", exc_info=True)
|
530 |
+
# Запасной вариант - просто соединяем тексты
|
531 |
+
fallback_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")])
|
532 |
+
return {"text": fallback_text, "entities_count": len(entity_results)}
|
533 |
+
|
534 |
+
|
535 |
+
def save_entities_to_csv(entities: List[LinkerEntity], csv_path: str) -> None:
|
536 |
+
"""
|
537 |
+
Сохраняет сущности в CSV файл.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
entities: Список сущностей
|
541 |
+
csv_path: Путь для сохранения CSV файла
|
542 |
+
"""
|
543 |
+
logging.info(f"Сохранение {len(entities)} сущностей в {csv_path}")
|
544 |
+
|
545 |
+
# Создаем директорию, если она не существует
|
546 |
+
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
|
547 |
+
|
548 |
+
# Преобразуем сущности в DataFrame и сохраняем
|
549 |
+
df = entities_to_dataframe(entities)
|
550 |
+
df.to_csv(csv_path, index=False)
|
551 |
+
|
552 |
+
logging.info(f"Сохранено {len(entities)} сущностей в {csv_path}")
|
553 |
+
|
554 |
+
|
555 |
+
def load_entities_from_csv(csv_path: str) -> List[LinkerEntity]:
|
556 |
+
"""
|
557 |
+
Загружает сущности из CSV файла.
|
558 |
+
|
559 |
+
Args:
|
560 |
+
csv_path: Путь к CSV файлу
|
561 |
+
|
562 |
+
Returns:
|
563 |
+
Список сущностей
|
564 |
+
"""
|
565 |
+
logging.info(f"Загрузка сущностей из {csv_path}")
|
566 |
+
|
567 |
+
if not os.path.exists(csv_path):
|
568 |
+
logging.error(f"Файл {csv_path} не найден")
|
569 |
+
return []
|
570 |
+
|
571 |
+
df = pd.read_csv(csv_path)
|
572 |
+
entities = []
|
573 |
+
|
574 |
+
for _, row in df.iterrows():
|
575 |
+
# Обработка метаданных
|
576 |
+
metadata = row.get("metadata", {})
|
577 |
+
if isinstance(metadata, str):
|
578 |
+
try:
|
579 |
+
metadata = eval(metadata) if metadata and not pd.isna(metadata) else {}
|
580 |
+
except:
|
581 |
+
metadata = {}
|
582 |
+
|
583 |
+
# Общие поля для всех типов сущностей
|
584 |
+
common_args = {
|
585 |
+
"id": row["id"],
|
586 |
+
"name": row["name"] if not pd.isna(row.get("name", "")) else "",
|
587 |
+
"text": row["text"] if not pd.isna(row.get("text", "")) else "",
|
588 |
+
"metadata": metadata,
|
589 |
+
"type": row["type"],
|
590 |
+
}
|
591 |
+
|
592 |
+
# Добавляем in_search_text, если он есть
|
593 |
+
if "in_search_text" in row and not pd.isna(row["in_search_text"]):
|
594 |
+
common_args["in_search_text"] = row["in_search_text"]
|
595 |
+
|
596 |
+
# Добавляем поля связи, если они есть
|
597 |
+
if "source_id" in row and not pd.isna(row["source_id"]):
|
598 |
+
common_args["source_id"] = row["source_id"]
|
599 |
+
common_args["target_id"] = row["target_id"]
|
600 |
+
if "number_in_relation" in row and not pd.isna(row["number_in_relation"]):
|
601 |
+
common_args["number_in_relation"] = int(row["number_in_relation"])
|
602 |
+
|
603 |
+
entity = LinkerEntity(**common_args)
|
604 |
+
entities.append(entity)
|
605 |
+
|
606 |
+
logging.info(f"Загружено {len(entities)} сущностей из {csv_path}")
|
607 |
+
return entities
|
608 |
+
|
609 |
+
|
610 |
+
def save_embeddings(embeddings: np.ndarray, file_path: str) -> None:
|
611 |
+
"""
|
612 |
+
Сохраняет эмбеддинги в numpy файл.
|
613 |
+
|
614 |
+
Args:
|
615 |
+
embeddings: Массив эмбеддингов
|
616 |
+
file_path: Путь для сохранения файла
|
617 |
+
"""
|
618 |
+
logging.info(f"Сохранение эмбеддингов размером {embeddings.shape} в {file_path}")
|
619 |
+
|
620 |
+
# Создаем директорию, если она не существует
|
621 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
622 |
+
|
623 |
+
# Сохраняем эмбеддинги
|
624 |
+
np.save(file_path, embeddings)
|
625 |
+
|
626 |
+
logging.info(f"Эмбеддинги сохранены в {file_path}")
|
627 |
+
|
628 |
+
|
629 |
+
def load_embeddings(file_path: str) -> np.ndarray:
|
630 |
+
"""
|
631 |
+
Загружает эмбеддинги из numpy файла.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
file_path: Путь к файлу
|
635 |
+
|
636 |
+
Returns:
|
637 |
+
Массив эмбеддингов
|
638 |
+
"""
|
639 |
+
logging.info(f"Загрузка эмбеддингов из {file_path}")
|
640 |
+
|
641 |
+
if not os.path.exists(file_path):
|
642 |
+
logging.error(f"Файл {file_path} не найден")
|
643 |
+
return np.array([])
|
644 |
+
|
645 |
+
embeddings = np.load(file_path)
|
646 |
+
|
647 |
+
logging.info(f"Загружены эмбеддинги размером {embeddings.shape}")
|
648 |
+
return embeddings
|
649 |
+
|
650 |
+
|
651 |
+
def prepare_data():
|
652 |
+
"""
|
653 |
+
Подготавливает все необходимые данные для API.
|
654 |
+
"""
|
655 |
+
global entities_df, entity_embeddings, entity_repository, injection_builder
|
656 |
+
|
657 |
+
# Проверяем наличие кэшированных данных
|
658 |
+
cache_exists = os.path.exists(ENTITIES_CSV) and os.path.exists(EMBEDDINGS_NPY)
|
659 |
+
|
660 |
+
if cache_exists:
|
661 |
+
logging.info("Найдены кэшированные данные, загружаем их")
|
662 |
+
|
663 |
+
# Загружаем сущности из CSV
|
664 |
+
entities = load_entities_from_csv(ENTITIES_CSV)
|
665 |
+
|
666 |
+
if not entities:
|
667 |
+
logging.error("Не удалось загрузить сущности из кэша, генерируем заново")
|
668 |
+
cache_exists = False
|
669 |
+
else:
|
670 |
+
# Преобразуем сущности в DataFrame
|
671 |
+
entities_df = entities_to_dataframe(entities)
|
672 |
+
|
673 |
+
# Загружаем эмбеддинги
|
674 |
+
entity_embeddings = load_embeddings(EMBEDDINGS_NPY)
|
675 |
+
|
676 |
+
if entity_embeddings.size == 0:
|
677 |
+
logging.error("Не удалось загрузить эмбеддинги из кэша, генерируем заново")
|
678 |
+
cache_exists = False
|
679 |
+
else:
|
680 |
+
# Инициализируем хранилище и сборщик
|
681 |
+
init_entity_repository_and_builder(entities)
|
682 |
+
logging.info("Данные успешно загружены из кэша")
|
683 |
+
|
684 |
+
# Если кэшированных данных нет или их не удалось загрузить, генерируем заново
|
685 |
+
if not cache_exists:
|
686 |
+
logging.info("Кэшированные данные не найдены или не могут быть загружены, обрабатываем документы")
|
687 |
+
|
688 |
+
# Загружаем и обрабатываем документы
|
689 |
+
documents = load_documents(DOCS_FOLDER)
|
690 |
+
|
691 |
+
if not documents:
|
692 |
+
logging.error(f"Не найдено документов в папке {DOCS_FOLDER}")
|
693 |
+
return
|
694 |
+
|
695 |
+
# Получаем сущности из всех документов
|
696 |
+
all_entities = process_documents(documents)
|
697 |
+
|
698 |
+
if not all_entities:
|
699 |
+
logging.error("Не получено сущностей из документов")
|
700 |
+
return
|
701 |
+
|
702 |
+
# Преобразуем сущности в DataFrame
|
703 |
+
entities_df = entities_to_dataframe(all_entities)
|
704 |
+
|
705 |
+
# Инициализируем хранилище и сборщик
|
706 |
+
init_entity_repository_and_builder(all_entities)
|
707 |
+
|
708 |
+
# Фильтруем только сущности для поиска
|
709 |
+
search_df = entities_df[entities_df['in_search_text'].notna()]
|
710 |
+
|
711 |
+
if search_df.empty:
|
712 |
+
logging.error("Нет сущностей для поиска с in_search_text")
|
713 |
+
return
|
714 |
+
|
715 |
+
# Векторизуем тексты сущностей
|
716 |
+
search_texts = search_df['in_search_text'].tolist()
|
717 |
+
entity_embeddings = get_embeddings(search_texts)
|
718 |
+
|
719 |
+
logging.info(f"Подготовлено {len(search_df)} сущностей для поиска")
|
720 |
+
logging.info(f"Размер эмбеддингов: {entity_embeddings.shape}")
|
721 |
+
|
722 |
+
# Сохраняем данные в кэш для последующего использования
|
723 |
+
save_entities_to_csv(all_entities, ENTITIES_CSV)
|
724 |
+
save_embeddings(entity_embeddings, EMBEDDINGS_NPY)
|
725 |
+
logging.info("Данные сохранены в кэш для последующего использования")
|
726 |
+
|
727 |
+
# Вывод итоговой информации (независимо от источника данных)
|
728 |
+
logging.info(f"Подготовка данных завершена. Готово к использованию {entity_embeddings.shape[0]} сущностей")
|
729 |
+
|
730 |
+
|
731 |
+
@app.on_event("startup")
|
732 |
+
async def startup_event():
|
733 |
+
"""Запускается при старте приложения."""
|
734 |
+
setup_logging()
|
735 |
+
prepare_data()
|
736 |
+
|
737 |
+
|
738 |
+
def main():
|
739 |
+
"""Основная функция для запуска скрипта вручную."""
|
740 |
+
setup_logging()
|
741 |
+
prepare_data()
|
742 |
+
|
743 |
+
# Запуск Uvicorn сервера
|
744 |
+
uvicorn.run(app, host="0.0.0.0", port=8017)
|
745 |
+
|
746 |
+
|
747 |
+
if __name__ == "__main__":
|
748 |
+
main()
|
lib/extractor/scripts/test_chunking_visualization.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для визуального тестирования процесса чанкинга и сборки документа.
|
4 |
+
|
5 |
+
Этот скрипт:
|
6 |
+
1. Считывает test_input/test.docx с помощью UniversalParser
|
7 |
+
2. Чанкит документ через Destructurer с fixed_size-стратегией
|
8 |
+
3. Сохраняет результат чанкинга в test_output/test.csv
|
9 |
+
4. Выбирает 20-30 случайных чанков из CSV
|
10 |
+
5. Создает InjectionBuilder с InMemoryEntityRepository
|
11 |
+
6. Собирает текст из выбранных чанков
|
12 |
+
7. Сохраняет результат в test_output/test_builded.txt
|
13 |
+
"""
|
14 |
+
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
import random
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import List
|
20 |
+
|
21 |
+
import pandas as pd
|
22 |
+
from ntr_fileparser import UniversalParser
|
23 |
+
|
24 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import \
|
25 |
+
FixedSizeChunkingStrategy
|
26 |
+
from ntr_text_fragmentation.core.destructurer import Destructurer
|
27 |
+
from ntr_text_fragmentation.core.entity_repository import \
|
28 |
+
InMemoryEntityRepository
|
29 |
+
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
|
30 |
+
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
|
31 |
+
|
32 |
+
|
33 |
+
def setup_logging() -> None:
|
34 |
+
"""Настройка логгирования."""
|
35 |
+
logging.basicConfig(
|
36 |
+
level=logging.INFO,
|
37 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def ensure_directories() -> None:
|
42 |
+
"""Проверка наличия необходимых директорий."""
|
43 |
+
for directory in ["test_input", "test_output"]:
|
44 |
+
Path(directory).mkdir(parents=True, exist_ok=True)
|
45 |
+
|
46 |
+
|
47 |
+
def save_entities_to_csv(entities: List[LinkerEntity], csv_path: str) -> None:
|
48 |
+
"""
|
49 |
+
Сохраняет сущности в CSV файл.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
entities: Список сущностей
|
53 |
+
csv_path: Путь для сохранения CSV файла
|
54 |
+
"""
|
55 |
+
data = []
|
56 |
+
for entity in entities:
|
57 |
+
# Базовые поля для всех типов сущностей
|
58 |
+
entity_dict = {
|
59 |
+
"id": str(entity.id),
|
60 |
+
"type": entity.type,
|
61 |
+
"name": entity.name,
|
62 |
+
"text": entity.text,
|
63 |
+
"metadata": str(entity.metadata),
|
64 |
+
"in_search_text": entity.in_search_text,
|
65 |
+
"source_id": entity.source_id,
|
66 |
+
"target_id": entity.target_id,
|
67 |
+
"number_in_relation": entity.number_in_relation,
|
68 |
+
}
|
69 |
+
|
70 |
+
data.append(entity_dict)
|
71 |
+
|
72 |
+
df = pd.DataFrame(data)
|
73 |
+
df.to_csv(csv_path, index=False)
|
74 |
+
logging.info(f"Сохранено {len(entities)} сущностей в {csv_path}")
|
75 |
+
|
76 |
+
|
77 |
+
def load_entities_from_csv(csv_path: str) -> List[LinkerEntity]:
|
78 |
+
"""
|
79 |
+
Загружает сущности из CSV файла.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
csv_path: Путь к CSV файлу
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
Список сущностей
|
86 |
+
"""
|
87 |
+
df = pd.read_csv(csv_path)
|
88 |
+
entities = []
|
89 |
+
|
90 |
+
for _, row in df.iterrows():
|
91 |
+
# Обработка метаданных
|
92 |
+
metadata_str = row.get("metadata", "{}")
|
93 |
+
try:
|
94 |
+
metadata = (
|
95 |
+
eval(metadata_str) if metadata_str and not pd.isna(metadata_str) else {}
|
96 |
+
)
|
97 |
+
except:
|
98 |
+
metadata = {}
|
99 |
+
|
100 |
+
# Общие поля для всех типов сущностей
|
101 |
+
common_args = {
|
102 |
+
"id": row["id"],
|
103 |
+
"name": row["name"] if not pd.isna(row.get("name", "")) else "",
|
104 |
+
"text": row["text"] if not pd.isna(row.get("text", "")) else "",
|
105 |
+
"metadata": metadata,
|
106 |
+
"in_search_text": row["in_search_text"],
|
107 |
+
"type": row["type"],
|
108 |
+
}
|
109 |
+
|
110 |
+
# Добавляем поля связи, если они есть
|
111 |
+
if not pd.isna(row.get("source_id", "")):
|
112 |
+
common_args["source_id"] = row["source_id"]
|
113 |
+
common_args["target_id"] = row["target_id"]
|
114 |
+
if not pd.isna(row.get("number_in_relation", "")):
|
115 |
+
common_args["number_in_relation"] = int(row["number_in_relation"])
|
116 |
+
|
117 |
+
entity = LinkerEntity(**common_args)
|
118 |
+
entities.append(entity)
|
119 |
+
|
120 |
+
logging.info(f"Загружено {len(entities)} сущностей из {csv_path}")
|
121 |
+
return entities
|
122 |
+
|
123 |
+
|
124 |
+
def main() -> None:
|
125 |
+
"""Основная функция скрипта."""
|
126 |
+
setup_logging()
|
127 |
+
ensure_directories()
|
128 |
+
|
129 |
+
# Пути к файлам
|
130 |
+
input_doc_path = "test_input/test.docx"
|
131 |
+
output_csv_path = "test_output/test.csv"
|
132 |
+
output_text_path = "test_output/test_builded.txt"
|
133 |
+
|
134 |
+
# Проверка наличия входного файла
|
135 |
+
if not os.path.exists(input_doc_path):
|
136 |
+
logging.error(f"Файл {input_doc_path} не найден!")
|
137 |
+
return
|
138 |
+
|
139 |
+
logging.info(f"Парсинг документа {input_doc_path}")
|
140 |
+
|
141 |
+
try:
|
142 |
+
# Шаг 1: Парсинг документа дважды, как если бы это были два разных документа
|
143 |
+
parser = UniversalParser()
|
144 |
+
document1 = parser.parse_by_path(input_doc_path)
|
145 |
+
document2 = parser.parse_by_path(input_doc_path)
|
146 |
+
|
147 |
+
# Меняем название второго документа, чтобы отличить его
|
148 |
+
document2.name = document2.name + "_copy" if document2.name else "copy_doc"
|
149 |
+
|
150 |
+
# Шаг 2: Чанкинг обоих документов с использованием fixed_size-стратегии
|
151 |
+
all_entities = []
|
152 |
+
|
153 |
+
# Обработка первого документа
|
154 |
+
destructurer1 = Destructurer(
|
155 |
+
document1, strategy_name="fixed_size", words_per_chunk=50, overlap_words=25
|
156 |
+
)
|
157 |
+
logging.info("Начало процесса чанкинга первого документа")
|
158 |
+
entities1 = destructurer1.destructure()
|
159 |
+
|
160 |
+
# Добавляем метаданные о документе к каждой сущности
|
161 |
+
for entity in entities1:
|
162 |
+
if not hasattr(entity, 'metadata') or entity.metadata is None:
|
163 |
+
entity.metadata = {}
|
164 |
+
entity.metadata['doc_name'] = "document1"
|
165 |
+
|
166 |
+
logging.info(f"Получено {len(entities1)} сущностей из первого документа")
|
167 |
+
all_entities.extend(entities1)
|
168 |
+
|
169 |
+
# Обработка второго документа
|
170 |
+
destructurer2 = Destructurer(
|
171 |
+
document2, strategy_name="fixed_size", words_per_chunk=50, overlap_words=25
|
172 |
+
)
|
173 |
+
logging.info("Начало процесса чанкинга второго документа")
|
174 |
+
entities2 = destructurer2.destructure()
|
175 |
+
|
176 |
+
# Добавляем метаданные о документе к каждой сущности
|
177 |
+
for entity in entities2:
|
178 |
+
if not hasattr(entity, 'metadata') or entity.metadata is None:
|
179 |
+
entity.metadata = {}
|
180 |
+
entity.metadata['doc_name'] = "document2"
|
181 |
+
|
182 |
+
logging.info(f"Получено {len(entities2)} сущностей из второго документа")
|
183 |
+
all_entities.extend(entities2)
|
184 |
+
|
185 |
+
logging.info(f"Всего получено {len(all_entities)} сущностей из обоих документов")
|
186 |
+
|
187 |
+
# Шаг 3: Сохранение результатов чанкинга в CSV
|
188 |
+
save_entities_to_csv(all_entities, output_csv_path)
|
189 |
+
|
190 |
+
# Шаг 4: Загрузка сущностей из CSV и выбор случайных чанков
|
191 |
+
loaded_entities = load_entities_from_csv(output_csv_path)
|
192 |
+
|
193 |
+
# Фильтрация только чанков
|
194 |
+
chunks = [e for e in loaded_entities if e.in_search_text is not None]
|
195 |
+
|
196 |
+
# Выбор случайных чанков (от 20 до 30)
|
197 |
+
num_chunks_to_select = min(random.randint(20, 30), len(chunks))
|
198 |
+
selected_chunks = random.sample(chunks, num_chunks_to_select)
|
199 |
+
|
200 |
+
logging.info(f"Выбрано {len(selected_chunks)} случайных чанков для сборки")
|
201 |
+
|
202 |
+
# Дополнительная статистика по документам
|
203 |
+
doc1_chunks = [c for c in selected_chunks if hasattr(c, 'metadata') and c.metadata.get('doc_name') == "document1"]
|
204 |
+
doc2_chunks = [c for c in selected_chunks if hasattr(c, 'metadata') and c.metadata.get('doc_name') == "document2"]
|
205 |
+
logging.info(f"Из них {len(doc1_chunks)} чанков из первого документа и {len(doc2_chunks)} из второго")
|
206 |
+
|
207 |
+
# Шаг 5: Создание InjectionBuilder с InMemoryEntityRepository
|
208 |
+
repository = InMemoryEntityRepository(loaded_entities)
|
209 |
+
builder = InjectionBuilder(repository=repository)
|
210 |
+
|
211 |
+
# Регистрация стратегии
|
212 |
+
builder.register_strategy("fixed_size", FixedSizeChunkingStrategy)
|
213 |
+
|
214 |
+
# Шаг 6: Сборка текста из выбранных чанков
|
215 |
+
logging.info("Начало сборки текста из выбранных чанков")
|
216 |
+
assembled_text = builder.build(selected_chunks)
|
217 |
+
|
218 |
+
# Шаг 7: Сохранение результата в файл
|
219 |
+
with open(output_text_path, "w", encoding="utf-8") as f:
|
220 |
+
f.write(assembled_text)
|
221 |
+
|
222 |
+
logging.info(f"Результат сборки сохранен в {output_text_path}")
|
223 |
+
|
224 |
+
# Вывод статистики
|
225 |
+
logging.info(f"Общее количество сущностей: {len(loaded_entities)}")
|
226 |
+
logging.info(f"Количество чанков: {len(chunks)}")
|
227 |
+
logging.info(f"Выбрано для сборки: {len(selected_chunks)}")
|
228 |
+
logging.info(f"Длина собранного текста: {len(assembled_text)} символов")
|
229 |
+
|
230 |
+
except Exception as e:
|
231 |
+
logging.error(f"П��оизошла ошибка: {e}", exc_info=True)
|
232 |
+
|
233 |
+
|
234 |
+
if __name__ == "__main__":
|
235 |
+
main()
|
lib/extractor/tests/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Пакет с тестами для ntr_text_fragmentation.
|
3 |
+
"""
|
lib/extractor/tests/chunking/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Тесты для компонентов чанкинга.
|
3 |
+
"""
|