Spaces:
Sleeping
Sleeping
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- common/configuration.py +18 -190
- common/dependencies.py +4 -3
- components/dbo/chunk_repository.py +130 -195
- components/embedding_extraction.py +26 -11
- components/nmd/faiss_vector_search.py +11 -14
- components/services/dataset.py +14 -26
- components/services/entity.py +95 -53
- config_dev.yaml +14 -55
- lib/extractor/ntr_text_fragmentation/__init__.py +13 -8
- lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py +3 -1
- lib/extractor/ntr_text_fragmentation/additors/tables/models/__init__.py +10 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/models/subtable_entity.py +61 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/models/table_entity.py +66 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/models/table_row_entity.py +57 -0
- lib/extractor/ntr_text_fragmentation/additors/tables/table_processor.py +178 -0
- lib/extractor/ntr_text_fragmentation/additors/tables_processor.py +27 -87
- lib/extractor/ntr_text_fragmentation/chunking/__init__.py +13 -1
- lib/extractor/ntr_text_fragmentation/chunking/chunking_registry.py +40 -0
- lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py +77 -64
- lib/extractor/ntr_text_fragmentation/chunking/models/__init__.py +7 -0
- lib/extractor/ntr_text_fragmentation/chunking/models/chunk.py +11 -0
- lib/extractor/ntr_text_fragmentation/chunking/models/custom_chunk.py +12 -0
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py +5 -1
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py +52 -101
- lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py +233 -484
- lib/extractor/ntr_text_fragmentation/chunking/text_to_text_base.py +45 -0
- lib/extractor/ntr_text_fragmentation/core/__init__.py +5 -3
- lib/extractor/ntr_text_fragmentation/core/extractor.py +207 -0
- lib/extractor/ntr_text_fragmentation/core/injection_builder.py +129 -364
- lib/extractor/ntr_text_fragmentation/integrations/__init__.py +3 -2
- lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/__init__.py +6 -0
- lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/sqlalchemy_repository.py +455 -0
- lib/extractor/ntr_text_fragmentation/models/__init__.py +7 -6
- lib/extractor/ntr_text_fragmentation/models/document.py +28 -14
- lib/extractor/ntr_text_fragmentation/models/linker_entity.py +96 -64
- lib/extractor/ntr_text_fragmentation/repositories/__init__.py +8 -0
- lib/extractor/ntr_text_fragmentation/repositories/entity_repository.py +106 -0
- lib/extractor/ntr_text_fragmentation/repositories/in_memory_repository.py +337 -0
- lib/extractor/scripts/test_chunking.py +352 -0
- lib/extractor/tests/chunking/test_chunking_registry.py +122 -0
- lib/extractor/tests/chunking/test_fixed_size_chunking.py +346 -325
- lib/extractor/tests/conftest.py +0 -44
- lib/extractor/tests/core/test_extractor.py +265 -0
- lib/extractor/tests/core/test_in_memory_repository.py +433 -0
- lib/extractor/tests/core/test_injection_builder.py +412 -0
- lib/extractor/tests/custom_entity.py +0 -102
- lib/extractor/tests/models/test_linker_entity.py +251 -0
- routes/dataset.py +1 -2
- routes/entity.py +99 -118
- routes/llm.py +29 -20
common/configuration.py
CHANGED
@@ -1,221 +1,48 @@
|
|
1 |
"""This module includes classes to define configurations."""
|
2 |
|
3 |
-
from typing import Any, Dict,
|
4 |
|
5 |
from pyaml_env import parse_config
|
6 |
-
from pydantic import BaseModel
|
7 |
|
8 |
|
9 |
-
class
|
10 |
-
query: str
|
11 |
-
query_abbreviation: str
|
12 |
-
abbreviations_replaced: Optional[List] = None
|
13 |
-
userName: Optional[str] = None
|
14 |
-
|
15 |
-
|
16 |
-
class SemanticChunk(BaseModel):
|
17 |
-
index_answer: int
|
18 |
-
doc_name: str
|
19 |
-
title: str
|
20 |
-
text_answer: str
|
21 |
-
# doc_number: str # TODO Потом поменять название переменной на doc_id везде с чем это будет связанно
|
22 |
-
other_info: List
|
23 |
-
start_index_paragraph: int
|
24 |
-
|
25 |
-
|
26 |
-
class FilterChunks(BaseModel):
|
27 |
-
id: str
|
28 |
-
filename: str
|
29 |
-
title: str
|
30 |
-
chunks: List[SemanticChunk]
|
31 |
-
|
32 |
-
|
33 |
-
class BusinessProcess(BaseModel):
|
34 |
-
production_activities_section: Optional[str]
|
35 |
-
processes_name: Optional[str]
|
36 |
-
level_process: Optional[str]
|
37 |
-
|
38 |
-
|
39 |
-
class Lead(BaseModel):
|
40 |
-
person: Optional[str]
|
41 |
-
leads: Optional[str]
|
42 |
-
|
43 |
-
|
44 |
-
class Subordinate(BaseModel):
|
45 |
-
person_name: Optional[str]
|
46 |
-
position: Optional[str]
|
47 |
-
|
48 |
-
|
49 |
-
class OrganizationalStructure(BaseModel):
|
50 |
-
position: Optional[str] = None
|
51 |
-
leads: Optional[List[Lead]] = None
|
52 |
-
subordinates: Optional[Subordinate] = None
|
53 |
-
|
54 |
-
|
55 |
-
class RocksNN(BaseModel):
|
56 |
-
division: Optional[str]
|
57 |
-
company_name: Optional[str]
|
58 |
-
|
59 |
-
|
60 |
-
class RocksNNSearch(BaseModel):
|
61 |
-
division: Optional[str]
|
62 |
-
company_name: Optional[List]
|
63 |
-
|
64 |
-
|
65 |
-
class SegmentationSearch(BaseModel):
|
66 |
-
segmentation_model: Optional[str]
|
67 |
-
company_name: Optional[List]
|
68 |
-
|
69 |
-
|
70 |
-
class Group(BaseModel):
|
71 |
-
group_name: Optional[str]
|
72 |
-
position_in_group: Optional[str]
|
73 |
-
block: Optional[str]
|
74 |
-
|
75 |
-
|
76 |
-
class GroupComposition(BaseModel):
|
77 |
-
person_name: Optional[str]
|
78 |
-
position_in_group: Optional[str]
|
79 |
-
|
80 |
-
|
81 |
-
class SearchGroupComposition(BaseModel):
|
82 |
-
group_name: Optional[str]
|
83 |
-
group_composition: Optional[List[GroupComposition]]
|
84 |
-
|
85 |
-
|
86 |
-
class PeopleChunks(BaseModel):
|
87 |
-
business_processes: Optional[List[BusinessProcess]] = None
|
88 |
-
organizatinal_structure: Optional[List[OrganizationalStructure]] = None
|
89 |
-
business_curator: Optional[List[RocksNN]] = None
|
90 |
-
groups: Optional[List[Group]] = None
|
91 |
-
person_name: str
|
92 |
-
|
93 |
-
|
94 |
-
class SummaryChunks(BaseModel):
|
95 |
-
doc_chunks: Optional[List[FilterChunks]] = None
|
96 |
-
people_search: Optional[List[PeopleChunks]] = None
|
97 |
-
groups_search: Optional[SearchGroupComposition] = None
|
98 |
-
rocks_nn_search: Optional[RocksNNSearch] = None
|
99 |
-
segmentation_search: Optional[SegmentationSearch] = None
|
100 |
-
query_type: str = '[3]'
|
101 |
-
|
102 |
-
|
103 |
-
class ElasticConfiguration:
|
104 |
-
def __init__(self, config_data):
|
105 |
-
self.es_host = str(config_data['es_host'])
|
106 |
-
self.es_port = int(config_data['es_port'])
|
107 |
-
self.use_elastic = bool(config_data['use_elastic'])
|
108 |
-
self.people_path = str(config_data['people_path'])
|
109 |
-
|
110 |
-
|
111 |
-
class FaissDataConfiguration:
|
112 |
def __init__(self, config_data):
|
113 |
-
self.
|
114 |
-
self.
|
115 |
-
self.
|
116 |
-
|
117 |
-
|
118 |
-
class ChunksElasticSearchConfiguration:
|
119 |
-
def __init__(self, config_data):
|
120 |
-
self.use_chunks_search = bool(config_data['use_chunks_search'])
|
121 |
-
self.index_name = str(config_data['index_name'])
|
122 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
123 |
-
|
124 |
-
|
125 |
-
class PeopleSearchConfiguration:
|
126 |
-
def __init__(self, config_data):
|
127 |
-
self.use_people_search = bool(config_data['use_people_search'])
|
128 |
-
self.index_name = str(config_data['index_name'])
|
129 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
130 |
-
|
131 |
-
|
132 |
-
class VectorSearchConfiguration:
|
133 |
-
def __init__(self, config_data):
|
134 |
-
self.use_vector_search = bool(config_data['use_vector_search'])
|
135 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
136 |
-
|
137 |
-
|
138 |
-
class GroupsSearchConfiguration:
|
139 |
-
def __init__(self, config_data):
|
140 |
-
self.use_groups_search = bool(config_data['use_groups_search'])
|
141 |
-
self.index_name = str(config_data['index_name'])
|
142 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
143 |
-
|
144 |
-
|
145 |
-
class RocksNNSearchConfiguration:
|
146 |
-
def __init__(self, config_data):
|
147 |
-
self.use_rocks_nn_search = bool(config_data['use_rocks_nn_search'])
|
148 |
-
self.index_name = str(config_data['index_name'])
|
149 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
150 |
-
|
151 |
-
|
152 |
-
class AbbreviationSearchConfiguration:
|
153 |
-
def __init__(self, config_data):
|
154 |
-
self.use_abbreviation_search = bool(config_data['use_abbreviation_search'])
|
155 |
-
self.index_name = str(config_data['index_name'])
|
156 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
157 |
-
|
158 |
-
|
159 |
-
class SegmentationSearchConfiguration:
|
160 |
-
def __init__(self, config_data):
|
161 |
-
self.use_segmentation_search = bool(config_data['use_segmentation_search'])
|
162 |
-
self.index_name = str(config_data['index_name'])
|
163 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
164 |
|
165 |
|
166 |
class SearchConfiguration:
|
167 |
def __init__(self, config_data):
|
168 |
-
self.
|
169 |
-
self.
|
170 |
-
|
171 |
-
)
|
172 |
-
self.
|
173 |
-
config_data['chunks_elastic_search']
|
174 |
-
)
|
175 |
-
self.groups_elastic_search = GroupsSearchConfiguration(
|
176 |
-
config_data['groups_elastic_search']
|
177 |
-
)
|
178 |
-
self.rocks_nn_elastic_search = RocksNNSearchConfiguration(
|
179 |
-
config_data['rocks_nn_elastic_search']
|
180 |
-
)
|
181 |
-
self.segmentation_elastic_search = SegmentationSearchConfiguration(
|
182 |
-
config_data['segmentation_elastic_search']
|
183 |
-
)
|
184 |
-
self.stop_index_names = list(config_data['stop_index_names'])
|
185 |
-
self.abbreviation_search = AbbreviationSearchConfiguration(
|
186 |
-
config_data['abbreviation_search']
|
187 |
-
)
|
188 |
self.use_qe = bool(config_data['use_qe'])
|
189 |
|
190 |
|
191 |
class FilesConfiguration:
|
192 |
def __init__(self, config_data):
|
193 |
self.empty_start = bool(config_data['empty_start'])
|
194 |
-
self.regulations_path = str(config_data['regulations_path'])
|
195 |
-
self.default_regulations_path = str(config_data['default_regulations_path'])
|
196 |
self.documents_path = str(config_data['documents_path'])
|
197 |
|
198 |
|
199 |
-
class RankingConfiguration:
|
200 |
-
def __init__(self, config_data):
|
201 |
-
self.use_ranging = bool(config_data['use_ranging'])
|
202 |
-
self.alpha = float(config_data['alpha'])
|
203 |
-
self.beta = float(config_data['beta'])
|
204 |
-
self.k_neighbors = int(config_data['k_neighbors'])
|
205 |
-
|
206 |
-
|
207 |
class DataBaseConfiguration:
|
208 |
def __init__(self, config_data):
|
209 |
-
self.
|
210 |
-
self.faiss = FaissDataConfiguration(config_data['faiss'])
|
211 |
self.search = SearchConfiguration(config_data['search'])
|
212 |
self.files = FilesConfiguration(config_data['files'])
|
213 |
-
self.ranker = RankingConfiguration(config_data['ranging'])
|
214 |
|
215 |
|
216 |
class LLMConfiguration:
|
217 |
def __init__(self, config_data):
|
218 |
-
self.base_url =
|
|
|
|
|
|
|
|
|
219 |
self.api_key_env = (
|
220 |
str(config_data['api_key_env'])
|
221 |
if config_data['api_key_env'] not in ("", "null", "None")
|
@@ -235,6 +62,7 @@ class CommonConfiguration:
|
|
235 |
def __init__(self, config_data):
|
236 |
self.log_file_path = str(config_data['log_file_path'])
|
237 |
self.log_sql_path = str(config_data['log_sql_path'])
|
|
|
238 |
|
239 |
|
240 |
class Configuration:
|
|
|
1 |
"""This module includes classes to define configurations."""
|
2 |
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
|
5 |
from pyaml_env import parse_config
|
|
|
6 |
|
7 |
|
8 |
+
class EntitiesExtractorConfiguration:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def __init__(self, config_data):
|
10 |
+
self.strategy_name = str(config_data['strategy_name'])
|
11 |
+
self.strategy_params: dict = config_data['strategy_params']
|
12 |
+
self.process_tables = bool(config_data['process_tables'])
|
13 |
+
self.neighbors_max_distance = int(config_data['neighbors_max_distance'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
class SearchConfiguration:
|
17 |
def __init__(self, config_data):
|
18 |
+
self.use_vector_search = bool(config_data['use_vector_search'])
|
19 |
+
self.vectorizer_path = str(config_data['vectorizer_path'])
|
20 |
+
self.device = str(config_data['device'])
|
21 |
+
self.max_entities_per_message = int(config_data['max_entities_per_message'])
|
22 |
+
self.max_entities_per_dialogue = int(config_data['max_entities_per_dialogue'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
self.use_qe = bool(config_data['use_qe'])
|
24 |
|
25 |
|
26 |
class FilesConfiguration:
|
27 |
def __init__(self, config_data):
|
28 |
self.empty_start = bool(config_data['empty_start'])
|
|
|
|
|
29 |
self.documents_path = str(config_data['documents_path'])
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
class DataBaseConfiguration:
|
33 |
def __init__(self, config_data):
|
34 |
+
self.entities = EntitiesExtractorConfiguration(config_data['entities'])
|
|
|
35 |
self.search = SearchConfiguration(config_data['search'])
|
36 |
self.files = FilesConfiguration(config_data['files'])
|
|
|
37 |
|
38 |
|
39 |
class LLMConfiguration:
|
40 |
def __init__(self, config_data):
|
41 |
+
self.base_url = (
|
42 |
+
str(config_data['base_url'])
|
43 |
+
if config_data['base_url'] not in ("", "null", "None")
|
44 |
+
else None
|
45 |
+
)
|
46 |
self.api_key_env = (
|
47 |
str(config_data['api_key_env'])
|
48 |
if config_data['api_key_env'] not in ("", "null", "None")
|
|
|
62 |
def __init__(self, config_data):
|
63 |
self.log_file_path = str(config_data['log_file_path'])
|
64 |
self.log_sql_path = str(config_data['log_sql_path'])
|
65 |
+
self.log_level = str(config_data['log_level'])
|
66 |
|
67 |
|
68 |
class Configuration:
|
common/dependencies.py
CHANGED
@@ -37,12 +37,13 @@ def get_embedding_extractor(
|
|
37 |
config: Annotated[Configuration, Depends(get_config)],
|
38 |
) -> EmbeddingExtractor:
|
39 |
return EmbeddingExtractor(
|
40 |
-
config.db_config.
|
41 |
-
config.db_config.
|
42 |
)
|
43 |
|
44 |
|
45 |
-
def get_chunk_repository(db: Annotated[
|
|
|
46 |
return ChunkRepository(db)
|
47 |
|
48 |
|
|
|
37 |
config: Annotated[Configuration, Depends(get_config)],
|
38 |
) -> EmbeddingExtractor:
|
39 |
return EmbeddingExtractor(
|
40 |
+
config.db_config.search.vectorizer_path,
|
41 |
+
config.db_config.search.device,
|
42 |
)
|
43 |
|
44 |
|
45 |
+
def get_chunk_repository(db: Annotated[sessionmaker, Depends(get_db)]) -> ChunkRepository:
|
46 |
+
"""Получение репозитория чанков через DI."""
|
47 |
return ChunkRepository(db)
|
48 |
|
49 |
|
components/dbo/chunk_repository.py
CHANGED
@@ -1,249 +1,184 @@
|
|
|
|
1 |
from uuid import UUID
|
2 |
|
3 |
import numpy as np
|
4 |
from ntr_text_fragmentation import LinkerEntity
|
5 |
-
from ntr_text_fragmentation.integrations import
|
6 |
-
|
7 |
-
from sqlalchemy
|
|
|
8 |
|
9 |
from components.dbo.models.entity import EntityModel
|
10 |
|
|
|
|
|
11 |
|
12 |
class ChunkRepository(SQLAlchemyEntityRepository):
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
|
|
16 |
def _entity_model_class(self):
|
|
|
17 |
return EntityModel
|
18 |
|
19 |
-
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel):
|
20 |
"""
|
21 |
-
Преобразует
|
22 |
-
|
|
|
23 |
Args:
|
24 |
-
db_entity: Сущность из базы
|
25 |
-
|
26 |
Returns:
|
27 |
-
LinkerEntity
|
28 |
"""
|
29 |
-
#
|
30 |
-
|
31 |
-
|
|
|
32 |
name=db_entity.name,
|
33 |
text=db_entity.text,
|
34 |
-
type=db_entity.entity_type,
|
35 |
in_search_text=db_entity.in_search_text,
|
36 |
-
metadata=db_entity.metadata_json,
|
37 |
-
source_id=UUID(db_entity.source_id) if db_entity.source_id else None,
|
38 |
-
target_id=UUID(db_entity.target_id) if db_entity.target_id else None,
|
39 |
number_in_relation=db_entity.number_in_relation,
|
|
|
|
|
40 |
)
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def add_entities(
|
44 |
self,
|
45 |
entities: list[LinkerEntity],
|
46 |
dataset_id: int,
|
47 |
-
embeddings: dict[str, np.ndarray],
|
48 |
):
|
49 |
"""
|
50 |
-
Добавляет
|
51 |
-
|
52 |
Args:
|
53 |
-
entities: Список сущностей для
|
54 |
-
dataset_id: ID
|
55 |
-
embeddings: Словарь эмбеддингов {
|
56 |
"""
|
|
|
57 |
with self.db() as session:
|
|
|
58 |
for entity in entities:
|
59 |
# Преобразуем UUID в строку для хранения в базе
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
number_in_relation=entity.number_in_relation,
|
78 |
-
chunk_index=getattr(entity, "chunk_index", None), # Добавляем chunk_index
|
79 |
-
dataset_id=dataset_id,
|
80 |
-
embedding=embedding,
|
81 |
-
)
|
82 |
)
|
|
|
83 |
|
|
|
84 |
session.commit()
|
85 |
|
86 |
def get_searching_entities(
|
87 |
self,
|
88 |
dataset_id: int,
|
89 |
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
|
90 |
-
with self.db() as session:
|
91 |
-
models = (
|
92 |
-
session.query(EntityModel)
|
93 |
-
.filter(EntityModel.in_search_text is not None)
|
94 |
-
.filter(EntityModel.dataset_id == dataset_id)
|
95 |
-
.all()
|
96 |
-
)
|
97 |
-
return (
|
98 |
-
[self._map_db_entity_to_linker_entity(model) for model in models],
|
99 |
-
[model.embedding for model in models],
|
100 |
-
)
|
101 |
-
|
102 |
-
def get_chunks_by_ids(
|
103 |
-
self,
|
104 |
-
chunk_ids: list[str],
|
105 |
-
) -> list[LinkerEntity]:
|
106 |
"""
|
107 |
-
|
108 |
-
|
|
|
109 |
Args:
|
110 |
-
|
111 |
-
|
112 |
Returns:
|
113 |
-
|
|
|
114 |
"""
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
118 |
with self.db() as session:
|
119 |
-
|
120 |
-
|
121 |
-
.
|
122 |
-
.
|
123 |
)
|
124 |
-
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
def get_neighboring_chunks(self, chunk_ids: list[UUID], max_distance: int = 1) -> list[LinkerEntity]:
|
151 |
"""
|
152 |
-
|
153 |
-
|
154 |
Args:
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
Returns:
|
159 |
-
|
160 |
"""
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
# Преобразуем UUID в строки
|
165 |
-
str_chunk_ids = [str(chunk_id) for chunk_id in chunk_ids]
|
166 |
-
|
167 |
with self.db() as session:
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
select(entity_model).where(
|
174 |
-
and_(
|
175 |
-
entity_model.uuid.in_(str_chunk_ids),
|
176 |
-
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
|
177 |
-
)
|
178 |
-
)
|
179 |
-
).scalars().all()
|
180 |
-
|
181 |
-
if not chunks:
|
182 |
-
return []
|
183 |
-
|
184 |
-
# Находим документы для чанков через связи
|
185 |
-
doc_ids = set()
|
186 |
-
chunk_indices = {}
|
187 |
-
|
188 |
-
for chunk in chunks:
|
189 |
-
chunk_indices[chunk.uuid] = chunk.chunk_index
|
190 |
-
|
191 |
-
# Находим связь от документа к чанку
|
192 |
-
links = session.execute(
|
193 |
-
select(entity_model).where(
|
194 |
-
and_(
|
195 |
-
entity_model.target_id == chunk.uuid,
|
196 |
-
entity_model.name == "document_to_chunk"
|
197 |
-
)
|
198 |
-
)
|
199 |
-
).scalars().all()
|
200 |
-
|
201 |
-
for link in links:
|
202 |
-
doc_ids.add(link.source_id)
|
203 |
-
|
204 |
-
if not doc_ids or not any(idx is not None for idx in chunk_indices.values()):
|
205 |
-
return []
|
206 |
-
|
207 |
-
# Для каждого документа находим все его чанки
|
208 |
-
for doc_id in doc_ids:
|
209 |
-
# Находим все связи от документа к чанкам
|
210 |
-
links = session.execute(
|
211 |
-
select(entity_model).where(
|
212 |
-
and_(
|
213 |
-
entity_model.source_id == doc_id,
|
214 |
-
entity_model.name == "document_to_chunk"
|
215 |
-
)
|
216 |
-
)
|
217 |
-
).scalars().all()
|
218 |
-
|
219 |
-
doc_chunk_ids = [link.target_id for link in links]
|
220 |
-
|
221 |
-
# Получаем все чанки документа
|
222 |
-
doc_chunks = session.execute(
|
223 |
-
select(entity_model).where(
|
224 |
-
and_(
|
225 |
-
entity_model.uuid.in_(doc_chunk_ids),
|
226 |
-
entity_model.entity_type == "Chunk" # Используем entity_type вместо type
|
227 |
-
)
|
228 |
-
)
|
229 |
-
).scalars().all()
|
230 |
-
|
231 |
-
# Для каждого чанка в документе проверяем, является ли он соседом
|
232 |
-
for doc_chunk in doc_chunks:
|
233 |
-
if doc_chunk.uuid in str_chunk_ids:
|
234 |
-
continue
|
235 |
-
|
236 |
-
if doc_chunk.chunk_index is None:
|
237 |
-
continue
|
238 |
-
|
239 |
-
# Проверяем, является ли чанк соседом какого-либо из исходных чанков
|
240 |
-
is_neighbor = False
|
241 |
-
for orig_chunk_id, orig_index in chunk_indices.items():
|
242 |
-
if orig_index is not None and abs(doc_chunk.chunk_index - orig_index) <= max_distance:
|
243 |
-
is_neighbor = True
|
244 |
-
break
|
245 |
-
|
246 |
-
if is_neighbor:
|
247 |
-
result.append(self._map_db_entity_to_linker_entity(doc_chunk))
|
248 |
-
|
249 |
-
return result
|
|
|
1 |
+
import logging
|
2 |
from uuid import UUID
|
3 |
|
4 |
import numpy as np
|
5 |
from ntr_text_fragmentation import LinkerEntity
|
6 |
+
from ntr_text_fragmentation.integrations.sqlalchemy import \
|
7 |
+
SQLAlchemyEntityRepository
|
8 |
+
from sqlalchemy import func, select
|
9 |
+
from sqlalchemy.orm import Session, sessionmaker
|
10 |
|
11 |
from components.dbo.models.entity import EntityModel
|
12 |
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
|
16 |
class ChunkRepository(SQLAlchemyEntityRepository):
|
17 |
+
"""
|
18 |
+
Репозиторий для работы с сущностями (чанками, документами, связями),
|
19 |
+
хранящимися в базе данных с использованием SQL Alchemy.
|
20 |
+
Наследуется от SQLAlchemyEntityRepository, предоставляя конкретную реализацию
|
21 |
+
для модели EntityModel.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, db_session_factory: sessionmaker[Session]):
|
25 |
+
"""
|
26 |
+
Инициализация репозитория.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
db_session_factory: Фабрика сессий SQLAlchemy.
|
30 |
+
"""
|
31 |
+
super().__init__(db_session_factory)
|
32 |
|
33 |
+
@property
|
34 |
def _entity_model_class(self):
|
35 |
+
"""Возвращает класс модели SQLAlchemy."""
|
36 |
return EntityModel
|
37 |
|
38 |
+
def _map_db_entity_to_linker_entity(self, db_entity: EntityModel) -> LinkerEntity:
|
39 |
"""
|
40 |
+
Преобразует объект EntityModel из базы данных в объект LinkerEntity
|
41 |
+
или его соответствующий подкласс.
|
42 |
+
|
43 |
Args:
|
44 |
+
db_entity: Сущность EntityModel из базы данных.
|
45 |
+
|
46 |
Returns:
|
47 |
+
Объект LinkerEntity или его подкласс.
|
48 |
"""
|
49 |
+
# Создаем базовый LinkerEntity со всеми данными из БД
|
50 |
+
# Преобразуем строковые UUID обратно в объекты UUID
|
51 |
+
base_data = LinkerEntity(
|
52 |
+
id=UUID(db_entity.uuid),
|
53 |
name=db_entity.name,
|
54 |
text=db_entity.text,
|
|
|
55 |
in_search_text=db_entity.in_search_text,
|
56 |
+
metadata=db_entity.metadata_json or {},
|
57 |
+
source_id=UUID(db_entity.source_id) if db_entity.source_id else None,
|
58 |
+
target_id=UUID(db_entity.target_id) if db_entity.target_id else None,
|
59 |
number_in_relation=db_entity.number_in_relation,
|
60 |
+
type=db_entity.entity_type,
|
61 |
+
groupper=db_entity.entity_type,
|
62 |
)
|
63 |
+
|
64 |
+
# Используем LinkerEntity._deserialize для получения объекта нужного типа
|
65 |
+
# на основе поля 'type', взятого из db_entity.entity_type
|
66 |
+
try:
|
67 |
+
deserialized_entity = base_data.deserialize()
|
68 |
+
return deserialized_entity
|
69 |
+
except Exception as e:
|
70 |
+
logger.error(
|
71 |
+
f"Error deserializing entity {base_data.id} of type {base_data.type}: {e}"
|
72 |
+
)
|
73 |
+
return base_data
|
74 |
|
75 |
def add_entities(
|
76 |
self,
|
77 |
entities: list[LinkerEntity],
|
78 |
dataset_id: int,
|
79 |
+
embeddings: dict[str, np.ndarray] | None = None,
|
80 |
):
|
81 |
"""
|
82 |
+
Добавляет список сущностей LinkerEntity в базу данных.
|
83 |
+
|
84 |
Args:
|
85 |
+
entities: Список сущностей LinkerEntity для добавления.
|
86 |
+
dataset_id: ID датасета, к которому принадлежат сущности.
|
87 |
+
embeddings: Словарь эмбеддингов {entity_id_str: embedding}, где entity_id_str - строка UUID.
|
88 |
"""
|
89 |
+
embeddings = embeddings or {}
|
90 |
with self.db() as session:
|
91 |
+
db_entities_to_add = []
|
92 |
for entity in entities:
|
93 |
# Преобразуем UUID в строку для хранения в базе
|
94 |
+
entity_id_str = str(entity.id)
|
95 |
+
embedding = embeddings.get(entity_id_str)
|
96 |
+
|
97 |
+
db_entity = EntityModel(
|
98 |
+
uuid=entity_id_str,
|
99 |
+
name=entity.name,
|
100 |
+
text=entity.text,
|
101 |
+
entity_type=entity.type,
|
102 |
+
in_search_text=entity.in_search_text,
|
103 |
+
metadata_json=(
|
104 |
+
entity.metadata if isinstance(entity.metadata, dict) else {}
|
105 |
+
),
|
106 |
+
source_id=str(entity.source_id) if entity.source_id else None,
|
107 |
+
target_id=str(entity.target_id) if entity.target_id else None,
|
108 |
+
number_in_relation=entity.number_in_relation,
|
109 |
+
dataset_id=dataset_id,
|
110 |
+
embedding=embedding,
|
|
|
|
|
|
|
|
|
|
|
111 |
)
|
112 |
+
db_entities_to_add.append(db_entity)
|
113 |
|
114 |
+
session.add_all(db_entities_to_add)
|
115 |
session.commit()
|
116 |
|
117 |
def get_searching_entities(
|
118 |
self,
|
119 |
dataset_id: int,
|
120 |
) -> tuple[list[LinkerEntity], list[np.ndarray]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
"""
|
122 |
+
Получает сущности из указанного датасета, которые имеют текст для поиска
|
123 |
+
(in_search_text не None), вместе с их эмбеддингами.
|
124 |
+
|
125 |
Args:
|
126 |
+
dataset_id: ID датасета.
|
127 |
+
|
128 |
Returns:
|
129 |
+
Кортеж из двух списков: список LinkerEntity и список их эмбеддингов (numpy array).
|
130 |
+
Порядок эмбеддингов соответствует порядку сущностей.
|
131 |
"""
|
132 |
+
entity_model = self._entity_model_class
|
133 |
+
linker_entities = []
|
134 |
+
embeddings_list = []
|
135 |
+
|
136 |
with self.db() as session:
|
137 |
+
stmt = select(entity_model).where(
|
138 |
+
entity_model.in_search_text.isnot(None),
|
139 |
+
entity_model.dataset_id == dataset_id,
|
140 |
+
entity_model.embedding.isnot(None)
|
141 |
)
|
142 |
+
db_models = session.execute(stmt).scalars().all()
|
143 |
|
144 |
+
# Переносим цикл внутрь сессии
|
145 |
+
for model in db_models:
|
146 |
+
# Теперь маппинг происходит при активной сессии
|
147 |
+
linker_entity = self._map_db_entity_to_linker_entity(model)
|
148 |
+
linker_entities.append(linker_entity)
|
149 |
+
|
150 |
+
# Извлекаем эмбеддинг.
|
151 |
+
# _map_db_entity_to_linker_entity может поместить его в метаданные.
|
152 |
+
embedding = linker_entity.metadata.get('_embedding')
|
153 |
+
if embedding is None and hasattr(model, 'embedding'): # Fallback
|
154 |
+
embedding = model.embedding # Доступ к model.embedding тоже должен быть внутри сессии
|
155 |
+
|
156 |
+
if embedding is not None:
|
157 |
+
embeddings_list.append(embedding)
|
158 |
+
else:
|
159 |
+
# Обработка случая отсутствия эмбеддинга
|
160 |
+
print(f"Warning: Entity {model.uuid} has in_search_text but no embedding.")
|
161 |
+
linker_entities.pop()
|
162 |
+
|
163 |
+
# Возвращаем результаты после закрытия сессии
|
164 |
+
return linker_entities, embeddings_list
|
165 |
+
|
166 |
+
def count_entities_by_dataset_id(self, dataset_id: int) -> int:
|
|
|
|
|
167 |
"""
|
168 |
+
Подсчитывает общее количество сущностей для указанного датасета.
|
169 |
+
|
170 |
Args:
|
171 |
+
dataset_id: ID датасета.
|
172 |
+
|
|
|
173 |
Returns:
|
174 |
+
Общее количество сущностей в датасете.
|
175 |
"""
|
176 |
+
entity_model = self._entity_model_class
|
177 |
+
id_column = self._get_id_column() # Получаем колонку ID (uuid или id)
|
178 |
+
|
|
|
|
|
|
|
179 |
with self.db() as session:
|
180 |
+
stmt = select(func.count(id_column)).where(
|
181 |
+
entity_model.dataset_id == dataset_id
|
182 |
+
)
|
183 |
+
count = session.execute(stmt).scalar_one()
|
184 |
+
return count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
components/embedding_extraction.py
CHANGED
@@ -5,15 +5,23 @@ import numpy as np
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import DataLoader
|
8 |
-
from transformers import (
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
from common.decorators import singleton
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
|
|
17 |
@singleton
|
18 |
class EmbeddingExtractor:
|
19 |
"""Класс обрабатывает текст вопроса и возвращает embedding"""
|
@@ -26,7 +34,7 @@ class EmbeddingExtractor:
|
|
26 |
do_normalization: bool = True,
|
27 |
max_len: int = 510,
|
28 |
model: XLMRobertaModel = None,
|
29 |
-
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None
|
30 |
):
|
31 |
"""
|
32 |
Класс, соединяющий в себе модель, токенизатор и параметры векторизации.
|
@@ -46,18 +54,25 @@ class EmbeddingExtractor:
|
|
46 |
device = torch.device(device)
|
47 |
|
48 |
self.device = device
|
49 |
-
|
50 |
# Инициализация модели
|
51 |
if model is not None and tokenizer is not None:
|
52 |
self.tokenizer = tokenizer
|
53 |
self.model = model
|
54 |
elif model_id is not None:
|
55 |
-
print(
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
)
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
print('EmbeddingExtractor: model loaded')
|
62 |
self.model.eval()
|
63 |
self.model.share_memory()
|
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from torch.utils.data import DataLoader
|
8 |
+
from transformers import (
|
9 |
+
AutoModel,
|
10 |
+
AutoTokenizer,
|
11 |
+
BatchEncoding,
|
12 |
+
XLMRobertaModel,
|
13 |
+
PreTrainedTokenizer,
|
14 |
+
PreTrainedTokenizerFast,
|
15 |
+
)
|
16 |
+
from transformers.modeling_outputs import (
|
17 |
+
BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput,
|
18 |
+
)
|
19 |
|
20 |
from common.decorators import singleton
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
+
|
25 |
@singleton
|
26 |
class EmbeddingExtractor:
|
27 |
"""Класс обрабатывает текст вопроса и возвращает embedding"""
|
|
|
34 |
do_normalization: bool = True,
|
35 |
max_len: int = 510,
|
36 |
model: XLMRobertaModel = None,
|
37 |
+
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None,
|
38 |
):
|
39 |
"""
|
40 |
Класс, соединяющий в себе модель, токенизатор и параметры векторизации.
|
|
|
54 |
device = torch.device(device)
|
55 |
|
56 |
self.device = device
|
57 |
+
|
58 |
# Инициализация модели
|
59 |
if model is not None and tokenizer is not None:
|
60 |
self.tokenizer = tokenizer
|
61 |
self.model = model
|
62 |
elif model_id is not None:
|
63 |
+
print(
|
64 |
+
'EmbeddingExtractor: model loading '
|
65 |
+
+ model_id
|
66 |
+
+ ' to '
|
67 |
+
+ str(self.device)
|
68 |
)
|
69 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
70 |
+
model_id, local_files_only=True
|
71 |
+
)
|
72 |
+
self.model: XLMRobertaModel = AutoModel.from_pretrained(
|
73 |
+
model_id, local_files_only=True
|
74 |
+
).to(self.device)
|
75 |
+
|
76 |
print('EmbeddingExtractor: model loaded')
|
77 |
self.model.eval()
|
78 |
self.model.share_memory()
|
components/nmd/faiss_vector_search.py
CHANGED
@@ -3,7 +3,6 @@ import logging
|
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
|
6 |
-
from common.configuration import DataBaseConfiguration
|
7 |
from common.constants import DO_NORMALIZATION
|
8 |
from components.embedding_extraction import EmbeddingExtractor
|
9 |
|
@@ -12,23 +11,16 @@ logger = logging.getLogger(__name__)
|
|
12 |
|
13 |
class FaissVectorSearch:
|
14 |
def __init__(
|
15 |
-
self,
|
16 |
-
model: EmbeddingExtractor,
|
17 |
ids_to_embeddings: dict[str, np.ndarray],
|
18 |
-
config: DataBaseConfiguration,
|
19 |
):
|
20 |
self.model = model
|
21 |
-
self.config = config
|
22 |
-
self.path_to_metadata = config.faiss.path_to_metadata
|
23 |
-
if self.config.ranker.use_ranging:
|
24 |
-
self.k_neighbors = config.ranker.k_neighbors
|
25 |
-
else:
|
26 |
-
self.k_neighbors = config.search.vector_search.k_neighbors
|
27 |
self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
|
28 |
self.__create_index(ids_to_embeddings)
|
29 |
|
30 |
def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
|
31 |
-
"""
|
32 |
if len(ids_to_embeddings) == 0:
|
33 |
self.index = None
|
34 |
return
|
@@ -37,12 +29,17 @@ class FaissVectorSearch:
|
|
37 |
self.index = faiss.IndexFlatIP(dim)
|
38 |
self.index.add(embeddings)
|
39 |
|
40 |
-
def search_vectors(
|
|
|
|
|
|
|
|
|
41 |
"""
|
42 |
Поиск векторов в индексе.
|
43 |
-
|
44 |
Args:
|
45 |
query: Строка, запрос для поиска.
|
|
|
46 |
|
47 |
Returns:
|
48 |
tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
|
@@ -54,6 +51,6 @@ class FaissVectorSearch:
|
|
54 |
if self.index is None:
|
55 |
return (np.array([]), np.array([]), np.array([]))
|
56 |
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
|
57 |
-
similarities, indexes = self.index.search(query_embeds,
|
58 |
ids = [self.index_to_id[index] for index in indexes[0]]
|
59 |
return query_embeds, similarities[0], np.array(ids)
|
|
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
|
|
|
6 |
from common.constants import DO_NORMALIZATION
|
7 |
from components.embedding_extraction import EmbeddingExtractor
|
8 |
|
|
|
11 |
|
12 |
class FaissVectorSearch:
|
13 |
def __init__(
|
14 |
+
self,
|
15 |
+
model: EmbeddingExtractor,
|
16 |
ids_to_embeddings: dict[str, np.ndarray],
|
|
|
17 |
):
|
18 |
self.model = model
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
|
20 |
self.__create_index(ids_to_embeddings)
|
21 |
|
22 |
def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
|
23 |
+
"""Создает индекс для векторного поиска."""
|
24 |
if len(ids_to_embeddings) == 0:
|
25 |
self.index = None
|
26 |
return
|
|
|
29 |
self.index = faiss.IndexFlatIP(dim)
|
30 |
self.index.add(embeddings)
|
31 |
|
32 |
+
def search_vectors(
|
33 |
+
self,
|
34 |
+
query: str,
|
35 |
+
max_entities: int = 100,
|
36 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
37 |
"""
|
38 |
Поиск векторов в индексе.
|
39 |
+
|
40 |
Args:
|
41 |
query: Строка, запрос для поиска.
|
42 |
+
max_entities: Максимальное количество найденных сущностей.
|
43 |
|
44 |
Returns:
|
45 |
tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
|
|
|
51 |
if self.index is None:
|
52 |
return (np.array([]), np.array([]), np.array([]))
|
53 |
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
|
54 |
+
similarities, indexes = self.index.search(query_embeds, max_entities)
|
55 |
ids = [self.index_to_id[index] for index in indexes[0]]
|
56 |
return query_embeds, similarities[0], np.array(ids)
|
components/services/dataset.py
CHANGED
@@ -6,9 +6,9 @@ import zipfile
|
|
6 |
from datetime import datetime
|
7 |
from pathlib import Path
|
8 |
|
9 |
-
import pandas as pd
|
10 |
import torch
|
11 |
from fastapi import BackgroundTasks, HTTPException, UploadFile
|
|
|
12 |
from ntr_fileparser import ParsedDocument, UniversalParser
|
13 |
from sqlalchemy.orm import Session
|
14 |
|
@@ -34,9 +34,9 @@ class DatasetService:
|
|
34 |
"""
|
35 |
|
36 |
def __init__(
|
37 |
-
self,
|
38 |
entity_service: EntityService,
|
39 |
-
config: Configuration,
|
40 |
db: Session,
|
41 |
) -> None:
|
42 |
"""
|
@@ -52,7 +52,6 @@ class DatasetService:
|
|
52 |
self.config = config
|
53 |
self.parser = UniversalParser()
|
54 |
self.entity_service = entity_service
|
55 |
-
self.regulations_path = Path(config.db_config.files.regulations_path)
|
56 |
self.documents_path = Path(config.db_config.files.documents_path)
|
57 |
self.tmp_path = Path(os.environ.get("APP_TMP_PATH", '.'))
|
58 |
logger.info("DatasetService initialized")
|
@@ -214,7 +213,8 @@ class DatasetService:
|
|
214 |
raise HTTPException(
|
215 |
status_code=403, detail='Active dataset cannot be deleted'
|
216 |
)
|
217 |
-
|
|
|
218 |
session.delete(dataset)
|
219 |
session.commit()
|
220 |
|
@@ -222,6 +222,7 @@ class DatasetService:
|
|
222 |
"""
|
223 |
Метод для выполнения в отдельном процессе.
|
224 |
"""
|
|
|
225 |
try:
|
226 |
with self.db() as session:
|
227 |
dataset = (
|
@@ -244,6 +245,7 @@ class DatasetService:
|
|
244 |
active_dataset.is_active = False
|
245 |
|
246 |
session.commit()
|
|
|
247 |
except Exception as e:
|
248 |
logger.error(f"Error applying draft: {e}")
|
249 |
raise
|
@@ -326,7 +328,7 @@ class DatasetService:
|
|
326 |
logger.info(f"Uploading ZIP file {file.filename}")
|
327 |
self.raise_if_processing()
|
328 |
|
329 |
-
file_location = Path(self.tmp_path / 'tmp
|
330 |
logger.debug(f"Saving uploaded file to {file_location}")
|
331 |
file_location.parent.mkdir(parents=True, exist_ok=True)
|
332 |
with open(file_location, 'wb') as f:
|
@@ -338,7 +340,6 @@ class DatasetService:
|
|
338 |
dataset = self.create_dataset_from_directory(
|
339 |
is_default=False,
|
340 |
directory_with_documents=file_location.parent,
|
341 |
-
directory_with_ready_dataset=None,
|
342 |
)
|
343 |
|
344 |
file_location.unlink()
|
@@ -386,8 +387,10 @@ class DatasetService:
|
|
386 |
|
387 |
TMP_PATH.touch()
|
388 |
|
389 |
-
documents: list[Document] = [
|
390 |
-
|
|
|
|
|
391 |
for document in documents:
|
392 |
path = self.documents_path / f'{document.id}.DOCX'
|
393 |
parsed = self.parser.parse_by_path(str(path))
|
@@ -396,16 +399,12 @@ class DatasetService:
|
|
396 |
logger.warning(f"Failed to parse document {document.id}")
|
397 |
continue
|
398 |
|
399 |
-
# Используем EntityService для обработки документа с callback
|
400 |
self.entity_service.process_document(
|
401 |
parsed,
|
402 |
dataset.id,
|
403 |
progress_callback=progress_callback,
|
404 |
-
words_per_chunk=50,
|
405 |
-
overlap_words=25,
|
406 |
-
respect_sentence_boundaries=True,
|
407 |
)
|
408 |
-
|
409 |
TMP_PATH.unlink()
|
410 |
|
411 |
def raise_if_processing(self) -> None:
|
@@ -422,7 +421,6 @@ class DatasetService:
|
|
422 |
self,
|
423 |
is_default: bool,
|
424 |
directory_with_documents: Path,
|
425 |
-
directory_with_ready_dataset: Path | None = None,
|
426 |
) -> Dataset:
|
427 |
"""
|
428 |
Создать датасет из директории с xml-документами.
|
@@ -446,7 +444,7 @@ class DatasetService:
|
|
446 |
|
447 |
dataset = Dataset(
|
448 |
name=name,
|
449 |
-
is_draft=True
|
450 |
is_active=True if is_default else False,
|
451 |
)
|
452 |
session.add(dataset)
|
@@ -465,16 +463,6 @@ class DatasetService:
|
|
465 |
|
466 |
session.flush()
|
467 |
|
468 |
-
if directory_with_ready_dataset is not None:
|
469 |
-
shutil.move(
|
470 |
-
directory_with_ready_dataset,
|
471 |
-
self.regulations_path / str(dataset.id),
|
472 |
-
)
|
473 |
-
|
474 |
-
logger.info(
|
475 |
-
f"Moved ready dataset to {self.regulations_path / str(dataset.id)}"
|
476 |
-
)
|
477 |
-
|
478 |
self.documents_path.mkdir(parents=True, exist_ok=True)
|
479 |
|
480 |
for document in documents:
|
|
|
6 |
from datetime import datetime
|
7 |
from pathlib import Path
|
8 |
|
|
|
9 |
import torch
|
10 |
from fastapi import BackgroundTasks, HTTPException, UploadFile
|
11 |
+
from components.dbo.models.entity import EntityModel
|
12 |
from ntr_fileparser import ParsedDocument, UniversalParser
|
13 |
from sqlalchemy.orm import Session
|
14 |
|
|
|
34 |
"""
|
35 |
|
36 |
def __init__(
|
37 |
+
self,
|
38 |
entity_service: EntityService,
|
39 |
+
config: Configuration,
|
40 |
db: Session,
|
41 |
) -> None:
|
42 |
"""
|
|
|
52 |
self.config = config
|
53 |
self.parser = UniversalParser()
|
54 |
self.entity_service = entity_service
|
|
|
55 |
self.documents_path = Path(config.db_config.files.documents_path)
|
56 |
self.tmp_path = Path(os.environ.get("APP_TMP_PATH", '.'))
|
57 |
logger.info("DatasetService initialized")
|
|
|
213 |
raise HTTPException(
|
214 |
status_code=403, detail='Active dataset cannot be deleted'
|
215 |
)
|
216 |
+
|
217 |
+
session.query(EntityModel).filter(EntityModel.dataset_id == dataset_id).delete()
|
218 |
session.delete(dataset)
|
219 |
session.commit()
|
220 |
|
|
|
222 |
"""
|
223 |
Метод для выполнения в отдельном процессе.
|
224 |
"""
|
225 |
+
logger.info(f"apply_draft_task started")
|
226 |
try:
|
227 |
with self.db() as session:
|
228 |
dataset = (
|
|
|
245 |
active_dataset.is_active = False
|
246 |
|
247 |
session.commit()
|
248 |
+
logger.info(f"apply_draft_task finished")
|
249 |
except Exception as e:
|
250 |
logger.error(f"Error applying draft: {e}")
|
251 |
raise
|
|
|
328 |
logger.info(f"Uploading ZIP file {file.filename}")
|
329 |
self.raise_if_processing()
|
330 |
|
331 |
+
file_location = Path(self.tmp_path / 'tmp' / 'tmp.zip')
|
332 |
logger.debug(f"Saving uploaded file to {file_location}")
|
333 |
file_location.parent.mkdir(parents=True, exist_ok=True)
|
334 |
with open(file_location, 'wb') as f:
|
|
|
340 |
dataset = self.create_dataset_from_directory(
|
341 |
is_default=False,
|
342 |
directory_with_documents=file_location.parent,
|
|
|
343 |
)
|
344 |
|
345 |
file_location.unlink()
|
|
|
387 |
|
388 |
TMP_PATH.touch()
|
389 |
|
390 |
+
documents: list[Document] = [
|
391 |
+
doc_dataset_link.document for doc_dataset_link in dataset.documents
|
392 |
+
]
|
393 |
+
|
394 |
for document in documents:
|
395 |
path = self.documents_path / f'{document.id}.DOCX'
|
396 |
parsed = self.parser.parse_by_path(str(path))
|
|
|
399 |
logger.warning(f"Failed to parse document {document.id}")
|
400 |
continue
|
401 |
|
|
|
402 |
self.entity_service.process_document(
|
403 |
parsed,
|
404 |
dataset.id,
|
405 |
progress_callback=progress_callback,
|
|
|
|
|
|
|
406 |
)
|
407 |
+
|
408 |
TMP_PATH.unlink()
|
409 |
|
410 |
def raise_if_processing(self) -> None:
|
|
|
421 |
self,
|
422 |
is_default: bool,
|
423 |
directory_with_documents: Path,
|
|
|
424 |
) -> Dataset:
|
425 |
"""
|
426 |
Создать датасет из директории с xml-документами.
|
|
|
444 |
|
445 |
dataset = Dataset(
|
446 |
name=name,
|
447 |
+
is_draft=True,
|
448 |
is_active=True if is_default else False,
|
449 |
)
|
450 |
session.add(dataset)
|
|
|
463 |
|
464 |
session.flush()
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
self.documents_path.mkdir(parents=True, exist_ok=True)
|
467 |
|
468 |
for document in documents:
|
components/services/entity.py
CHANGED
@@ -2,9 +2,10 @@ import logging
|
|
2 |
from typing import Callable, Optional
|
3 |
from uuid import UUID
|
4 |
|
5 |
-
import numpy as np
|
6 |
from ntr_fileparser import ParsedDocument
|
7 |
-
from ntr_text_fragmentation import
|
|
|
|
|
8 |
|
9 |
from common.configuration import Configuration
|
10 |
from components.dbo.chunk_repository import ChunkRepository
|
@@ -39,6 +40,16 @@ class EntityService:
|
|
39 |
self.chunk_repository = chunk_repository
|
40 |
self.faiss_search = None # Инициализируется при необходимости
|
41 |
self.current_dataset_id = None # Текущий dataset_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def _ensure_faiss_initialized(self, dataset_id: int) -> None:
|
44 |
"""
|
@@ -50,7 +61,9 @@ class EntityService:
|
|
50 |
# Если индекс не инициализирован или датасет изменился
|
51 |
if self.faiss_search is None or self.current_dataset_id != dataset_id:
|
52 |
logger.info(f'Initializing FAISS for dataset {dataset_id}')
|
53 |
-
entities, embeddings = self.chunk_repository.get_searching_entities(
|
|
|
|
|
54 |
if entities:
|
55 |
# Создаем словарь только из не-None эмбеддингов
|
56 |
embeddings_dict = {
|
@@ -62,12 +75,15 @@ class EntityService:
|
|
62 |
self.faiss_search = FaissVectorSearch(
|
63 |
self.vectorizer,
|
64 |
embeddings_dict,
|
65 |
-
self.config.db_config,
|
66 |
)
|
67 |
self.current_dataset_id = dataset_id
|
68 |
-
logger.info(
|
|
|
|
|
69 |
else:
|
70 |
-
logger.warning(
|
|
|
|
|
71 |
self.faiss_search = None
|
72 |
self.current_dataset_id = None
|
73 |
else:
|
@@ -80,7 +96,6 @@ class EntityService:
|
|
80 |
document: ParsedDocument,
|
81 |
dataset_id: int,
|
82 |
progress_callback: Optional[Callable] = None,
|
83 |
-
**destructurer_kwargs,
|
84 |
) -> None:
|
85 |
"""
|
86 |
Обработка документа: разбиение на чанки и сохранение в базу.
|
@@ -89,49 +104,33 @@ class EntityService:
|
|
89 |
document: Документ для обработки
|
90 |
dataset_id: ID датасета
|
91 |
progress_callback: Функция для отслеживания прогресса
|
92 |
-
**destructurer_kwargs: Дополнительные параметры для Destructurer
|
93 |
"""
|
94 |
logger.info(f"Processing document {document.name} for dataset {dataset_id}")
|
95 |
-
|
96 |
-
# Создаем деструктуризатор с параметрами по умолчанию
|
97 |
-
destructurer = Destructurer(
|
98 |
-
document,
|
99 |
-
strategy_name="fixed_size",
|
100 |
-
process_tables=True,
|
101 |
-
**{
|
102 |
-
"words_per_chunk": 50,
|
103 |
-
"overlap_words": 25,
|
104 |
-
"respect_sentence_boundaries": True,
|
105 |
-
**destructurer_kwargs,
|
106 |
-
}
|
107 |
-
)
|
108 |
-
|
109 |
# Получаем сущности
|
110 |
-
entities =
|
111 |
-
|
112 |
# Фильтруем сущности для поиска
|
113 |
-
filtering_entities = [
|
|
|
|
|
114 |
filtering_texts = [entity.in_search_text for entity in filtering_entities]
|
115 |
-
|
116 |
# Получаем эмбеддинги с поддержкой callback
|
117 |
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
|
118 |
embeddings_dict = {
|
119 |
str(entity.id): embedding # Прео��разуем UUID в строку для ключа
|
120 |
for entity, embedding in zip(filtering_entities, embeddings)
|
121 |
}
|
122 |
-
|
123 |
# Сохраняем в базу
|
124 |
self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict)
|
125 |
-
|
126 |
-
# Переинициализируем FAISS индекс, если это текущий датасет
|
127 |
-
if self.current_dataset_id == dataset_id:
|
128 |
-
self._ensure_faiss_initialized(dataset_id)
|
129 |
-
|
130 |
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
|
131 |
|
132 |
def build_text(
|
133 |
self,
|
134 |
-
entities: list[
|
135 |
chunk_scores: Optional[list[float]] = None,
|
136 |
include_tables: bool = True,
|
137 |
max_documents: Optional[int] = None,
|
@@ -140,7 +139,7 @@ class EntityService:
|
|
140 |
Сборка текста из сущностей.
|
141 |
|
142 |
Args:
|
143 |
-
entities: Список сущностей
|
144 |
chunk_scores: Список весов чанков
|
145 |
include_tables: Флаг включения таблиц
|
146 |
max_documents: Максимальное количество документов
|
@@ -148,18 +147,23 @@ class EntityService:
|
|
148 |
Returns:
|
149 |
Собранный текст
|
150 |
"""
|
|
|
|
|
151 |
logger.info(f"Building text for {len(entities)} entities")
|
152 |
if chunk_scores is not None:
|
153 |
-
chunk_scores = {
|
|
|
|
|
154 |
builder = InjectionBuilder(self.chunk_repository)
|
155 |
return builder.build(
|
156 |
-
|
157 |
-
|
158 |
include_tables=include_tables,
|
|
|
159 |
max_documents=max_documents,
|
160 |
)
|
161 |
|
162 |
-
def
|
163 |
self,
|
164 |
query: str,
|
165 |
dataset_id: int,
|
@@ -185,26 +189,64 @@ class EntityService:
|
|
185 |
|
186 |
# Выполняем поиск
|
187 |
return self.faiss_search.search_vectors(query)
|
188 |
-
|
189 |
-
def
|
190 |
self,
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
194 |
"""
|
195 |
-
|
196 |
|
197 |
Args:
|
198 |
-
|
199 |
-
|
|
|
200 |
|
201 |
Returns:
|
202 |
-
|
|
|
|
|
|
|
203 |
"""
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from typing import Callable, Optional
|
3 |
from uuid import UUID
|
4 |
|
|
|
5 |
from ntr_fileparser import ParsedDocument
|
6 |
+
from ntr_text_fragmentation import (EntitiesExtractor, InjectionBuilder,
|
7 |
+
LinkerEntity)
|
8 |
+
import numpy as np
|
9 |
|
10 |
from common.configuration import Configuration
|
11 |
from components.dbo.chunk_repository import ChunkRepository
|
|
|
40 |
self.chunk_repository = chunk_repository
|
41 |
self.faiss_search = None # Инициализируется при необходимости
|
42 |
self.current_dataset_id = None # Текущий dataset_id
|
43 |
+
|
44 |
+
self.neighbors_max_distance = config.db_config.entities.neighbors_max_distance
|
45 |
+
self.max_entities_per_message = config.db_config.search.max_entities_per_message
|
46 |
+
self.max_entities_per_dialogue = config.db_config.search.max_entities_per_dialogue
|
47 |
+
|
48 |
+
self.entities_extractor = EntitiesExtractor(
|
49 |
+
strategy_name=config.db_config.entities.strategy_name,
|
50 |
+
strategy_params=config.db_config.entities.strategy_params,
|
51 |
+
process_tables=config.db_config.entities.process_tables,
|
52 |
+
)
|
53 |
|
54 |
def _ensure_faiss_initialized(self, dataset_id: int) -> None:
|
55 |
"""
|
|
|
61 |
# Если индекс не инициализирован или датасет изменился
|
62 |
if self.faiss_search is None or self.current_dataset_id != dataset_id:
|
63 |
logger.info(f'Initializing FAISS for dataset {dataset_id}')
|
64 |
+
entities, embeddings = self.chunk_repository.get_searching_entities(
|
65 |
+
dataset_id
|
66 |
+
)
|
67 |
if entities:
|
68 |
# Создаем словарь только из не-None эмбеддингов
|
69 |
embeddings_dict = {
|
|
|
75 |
self.faiss_search = FaissVectorSearch(
|
76 |
self.vectorizer,
|
77 |
embeddings_dict,
|
|
|
78 |
)
|
79 |
self.current_dataset_id = dataset_id
|
80 |
+
logger.info(
|
81 |
+
f'FAISS initialized for dataset {dataset_id} with {len(embeddings_dict)} embeddings'
|
82 |
+
)
|
83 |
else:
|
84 |
+
logger.warning(
|
85 |
+
f'No valid embeddings found for dataset {dataset_id}'
|
86 |
+
)
|
87 |
self.faiss_search = None
|
88 |
self.current_dataset_id = None
|
89 |
else:
|
|
|
96 |
document: ParsedDocument,
|
97 |
dataset_id: int,
|
98 |
progress_callback: Optional[Callable] = None,
|
|
|
99 |
) -> None:
|
100 |
"""
|
101 |
Обработка документа: разбиение на чанки и сохранение в базу.
|
|
|
104 |
document: Документ для обработки
|
105 |
dataset_id: ID датасета
|
106 |
progress_callback: Функция для отслеживания прогресса
|
|
|
107 |
"""
|
108 |
logger.info(f"Processing document {document.name} for dataset {dataset_id}")
|
109 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
# Получаем сущности
|
111 |
+
entities = self.entities_extractor.extract(document)
|
112 |
+
|
113 |
# Фильтруем сущности для поиска
|
114 |
+
filtering_entities = [
|
115 |
+
entity for entity in entities if entity.in_search_text is not None
|
116 |
+
]
|
117 |
filtering_texts = [entity.in_search_text for entity in filtering_entities]
|
118 |
+
|
119 |
# Получаем эмбеддинги с поддержкой callback
|
120 |
embeddings = self.vectorizer.vectorize(filtering_texts, progress_callback)
|
121 |
embeddings_dict = {
|
122 |
str(entity.id): embedding # Прео��разуем UUID в строку для ключа
|
123 |
for entity, embedding in zip(filtering_entities, embeddings)
|
124 |
}
|
125 |
+
|
126 |
# Сохраняем в базу
|
127 |
self.chunk_repository.add_entities(entities, dataset_id, embeddings_dict)
|
128 |
+
|
|
|
|
|
|
|
|
|
129 |
logger.info(f"Added {len(entities)} entities to dataset {dataset_id}")
|
130 |
|
131 |
def build_text(
|
132 |
self,
|
133 |
+
entities: list[str],
|
134 |
chunk_scores: Optional[list[float]] = None,
|
135 |
include_tables: bool = True,
|
136 |
max_documents: Optional[int] = None,
|
|
|
139 |
Сборка текста из сущностей.
|
140 |
|
141 |
Args:
|
142 |
+
entities: Список идентификаторов сущностей
|
143 |
chunk_scores: Список весов чанков
|
144 |
include_tables: Флаг включения таблиц
|
145 |
max_documents: Максимальное количество документов
|
|
|
147 |
Returns:
|
148 |
Собранный текст
|
149 |
"""
|
150 |
+
entities = [UUID(entity) for entity in entities]
|
151 |
+
entities = self.chunk_repository.get_entities_by_ids(entities)
|
152 |
logger.info(f"Building text for {len(entities)} entities")
|
153 |
if chunk_scores is not None:
|
154 |
+
chunk_scores = {
|
155 |
+
entity.id: score for entity, score in zip(entities, chunk_scores)
|
156 |
+
}
|
157 |
builder = InjectionBuilder(self.chunk_repository)
|
158 |
return builder.build(
|
159 |
+
entities,
|
160 |
+
scores=chunk_scores,
|
161 |
include_tables=include_tables,
|
162 |
+
neighbors_max_distance=self.neighbors_max_distance,
|
163 |
max_documents=max_documents,
|
164 |
)
|
165 |
|
166 |
+
def search_similar_old(
|
167 |
self,
|
168 |
query: str,
|
169 |
dataset_id: int,
|
|
|
189 |
|
190 |
# Выполняем поиск
|
191 |
return self.faiss_search.search_vectors(query)
|
192 |
+
|
193 |
+
def search_similar(
|
194 |
self,
|
195 |
+
query: str,
|
196 |
+
dataset_id: int,
|
197 |
+
previous_entities: list[list[str]] = None,
|
198 |
+
) -> tuple[list[list[str]], list[str], list[float]]:
|
199 |
"""
|
200 |
+
Поиск похожих сущностей.
|
201 |
|
202 |
Args:
|
203 |
+
query: Текст запроса
|
204 |
+
dataset_id: ID датасета
|
205 |
+
previous_entities: Список идентификаторов сущностей, которые уже были найдены
|
206 |
|
207 |
Returns:
|
208 |
+
tuple[list[list[str]], list[str], list[float]]:
|
209 |
+
- Перефильтрованный список идентификаторов сущностей из прошлых запросов
|
210 |
+
- Список идентификаторов найденных сущностей
|
211 |
+
- Скоры найденных сущностей
|
212 |
"""
|
213 |
+
self._ensure_faiss_initialized(dataset_id)
|
214 |
+
|
215 |
+
if self.faiss_search is None:
|
216 |
+
return previous_entities, [], []
|
217 |
+
|
218 |
+
if sum(len(entities) for entities in previous_entities) < self.max_entities_per_dialogue - self.max_entities_per_message:
|
219 |
+
_, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_message)
|
220 |
+
try:
|
221 |
+
scores = scores.tolist()
|
222 |
+
ids = ids.tolist()
|
223 |
+
except:
|
224 |
+
scores = list(scores)
|
225 |
+
ids = list(ids)
|
226 |
+
return previous_entities, ids, scores
|
227 |
+
|
228 |
+
if previous_entities:
|
229 |
+
_, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_dialogue)
|
230 |
+
scores = scores.tolist()
|
231 |
+
ids = ids.tolist()
|
232 |
+
|
233 |
+
print(ids)
|
234 |
+
|
235 |
+
previous_entities_ids = [[entity for entity in sublist if entity in ids] for sublist in previous_entities]
|
236 |
+
previous_entities_flat = [entity for sublist in previous_entities_ids for entity in sublist]
|
237 |
+
new_entities = []
|
238 |
+
new_scores = []
|
239 |
+
for id_, score in zip(ids, scores):
|
240 |
+
if id_ not in previous_entities_flat:
|
241 |
+
new_entities.append(id_)
|
242 |
+
new_scores.append(score)
|
243 |
+
if len(new_entities) >= self.max_entities_per_message:
|
244 |
+
break
|
245 |
|
246 |
+
return previous_entities, new_entities, new_scores
|
247 |
+
|
248 |
+
else:
|
249 |
+
_, scores, ids = self.faiss_search.search_vectors(query, self.max_entities_per_dialogue)
|
250 |
+
scores = scores.tolist()
|
251 |
+
ids = ids.tolist()
|
252 |
+
return [], ids, scores
|
config_dev.yaml
CHANGED
@@ -1,69 +1,28 @@
|
|
1 |
common:
|
2 |
log_file_path: !ENV ${LOG_FILE_PATH:/data/logs/common.log}
|
3 |
log_sql_path: !ENV ${SQLALCHEMY_DATABASE_URL:sqlite:////data/logs.db}
|
|
|
4 |
|
5 |
bd:
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
es_port: !ENV ${ELASTIC_PORT:9200}
|
15 |
-
people_path: /data/person_card
|
16 |
-
|
17 |
-
ranging:
|
18 |
-
use_ranging: false
|
19 |
-
alpha: 0.35
|
20 |
-
beta: -0.15
|
21 |
-
k_neighbors: 100
|
22 |
|
23 |
search:
|
24 |
use_qe: true
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
people_elastic_search:
|
31 |
-
use_people_search: false
|
32 |
-
index_name: 'people_search'
|
33 |
-
k_neighbors: 10
|
34 |
-
|
35 |
-
chunks_elastic_search:
|
36 |
-
use_chunks_search: true
|
37 |
-
index_name: 'nmd_full_text'
|
38 |
-
k_neighbors: 5
|
39 |
-
|
40 |
-
groups_elastic_search:
|
41 |
-
use_groups_search: false
|
42 |
-
index_name: 'group_search_elastic_nn'
|
43 |
-
k_neighbors: 1
|
44 |
-
|
45 |
-
rocks_nn_elastic_search:
|
46 |
-
use_rocks_nn_search: false
|
47 |
-
index_name: 'rocks_nn_search_elastic'
|
48 |
-
k_neighbors: 1
|
49 |
-
|
50 |
-
segmentation_elastic_search:
|
51 |
-
use_segmentation_search: false
|
52 |
-
index_name: 'segmentation_search_elastic'
|
53 |
-
k_neighbors: 1
|
54 |
-
|
55 |
-
# Если поиск будет не по чанкам, то добавить название ключа из функции search_answer словаря answer!!!
|
56 |
-
stop_index_names: ['people_answer', 'groups_answer', 'rocks_nn_answer', 'segmentation_answer']
|
57 |
-
|
58 |
-
abbreviation_search:
|
59 |
-
use_abbreviation_search: true
|
60 |
-
index_name: 'nmd_abbreviation_elastic'
|
61 |
-
k_neighbors: 10
|
62 |
|
63 |
files:
|
64 |
empty_start: true
|
65 |
-
regulations_path: /data/regulation_datasets
|
66 |
-
default_regulations_path: /data/regulation_datasets/default
|
67 |
documents_path: /data/documents
|
68 |
|
69 |
llm:
|
|
|
1 |
common:
|
2 |
log_file_path: !ENV ${LOG_FILE_PATH:/data/logs/common.log}
|
3 |
log_sql_path: !ENV ${SQLALCHEMY_DATABASE_URL:sqlite:////data/logs.db}
|
4 |
+
log_level: !ENV ${LOG_LEVEL:INFO}
|
5 |
|
6 |
bd:
|
7 |
+
entities:
|
8 |
+
strategy_name: !ENV ${ENTITIES_STRATEGY_NAME:fixed_size}
|
9 |
+
strategy_params:
|
10 |
+
words_per_chunk: 50
|
11 |
+
overlap_words: 25
|
12 |
+
respect_sentence_boundaries: true
|
13 |
+
process_tables: true
|
14 |
+
neighbors_max_distance: 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
search:
|
17 |
use_qe: true
|
18 |
+
use_vector_search: true
|
19 |
+
vectorizer_path: !ENV ${EMBEDDING_MODEL_PATH:BAAI/bge-m3}
|
20 |
+
device: !ENV ${DEVICE:cuda}
|
21 |
+
max_entities_per_message: 75
|
22 |
+
max_entities_per_dialogue: 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
files:
|
25 |
empty_start: true
|
|
|
|
|
26 |
documents_path: /data/documents
|
27 |
|
28 |
llm:
|
lib/extractor/ntr_text_fragmentation/__init__.py
CHANGED
@@ -2,18 +2,23 @@
|
|
2 |
Модуль извлечения и сборки документов.
|
3 |
"""
|
4 |
|
5 |
-
from .core.
|
6 |
-
from .
|
7 |
from .core.injection_builder import InjectionBuilder
|
8 |
-
from .
|
|
|
|
|
9 |
|
10 |
__all__ = [
|
11 |
-
"
|
12 |
-
"InjectionBuilder",
|
13 |
-
"EntityRepository",
|
14 |
"InMemoryEntityRepository",
|
15 |
"LinkerEntity",
|
16 |
-
"
|
|
|
|
|
17 |
"DocumentAsEntity",
|
18 |
"integrations",
|
19 |
-
|
|
|
|
2 |
Модуль извлечения и сборки документов.
|
3 |
"""
|
4 |
|
5 |
+
from .core.extractor import EntitiesExtractor
|
6 |
+
from .repositories.entity_repository import EntityRepository
|
7 |
from .core.injection_builder import InjectionBuilder
|
8 |
+
from .repositories import InMemoryEntityRepository
|
9 |
+
from .models import DocumentAsEntity, LinkerEntity, Link, Entity, register_entity
|
10 |
+
from .chunking import FIXED_SIZE
|
11 |
|
12 |
__all__ = [
|
13 |
+
"EntitiesExtractor",
|
14 |
+
"InjectionBuilder",
|
15 |
+
"EntityRepository",
|
16 |
"InMemoryEntityRepository",
|
17 |
"LinkerEntity",
|
18 |
+
"Entity",
|
19 |
+
"Link",
|
20 |
+
"register_entity",
|
21 |
"DocumentAsEntity",
|
22 |
"integrations",
|
23 |
+
"FIXED_SIZE",
|
24 |
+
]
|
lib/extractor/ntr_text_fragmentation/additors/tables/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
-
from .table_entity import TableEntity
|
|
|
2 |
|
3 |
__all__ = [
|
4 |
'TableEntity',
|
|
|
5 |
]
|
|
|
1 |
+
from .models.table_entity import TableEntity
|
2 |
+
from .table_processor import TableProcessor
|
3 |
|
4 |
__all__ = [
|
5 |
'TableEntity',
|
6 |
+
'TableProcessor',
|
7 |
]
|
lib/extractor/ntr_text_fragmentation/additors/tables/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .table_entity import TableEntity
|
2 |
+
from .subtable_entity import SubTableEntity
|
3 |
+
from .table_row_entity import TableRowEntity
|
4 |
+
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'TableEntity',
|
8 |
+
'SubTableEntity',
|
9 |
+
'TableRowEntity',
|
10 |
+
]
|
lib/extractor/ntr_text_fragmentation/additors/tables/models/subtable_entity.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ....models import Entity, register_entity
|
4 |
+
|
5 |
+
|
6 |
+
@register_entity
|
7 |
+
@dataclass
|
8 |
+
class SubTableEntity(Entity):
|
9 |
+
"""
|
10 |
+
Сущность подтаблицы из документа.
|
11 |
+
|
12 |
+
Расширяет основную сущность LinkerEntity, добавляя информацию о таблице.
|
13 |
+
"""
|
14 |
+
|
15 |
+
header: list[str] | None = None
|
16 |
+
title: str | None = None
|
17 |
+
|
18 |
+
@classmethod
|
19 |
+
def _deserialize_to_me(cls, data: Entity) -> 'SubTableEntity':
|
20 |
+
"""
|
21 |
+
Десериализует SubTableEntity из объекта Entity.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
data (Entity): Объект Entity для десериализации.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
SubTableEntity: Новый экземпляр SubTableEntity с данными из Entity.
|
28 |
+
|
29 |
+
Raises:
|
30 |
+
TypeError: Если data не является экземпляром Entity.
|
31 |
+
"""
|
32 |
+
if not isinstance(data, Entity):
|
33 |
+
raise TypeError(f"Ожидался Entity, получен {type(data)}")
|
34 |
+
|
35 |
+
# Пытаемся получить из полей объекта, если это уже SubTableEntity или его потомок
|
36 |
+
header = getattr(data, 'header', None)
|
37 |
+
title = getattr(data, 'title', None)
|
38 |
+
|
39 |
+
# Если не нашли в полях, ищем в метаданных
|
40 |
+
metadata = data.metadata or {}
|
41 |
+
if header is None:
|
42 |
+
header = metadata.get('_header')
|
43 |
+
if title is None:
|
44 |
+
title = metadata.get('_title')
|
45 |
+
|
46 |
+
# Переписываем блок return, чтобы точно включить groupper
|
47 |
+
return cls(
|
48 |
+
id=data.id,
|
49 |
+
name=data.name,
|
50 |
+
text=data.text,
|
51 |
+
in_search_text=data.in_search_text,
|
52 |
+
metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные
|
53 |
+
source_id=data.source_id,
|
54 |
+
target_id=data.target_id,
|
55 |
+
number_in_relation=data.number_in_relation,
|
56 |
+
groupper=data.groupper, # Убеждаемся, что groupper здесь
|
57 |
+
type=cls.__name__, # Используем имя класса для типа
|
58 |
+
# Специфичные поля
|
59 |
+
header=header,
|
60 |
+
title=title
|
61 |
+
)
|
lib/extractor/ntr_text_fragmentation/additors/tables/models/table_entity.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ....models import Entity, register_entity
|
4 |
+
|
5 |
+
|
6 |
+
@register_entity
|
7 |
+
@dataclass
|
8 |
+
class TableEntity(Entity):
|
9 |
+
"""
|
10 |
+
Сущность таблицы из документа.
|
11 |
+
|
12 |
+
Расширяет основную сущность LinkerEntity, добавляя информацию о таблице.
|
13 |
+
"""
|
14 |
+
|
15 |
+
title: str | None = None
|
16 |
+
header: list[str] | None = None
|
17 |
+
note: str | None = None
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def _deserialize_to_me(cls, data: Entity) -> 'TableEntity':
|
21 |
+
"""
|
22 |
+
Десериализует TableEntity из объекта Entity.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
data (Entity): Объект Entity для десериализации.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
TableEntity: Новый экземпляр TableEntity с данными из Entity.
|
29 |
+
|
30 |
+
Raises:
|
31 |
+
TypeError: Если data не является экземпляром Entity.
|
32 |
+
"""
|
33 |
+
if not isinstance(data, Entity):
|
34 |
+
raise TypeError(f"Ожидался Entity, получен {type(data)}")
|
35 |
+
|
36 |
+
# Пытаемся получить из полей объекта, если это уже TableEntity или его потомок
|
37 |
+
title = getattr(data, 'title', None)
|
38 |
+
header = getattr(data, 'header', None)
|
39 |
+
note = getattr(data, 'note', None)
|
40 |
+
|
41 |
+
# Если не нашли в полях, ищем в метаданных
|
42 |
+
metadata = data.metadata or {}
|
43 |
+
if title is None:
|
44 |
+
title = metadata.get('_title')
|
45 |
+
if header is None:
|
46 |
+
header = metadata.get('_header')
|
47 |
+
if note is None:
|
48 |
+
note = metadata.get('_note')
|
49 |
+
|
50 |
+
# Переписываем блок return, чтобы точно включить groupper
|
51 |
+
return cls(
|
52 |
+
id=data.id,
|
53 |
+
name=data.name,
|
54 |
+
text=data.text,
|
55 |
+
in_search_text=data.in_search_text,
|
56 |
+
metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные
|
57 |
+
source_id=data.source_id,
|
58 |
+
target_id=data.target_id,
|
59 |
+
number_in_relation=data.number_in_relation,
|
60 |
+
groupper=data.groupper, # Убеждаемся, что groupper здесь
|
61 |
+
type=cls.__name__, # Используем имя класса для типа
|
62 |
+
# Специфичные поля
|
63 |
+
title=title,
|
64 |
+
header=header,
|
65 |
+
note=note
|
66 |
+
)
|
lib/extractor/ntr_text_fragmentation/additors/tables/models/table_row_entity.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
from ....models import Entity, register_entity
|
4 |
+
|
5 |
+
|
6 |
+
@register_entity
|
7 |
+
@dataclass
|
8 |
+
class TableRowEntity(Entity):
|
9 |
+
"""
|
10 |
+
Сущность строки таблицы из документа.
|
11 |
+
|
12 |
+
Расширяет основную сущность LinkerEntity, добавляя информацию о строке таблицы.
|
13 |
+
"""
|
14 |
+
|
15 |
+
cells: list[str] = field(default_factory=list)
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def _deserialize_to_me(cls, data: Entity) -> "TableRowEntity":
|
19 |
+
"""
|
20 |
+
Десериализует TableRowEntity из объекта Entity.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
data (Entity): Объект Entity для десериализации.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
TableRowEntity: Новый экземпляр TableRowEntity с данными из Entity.
|
27 |
+
|
28 |
+
Raises:
|
29 |
+
TypeError: Если data не является экземпляром Entity.
|
30 |
+
"""
|
31 |
+
if not isinstance(data, Entity):
|
32 |
+
raise TypeError(f"Ожидался Entity, получен {type(data)}")
|
33 |
+
|
34 |
+
# Пытаемся получить из полей объекта, если это уже TableRowEntity или его потомок
|
35 |
+
cells = getattr(data, 'cells', None)
|
36 |
+
|
37 |
+
# Если не нашли в полях, ищем в метаданных
|
38 |
+
metadata = data.metadata or {}
|
39 |
+
if cells is None:
|
40 |
+
cells = metadata.get('_cells', [])
|
41 |
+
|
42 |
+
# Переписываем блок return, чтобы точно включить groupper
|
43 |
+
return cls(
|
44 |
+
id=data.id,
|
45 |
+
name=data.name,
|
46 |
+
text=data.text,
|
47 |
+
in_search_text=data.in_search_text,
|
48 |
+
metadata={k: v for k, v in metadata.items() if not k.startswith('_')}, # Очищаем метаданные
|
49 |
+
source_id=data.source_id,
|
50 |
+
target_id=data.target_id,
|
51 |
+
number_in_relation=data.number_in_relation,
|
52 |
+
groupper=data.groupper, # Убеждаемся, что groupper здесь
|
53 |
+
type=cls.__name__, # Используем имя класса для типа
|
54 |
+
# Специфичные поля
|
55 |
+
cells=cells
|
56 |
+
)
|
57 |
+
|
lib/extractor/ntr_text_fragmentation/additors/tables/table_processor.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ntr_fileparser import ParsedRow, ParsedSubtable, ParsedTable
|
2 |
+
|
3 |
+
from ...models import LinkerEntity
|
4 |
+
from ...repositories.entity_repository import EntityRepository, GroupedEntities
|
5 |
+
from .models import SubTableEntity, TableEntity, TableRowEntity
|
6 |
+
|
7 |
+
|
8 |
+
class TableProcessor:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def extract(
|
13 |
+
self,
|
14 |
+
table: ParsedTable,
|
15 |
+
doc_entity: LinkerEntity,
|
16 |
+
) -> list[LinkerEntity]:
|
17 |
+
"""
|
18 |
+
Извлекает таблицу из документа и создает для нее сущность, а также сущности для всех строк таблицы.
|
19 |
+
"""
|
20 |
+
entities = []
|
21 |
+
table_entity = self._create_table_entity(table, doc_entity)
|
22 |
+
|
23 |
+
entities.append(table_entity)
|
24 |
+
|
25 |
+
for i, subtable in enumerate(table.subtables):
|
26 |
+
index_in_relation = i + 1
|
27 |
+
subtable_entity = self._create_subtable_entity(
|
28 |
+
subtable,
|
29 |
+
table_entity,
|
30 |
+
index_in_relation,
|
31 |
+
)
|
32 |
+
|
33 |
+
entities.append(subtable_entity)
|
34 |
+
|
35 |
+
for j, row in enumerate(subtable.rows):
|
36 |
+
index_in_relation = j + 1
|
37 |
+
row_entity = self._create_row_entity(
|
38 |
+
row,
|
39 |
+
subtable_entity,
|
40 |
+
row.to_string(),
|
41 |
+
index_in_relation,
|
42 |
+
)
|
43 |
+
|
44 |
+
entities.append(row_entity)
|
45 |
+
|
46 |
+
return entities
|
47 |
+
|
48 |
+
def _create_table_entity(
|
49 |
+
self,
|
50 |
+
table: ParsedTable,
|
51 |
+
doc_entity: LinkerEntity,
|
52 |
+
) -> TableEntity:
|
53 |
+
entity = TableEntity(
|
54 |
+
name=table.title or 'NonameTable',
|
55 |
+
text=table.title or '',
|
56 |
+
title=table.title,
|
57 |
+
header=self._create_header(table),
|
58 |
+
note=table.note,
|
59 |
+
groupper=f'Table_{doc_entity.id}',
|
60 |
+
number_in_relation=table.index_in_document,
|
61 |
+
)
|
62 |
+
entity.owner_id = doc_entity.id
|
63 |
+
|
64 |
+
return entity
|
65 |
+
|
66 |
+
def _create_header(self, table: ParsedTable) -> list[str] | None:
|
67 |
+
if len(table.headers) == 0:
|
68 |
+
return None
|
69 |
+
|
70 |
+
rows = table.headers
|
71 |
+
header: list[list[str]] = [[] for _ in range(len(rows[0].cells))]
|
72 |
+
for row in rows:
|
73 |
+
for i, cell in enumerate(row.cells):
|
74 |
+
header[i].append(cell)
|
75 |
+
result = [" > ".join(column) for column in header]
|
76 |
+
|
77 |
+
return result
|
78 |
+
|
79 |
+
def _create_subtable_entity(
|
80 |
+
self,
|
81 |
+
subtable: ParsedSubtable,
|
82 |
+
table_entity: TableEntity,
|
83 |
+
number_in_relation: int,
|
84 |
+
) -> SubTableEntity:
|
85 |
+
header = None
|
86 |
+
if subtable.header:
|
87 |
+
header = subtable.header.cells
|
88 |
+
entity = SubTableEntity(
|
89 |
+
name=subtable.title or 'NonameSubTable',
|
90 |
+
text=subtable.title or '',
|
91 |
+
title=subtable.title,
|
92 |
+
header=header,
|
93 |
+
groupper=f'SubTable_{table_entity.id}',
|
94 |
+
number_in_relation=number_in_relation,
|
95 |
+
)
|
96 |
+
entity.owner_id = table_entity.id
|
97 |
+
return entity
|
98 |
+
|
99 |
+
def _create_row_entity(
|
100 |
+
self,
|
101 |
+
row: ParsedRow,
|
102 |
+
subtable_entity: SubTableEntity,
|
103 |
+
in_search_text: str,
|
104 |
+
number_in_relation: int,
|
105 |
+
) -> TableRowEntity:
|
106 |
+
entity = TableRowEntity(
|
107 |
+
name=f'{row.index}',
|
108 |
+
text='',
|
109 |
+
cells=row.cells,
|
110 |
+
in_search_text=in_search_text,
|
111 |
+
groupper=f'Row_{subtable_entity.id}',
|
112 |
+
number_in_relation=number_in_relation,
|
113 |
+
)
|
114 |
+
entity.owner_id = subtable_entity.id
|
115 |
+
return entity
|
116 |
+
|
117 |
+
def build(
|
118 |
+
self,
|
119 |
+
repository: EntityRepository,
|
120 |
+
group: GroupedEntities[TableEntity],
|
121 |
+
) -> str:
|
122 |
+
"""
|
123 |
+
Собирает текст таблицы из списка сущностей.
|
124 |
+
"""
|
125 |
+
table = group.composer
|
126 |
+
entities = group.entities
|
127 |
+
|
128 |
+
subtable_grouped: list[GroupedEntities[SubTableEntity]] = (
|
129 |
+
repository.group_entities_hierarchically(
|
130 |
+
entities=entities,
|
131 |
+
root_type=SubTableEntity,
|
132 |
+
sort=True,
|
133 |
+
)
|
134 |
+
)
|
135 |
+
|
136 |
+
result = ""
|
137 |
+
|
138 |
+
if table.title:
|
139 |
+
result += f"#### {table.title}\n"
|
140 |
+
else:
|
141 |
+
result += f"#### Таблица {table.number_in_relation}\n"
|
142 |
+
|
143 |
+
table_header = table.header
|
144 |
+
|
145 |
+
for subtable_group in subtable_grouped:
|
146 |
+
subtable = subtable_group.composer
|
147 |
+
subtable_header = subtable.header
|
148 |
+
rows = [
|
149 |
+
row
|
150 |
+
for row in subtable_group.entities
|
151 |
+
if isinstance(row, TableRowEntity)
|
152 |
+
]
|
153 |
+
if subtable.title:
|
154 |
+
result += f"##### {subtable.title}\n"
|
155 |
+
for row in rows:
|
156 |
+
result += self._prepare_row(
|
157 |
+
row,
|
158 |
+
subtable_header or table_header,
|
159 |
+
)
|
160 |
+
|
161 |
+
if table.note:
|
162 |
+
result += f"**Примечание:** {table.note}\n"
|
163 |
+
|
164 |
+
return result
|
165 |
+
|
166 |
+
def _prepare_row(
|
167 |
+
self,
|
168 |
+
row: TableRowEntity,
|
169 |
+
header: list[str] | None = None,
|
170 |
+
) -> str:
|
171 |
+
row_name = f'Строка {row.number_in_relation}'
|
172 |
+
if header is None:
|
173 |
+
cells = "\n".join([f"- - {cell}" for cell in row.cells])
|
174 |
+
else:
|
175 |
+
normalized_header = [h.replace('\n', '') for h in header]
|
176 |
+
cells = "\n".join([f" - **{normalized_header[i]}**: {row.cells[i]}".replace('\n', '\n -') for i in range(len(header))])
|
177 |
+
|
178 |
+
return f"- {row_name}\n{cells}\n"
|
lib/extractor/ntr_text_fragmentation/additors/tables_processor.py
CHANGED
@@ -2,12 +2,11 @@
|
|
2 |
Процессор таблиц из документа.
|
3 |
"""
|
4 |
|
5 |
-
from uuid import uuid4
|
6 |
-
|
7 |
from ntr_fileparser import ParsedDocument
|
8 |
|
9 |
from ..models import LinkerEntity
|
10 |
-
from .tables import TableEntity
|
|
|
11 |
|
12 |
|
13 |
class TablesProcessor:
|
@@ -17,101 +16,42 @@ class TablesProcessor:
|
|
17 |
|
18 |
def __init__(self):
|
19 |
"""Инициализация процессора таблиц."""
|
20 |
-
|
21 |
|
22 |
-
def
|
23 |
self,
|
24 |
document: ParsedDocument,
|
25 |
doc_entity: LinkerEntity,
|
26 |
) -> list[LinkerEntity]:
|
27 |
-
"""
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
doc_entity: Сущность документа для связи с таблицами
|
33 |
-
|
34 |
-
Returns:
|
35 |
-
Список сущностей TableEntity и связей
|
36 |
-
"""
|
37 |
-
if not document.tables:
|
38 |
-
return []
|
39 |
-
|
40 |
-
table_entities = []
|
41 |
-
links = []
|
42 |
-
|
43 |
-
rows = '\n\n'.join([table.to_string() for table in document.tables]).split(
|
44 |
-
'\n\n'
|
45 |
-
)
|
46 |
-
|
47 |
-
# Обрабатываем каждую таблицу
|
48 |
-
for idx, row in enumerate(rows):
|
49 |
-
# Создаем сущность таблицы
|
50 |
-
table_entity = self._create_table_entity(
|
51 |
-
table_text=row,
|
52 |
-
table_index=idx,
|
53 |
-
doc_name=doc_entity.name,
|
54 |
-
)
|
55 |
-
|
56 |
-
# Создаем связь между документом и таблицей
|
57 |
-
link = self._create_link(doc_entity, table_entity, idx)
|
58 |
|
59 |
-
|
60 |
-
links.append(link)
|
61 |
-
|
62 |
-
# Возвращаем список таблиц и связей
|
63 |
-
return table_entities + links
|
64 |
-
|
65 |
-
def _create_table_entity(
|
66 |
self,
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
) -> TableEntity:
|
71 |
"""
|
72 |
-
|
73 |
-
|
74 |
-
Args:
|
75 |
-
table_text: Текст таблицы
|
76 |
-
table_index: Индекс таблицы в документе
|
77 |
-
doc_name: Имя документа
|
78 |
-
|
79 |
-
Returns:
|
80 |
-
Сущность TableEntity
|
81 |
"""
|
82 |
-
entity_name = f"{doc_name}_table_{table_index}"
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
type=TableEntity.__name__,
|
91 |
-
table_index=table_index,
|
92 |
)
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
"""
|
98 |
-
Создает связь между документом и таблицей.
|
99 |
-
|
100 |
-
Args:
|
101 |
-
doc_entity: Сущность документа
|
102 |
-
table_entity: Сущность таблицы
|
103 |
-
index: Индекс таблицы в документе
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
"""
|
108 |
-
return LinkerEntity(
|
109 |
-
id=uuid4(),
|
110 |
-
name="document_to_table",
|
111 |
-
text="",
|
112 |
-
metadata={},
|
113 |
-
source_id=doc_entity.id,
|
114 |
-
target_id=table_entity.id,
|
115 |
-
number_in_relation=index,
|
116 |
-
type="Link",
|
117 |
)
|
|
|
|
|
|
2 |
Процессор таблиц из документа.
|
3 |
"""
|
4 |
|
|
|
|
|
5 |
from ntr_fileparser import ParsedDocument
|
6 |
|
7 |
from ..models import LinkerEntity
|
8 |
+
from .tables import TableProcessor, TableEntity
|
9 |
+
from ..repositories import EntityRepository, GroupedEntities
|
10 |
|
11 |
|
12 |
class TablesProcessor:
|
|
|
16 |
|
17 |
def __init__(self):
|
18 |
"""Инициализация процессора таблиц."""
|
19 |
+
self.table_processor = TableProcessor()
|
20 |
|
21 |
+
def extract(
|
22 |
self,
|
23 |
document: ParsedDocument,
|
24 |
doc_entity: LinkerEntity,
|
25 |
) -> list[LinkerEntity]:
|
26 |
+
"""Извлекает таблицы из документа и создает для них сущности."""
|
27 |
+
entities = []
|
28 |
+
for table in document.tables:
|
29 |
+
entities.extend(self.table_processor.extract(table, doc_entity))
|
30 |
+
return entities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
def build(
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
self,
|
34 |
+
repository: EntityRepository,
|
35 |
+
entities: list[LinkerEntity],
|
36 |
+
) -> str:
|
|
|
37 |
"""
|
38 |
+
Собирает текст таблиц из списка сущностей.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
"""
|
|
|
40 |
|
41 |
+
groups: list[GroupedEntities[TableEntity]] = (
|
42 |
+
repository.group_entities_hierarchically(
|
43 |
+
entities=entities,
|
44 |
+
root_type=TableEntity,
|
45 |
+
sort=True,
|
46 |
+
)
|
|
|
|
|
47 |
)
|
48 |
|
49 |
+
groups = sorted(
|
50 |
+
groups, key=lambda x: x.composer.number_in_relation,
|
51 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
result = "\n\n".join(
|
54 |
+
self.table_processor.build(repository, group) for group in groups
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
+
|
57 |
+
return result
|
lib/extractor/ntr_text_fragmentation/chunking/__init__.py
CHANGED
@@ -3,9 +3,21 @@
|
|
3 |
"""
|
4 |
|
5 |
from .chunking_strategy import ChunkingStrategy
|
6 |
-
from .specific_strategies import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
__all__ = [
|
9 |
"ChunkingStrategy",
|
|
|
10 |
"FixedSizeChunkingStrategy",
|
|
|
|
|
|
|
|
|
11 |
]
|
|
|
3 |
"""
|
4 |
|
5 |
from .chunking_strategy import ChunkingStrategy
|
6 |
+
from .specific_strategies import (
|
7 |
+
FixedSizeChunk,
|
8 |
+
FixedSizeChunkingStrategy,
|
9 |
+
FIXED_SIZE,
|
10 |
+
)
|
11 |
+
from .text_to_text_base import TextToTextBaseStrategy
|
12 |
+
|
13 |
+
from .chunking_registry import register_chunking_strategy, chunking_registry
|
14 |
|
15 |
__all__ = [
|
16 |
"ChunkingStrategy",
|
17 |
+
"FixedSizeChunk",
|
18 |
"FixedSizeChunkingStrategy",
|
19 |
+
"FIXED_SIZE",
|
20 |
+
"TextToTextBaseStrategy",
|
21 |
+
"register_chunking_strategy",
|
22 |
+
"chunking_registry",
|
23 |
]
|
lib/extractor/ntr_text_fragmentation/chunking/chunking_registry.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..chunking.chunking_strategy import ChunkingStrategy
|
2 |
+
|
3 |
+
|
4 |
+
class _ChunkingRegistry:
|
5 |
+
def __init__(self):
|
6 |
+
self._chunking_strategies: dict[str, ChunkingStrategy] = {}
|
7 |
+
|
8 |
+
def register(self, name: str, strategy: ChunkingStrategy):
|
9 |
+
self._chunking_strategies[name] = strategy
|
10 |
+
|
11 |
+
def get(self, name: str) -> ChunkingStrategy:
|
12 |
+
return self._chunking_strategies[name]
|
13 |
+
|
14 |
+
def get_names(self) -> list[str]:
|
15 |
+
return list(self._chunking_strategies.keys())
|
16 |
+
|
17 |
+
def __len__(self) -> int:
|
18 |
+
return len(self._chunking_strategies)
|
19 |
+
|
20 |
+
def __contains__(self, name: str | ChunkingStrategy) -> bool:
|
21 |
+
if isinstance(name, ChunkingStrategy):
|
22 |
+
return name in self._chunking_strategies.values()
|
23 |
+
return name in self._chunking_strategies
|
24 |
+
|
25 |
+
def __dict__(self) -> dict:
|
26 |
+
return self._chunking_strategies
|
27 |
+
|
28 |
+
def __getitem__(self, name: str) -> ChunkingStrategy:
|
29 |
+
return self._chunking_strategies[name]
|
30 |
+
|
31 |
+
|
32 |
+
chunking_registry = _ChunkingRegistry()
|
33 |
+
|
34 |
+
|
35 |
+
def register_chunking_strategy(name: str | None = None):
|
36 |
+
def decorator(cls: type[ChunkingStrategy]) -> type[ChunkingStrategy]:
|
37 |
+
chunking_registry.register(name or cls.__name__, cls)
|
38 |
+
return cls
|
39 |
+
|
40 |
+
return decorator
|
lib/extractor/ntr_text_fragmentation/chunking/chunking_strategy.py
CHANGED
@@ -1,86 +1,99 @@
|
|
1 |
"""
|
2 |
-
|
3 |
"""
|
4 |
|
|
|
5 |
from abc import ABC, abstractmethod
|
6 |
|
7 |
from ntr_fileparser import ParsedDocument
|
8 |
|
9 |
-
from ..models import
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
class ChunkingStrategy(ABC):
|
13 |
-
"""
|
14 |
-
|
15 |
-
"""
|
16 |
-
|
17 |
@abstractmethod
|
18 |
-
def chunk(
|
|
|
|
|
|
|
|
|
19 |
"""
|
20 |
Разбивает документ на чанки в соответствии со стратегией.
|
21 |
-
|
22 |
Args:
|
23 |
-
document: ParsedDocument для извлечения текста
|
24 |
-
doc_entity:
|
25 |
-
|
26 |
-
|
27 |
Returns:
|
28 |
-
|
29 |
"""
|
30 |
raise NotImplementedError("Стратегия чанкинга должна реализовать метод chunk")
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
"""
|
34 |
-
Собирает
|
35 |
-
|
36 |
-
Базовая реализация сортирует чанки по chunk_index и объединяет их тексты,
|
37 |
-
сохраняя структуру параграфов и избегая дублирования текста.
|
38 |
-
|
39 |
Args:
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
Returns:
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
"""
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
# Отбираем только чанки
|
53 |
-
valid_chunks = [c for c in chunks if isinstance(c, Chunk)]
|
54 |
-
|
55 |
-
# Сортируем чанки по chunk_index
|
56 |
-
sorted_chunks = sorted(valid_chunks, key=lambda c: c.chunk_index or 0)
|
57 |
-
|
58 |
-
# Собираем текст документа с учетом структуры параграфов
|
59 |
-
result_text = ""
|
60 |
-
|
61 |
-
for chunk in sorted_chunks:
|
62 |
-
# Получаем текст чанка (предпочитаем text, а не in_search_text для избежания дублирования)
|
63 |
-
chunk_text = chunk.text if hasattr(chunk, 'text') and chunk.text else ""
|
64 |
-
|
65 |
-
# Добавляем текст чанка с сохранением структуры параграфов
|
66 |
-
if result_text and result_text[-1] != "\n" and chunk_text and chunk_text[0] != "\n":
|
67 |
-
result_text += " "
|
68 |
-
result_text += chunk_text
|
69 |
-
|
70 |
-
# Пост-обработка результата
|
71 |
-
# Заменяем множественные переносы строк на одиночные
|
72 |
-
result_text = re.sub(r'\n+', '\n', result_text)
|
73 |
-
|
74 |
-
# Заменяем множественные пробелы на одиночные
|
75 |
-
result_text = re.sub(r' +', ' ', result_text)
|
76 |
-
|
77 |
-
# Убираем пробелы перед переносами строк
|
78 |
-
result_text = re.sub(r' +\n', '\n', result_text)
|
79 |
-
|
80 |
-
# Убираем пробелы после переносов строк
|
81 |
-
result_text = re.sub(r'\n +', '\n', result_text)
|
82 |
-
|
83 |
-
# Убираем лишние переносы строк в начале и конце текста
|
84 |
-
result_text = result_text.strip()
|
85 |
-
|
86 |
-
return result_text
|
|
|
1 |
"""
|
2 |
+
Абстрактный базовый класс для стратегий чанкинга.
|
3 |
"""
|
4 |
|
5 |
+
import logging
|
6 |
from abc import ABC, abstractmethod
|
7 |
|
8 |
from ntr_fileparser import ParsedDocument
|
9 |
|
10 |
+
from ..models import DocumentAsEntity, LinkerEntity
|
11 |
+
from ..repositories import EntityRepository
|
12 |
+
from .models import Chunk
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
|
16 |
|
17 |
class ChunkingStrategy(ABC):
|
18 |
+
"""Абстрактный класс для стратегий чанкинга."""
|
19 |
+
|
|
|
|
|
20 |
@abstractmethod
|
21 |
+
def chunk(
|
22 |
+
self,
|
23 |
+
document: ParsedDocument,
|
24 |
+
doc_entity: DocumentAsEntity,
|
25 |
+
) -> list[LinkerEntity]:
|
26 |
"""
|
27 |
Разбивает документ на чанки в соответствии со стратегией.
|
28 |
+
|
29 |
Args:
|
30 |
+
document: ParsedDocument для извлечения текста и структуры.
|
31 |
+
doc_entity: Сущность документа-владельца, к которой будут привязаны чанки.
|
32 |
+
|
|
|
33 |
Returns:
|
34 |
+
Список сущностей (чанки)
|
35 |
"""
|
36 |
raise NotImplementedError("Стратегия чанкинга должна реализовать метод chunk")
|
37 |
|
38 |
+
@classmethod
|
39 |
+
def dechunk(
|
40 |
+
cls,
|
41 |
+
repository: EntityRepository,
|
42 |
+
filtered_entities: list[LinkerEntity],
|
43 |
+
) -> str:
|
44 |
"""
|
45 |
+
Собирает текст из отфильтрованных чанков к одному документу.
|
46 |
+
|
|
|
|
|
|
|
47 |
Args:
|
48 |
+
repository: Репозиторий (может понадобиться для получения доп. информации,
|
49 |
+
хотя в текущей реализации не используется).
|
50 |
+
filtered_entities: Список отфильтрованных сущностей (чанков),
|
51 |
+
относящихся к одному документу.
|
52 |
+
|
53 |
Returns:
|
54 |
+
Собранный текст из чанков.
|
55 |
+
"""
|
56 |
+
chunks = [e for e in filtered_entities if isinstance(e, Chunk)]
|
57 |
+
chunks.sort(key=lambda x: x.number_in_relation)
|
58 |
+
|
59 |
+
groups: list[list[Chunk]] = []
|
60 |
+
for chunk in chunks:
|
61 |
+
if len(groups) == 0:
|
62 |
+
groups.append([chunk])
|
63 |
+
continue
|
64 |
+
|
65 |
+
last_chunk = groups[-1][-1]
|
66 |
+
if chunk.number_in_relation == last_chunk.number_in_relation + 1:
|
67 |
+
groups[-1].append(chunk)
|
68 |
+
else:
|
69 |
+
groups.append([chunk])
|
70 |
+
|
71 |
+
result = ""
|
72 |
+
previous_last_index = 0
|
73 |
+
for group in groups:
|
74 |
+
if previous_last_index is not None:
|
75 |
+
missing_chunks = group[0].number_in_relation - previous_last_index - 1
|
76 |
+
missing_string = f'\n_<...Пропущено {missing_chunks} фрагментов...>_\n'
|
77 |
+
else:
|
78 |
+
missing_string = '\n_<...>_\n'
|
79 |
+
result += missing_string + cls._build_sequenced_chunks(repository, group)
|
80 |
+
previous_last_index = group[-1].number_in_relation
|
81 |
+
|
82 |
+
return result.strip()
|
83 |
+
|
84 |
+
@classmethod
|
85 |
+
def _build_sequenced_chunks(
|
86 |
+
cls,
|
87 |
+
repository: EntityRepository,
|
88 |
+
group: list[Chunk],
|
89 |
+
) -> str:
|
90 |
+
"""
|
91 |
+
Строит текст для последовательных чанко��.
|
92 |
+
Стоит переопределить в конкретной стратегии, если она предполагает сложную логику
|
93 |
"""
|
94 |
+
return " ".join([cls._build_chunk(chunk) for chunk in group])
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def _build_chunk(cls, chunk: Chunk) -> str:
|
98 |
+
"""Строит текст для одного чанка."""
|
99 |
+
return chunk.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/extractor/ntr_text_fragmentation/chunking/models/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .chunk import Chunk
|
2 |
+
from .custom_chunk import CustomChunk
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"Chunk",
|
6 |
+
"CustomChunk",
|
7 |
+
]
|
lib/extractor/ntr_text_fragmentation/chunking/models/chunk.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ...models.linker_entity import Entity, register_entity
|
4 |
+
|
5 |
+
|
6 |
+
@register_entity
|
7 |
+
@dataclass
|
8 |
+
class Chunk(Entity):
|
9 |
+
"""
|
10 |
+
Чанк документа.
|
11 |
+
"""
|
lib/extractor/ntr_text_fragmentation/chunking/models/custom_chunk.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ...models.linker_entity import Entity, register_entity
|
4 |
+
|
5 |
+
|
6 |
+
@register_entity
|
7 |
+
@dataclass
|
8 |
+
class CustomChunk(Entity):
|
9 |
+
"""
|
10 |
+
Чанк документа, полученный в результате применения пользовательской стратегии
|
11 |
+
чанкинга.
|
12 |
+
"""
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/__init__.py
CHANGED
@@ -3,9 +3,13 @@
|
|
3 |
"""
|
4 |
|
5 |
from .fixed_size import FixedSizeChunk
|
6 |
-
from .fixed_size_chunking import
|
|
|
|
|
|
|
7 |
|
8 |
__all__ = [
|
9 |
"FixedSizeChunk",
|
10 |
"FixedSizeChunkingStrategy",
|
|
|
11 |
]
|
|
|
3 |
"""
|
4 |
|
5 |
from .fixed_size import FixedSizeChunk
|
6 |
+
from .fixed_size_chunking import (
|
7 |
+
FixedSizeChunkingStrategy,
|
8 |
+
FIXED_SIZE,
|
9 |
+
)
|
10 |
|
11 |
__all__ = [
|
12 |
"FixedSizeChunk",
|
13 |
"FixedSizeChunkingStrategy",
|
14 |
+
"FIXED_SIZE",
|
15 |
]
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size/fixed_size_chunk.py
CHANGED
@@ -2,11 +2,11 @@
|
|
2 |
Класс для представления чанка фиксированного размера.
|
3 |
"""
|
4 |
|
5 |
-
from dataclasses import dataclass
|
6 |
from typing import Any
|
7 |
|
8 |
-
from ....models
|
9 |
-
from
|
10 |
|
11 |
|
12 |
@register_entity
|
@@ -15,21 +15,14 @@ class FixedSizeChunk(Chunk):
|
|
15 |
"""
|
16 |
Представляет чанк фиксированного размера.
|
17 |
|
18 |
-
Расширяет базовый класс Chunk дополнительными полями, связанными с
|
19 |
-
|
20 |
-
границ предложений.
|
21 |
"""
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
right_sentence_part: str = "" # Часть предложения справа от text
|
28 |
-
overlap_left: str = "" # Нахлест слева (без учета границ предложений)
|
29 |
-
overlap_right: str = "" # Нахлест справа (без учета границ предложений)
|
30 |
-
|
31 |
-
# Метаданные для дополнительной информации
|
32 |
-
metadata: dict[str, Any] = field(default_factory=dict)
|
33 |
|
34 |
def __str__(self) -> str:
|
35 |
"""
|
@@ -38,106 +31,64 @@ class FixedSizeChunk(Chunk):
|
|
38 |
Returns:
|
39 |
Строка с информацией о чанке.
|
40 |
"""
|
|
|
|
|
|
|
41 |
return (
|
42 |
f"FixedSizeChunk(id={self.id}, chunk_index={self.chunk_index}, "
|
43 |
-
f"tokens={self.token_count}, "
|
44 |
-
f"text='{self.text[:30]}{'...' if len(self.text) > 30 else ''}'"
|
45 |
-
f")"
|
46 |
)
|
47 |
|
48 |
-
def get_adjacent_chunks_indices(self, max_distance: int = 1) -> list[int]:
|
49 |
-
"""
|
50 |
-
Возвращает индексы соседних чанков в пределах указанного расстояния.
|
51 |
-
|
52 |
-
Args:
|
53 |
-
max_distance: Максимальное расстояние от текущего чанка
|
54 |
-
|
55 |
-
Returns:
|
56 |
-
Список индексов соседних чанков
|
57 |
-
"""
|
58 |
-
indices = []
|
59 |
-
for i in range(1, max_distance + 1):
|
60 |
-
# Добавляем предыдущие чанки
|
61 |
-
if self.chunk_index - i >= 0:
|
62 |
-
indices.append(self.chunk_index - i)
|
63 |
-
# Добавляем следующие чанки
|
64 |
-
indices.append(self.chunk_index + i)
|
65 |
-
|
66 |
-
return sorted(indices)
|
67 |
-
|
68 |
@classmethod
|
69 |
-
def
|
70 |
"""
|
71 |
-
Десериализует FixedSizeChunk из объекта LinkerEntity.
|
|
|
|
|
|
|
|
|
72 |
|
73 |
Args:
|
74 |
-
|
75 |
|
76 |
Returns:
|
77 |
-
|
78 |
-
"""
|
79 |
-
metadata = entity.metadata or {}
|
80 |
-
|
81 |
-
# Извлекаем параметры из метаданных
|
82 |
-
# Сначала проверяем в метаданных под ключом _chunk_index
|
83 |
-
chunk_index = metadata.get('_chunk_index')
|
84 |
-
if chunk_index is None:
|
85 |
-
# Затем пробуем получить как атрибут объекта
|
86 |
-
chunk_index = getattr(entity, 'chunk_index', None)
|
87 |
-
if chunk_index is None:
|
88 |
-
# Если и там нет, пробуем обычный поиск по метаданным
|
89 |
-
chunk_index = metadata.get('chunk_index')
|
90 |
-
|
91 |
-
# Преобразуем к int, если значение найдено
|
92 |
-
if chunk_index is not None:
|
93 |
-
try:
|
94 |
-
chunk_index = int(chunk_index)
|
95 |
-
except (ValueError, TypeError):
|
96 |
-
chunk_index = None
|
97 |
-
|
98 |
-
start_token = metadata.get('start_token', 0)
|
99 |
-
end_token = metadata.get('end_token', 0)
|
100 |
-
token_count = metadata.get(
|
101 |
-
'_token_count', metadata.get('token_count', end_token - start_token + 1)
|
102 |
-
)
|
103 |
-
|
104 |
-
# Извлекаем параметры для границ предложений и нахлестов
|
105 |
-
# Сначала ищем в метаданных с префиксом _
|
106 |
-
left_sentence_part = metadata.get('_left_sentence_part')
|
107 |
-
if left_sentence_part is None:
|
108 |
-
# Затем пробуем получить как атрибут объекта
|
109 |
-
left_sentence_part = getattr(entity, 'left_sentence_part', '')
|
110 |
-
|
111 |
-
right_sentence_part = metadata.get('_right_sentence_part')
|
112 |
-
if right_sentence_part is None:
|
113 |
-
right_sentence_part = getattr(entity, 'right_sentence_part', '')
|
114 |
-
|
115 |
-
overlap_left = metadata.get('_overlap_left')
|
116 |
-
if overlap_left is None:
|
117 |
-
overlap_left = getattr(entity, 'overlap_left', '')
|
118 |
-
|
119 |
-
overlap_right = metadata.get('_overlap_right')
|
120 |
-
if overlap_right is None:
|
121 |
-
overlap_right = getattr(entity, 'overlap_right', '')
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
# Создаем чистые метаданные без служебных полей
|
124 |
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
125 |
|
126 |
# Создаем и возвращаем новый экземпляр FixedSizeChunk
|
127 |
return cls(
|
128 |
-
id=
|
129 |
-
name=
|
130 |
-
text=
|
131 |
-
in_search_text=
|
132 |
metadata=clean_metadata,
|
133 |
-
source_id=
|
134 |
-
target_id=
|
135 |
-
number_in_relation=
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
)
|
|
|
2 |
Класс для представления чанка фиксированного размера.
|
3 |
"""
|
4 |
|
5 |
+
from dataclasses import dataclass
|
6 |
from typing import Any
|
7 |
|
8 |
+
from ....models import Entity, LinkerEntity, register_entity
|
9 |
+
from ...models.chunk import Chunk
|
10 |
|
11 |
|
12 |
@register_entity
|
|
|
15 |
"""
|
16 |
Представляет чанк фиксированного размера.
|
17 |
|
18 |
+
Расширяет базовый класс Chunk дополнительными полями, связанными с токенами,
|
19 |
+
границами предложений и перекрытиями.
|
|
|
20 |
"""
|
21 |
|
22 |
+
left_sentence_part: str | None = None
|
23 |
+
right_sentence_part: str | None = None
|
24 |
+
overlap_left: str | None = None
|
25 |
+
overlap_right: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def __str__(self) -> str:
|
28 |
"""
|
|
|
31 |
Returns:
|
32 |
Строка с информацией о чанке.
|
33 |
"""
|
34 |
+
text_preview = (
|
35 |
+
f"{self.text[:30]}..." if self.text and len(self.text) > 30 else self.text
|
36 |
+
)
|
37 |
return (
|
38 |
f"FixedSizeChunk(id={self.id}, chunk_index={self.chunk_index}, "
|
39 |
+
f"tokens={self.token_count}, text='{text_preview}')"
|
|
|
|
|
40 |
)
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
@classmethod
|
43 |
+
def _deserialize_to_me(cls, data: Entity) -> "FixedSizeChunk":
|
44 |
"""
|
45 |
+
Десериализует FixedSizeChunk из объекта Entity (LinkerEntity).
|
46 |
+
|
47 |
+
Использует паттерн: сначала ищет поле в атрибутах объекта `data`
|
48 |
+
(на случай, если он уже частично десериализован или является подклассом),
|
49 |
+
затем ищет поле в `data.metadata` с префиксом '_'.
|
50 |
|
51 |
Args:
|
52 |
+
data: Объект Entity (LinkerEntity) для десериализации.
|
53 |
|
54 |
Returns:
|
55 |
+
Новый экземпляр FixedSizeChunk с данными из Entity.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
+
Raises:
|
58 |
+
TypeError: Если data не является экземпляром LinkerEntity или его подкласса.
|
59 |
+
"""
|
60 |
+
if not isinstance(data, LinkerEntity):
|
61 |
+
raise TypeError(
|
62 |
+
f"Ожидался LinkerEntity или его подкласс, получен {type(data)}"
|
63 |
+
)
|
64 |
+
|
65 |
+
metadata = data.metadata or {}
|
66 |
+
|
67 |
+
# Извлечение специфичных полей с использованием паттерна getattr/metadata.get
|
68 |
+
def get_field(field_name: str, default: Any = None) -> Any:
|
69 |
+
value = getattr(data, field_name, None)
|
70 |
+
if value is None:
|
71 |
+
value = metadata.get(f"_{field_name}", default)
|
72 |
+
return value
|
73 |
+
|
74 |
# Создаем чистые метаданные без служебных полей
|
75 |
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
76 |
|
77 |
# Создаем и возвращаем новый экземпляр FixedSizeChunk
|
78 |
return cls(
|
79 |
+
id=data.id,
|
80 |
+
name=data.name,
|
81 |
+
text=data.text,
|
82 |
+
in_search_text=data.in_search_text,
|
83 |
metadata=clean_metadata,
|
84 |
+
source_id=data.source_id,
|
85 |
+
target_id=data.target_id, # owner_id
|
86 |
+
number_in_relation=data.number_in_relation,
|
87 |
+
groupper=data.groupper,
|
88 |
+
type=cls.__name__, # Устанавливаем конкретный тип
|
89 |
+
# Специфичные поля FixedSizeChunk
|
90 |
+
left_sentence_part=get_field('left_sentence_part', ""),
|
91 |
+
right_sentence_part=get_field('right_sentence_part', ""),
|
92 |
+
overlap_left=get_field('overlap_left', ""),
|
93 |
+
overlap_right=get_field('overlap_right', ""),
|
94 |
)
|
lib/extractor/ntr_text_fragmentation/chunking/specific_strategies/fixed_size_chunking.py
CHANGED
@@ -2,46 +2,38 @@
|
|
2 |
Стратегия чанкинга фиксированного размера.
|
3 |
"""
|
4 |
|
|
|
5 |
import re
|
6 |
-
from
|
7 |
from uuid import uuid4
|
8 |
|
9 |
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
10 |
|
11 |
-
from ...chunking.chunking_strategy import ChunkingStrategy
|
12 |
from ...models import DocumentAsEntity, LinkerEntity
|
|
|
|
|
|
|
|
|
13 |
from .fixed_size.fixed_size_chunk import FixedSizeChunk
|
14 |
|
15 |
-
|
16 |
|
17 |
-
|
18 |
-
class _FixedSizeChunkingStrategyParams(NamedTuple):
|
19 |
-
words_per_chunk: int = 50
|
20 |
-
overlap_words: int = 25
|
21 |
-
respect_sentence_boundaries: bool = True
|
22 |
|
23 |
|
|
|
24 |
class FixedSizeChunkingStrategy(ChunkingStrategy):
|
25 |
"""
|
26 |
-
Стратегия чанкинга, разбивающая текст на чанки фиксированного
|
27 |
-
|
28 |
-
Преимущества:
|
29 |
-
- Простое и предсказуемое разбиение
|
30 |
-
- Равные по размеру чанки
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
- Не учитывает смысловую структуру текста
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
- В поле `in_search_text` хранится текст с нахлестом (для улучшения векторизации)
|
39 |
"""
|
40 |
|
41 |
-
|
42 |
-
description = (
|
43 |
-
"Стратегия чанкинга, разбивающая текст на чанки фиксированного размера."
|
44 |
-
)
|
45 |
|
46 |
def __init__(
|
47 |
self,
|
@@ -50,519 +42,276 @@ class FixedSizeChunkingStrategy(ChunkingStrategy):
|
|
50 |
respect_sentence_boundaries: bool = True,
|
51 |
):
|
52 |
"""
|
53 |
-
Инициализация
|
54 |
|
55 |
Args:
|
56 |
-
words_per_chunk:
|
57 |
-
overlap_words: Количество слов перекрытия между
|
58 |
-
respect_sentence_boundaries:
|
|
|
59 |
"""
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
overlap_words
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
def chunk(
|
68 |
-
self,
|
69 |
-
document: ParsedDocument | str,
|
70 |
-
doc_entity: DocumentAsEntity | None = None,
|
71 |
) -> list[LinkerEntity]:
|
72 |
"""
|
73 |
-
Разбивает документ на чанки
|
74 |
|
75 |
Args:
|
76 |
-
document: Документ для
|
77 |
-
doc_entity: Сущность
|
78 |
|
79 |
Returns:
|
80 |
-
Список
|
81 |
"""
|
82 |
-
|
83 |
-
|
84 |
|
85 |
-
|
86 |
-
|
87 |
return []
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
#
|
101 |
-
chunk_text = self._prepare_chunk_text(words,
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
)
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
115 |
)
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
return [doc_entity] + chunks + links
|
122 |
-
|
123 |
-
def _find_nearest_sentence_boundary(
|
124 |
-
self, text: str, position: int
|
125 |
-
) -> tuple[int, str, str]:
|
126 |
-
"""
|
127 |
-
Находит ближайшую границу предложения к указанной позиции.
|
128 |
-
|
129 |
-
Args:
|
130 |
-
text: Полный текст для поиска границ
|
131 |
-
position: Позиция, для которой ищем ближайшую границу
|
132 |
-
|
133 |
-
Returns:
|
134 |
-
tuple из (позиция границы, левая часть текста, правая часть текста)
|
135 |
-
"""
|
136 |
-
# Регулярное выражение для поиска конца предложения
|
137 |
-
sentence_end_pattern = r'[.!?](?:\s|$)'
|
138 |
-
|
139 |
-
# Ищем все совпадения в тексте
|
140 |
-
matches = list(re.finditer(sentence_end_pattern, text))
|
141 |
-
|
142 |
-
if not matches:
|
143 |
-
# Если совпадений нет, возвращаем исходную позицию
|
144 |
-
return position, text[:position], text[position:]
|
145 |
-
|
146 |
-
# Находим ближайшую границу предложения
|
147 |
-
nearest_pos = position
|
148 |
-
min_distance = float('inf')
|
149 |
-
|
150 |
-
for match in matches:
|
151 |
-
end_pos = match.end()
|
152 |
-
distance = abs(end_pos - position)
|
153 |
-
|
154 |
-
if distance < min_distance:
|
155 |
-
min_distance = distance
|
156 |
-
nearest_pos = end_pos
|
157 |
-
|
158 |
-
# Возвращаем позицию и соответствующие части текста
|
159 |
-
return nearest_pos, text[:nearest_pos], text[nearest_pos:]
|
160 |
-
|
161 |
-
def _find_sentence_boundary(self, text: str, is_left_boundary: bool) -> str:
|
162 |
-
"""
|
163 |
-
Находит часть текста на границе предложения.
|
164 |
-
|
165 |
-
Args:
|
166 |
-
text: Текст для обработки
|
167 |
-
is_left_boundary: True для левой границы, False для правой
|
168 |
-
|
169 |
-
Returns:
|
170 |
-
Часть предложения на границе
|
171 |
-
"""
|
172 |
-
# Регулярное выражение для поиска конца предложения
|
173 |
-
sentence_end_pattern = r'[.!?](?:\s|$)'
|
174 |
-
matches = list(re.finditer(sentence_end_pattern, text))
|
175 |
-
|
176 |
-
if not matches:
|
177 |
-
return text
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
-
def
|
189 |
-
self,
|
190 |
-
filtered_chunks: list[LinkerEntity],
|
191 |
-
repository: 'EntityRepository' = None, # type: ignore
|
192 |
) -> str:
|
193 |
-
"""
|
194 |
-
|
195 |
-
|
196 |
-
Args:
|
197 |
-
filtered_chunks: Список отфильтрованных чанков
|
198 |
-
repository: Репозиторий сущностей для получения дополнительной информации (может быть None)
|
199 |
-
|
200 |
-
Returns:
|
201 |
-
Восстановленный текст документа
|
202 |
-
"""
|
203 |
-
if not filtered_chunks:
|
204 |
return ""
|
205 |
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
# Группируем последовательные чанки
|
221 |
-
current_group = []
|
222 |
-
groups = []
|
223 |
-
|
224 |
-
for i, chunk in enumerate(sorted_chunks):
|
225 |
-
current_index = chunk.chunk_index or 0
|
226 |
-
|
227 |
-
# Если первый чанк или продолжение последовательности
|
228 |
-
if i == 0 or current_index == (sorted_chunks[i - 1].chunk_index or 0) + 1:
|
229 |
-
current_group.append(chunk)
|
230 |
-
else:
|
231 |
-
# Закрываем текущую группу и начинаем новую
|
232 |
-
if current_group:
|
233 |
-
groups.append(current_group)
|
234 |
-
current_group = [chunk]
|
235 |
-
|
236 |
-
# Добавляем последнюю группу
|
237 |
-
if current_group:
|
238 |
-
groups.append(current_group)
|
239 |
-
|
240 |
-
# Обрабатываем каждую группу
|
241 |
-
for group_index, group in enumerate(groups):
|
242 |
-
# Добавляем многоточие между непоследовательными группами
|
243 |
-
if group_index > 0:
|
244 |
-
result_text += "\n\n...\n\n"
|
245 |
-
|
246 |
-
# Обрабатываем группу соседних чанков
|
247 |
-
group_text = ""
|
248 |
-
|
249 |
-
# Добавляем левую недостающую часть к первому чанку группы
|
250 |
-
first_chunk = group[0]
|
251 |
-
|
252 |
-
# Добавляем левую часть предложения к первому чанку группы
|
253 |
-
if (
|
254 |
-
hasattr(first_chunk, 'left_sentence_part')
|
255 |
-
and first_chunk.left_sentence_part
|
256 |
-
):
|
257 |
-
group_text += first_chunk.left_sentence_part
|
258 |
-
|
259 |
-
# Добавляем текст всех чанков группы
|
260 |
-
for i, chunk in enumerate(group):
|
261 |
-
current_text = chunk.text.strip() if hasattr(chunk, 'text') else ""
|
262 |
-
if not current_text:
|
263 |
-
continue
|
264 |
-
|
265 |
-
# Проверяем, нужно ли добавить пробел между предыдущим текстом и текущим чанком
|
266 |
-
if group_text:
|
267 |
-
# Если текущий чанк начинается с новой строки, не добавляем пробел
|
268 |
-
if current_text.startswith("\n"):
|
269 |
-
pass
|
270 |
-
# Если предыдущий текст заканчивается переносом строки, также не добавляем пробел
|
271 |
-
elif group_text.endswith("\n"):
|
272 |
-
pass
|
273 |
-
# Если предыдущий текст заканчивается знаком препинания без пробела, добавляем пробел
|
274 |
-
elif group_text.rstrip()[-1] not in [
|
275 |
-
"\n",
|
276 |
-
" ",
|
277 |
-
".",
|
278 |
-
",",
|
279 |
-
"!",
|
280 |
-
"?",
|
281 |
-
":",
|
282 |
-
";",
|
283 |
-
"-",
|
284 |
-
"–",
|
285 |
-
"—",
|
286 |
-
]:
|
287 |
-
group_text += " "
|
288 |
-
|
289 |
-
# Добавляем текст чанка
|
290 |
-
group_text += current_text
|
291 |
-
|
292 |
-
# Добавляем правую недостающую часть к последнему чанку группы
|
293 |
-
last_chunk = group[-1]
|
294 |
-
|
295 |
-
# Добавляем правую часть предложения к последнему чанку группы
|
296 |
-
if (
|
297 |
-
hasattr(last_chunk, 'right_sentence_part')
|
298 |
-
and last_chunk.right_sentence_part
|
299 |
-
):
|
300 |
-
right_part = last_chunk.right_sentence_part.strip()
|
301 |
-
if right_part:
|
302 |
-
# Проверяем нужен ли пробел перед правой частью
|
303 |
-
if (
|
304 |
-
group_text
|
305 |
-
and group_text[-1] not in ["\n", " "]
|
306 |
-
and right_part[0]
|
307 |
-
not in ["\n", " ", ".", ",", "!", "?", ":", ";", "-", "–", "—"]
|
308 |
-
):
|
309 |
-
group_text += " "
|
310 |
-
group_text += right_part
|
311 |
-
|
312 |
-
# Добавляем текст группы к результату
|
313 |
-
if (
|
314 |
-
result_text
|
315 |
-
and result_text[-1] not in ["\n", " "]
|
316 |
-
and group_text
|
317 |
-
and group_text[0] not in ["\n", " "]
|
318 |
-
):
|
319 |
-
result_text += " "
|
320 |
-
result_text += group_text
|
321 |
-
|
322 |
-
# Постобработка те��ста: удаляем лишние пробелы и символы переноса строк
|
323 |
-
|
324 |
-
# Заменяем множественные переносы строк на двойные (для разделения абзацев)
|
325 |
-
result_text = re.sub(r'\n{3,}', '\n\n', result_text)
|
326 |
-
|
327 |
-
# Заменяем множественные пробелы на одиночные
|
328 |
-
result_text = re.sub(r' +', ' ', result_text)
|
329 |
-
|
330 |
-
# Убираем пробелы перед знаками препинания
|
331 |
-
result_text = re.sub(r' ([.,!?:;)])', r'\1', result_text)
|
332 |
-
|
333 |
-
# Убираем пробелы перед переносами строк и после переносов строк
|
334 |
-
result_text = re.sub(r' +\n', '\n', result_text)
|
335 |
-
result_text = re.sub(r'\n +', '\n', result_text)
|
336 |
-
|
337 |
-
# Убираем лишние переносы строк и пробелы в начале и конце текста
|
338 |
-
result_text = result_text.strip()
|
339 |
-
|
340 |
-
return result_text
|
341 |
-
|
342 |
-
def _get_sorted_chunks(
|
343 |
-
self, chunks: list[LinkerEntity], links: list[LinkerEntity]
|
344 |
-
) -> list[LinkerEntity]:
|
345 |
-
"""
|
346 |
-
Получает отсортированные чанки на основе связей или поля chunk_index.
|
347 |
-
|
348 |
-
Args:
|
349 |
-
chunks: Список чанков для сортировки
|
350 |
-
links: Список связей для определения порядка
|
351 |
-
|
352 |
-
Returns:
|
353 |
-
Отсортированные чанки
|
354 |
-
"""
|
355 |
-
# Сортируем чанки по порядку в связях
|
356 |
-
if links:
|
357 |
-
# Получаем словарь для быстрого доступа к чанкам по ID
|
358 |
-
chunk_dict = {c.id: c for c in chunks}
|
359 |
-
|
360 |
-
# Сортируем по порядку в связях
|
361 |
-
sorted_chunks = []
|
362 |
-
for link in sorted(links, key=lambda l: l.number_in_relation or 0):
|
363 |
-
if link.target_id in chunk_dict:
|
364 |
-
sorted_chunks.append(chunk_dict[link.target_id])
|
365 |
-
|
366 |
-
return sorted_chunks
|
367 |
-
|
368 |
-
# Если нет связей, сортируем по chunk_index
|
369 |
-
return sorted(chunks, key=lambda c: c.chunk_index or 0)
|
370 |
-
|
371 |
-
def _prepare_document(self, document: ParsedDocument | str) -> ParsedDocument:
|
372 |
-
"""
|
373 |
-
Обрабатывает входные данные и возвращает ParsedDocument.
|
374 |
-
|
375 |
-
Args:
|
376 |
-
document: Документ (ParsedDocument или текст)
|
377 |
-
|
378 |
-
Returns:
|
379 |
-
Обработанный документ типа ParsedDocument
|
380 |
-
"""
|
381 |
-
if isinstance(document, ParsedDocument):
|
382 |
-
return document
|
383 |
-
elif isinstance(document, str):
|
384 |
-
# Простая обработка текстового документа
|
385 |
-
return ParsedDocument(
|
386 |
-
paragraphs=[
|
387 |
-
ParsedTextBlock(text=paragraph)
|
388 |
-
for paragraph in document.split('\n')
|
389 |
-
]
|
390 |
-
)
|
391 |
-
|
392 |
-
def _extract_words(self, doc: ParsedDocument) -> list[str]:
|
393 |
-
"""
|
394 |
-
Извлекает все слова из документа.
|
395 |
-
|
396 |
-
Args:
|
397 |
-
doc: Документ для извлечения слов
|
398 |
-
|
399 |
-
Returns:
|
400 |
-
Список слов документа
|
401 |
-
"""
|
402 |
-
words = []
|
403 |
-
for paragraph in doc.paragraphs:
|
404 |
-
# Добавляем слова из параграфа
|
405 |
-
paragraph_words = paragraph.text.split()
|
406 |
-
words.extend(paragraph_words)
|
407 |
-
# Добавляем маркер конца параграфа как отдельный элемент
|
408 |
-
words.append("\n")
|
409 |
-
return words
|
410 |
-
|
411 |
-
def _ensure_document_entity(
|
412 |
self,
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
|
|
|
|
|
|
418 |
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
metadata={"type": doc.type},
|
432 |
-
type="Document",
|
433 |
-
)
|
434 |
-
return doc_entity
|
435 |
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
|
|
439 |
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
|
|
|
|
444 |
|
445 |
-
def
|
446 |
-
self,
|
447 |
-
words: list[str],
|
448 |
-
start_idx: int,
|
449 |
-
length: int,
|
450 |
-
) -> str:
|
451 |
"""
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
words: Список слов документа
|
456 |
-
start_idx: Индекс начала чанка
|
457 |
-
end_idx: Длина текста в словах
|
458 |
-
|
459 |
-
Returns:
|
460 |
-
Итоговый текст
|
461 |
"""
|
462 |
-
|
463 |
-
|
464 |
-
chunk_words = words[start_idx:end_idx]
|
465 |
-
chunk_text = ""
|
466 |
|
467 |
-
|
468 |
-
if word == "\n":
|
469 |
-
# Если это маркер конца параграфа, добавляем перенос строки
|
470 |
-
chunk_text += "\n"
|
471 |
-
else:
|
472 |
-
# Иначе добавляем слово с пробелом
|
473 |
-
if chunk_text and chunk_text[-1] != "\n":
|
474 |
-
chunk_text += " "
|
475 |
-
chunk_text += word
|
476 |
|
477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
|
479 |
-
def
|
480 |
self,
|
|
|
|
|
481 |
chunk_text: str,
|
482 |
in_search_text: str,
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
doc_name: str,
|
489 |
) -> FixedSizeChunk:
|
490 |
-
"""
|
491 |
-
Создает чанк фиксированного размера.
|
492 |
-
|
493 |
-
Args:
|
494 |
-
chunk_text: Текст чанка без нахлеста
|
495 |
-
in_search_text: Текст чанка с нахлестом
|
496 |
-
start_idx: Индекс первого слова в чанке
|
497 |
-
end_idx: Индекс последнего слова в чанке
|
498 |
-
chunk_index: Индекс чанка в документе
|
499 |
-
words: Список всех слов документа
|
500 |
-
total_words: Общее количество слов в документе
|
501 |
-
doc_name: Имя документа
|
502 |
-
|
503 |
-
Returns:
|
504 |
-
FixedSizeChunk: Созданный чанк
|
505 |
-
"""
|
506 |
-
# Определяем нахлесты без учета границ предложений
|
507 |
-
overlap_left = " ".join(
|
508 |
-
words[max(0, start_idx - self.params.overlap_words) : start_idx]
|
509 |
-
)
|
510 |
-
overlap_right = " ".join(
|
511 |
-
words[end_idx : min(total_words, end_idx + self.params.overlap_words)]
|
512 |
-
)
|
513 |
-
|
514 |
-
# Определяем границы предложений
|
515 |
-
left_sentence_part = ""
|
516 |
-
right_sentence_part = ""
|
517 |
-
|
518 |
-
if self.params.respect_sentence_boundaries:
|
519 |
-
# Находим ближайшую границу предложения слева
|
520 |
-
left_text = " ".join(
|
521 |
-
words[max(0, start_idx - self.params.overlap_words) : start_idx]
|
522 |
-
)
|
523 |
-
left_sentence_part = self._find_sentence_boundary(left_text, True)
|
524 |
-
|
525 |
-
# Находим ближайшую границу предложения справа
|
526 |
-
right_text = " ".join(
|
527 |
-
words[end_idx : min(total_words, end_idx + self.params.overlap_words)]
|
528 |
-
)
|
529 |
-
right_sentence_part = self._find_sentence_boundary(right_text, False)
|
530 |
-
|
531 |
-
# Создаем чанк с учетом границ предложений
|
532 |
return FixedSizeChunk(
|
533 |
id=uuid4(),
|
534 |
-
name=f"{
|
535 |
text=chunk_text,
|
536 |
-
chunk_index=chunk_index,
|
537 |
in_search_text=in_search_text,
|
538 |
-
|
|
|
|
|
|
|
|
|
|
|
539 |
left_sentence_part=left_sentence_part,
|
540 |
right_sentence_part=right_sentence_part,
|
541 |
overlap_left=overlap_left,
|
542 |
overlap_right=overlap_right,
|
543 |
-
metadata={},
|
544 |
-
type=FixedSizeChunk.__name__,
|
545 |
)
|
546 |
|
547 |
-
|
548 |
-
|
549 |
-
|
|
|
|
|
|
|
550 |
"""
|
551 |
-
|
|
|
|
|
|
|
552 |
|
553 |
Args:
|
554 |
-
|
555 |
-
|
556 |
|
557 |
Returns:
|
558 |
-
|
559 |
"""
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
Стратегия чанкинга фиксированного размера.
|
3 |
"""
|
4 |
|
5 |
+
import logging
|
6 |
import re
|
7 |
+
from io import StringIO
|
8 |
from uuid import uuid4
|
9 |
|
10 |
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
11 |
|
|
|
12 |
from ...models import DocumentAsEntity, LinkerEntity
|
13 |
+
from ...repositories import EntityRepository
|
14 |
+
from ..chunking_strategy import ChunkingStrategy
|
15 |
+
from ..chunking_registry import register_chunking_strategy
|
16 |
+
from ..models import Chunk
|
17 |
from .fixed_size.fixed_size_chunk import FixedSizeChunk
|
18 |
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
|
21 |
+
FIXED_SIZE = "fixed_size"
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
+
@register_chunking_strategy(FIXED_SIZE)
|
25 |
class FixedSizeChunkingStrategy(ChunkingStrategy):
|
26 |
"""
|
27 |
+
Стратегия чанкинга, разбивающая текст на чанки фиксированного размера словами.
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
Поддерживает перекрытие между чанками и опциональный учет границ предложений
|
30 |
+
для более качественной сборки текста в `dechunk`.
|
|
|
31 |
|
32 |
+
При чанкинге создает экземпляры `FixedSizeChunk`.
|
33 |
+
При сборке (`dechunk`) использует специфичную логику с `left/right_sentence_part`.
|
|
|
34 |
"""
|
35 |
|
36 |
+
DEFAULT_GROUPPER: str = "chunk" # Группа для связывания и сортировки чанков
|
|
|
|
|
|
|
37 |
|
38 |
def __init__(
|
39 |
self,
|
|
|
42 |
respect_sentence_boundaries: bool = True,
|
43 |
):
|
44 |
"""
|
45 |
+
Инициализация стратегии.
|
46 |
|
47 |
Args:
|
48 |
+
words_per_chunk: Целевое количество слов в чанке (включая перекрытие).
|
49 |
+
overlap_words: Количество слов перекрытия между чанками.
|
50 |
+
respect_sentence_boundaries: Учитывать ли границы предложений при
|
51 |
+
формировании `left/right_sentence_part` для улучшения сборки.
|
52 |
"""
|
53 |
+
if overlap_words >= words_per_chunk:
|
54 |
+
raise ValueError("overlap_words должен быть меньше words_per_chunk")
|
55 |
+
if words_per_chunk <= 0 or overlap_words < 0:
|
56 |
+
raise ValueError("words_per_chunk должен быть > 0, overlap_words >= 0")
|
57 |
+
|
58 |
+
self.words_per_chunk = words_per_chunk
|
59 |
+
self.overlap_words = overlap_words
|
60 |
+
self.respect_sentence_boundaries = respect_sentence_boundaries
|
61 |
+
self._step = self.words_per_chunk - self.overlap_words
|
62 |
+
|
63 |
+
# Регулярное выражение для поиска конца предложения (точка, ?, ! перед пробелом или концом строки)
|
64 |
+
self._sentence_end_pattern = re.compile(r'[.!?](?:\s|$)')
|
65 |
+
# Регулярное выражение для очистки текста при сборке
|
66 |
+
self._re_multi_newline = re.compile(r'\n{3,}')
|
67 |
+
self._re_multi_space = re.compile(r' +')
|
68 |
+
self._re_space_punct = re.compile(r' ([.,!?:;)])')
|
69 |
+
self._re_space_newline = re.compile(r' +\n')
|
70 |
+
self._re_newline_space = re.compile(r'\n +')
|
71 |
|
72 |
def chunk(
|
73 |
+
self, document: ParsedDocument, doc_entity: DocumentAsEntity
|
|
|
|
|
74 |
) -> list[LinkerEntity]:
|
75 |
"""
|
76 |
+
Разбивает документ на чанки FixedSizeChunk.
|
77 |
|
78 |
Args:
|
79 |
+
document: Документ для чанкинга.
|
80 |
+
doc_entity: Сущность документа-владельца.
|
81 |
|
82 |
Returns:
|
83 |
+
Список созданных FixedSizeChunk.
|
84 |
"""
|
85 |
+
words = self._extract_words(document)
|
86 |
+
total_words = len(words)
|
87 |
|
88 |
+
if total_words == 0:
|
89 |
+
logger.debug(f"Документ {doc_entity.name} не содержит слов для чанкинга.")
|
90 |
return []
|
91 |
|
92 |
+
result_chunks: list[FixedSizeChunk] = []
|
93 |
+
chunk_index = 0
|
94 |
+
|
95 |
+
# Идем по словам с шагом, равным размеру чанка минус перекрытие
|
96 |
+
for i in range(0, total_words, self._step):
|
97 |
+
start_idx = i
|
98 |
+
# Конец основной части чанка (без правого перекрытия)
|
99 |
+
step_end_idx = min(start_idx + self._step, total_words)
|
100 |
+
# Конец чанка с правым перекрытием (для in_search_text и подсчета токенов)
|
101 |
+
chunk_end_idx = min(start_idx + self.words_per_chunk, total_words)
|
102 |
+
|
103 |
+
# Текст чанка без перекрытия (то, что будет соединяться в dechunk)
|
104 |
+
chunk_text = self._prepare_chunk_text(words, start_idx, step_end_idx)
|
105 |
+
# Текст для поиска (с правым перекрытием)
|
106 |
+
in_search_text = self._prepare_chunk_text(words, start_idx, chunk_end_idx)
|
107 |
+
|
108 |
+
# Границы предложений и нахлесты
|
109 |
+
left_part, right_part, left_overlap, right_overlap = (
|
110 |
+
self._calculate_boundaries(words, start_idx, chunk_end_idx, total_words)
|
111 |
)
|
112 |
|
113 |
+
chunk_instance = self._create_chunk_instance(
|
114 |
+
doc_entity=doc_entity,
|
115 |
+
chunk_index=chunk_index,
|
116 |
+
chunk_text=chunk_text,
|
117 |
+
in_search_text=in_search_text,
|
118 |
+
token_count=(
|
119 |
+
chunk_end_idx - start_idx
|
120 |
+
), # Кол-во слов в чанке с правым нахлестом
|
121 |
+
left_sentence_part=left_part,
|
122 |
+
right_sentence_part=right_part,
|
123 |
+
overlap_left=left_overlap,
|
124 |
+
overlap_right=right_overlap,
|
125 |
)
|
126 |
+
result_chunks.append(chunk_instance)
|
127 |
+
chunk_index += 1
|
128 |
|
129 |
+
logger.info(
|
130 |
+
f"Документ {doc_entity.name} разбит на {len(result_chunks)} FixedSizeChunk."
|
131 |
+
)
|
132 |
+
return result_chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
def _extract_words(self, document: ParsedDocument) -> list[str]:
|
135 |
+
"""Извлекает слова из документа, добавляя '\n' как маркер конца параграфа."""
|
136 |
+
words = []
|
137 |
+
for paragraph in document.paragraphs:
|
138 |
+
if isinstance(paragraph, ParsedTextBlock) and paragraph.text:
|
139 |
+
paragraph_words = paragraph.text.split()
|
140 |
+
# Добавляем только непустые слова
|
141 |
+
words.extend(w for w in paragraph_words if w)
|
142 |
+
# Добавляем маркер конца параграфа, только если были слова
|
143 |
+
if paragraph_words:
|
144 |
+
words.append("\n") # Используем '\n' как специальный "токен"
|
145 |
+
# Удаляем последний '\n', если он есть (не нужен после последнего параграфа)
|
146 |
+
if words and words[-1] == "\n":
|
147 |
+
words.pop()
|
148 |
+
return words
|
149 |
|
150 |
+
def _prepare_chunk_text(
|
151 |
+
self, words: list[str], start_idx: int, end_idx: int
|
|
|
|
|
152 |
) -> str:
|
153 |
+
"""Собирает текст из среза слов, корректно обрабатывая маркеры '\n'."""
|
154 |
+
chunk_words = words[start_idx:end_idx]
|
155 |
+
if not chunk_words:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
return ""
|
157 |
|
158 |
+
with StringIO() as buffer:
|
159 |
+
first_word = True
|
160 |
+
for word in chunk_words:
|
161 |
+
if word == "\n":
|
162 |
+
buffer.write("\n")
|
163 |
+
first_word = True # После переноса строки пробел не нужен
|
164 |
+
else:
|
165 |
+
if not first_word:
|
166 |
+
buffer.write(" ")
|
167 |
+
buffer.write(word)
|
168 |
+
first_word = False
|
169 |
+
return buffer.getvalue()
|
170 |
+
|
171 |
+
def _calculate_boundaries(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
self,
|
173 |
+
words: list[str],
|
174 |
+
chunk_start_idx: int,
|
175 |
+
chunk_end_idx: int,
|
176 |
+
total_words: int,
|
177 |
+
) -> tuple[str, str, str, str]:
|
178 |
+
"""Вычисляет границы предложений и тексты перекрытий."""
|
179 |
+
left_sentence_part = ""
|
180 |
+
right_sentence_part = ""
|
181 |
|
182 |
+
# Границы для перекрытий
|
183 |
+
overlap_left_start = max(0, chunk_start_idx - self.overlap_words)
|
184 |
+
overlap_right_end = min(total_words, chunk_end_idx + self.overlap_words)
|
185 |
|
186 |
+
# Текст левого перекрытия (для поиска границ и как fallback)
|
187 |
+
left_overlap_text = self._prepare_chunk_text(
|
188 |
+
words, overlap_left_start, chunk_start_idx
|
189 |
+
)
|
190 |
+
# Текст правого перекрытия (для поиска границ и как fallback)
|
191 |
+
right_overlap_text = self._prepare_chunk_text(
|
192 |
+
words, chunk_end_idx, overlap_right_end
|
193 |
+
)
|
|
|
|
|
|
|
|
|
194 |
|
195 |
+
if self.respect_sentence_boundaries:
|
196 |
+
# Ищем границу предложения в левом перекрытии
|
197 |
+
left_sentence_part = self._find_sentence_boundary(left_overlap_text, True)
|
198 |
+
# Ищем границу предложения в правом перекрытии
|
199 |
+
right_sentence_part = self._find_sentence_boundary(
|
200 |
+
right_overlap_text, False
|
201 |
+
)
|
202 |
|
203 |
+
return (
|
204 |
+
left_sentence_part,
|
205 |
+
right_sentence_part,
|
206 |
+
left_overlap_text,
|
207 |
+
right_overlap_text,
|
208 |
+
)
|
209 |
|
210 |
+
def _find_sentence_boundary(self, text: str, find_left_part: bool) -> str:
|
|
|
|
|
|
|
|
|
|
|
211 |
"""
|
212 |
+
Находит часть текста на границе предложения.
|
213 |
+
Если find_left_part=True, ищет часть ПОСЛЕ последнего знака препинания.
|
214 |
+
Если find_left_part=False, ищет часть ДО первого знака препинания.
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
"""
|
216 |
+
if not text:
|
217 |
+
return ""
|
|
|
|
|
218 |
|
219 |
+
matches = list(self._sentence_end_pattern.finditer(text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
+
if not matches:
|
222 |
+
# Если нет знаков конца предложения, то для левой части ничего не берем,
|
223 |
+
# а для правой берем всё (т.к. непонятно, где предложение заканчивается).
|
224 |
+
return "" if find_left_part else text.strip()
|
225 |
+
|
226 |
+
if find_left_part:
|
227 |
+
# Ищем часть после последнего знака
|
228 |
+
last_match_end = matches[-1].end()
|
229 |
+
return text[last_match_end:].strip()
|
230 |
+
else:
|
231 |
+
# Ищем часть до первого знака (включая сам знак)
|
232 |
+
first_match_end = matches[0].end()
|
233 |
+
return text[:first_match_end].strip()
|
234 |
|
235 |
+
def _create_chunk_instance(
|
236 |
self,
|
237 |
+
doc_entity: DocumentAsEntity,
|
238 |
+
chunk_index: int,
|
239 |
chunk_text: str,
|
240 |
in_search_text: str,
|
241 |
+
token_count: int,
|
242 |
+
left_sentence_part: str,
|
243 |
+
right_sentence_part: str,
|
244 |
+
overlap_left: str,
|
245 |
+
overlap_right: str,
|
|
|
246 |
) -> FixedSizeChunk:
|
247 |
+
"""Создает экземпляр FixedSizeChunk с необходимыми атрибутами."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
return FixedSizeChunk(
|
249 |
id=uuid4(),
|
250 |
+
name=f"{doc_entity.name}_chunk_{chunk_index}",
|
251 |
text=chunk_text,
|
|
|
252 |
in_search_text=in_search_text,
|
253 |
+
metadata={}, # Все нужные поля теперь атрибуты
|
254 |
+
source_id=None, # Является компонентом, а не связью
|
255 |
+
target_id=doc_entity.id, # Указывает на владельца (документ)
|
256 |
+
number_in_relation=chunk_index, # Порядковый номер для сортировки
|
257 |
+
groupper=self.DEFAULT_GROUPPER, # Группа для сортировки/соседей
|
258 |
+
# Специфичные поля
|
259 |
left_sentence_part=left_sentence_part,
|
260 |
right_sentence_part=right_sentence_part,
|
261 |
overlap_left=overlap_left,
|
262 |
overlap_right=overlap_right,
|
|
|
|
|
263 |
)
|
264 |
|
265 |
+
@classmethod
|
266 |
+
def _build_sequenced_chunks(
|
267 |
+
cls,
|
268 |
+
repository: EntityRepository,
|
269 |
+
group: list[Chunk],
|
270 |
+
) -> str:
|
271 |
"""
|
272 |
+
Собирает текст для НЕПРЕРЫВНОЙ последовательности FixedSizeChunk.
|
273 |
+
|
274 |
+
Использует `left_sentence_part` первого чанка, `text` всех чанков
|
275 |
+
и `right_sentence_part` последнего чанка. Переопределяет базовый метод.
|
276 |
|
277 |
Args:
|
278 |
+
repository: Репозиторий для получения сущностей.
|
279 |
+
group: Список последовательных FixedSizeChunk. Гарантируется непустым.
|
280 |
|
281 |
Returns:
|
282 |
+
Собранный текст для данной группы.
|
283 |
"""
|
284 |
+
# Важно: Проверяем, что все чанки в группе - это FixedSizeChunk
|
285 |
+
# Это важно, так как мы обращаемся к специфичным атрибутам
|
286 |
+
if not all(isinstance(c, FixedSizeChunk) for c in group):
|
287 |
+
logger.warning(
|
288 |
+
"В _build_sequenced_chunks передан список, содержащий не FixedSizeChunk. Используется базовая сборка."
|
289 |
+
)
|
290 |
+
# Вызываем базовую реализацию, если типы не совпадают
|
291 |
+
return super()._build_sequenced_chunks(repository, group)
|
292 |
+
|
293 |
+
# Гарантированно работаем с FixedSizeChunk
|
294 |
+
typed_group: list[FixedSizeChunk] = group # type: ignore
|
295 |
+
|
296 |
+
parts = []
|
297 |
+
first_chunk = typed_group[0]
|
298 |
+
last_chunk = typed_group[-1]
|
299 |
+
|
300 |
+
# Добавляем левую часть предложения (если есть)
|
301 |
+
if first_chunk.left_sentence_part:
|
302 |
+
parts.append(first_chunk.left_sentence_part.strip())
|
303 |
+
|
304 |
+
# Добавляем текст всех чанков группы
|
305 |
+
for chunk in typed_group:
|
306 |
+
if chunk.text:
|
307 |
+
parts.append(chunk.text.strip())
|
308 |
+
|
309 |
+
# Добавляем правую часть предложения (если есть)
|
310 |
+
if last_chunk.right_sentence_part:
|
311 |
+
parts.append(last_chunk.right_sentence_part.strip())
|
312 |
+
|
313 |
+
# Соединяем все части через пробел, удаляя пустые строки
|
314 |
+
# Очистка _clean_final_text будет вызвана в конце базового dechunk
|
315 |
+
group_text = " ".join(filter(None, parts))
|
316 |
+
|
317 |
+
return group_text
|
lib/extractor/ntr_text_fragmentation/chunking/text_to_text_base.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
|
3 |
+
from ntr_fileparser import ParsedDocument
|
4 |
+
|
5 |
+
from ..models import LinkerEntity, DocumentAsEntity
|
6 |
+
from .models import CustomChunk
|
7 |
+
from .chunking_strategy import ChunkingStrategy
|
8 |
+
|
9 |
+
|
10 |
+
class TextToTextBaseStrategy(ChunkingStrategy):
|
11 |
+
"""
|
12 |
+
Базовый класс для всех стратегий чанкинга, которые преобразуют текст в текст.
|
13 |
+
Наследуясь от этого класса, не забывайте зарегистрировать стратегию через
|
14 |
+
декоратор @register_chunking_strategy.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def chunk(
|
18 |
+
self, document: ParsedDocument, doc_entity: DocumentAsEntity
|
19 |
+
) -> list[LinkerEntity]:
|
20 |
+
text = self._get_text(document)
|
21 |
+
texts = self._chunk(text, doc_entity)
|
22 |
+
return [
|
23 |
+
CustomChunk(
|
24 |
+
text=chunk_text,
|
25 |
+
in_search_text=chunk_text,
|
26 |
+
doc_entity=doc_entity,
|
27 |
+
number_in_relation=i,
|
28 |
+
groupper=self.__class__.__name__,
|
29 |
+
)
|
30 |
+
for i, chunk_text in enumerate(texts)
|
31 |
+
]
|
32 |
+
|
33 |
+
def _get_text(self, document: ParsedDocument) -> str:
|
34 |
+
return "\n".join(
|
35 |
+
[
|
36 |
+
f"{block.text} {block.number_in_relation}"
|
37 |
+
for block in document.paragraphs
|
38 |
+
]
|
39 |
+
)
|
40 |
+
|
41 |
+
@abstractmethod
|
42 |
+
def _chunk(self, text: str, doc_entity: DocumentAsEntity) -> list[LinkerEntity]:
|
43 |
+
raise NotImplementedError(
|
44 |
+
"Метод _chunk должен быть реализован в классе-наследнике"
|
45 |
+
)
|
lib/extractor/ntr_text_fragmentation/core/__init__.py
CHANGED
@@ -2,8 +2,10 @@
|
|
2 |
Основные классы для разбиения и сборки документов.
|
3 |
"""
|
4 |
|
5 |
-
from .
|
6 |
-
from .entity_repository import EntityRepository, InMemoryEntityRepository
|
7 |
from .injection_builder import InjectionBuilder
|
8 |
|
9 |
-
__all__ = [
|
|
|
|
|
|
|
|
2 |
Основные классы для разбиения и сборки документов.
|
3 |
"""
|
4 |
|
5 |
+
from .extractor import EntitiesExtractor
|
|
|
6 |
from .injection_builder import InjectionBuilder
|
7 |
|
8 |
+
__all__ = [
|
9 |
+
"EntitiesExtractor",
|
10 |
+
"InjectionBuilder",
|
11 |
+
]
|
lib/extractor/ntr_text_fragmentation/core/extractor.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Модуль для деструктуризации документа.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
from typing import Any, NamedTuple
|
7 |
+
from uuid import uuid4
|
8 |
+
|
9 |
+
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
10 |
+
|
11 |
+
from ..additors import TablesProcessor
|
12 |
+
from ..chunking import ChunkingStrategy, FIXED_SIZE, chunking_registry
|
13 |
+
from ..models import DocumentAsEntity, LinkerEntity
|
14 |
+
|
15 |
+
|
16 |
+
def _check_namedtuple(obj: Any) -> bool:
|
17 |
+
return hasattr(type(obj), '_fields') and isinstance(obj, tuple)
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class EntitiesExtractor:
|
24 |
+
"""
|
25 |
+
Оркестратор процесса извлечения информации из документа.
|
26 |
+
|
27 |
+
Координирует разбиение документа на чанки и обработку
|
28 |
+
дополнительных сущностей (например, таблиц) с использованием
|
29 |
+
зарегистрированных стратегий и процессоров.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
strategy_name: str = FIXED_SIZE,
|
35 |
+
strategy_params: dict[str, Any] | tuple = {},
|
36 |
+
process_tables: bool = True,
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Инициализация деструктуризатора.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
strategy_name: Имя стратегии чанкинга для использования
|
43 |
+
strategy_params: Параметры для выбранной стратегии чанкинга
|
44 |
+
process_tables: Флаг обработки таблиц
|
45 |
+
"""
|
46 |
+
self.strategy: ChunkingStrategy | None = None
|
47 |
+
self._strategy_name: str | None = None
|
48 |
+
self.tables_processor: TablesProcessor | None = None
|
49 |
+
|
50 |
+
self.configure(strategy_name, strategy_params, process_tables)
|
51 |
+
|
52 |
+
def configure(
|
53 |
+
self,
|
54 |
+
strategy_name: str | None = None,
|
55 |
+
strategy_params: dict[str, Any] | tuple = {},
|
56 |
+
process_tables: bool | None = None,
|
57 |
+
) -> 'EntitiesExtractor':
|
58 |
+
"""
|
59 |
+
Переконфигурирование деструктуризатора.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
strategy_name: Имя стратегии чанкинга
|
63 |
+
strategy_params: Параметры для выбранной стратегии чанкинга, которыми нужно перезаписать дефолтные
|
64 |
+
process_tables: Обрабатывать ли таблицы
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Destructurer: Возвращает сам себя для удобства использования в цепочке вызовов
|
68 |
+
"""
|
69 |
+
if strategy_name is not None:
|
70 |
+
self.configure_chunking(strategy_name, strategy_params)
|
71 |
+
if process_tables is not None:
|
72 |
+
self.configure_tables_extraction(process_tables)
|
73 |
+
|
74 |
+
return self
|
75 |
+
|
76 |
+
def configure_chunking(
|
77 |
+
self,
|
78 |
+
strategy_name: str = FIXED_SIZE,
|
79 |
+
strategy_params: dict[str, Any] | tuple | None = None,
|
80 |
+
) -> 'EntitiesExtractor':
|
81 |
+
"""
|
82 |
+
Переконфигурирование стратегии чанкинга.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
strategy_name: Имя стратегии чанкинга
|
86 |
+
strategy_params: Параметры для выбранной стратегии чанкинга, которыми нужно перезаписать дефолтные
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Destructurer: Возвращает сам себя
|
90 |
+
"""
|
91 |
+
if strategy_name not in chunking_registry:
|
92 |
+
raise ValueError(
|
93 |
+
f"Неизвестная стратегия: {strategy_name}. "
|
94 |
+
f"Доступные стратегии: {chunking_registry.get_names()}"
|
95 |
+
f"Для регистрации новой стратегии используйте метод `register_chunking_strategy`"
|
96 |
+
)
|
97 |
+
|
98 |
+
strategy_class = chunking_registry[strategy_name]
|
99 |
+
if _check_namedtuple(strategy_params):
|
100 |
+
strategy_params = strategy_params._asdict()
|
101 |
+
elif strategy_params is None:
|
102 |
+
strategy_params = {}
|
103 |
+
try:
|
104 |
+
self.strategy = strategy_class(**strategy_params)
|
105 |
+
self._strategy_name = strategy_name
|
106 |
+
except TypeError as e:
|
107 |
+
raise ValueError(
|
108 |
+
f"Ошибка при попытке инициализировать стратегию {strategy_class.__name__}: {e}. "
|
109 |
+
f"Параметры: {strategy_params}"
|
110 |
+
f"Пожалуйста, проверьте правильность параметров и их соответствие типу стратегии."
|
111 |
+
)
|
112 |
+
|
113 |
+
logger.info(
|
114 |
+
f"Стратегия чанкинга установлена: {strategy_name} с параметрами: {strategy_params}"
|
115 |
+
)
|
116 |
+
|
117 |
+
return self
|
118 |
+
|
119 |
+
def configure_tables_extraction(
|
120 |
+
self,
|
121 |
+
process_tables: bool = True,
|
122 |
+
) -> 'EntitiesExtractor':
|
123 |
+
"""
|
124 |
+
Переконфигурирование процессора таблиц.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
process_tables: Флаг обработки таблиц
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Destructurer: Возвращает сам себя для удобства использования в цепочке вызовов
|
131 |
+
"""
|
132 |
+
self.tables_processor = TablesProcessor()
|
133 |
+
logger.info(f"Процессор таблиц установлен: {process_tables}")
|
134 |
+
return self
|
135 |
+
|
136 |
+
def extract(self, document: ParsedDocument | str) -> list[LinkerEntity]:
|
137 |
+
"""
|
138 |
+
Основной метод извлечения информации из документа.
|
139 |
+
Чанкает и извлекает из документа всё, что можно из него извлечь.
|
140 |
+
Возвращает список сущностей.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
document: Документ для извлечения информации. Если передать строку, она будет \
|
144 |
+
автоматически преобразована в `ParsedDocument`
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
list[LinkerEntity]: список сущностей (документ, чанки, таблицы, связи)
|
148 |
+
|
149 |
+
Raises:
|
150 |
+
RuntimeError: Если стратегия не была сконфигурирована
|
151 |
+
"""
|
152 |
+
if isinstance(document, str):
|
153 |
+
document = ParsedDocument(
|
154 |
+
name='unknown',
|
155 |
+
type='PlainText',
|
156 |
+
paragraphs=[
|
157 |
+
ParsedTextBlock(text=paragraph)
|
158 |
+
for paragraph in document.split('\n')
|
159 |
+
],
|
160 |
+
)
|
161 |
+
|
162 |
+
doc_entity = self._create_document_entity(document)
|
163 |
+
entities: list[LinkerEntity] = [doc_entity]
|
164 |
+
|
165 |
+
if self.strategy is not None:
|
166 |
+
logger.info(
|
167 |
+
f"Чанкирование документа {document.name} с помощью стратегии {self.strategy.__class__.__name__}..."
|
168 |
+
)
|
169 |
+
entities += self._chunk(document, doc_entity)
|
170 |
+
|
171 |
+
if self.tables_processor is not None:
|
172 |
+
logger.info(f"Обработка таблиц в документе {document.name}...")
|
173 |
+
entities += self.tables_processor.extract(document, doc_entity)
|
174 |
+
|
175 |
+
logger.info(f"Извлечение информации из документа {document.name} завершено.")
|
176 |
+
entities = [entity.serialize() for entity in entities]
|
177 |
+
|
178 |
+
return entities
|
179 |
+
|
180 |
+
def _chunk(
|
181 |
+
self,
|
182 |
+
document: ParsedDocument,
|
183 |
+
doc_entity: DocumentAsEntity,
|
184 |
+
) -> list[LinkerEntity]:
|
185 |
+
if self.strategy is None:
|
186 |
+
raise RuntimeError("Стратегия чанкинга не выставлена")
|
187 |
+
|
188 |
+
doc_entity.chunking_strategy_ref = self._strategy_name
|
189 |
+
|
190 |
+
return self.strategy.chunk(document, doc_entity)
|
191 |
+
|
192 |
+
def _create_document_entity(self, document: ParsedDocument) -> DocumentAsEntity:
|
193 |
+
"""
|
194 |
+
Создает сущность документа.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
document: Документ для создания сущности
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
DocumentAsEntity: Сущность документа
|
201 |
+
"""
|
202 |
+
return DocumentAsEntity(
|
203 |
+
id=uuid4(),
|
204 |
+
name=document.name or "Document",
|
205 |
+
text="",
|
206 |
+
metadata={"source_type": document.type},
|
207 |
+
)
|
lib/extractor/ntr_text_fragmentation/core/injection_builder.py
CHANGED
@@ -1,26 +1,30 @@
|
|
1 |
"""
|
2 |
-
Класс для сборки документа из
|
3 |
"""
|
4 |
|
5 |
-
|
6 |
-
from typing import Optional, Type
|
7 |
from uuid import UUID
|
8 |
|
9 |
-
from ..
|
10 |
-
from ..
|
11 |
-
from ..models
|
12 |
-
from
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
class InjectionBuilder:
|
16 |
"""
|
17 |
-
Класс для сборки документов из
|
18 |
|
19 |
Отвечает за:
|
20 |
-
-
|
21 |
-
-
|
22 |
-
-
|
23 |
-
|
|
|
|
|
24 |
"""
|
25 |
|
26 |
def __init__(
|
@@ -32,382 +36,143 @@ class InjectionBuilder:
|
|
32 |
Инициализация сборщика инъекций.
|
33 |
|
34 |
Args:
|
35 |
-
repository: Репозиторий
|
36 |
-
entities: Список
|
37 |
-
|
38 |
-
|
39 |
-
if repository is None and entities is
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
repository = InMemoryEntityRepository(entities)
|
41 |
-
|
42 |
-
self.
|
43 |
-
self.strategy_map: dict[str, Type[ChunkingStrategy]] = {}
|
44 |
-
|
45 |
-
def register_strategy(
|
46 |
-
self,
|
47 |
-
doc_type: str,
|
48 |
-
strategy: Type[ChunkingStrategy],
|
49 |
-
) -> None:
|
50 |
-
"""
|
51 |
-
Регистрирует стратегию для определенного типа документа.
|
52 |
-
|
53 |
-
Args:
|
54 |
-
doc_type: Тип документа
|
55 |
-
strategy: Стратегия чанкинга
|
56 |
-
"""
|
57 |
-
self.strategy_map[doc_type] = strategy
|
58 |
|
59 |
def build(
|
60 |
self,
|
61 |
-
|
62 |
-
|
63 |
include_tables: bool = True,
|
64 |
-
|
65 |
-
|
66 |
-
""
|
67 |
-
Собирает текст из всех документов, связанных с предоставленными чанками.
|
68 |
-
|
69 |
-
Args:
|
70 |
-
filtered_entities: Список чанков или их идентификаторов
|
71 |
-
chunk_scores: Словарь весов чанков {chunk_id: score}
|
72 |
-
include_tables: Флаг для включе��ия таблиц в результат
|
73 |
-
max_documents: Максимальное количество документов (None = все)
|
74 |
-
|
75 |
-
Returns:
|
76 |
-
Собранный текст со всеми документами
|
77 |
-
"""
|
78 |
-
# Преобразуем входные данные в список идентификаторов
|
79 |
-
entity_ids = [
|
80 |
-
entity.id if isinstance(entity, LinkerEntity) else entity
|
81 |
-
for entity in filtered_entities
|
82 |
-
]
|
83 |
-
|
84 |
-
if not entity_ids:
|
85 |
-
return ""
|
86 |
-
|
87 |
-
# Получаем сущности по их идентификаторам
|
88 |
-
entities = self.repository.get_entities_by_ids(entity_ids)
|
89 |
-
|
90 |
-
# Десериализуем сущности в их специализированные типы
|
91 |
-
deserialized_entities = []
|
92 |
-
for entity in entities:
|
93 |
-
# Используем статический метод десериализации
|
94 |
-
deserialized_entity = LinkerEntity.deserialize(entity)
|
95 |
-
deserialized_entities.append(deserialized_entity)
|
96 |
-
|
97 |
-
# Фильтруем сущности на чанки и таблицы
|
98 |
-
chunks = [e for e in deserialized_entities if "Chunk" in e.type]
|
99 |
-
tables = [e for e in deserialized_entities if "Table" in e.type]
|
100 |
-
|
101 |
-
# Группируем таблицы по документам
|
102 |
-
table_ids = {table.id for table in tables}
|
103 |
-
doc_tables = self._group_tables_by_document(table_ids)
|
104 |
-
|
105 |
-
if not chunks and not tables:
|
106 |
-
return ""
|
107 |
-
|
108 |
-
# Получаем идентификаторы чанков
|
109 |
-
chunk_ids = [chunk.id for chunk in chunks]
|
110 |
-
|
111 |
-
# Получаем связи для чанков (чанки являются целями связей)
|
112 |
-
links = self.repository.get_related_entities(
|
113 |
-
chunk_ids,
|
114 |
-
relation_name="document_to_chunk",
|
115 |
-
as_target=True,
|
116 |
-
)
|
117 |
-
|
118 |
-
# Группируем чанки по документам
|
119 |
-
doc_chunks = self._group_chunks_by_document(chunks, links)
|
120 |
-
|
121 |
-
# Получаем все документы для чанков и таблиц
|
122 |
-
doc_ids = set(doc_chunks.keys()) | set(doc_tables.keys())
|
123 |
-
docs = self.repository.get_entities_by_ids(doc_ids)
|
124 |
-
|
125 |
-
# Десериализуем документы
|
126 |
-
deserialized_docs = []
|
127 |
-
for doc in docs:
|
128 |
-
deserialized_doc = LinkerEntity.deserialize(doc)
|
129 |
-
deserialized_docs.append(deserialized_doc)
|
130 |
-
|
131 |
-
# Вычисляем веса документов на основе весов чанков
|
132 |
-
doc_scores = self._calculate_document_scores(doc_chunks, chunk_scores)
|
133 |
-
|
134 |
-
# Сортируем документы по весам (по убыванию)
|
135 |
-
sorted_docs = sorted(
|
136 |
-
deserialized_docs,
|
137 |
-
key=lambda d: doc_scores.get(str(d.id), 0.0),
|
138 |
-
reverse=True
|
139 |
-
)
|
140 |
-
|
141 |
-
# Ограничиваем количество документов, если указано
|
142 |
-
if max_documents:
|
143 |
-
sorted_docs = sorted_docs[:max_documents]
|
144 |
-
|
145 |
-
# Собираем текст для каждого документа
|
146 |
-
result_parts = []
|
147 |
-
for doc in sorted_docs:
|
148 |
-
doc_text = self._build_document_text(
|
149 |
-
doc,
|
150 |
-
doc_chunks.get(doc.id, []),
|
151 |
-
doc_tables.get(doc.id, []),
|
152 |
-
include_tables
|
153 |
-
)
|
154 |
-
if doc_text:
|
155 |
-
result_parts.append(doc_text)
|
156 |
-
|
157 |
-
# Объединяем результаты
|
158 |
-
return "\n\n".join(result_parts)
|
159 |
-
|
160 |
-
def _build_document_text(
|
161 |
-
self,
|
162 |
-
doc: LinkerEntity,
|
163 |
-
chunks: list[LinkerEntity],
|
164 |
-
tables: list[LinkerEntity],
|
165 |
-
include_tables: bool
|
166 |
) -> str:
|
167 |
"""
|
168 |
-
Собирает текст
|
169 |
-
|
170 |
-
Args:
|
171 |
-
doc: Сущность документа
|
172 |
-
chunks: Список чанков документа
|
173 |
-
tables: Список таблиц документа
|
174 |
-
include_tables: Флаг для включения таблиц
|
175 |
-
|
176 |
-
Returns:
|
177 |
-
Собранный текст документа
|
178 |
-
"""
|
179 |
-
# Получаем стратегию чанкинга
|
180 |
-
strategy_name = doc.metadata.get("chunking_strategy", "fixed_size")
|
181 |
-
strategy = self._get_strategy_instance(strategy_name)
|
182 |
-
|
183 |
-
# Собираем текст из чанков
|
184 |
-
chunks_text = strategy.dechunk(chunks, self.repository) if chunks else ""
|
185 |
-
|
186 |
-
# Собираем текст из таблиц, если нужно
|
187 |
-
tables_text = ""
|
188 |
-
if include_tables and tables:
|
189 |
-
# Сортируем таблицы по индексу, если он есть
|
190 |
-
sorted_tables = sorted(
|
191 |
-
tables,
|
192 |
-
key=lambda t: t.metadata.get("table_index", 0) if t.metadata else 0
|
193 |
-
)
|
194 |
-
|
195 |
-
# Собираем текст таблиц
|
196 |
-
tables_text = "\n\n".join(table.text for table in sorted_tables if hasattr(table, 'text'))
|
197 |
-
|
198 |
-
# Формируем результат
|
199 |
-
result = f"[Источник] - {doc.name}\n"
|
200 |
-
if chunks_text:
|
201 |
-
result += chunks_text
|
202 |
-
if tables_text:
|
203 |
-
if chunks_text:
|
204 |
-
result += "\n\n"
|
205 |
-
result += tables_text
|
206 |
-
|
207 |
-
return result
|
208 |
-
|
209 |
-
def _group_chunks_by_document(
|
210 |
-
self,
|
211 |
-
chunks: list[LinkerEntity],
|
212 |
-
links: list[LinkerEntity]
|
213 |
-
) -> dict[UUID, list[LinkerEntity]]:
|
214 |
-
"""
|
215 |
-
Группирует чанки по документам.
|
216 |
-
|
217 |
-
Args:
|
218 |
-
chunks: Список чанков
|
219 |
-
links: Список связей между документами и чанками
|
220 |
-
|
221 |
-
Returns:
|
222 |
-
Словарь {doc_id: [chunks]}
|
223 |
-
"""
|
224 |
-
result = defaultdict(list)
|
225 |
-
|
226 |
-
# Создаем словарь для быстрого доступа к чанкам по ID
|
227 |
-
chunk_dict = {chunk.id: chunk for chunk in chunks}
|
228 |
-
|
229 |
-
# Группируем чанки по документам на основе связей
|
230 |
-
for link in links:
|
231 |
-
if link.target_id in chunk_dict and link.source_id:
|
232 |
-
result[link.source_id].append(chunk_dict[link.target_id])
|
233 |
-
|
234 |
-
return result
|
235 |
-
|
236 |
-
def _group_tables_by_document(
|
237 |
-
self,
|
238 |
-
table_ids: set[UUID]
|
239 |
-
) -> dict[UUID, list[LinkerEntity]]:
|
240 |
-
"""
|
241 |
-
Группирует таблицы по документам.
|
242 |
-
|
243 |
-
Args:
|
244 |
-
table_ids: Множество идентификаторов таблиц
|
245 |
-
|
246 |
-
Returns:
|
247 |
-
Словарь {doc_id: [tables]}
|
248 |
-
"""
|
249 |
-
result = defaultdict(list)
|
250 |
-
|
251 |
-
table_ids = [str(table_id) for table_id in table_ids]
|
252 |
-
|
253 |
-
# Получаем связи для таблиц (таблицы являются целями связей)
|
254 |
-
if not table_ids:
|
255 |
-
return result
|
256 |
-
|
257 |
-
links = self.repository.get_related_entities(
|
258 |
-
table_ids,
|
259 |
-
relation_name="document_to_table",
|
260 |
-
as_target=True,
|
261 |
-
)
|
262 |
-
|
263 |
-
# Получаем сами таблицы
|
264 |
-
tables = self.repository.get_entities_by_ids(table_ids)
|
265 |
-
|
266 |
-
# Десериализуем таблицы
|
267 |
-
deserialized_tables = []
|
268 |
-
for table in tables:
|
269 |
-
deserialized_table = LinkerEntity.deserialize(table)
|
270 |
-
deserialized_tables.append(deserialized_table)
|
271 |
-
|
272 |
-
# Создаем словарь для быстрого доступа к таблицам по ID
|
273 |
-
table_dict = {str(table.id): table for table in deserialized_tables}
|
274 |
-
|
275 |
-
# Группируем таблицы по документам на основе связей
|
276 |
-
for link in links:
|
277 |
-
if link.target_id in table_dict and link.source_id:
|
278 |
-
result[link.source_id].append(table_dict[link.target_id])
|
279 |
-
|
280 |
-
return result
|
281 |
-
|
282 |
-
def _calculate_document_scores(
|
283 |
-
self,
|
284 |
-
doc_chunks: dict[UUID, list[LinkerEntity]],
|
285 |
-
chunk_scores: Optional[dict[str, float]]
|
286 |
-
) -> dict[str, float]:
|
287 |
-
"""
|
288 |
-
Вычисляет веса документов на основе весов чанков.
|
289 |
-
|
290 |
-
Args:
|
291 |
-
doc_chunks: Словарь {doc_id: [chunks]}
|
292 |
-
chunk_scores: Словарь весов чанков {chunk_id: score}
|
293 |
-
|
294 |
-
Returns:
|
295 |
-
Словарь весов документов {doc_id: score}
|
296 |
-
"""
|
297 |
-
if not chunk_scores:
|
298 |
-
return {str(doc_id): 1.0 for doc_id in doc_chunks.keys()}
|
299 |
-
|
300 |
-
result = {}
|
301 |
-
for doc_id, chunks in doc_chunks.items():
|
302 |
-
# Берем максимальный вес среди чанков документа
|
303 |
-
chunk_weights = [chunk_scores.get(str(c.id), 0.0) for c in chunks]
|
304 |
-
result[str(doc_id)] = max(chunk_weights) if chunk_weights else 0.0
|
305 |
-
|
306 |
-
return result
|
307 |
-
|
308 |
-
def add_neighboring_chunks(
|
309 |
-
self, entities: list[LinkerEntity] | list[UUID], max_distance: int = 1
|
310 |
-
) -> list[LinkerEntity]:
|
311 |
-
"""
|
312 |
-
Добавляет соседние чанки к отфильтрованному списку чанков.
|
313 |
|
314 |
Args:
|
315 |
-
entities: Список сущностей
|
316 |
-
|
|
|
|
|
|
|
|
|
317 |
|
318 |
Returns:
|
319 |
-
|
320 |
-
|
321 |
-
# Преобразуем входные данные в список идентификаторов
|
322 |
-
entity_ids = [
|
323 |
-
entity.id if isinstance(entity, LinkerEntity) else entity
|
324 |
-
for entity in entities
|
325 |
-
]
|
326 |
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
329 |
|
330 |
-
|
331 |
-
|
|
|
332 |
|
333 |
-
|
334 |
-
chunk_entities = [e for e in original_entities if isinstance(e, Chunk)]
|
335 |
|
336 |
-
|
337 |
-
return original_entities
|
338 |
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
341 |
|
342 |
-
|
343 |
-
neighboring_chunks = self.repository.get_neighboring_chunks(
|
344 |
-
chunk_ids, max_distance
|
345 |
-
)
|
346 |
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
|
353 |
-
|
354 |
-
all_chunk_ids = [chunk.id for chunk in result if isinstance(chunk, Chunk)]
|
355 |
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
|
|
359 |
)
|
360 |
|
361 |
-
|
362 |
-
for doc in docs:
|
363 |
-
if doc not in result:
|
364 |
-
result.append(doc)
|
365 |
-
|
366 |
-
for link in links:
|
367 |
-
if link not in result:
|
368 |
-
result.append(link)
|
369 |
-
|
370 |
-
return result
|
371 |
-
|
372 |
-
def _get_strategy_instance(self, strategy_name: str) -> ChunkingStrategy:
|
373 |
-
"""
|
374 |
-
Создает экземпляр стратегии чанкинга по имени.
|
375 |
-
|
376 |
-
Args:
|
377 |
-
strategy_name: Имя стратегии
|
378 |
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
}
|
386 |
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
# Если стратегия известна, импортируем и инициализируем её
|
392 |
-
if strategy_name in strategies:
|
393 |
-
import importlib
|
394 |
-
|
395 |
-
module_path, class_name = strategies[strategy_name].rsplit(".", 1)
|
396 |
-
try:
|
397 |
-
# Конвертируем относительный путь в абсолютный
|
398 |
-
abs_module_path = f"ntr_text_fragmentation{module_path[2:]}"
|
399 |
-
module = importlib.import_module(abs_module_path)
|
400 |
-
strategy_class = getattr(module, class_name)
|
401 |
-
return strategy_class()
|
402 |
-
except (ImportError, AttributeError) as e:
|
403 |
-
# Если импорт не удался, используем стратегию по умолчанию
|
404 |
-
from ..chunking.specific_strategies.fixed_size_chunking import \
|
405 |
-
FixedSizeChunkingStrategy
|
406 |
-
|
407 |
-
return FixedSizeChunkingStrategy()
|
408 |
|
409 |
-
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
412 |
|
413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
Класс для сборки документа из деструктурированных сущностей (чанков, таблиц).
|
3 |
"""
|
4 |
|
5 |
+
import logging
|
|
|
6 |
from uuid import UUID
|
7 |
|
8 |
+
from ..additors import TablesProcessor
|
9 |
+
from ..chunking import chunking_registry
|
10 |
+
from ..models import DocumentAsEntity, LinkerEntity
|
11 |
+
from ..repositories import EntityRepository, GroupedEntities, InMemoryEntityRepository
|
12 |
+
|
13 |
+
# Настраиваем базовый логгер
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
|
16 |
|
17 |
class InjectionBuilder:
|
18 |
"""
|
19 |
+
Класс для сборки документов из отфильтрованного набора сущностей.
|
20 |
|
21 |
Отвечает за:
|
22 |
+
- Получение десериализованных сущностей по их ID.
|
23 |
+
- Группировку сущностей по документам, к которым они относятся.
|
24 |
+
- Вызов соответствующего метода сборки (`dechunk` или `build`)
|
25 |
+
у стратегии/процессора, передавая им документ, репозиторий и
|
26 |
+
список *всех* отфильтрованных сущностей, относящихся к этому документу.
|
27 |
+
- Агрегацию результатов сборки для нескольких документов.
|
28 |
"""
|
29 |
|
30 |
def __init__(
|
|
|
36 |
Инициализация сборщика инъекций.
|
37 |
|
38 |
Args:
|
39 |
+
repository: Репозиторий для доступа к сущностям.
|
40 |
+
entities: Список сущностей для инициализации дефолтного репозитория, если не указан repository.
|
41 |
+
Использование одновременно repository и entities не допускается.
|
42 |
+
"""
|
43 |
+
if repository is None and entities is None:
|
44 |
+
raise ValueError("Необходимо указать либо repository, либо entities.")
|
45 |
+
if repository is not None and entities is not None:
|
46 |
+
raise ValueError(
|
47 |
+
"Использование одновременно repository и entities не допускается."
|
48 |
+
)
|
49 |
+
if repository is None:
|
50 |
repository = InMemoryEntityRepository(entities)
|
51 |
+
self.repository = repository
|
52 |
+
self.tables_processor = TablesProcessor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
def build(
|
55 |
self,
|
56 |
+
entities: list[UUID] | list[LinkerEntity],
|
57 |
+
scores: list[float] | None = None,
|
58 |
include_tables: bool = True,
|
59 |
+
neighbors_max_distance: int = 1,
|
60 |
+
max_documents: int | None = None,
|
61 |
+
document_prefix: str = "[Источник] - ",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
) -> str:
|
63 |
"""
|
64 |
+
Собирает текст документов на основе *отфильтрованного* списка ID сущностей
|
65 |
+
(чанков, строк таблиц и т.д.).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
Args:
|
68 |
+
entities: Список ID сущностей (UUID), которые были отобраны
|
69 |
+
(например, в результате поиска) и должны войти в контекст.
|
70 |
+
scores: Список оценок для каждой сущности из логики больше - лучше.
|
71 |
+
include_tables: Включать ли таблицы из соответствующих документов.
|
72 |
+
max_documents: Максимальное количество документов для включения в результат
|
73 |
+
(сортировка документов пока не реализована).
|
74 |
|
75 |
Returns:
|
76 |
+
Собранный текст из указанных сущностей (и, возможно, таблиц)
|
77 |
+
сгруппированный по документам.
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
Raises:
|
80 |
+
ValueError: Если entity_ids пуст или содержит невалидные UUID.
|
81 |
+
"""
|
82 |
+
if any(isinstance(eid, UUID) for eid in entities):
|
83 |
+
entities = self.repository.get_entities_by_ids(entities)
|
84 |
|
85 |
+
if not entities:
|
86 |
+
logger.warning("Не удалось получить ни одной сущности по переданным ID.")
|
87 |
+
return ""
|
88 |
|
89 |
+
entities = [e.deserialize() for e in entities]
|
|
|
90 |
|
91 |
+
logger.info(f"Получено {len(entities)} сущностей для сборки.")
|
|
|
92 |
|
93 |
+
if neighbors_max_distance > 0:
|
94 |
+
neighbors = self.repository.get_neighboring_entities(
|
95 |
+
entities, neighbors_max_distance
|
96 |
+
)
|
97 |
+
neighbors = [e.deserialize() for e in neighbors]
|
98 |
+
entities.extend(neighbors)
|
99 |
|
100 |
+
logger.info(f"Получено {len(entities)} сущностей для сборки с соседями.")
|
|
|
|
|
|
|
101 |
|
102 |
+
if scores is None:
|
103 |
+
logger.info(
|
104 |
+
"Оценки не предоставлены, используем порядковые номера в обратном порядке."
|
105 |
+
)
|
106 |
+
scores = [float(i) for i in range(len(entities), 0, -1)]
|
107 |
|
108 |
+
id_to_score = {entity.id: score for entity, score in zip(entities, scores)}
|
|
|
109 |
|
110 |
+
groups: list[GroupedEntities[DocumentAsEntity]] = (
|
111 |
+
self.repository.group_entities_hierarchically(
|
112 |
+
entities=entities,
|
113 |
+
root_type=DocumentAsEntity,
|
114 |
+
)
|
115 |
)
|
116 |
|
117 |
+
logger.info(f"Сгруппировано {len(groups)} документов.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
+
document_scores = {
|
120 |
+
group.composer.id: max(
|
121 |
+
id_to_score[eid.id] for eid in group.entities if eid.id in id_to_score
|
122 |
+
)
|
123 |
+
for group in groups
|
124 |
+
if any(eid.id in id_to_score for eid in group.entities)
|
125 |
}
|
126 |
|
127 |
+
groups = sorted(
|
128 |
+
groups, key=lambda x: document_scores[x.composer.id], reverse=True
|
129 |
+
)
|
130 |
+
groups = list(groups)[:max_documents]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
+
builded_documents = [
|
133 |
+
self._build_document(group, include_tables, document_prefix).replace(
|
134 |
+
"\n", "\n\n"
|
135 |
+
)
|
136 |
+
for group in groups
|
137 |
+
]
|
138 |
+
return "\n\n".join(builded_documents)
|
139 |
|
140 |
+
def _build_document(
|
141 |
+
self,
|
142 |
+
group: GroupedEntities,
|
143 |
+
include_tables: bool = True,
|
144 |
+
document_prefix: str = "[Источник] - ",
|
145 |
+
) -> str:
|
146 |
+
document = group.composer
|
147 |
+
entities = group.entities
|
148 |
+
|
149 |
+
name = document.name
|
150 |
+
|
151 |
+
strategy = document.chunking_strategy_ref
|
152 |
+
builded_chunks = None
|
153 |
+
builded_tables = None
|
154 |
+
if strategy is None:
|
155 |
+
logger.warning(f"Стратегия чанкинга не указана для документа {name}")
|
156 |
+
else:
|
157 |
+
strategy_class = chunking_registry.get(strategy)
|
158 |
+
builded_chunks = strategy_class.dechunk(self.repository, entities)
|
159 |
+
|
160 |
+
if include_tables:
|
161 |
+
builded_tables = self.tables_processor.build(self.repository, entities)
|
162 |
+
|
163 |
+
result_text = f"## {document_prefix}{name}\n\n"
|
164 |
+
if builded_chunks:
|
165 |
+
result_text += f'### Текст\n{builded_chunks}\n\n'
|
166 |
+
if builded_tables:
|
167 |
+
result_text += f'### Таблицы\n{builded_tables}\n\n'
|
168 |
+
|
169 |
+
return result_text
|
170 |
+
|
171 |
+
def _deserialize_all(self, groups: list[GroupedEntities]) -> list[GroupedEntities]:
|
172 |
+
return [
|
173 |
+
GroupedEntities(
|
174 |
+
composer=group.composer.deserialize(),
|
175 |
+
entities=[entity.deserialize() for entity in group.entities],
|
176 |
+
)
|
177 |
+
for group in groups
|
178 |
+
]
|
lib/extractor/ntr_text_fragmentation/integrations/__init__.py
CHANGED
@@ -2,8 +2,9 @@
|
|
2 |
Модуль интеграций с внешними хранилищами данных и ORM системами.
|
3 |
"""
|
4 |
|
5 |
-
from .
|
6 |
|
|
|
7 |
__all__ = [
|
8 |
-
"
|
9 |
]
|
|
|
2 |
Модуль интеграций с внешними хранилищами данных и ORM системами.
|
3 |
"""
|
4 |
|
5 |
+
from ..repositories.in_memory_repository import InMemoryEntityRepository
|
6 |
|
7 |
+
# SQLAlchemy не импортируется, чтобы не тащить лишние зависимости в основной код
|
8 |
__all__ = [
|
9 |
+
"InMemoryEntityRepository",
|
10 |
]
|
lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sqlalchemy_repository import SQLAlchemyEntityRepository
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
"SQLAlchemyEntityRepository",
|
5 |
+
]
|
6 |
+
|
lib/extractor/ntr_text_fragmentation/integrations/sqlalchemy/sqlalchemy_repository.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Реализация EntityRepository для работы с SQLAlchemy.
|
3 |
+
"""
|
4 |
+
# Добавляем импорт logging и создаем логгер
|
5 |
+
import logging
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Any, Dict, Iterable, List, Type
|
9 |
+
from uuid import UUID
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
from sqlalchemy import Column, and_, or_, select
|
14 |
+
from sqlalchemy.orm import Session, sessionmaker
|
15 |
+
|
16 |
+
from ...models import LinkerEntity
|
17 |
+
from ...repositories.entity_repository import EntityRepository, GroupedEntities
|
18 |
+
|
19 |
+
Base = Any
|
20 |
+
|
21 |
+
|
22 |
+
class SQLAlchemyEntityRepository(EntityRepository, ABC):
|
23 |
+
"""
|
24 |
+
Абстрактная реализация EntityRepository для работы с базой данных через SQLAlchemy.
|
25 |
+
|
26 |
+
Требует определения методов `_entity_model_class` и
|
27 |
+
`_map_db_entity_to_linker_entity` в дочерних классах для работы с конкретной
|
28 |
+
моделью SQLAlchemy и маппинга на LinkerEntity.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, db_session_factory: sessionmaker[Session]):
|
32 |
+
"""
|
33 |
+
Инициализирует репозиторий с фабрикой сессий SQLAlchemy.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
db_session_factory: Фабрика сессий SQLAlchemy (sessionmaker).
|
37 |
+
"""
|
38 |
+
self.db = db_session_factory
|
39 |
+
|
40 |
+
@property
|
41 |
+
@abstractmethod
|
42 |
+
def _entity_model_class(self) -> Type[Base]:
|
43 |
+
"""Возвращает класс модели SQLAlchemy, используемый этим репозиторием."""
|
44 |
+
pass
|
45 |
+
|
46 |
+
@abstractmethod
|
47 |
+
def _map_db_entity_to_linker_entity(self, db_entity: Base) -> LinkerEntity:
|
48 |
+
"""Преобразует объект модели SQLAlchemy в объект LinkerEntity."""
|
49 |
+
pass
|
50 |
+
|
51 |
+
def _get_id_column(self) -> Column:
|
52 |
+
"""Возвращает колонку ID (uuid или id) из модели."""
|
53 |
+
entity_model = self._entity_model_class
|
54 |
+
# SQLAlchemy 2.0 style attribute access if Base is DeclarativeBase
|
55 |
+
id_column = getattr(entity_model, 'uuid', getattr(entity_model, 'id', None))
|
56 |
+
if id_column is None:
|
57 |
+
raise AttributeError(f"Модель {entity_model.__name__} не имеет атрибута/колонки 'id' или 'uuid'")
|
58 |
+
# Ensure it's a Column object if using older style mapping
|
59 |
+
# If using 2.0 MappedAsDataclass, this might need adjustment
|
60 |
+
# For now, assuming it returns something comparable
|
61 |
+
return id_column
|
62 |
+
|
63 |
+
def _normalize_entities(
|
64 |
+
self, entities: Iterable[UUID] | Iterable[LinkerEntity]
|
65 |
+
) -> list[UUID]:
|
66 |
+
"""Преобразует входные данные в список UUID."""
|
67 |
+
result = []
|
68 |
+
if entities is None:
|
69 |
+
return result
|
70 |
+
for entity in entities:
|
71 |
+
if isinstance(entity, UUID):
|
72 |
+
result.append(entity)
|
73 |
+
elif isinstance(entity, LinkerEntity):
|
74 |
+
result.append(entity.id)
|
75 |
+
return result
|
76 |
+
|
77 |
+
def get_entities_by_ids(
|
78 |
+
self, entity_ids: Iterable[UUID]
|
79 |
+
) -> List[LinkerEntity]:
|
80 |
+
"""
|
81 |
+
Получить сущности по списку идентификаторов UUID.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
entity_ids: Итерируемый объект с UUID сущностей.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
Список найденных сущностей LinkerEntity.
|
88 |
+
"""
|
89 |
+
ids_list = list(entity_ids)
|
90 |
+
if not ids_list:
|
91 |
+
return []
|
92 |
+
|
93 |
+
string_ids = [str(eid) for eid in ids_list]
|
94 |
+
entity_model = self._entity_model_class
|
95 |
+
id_column = self._get_id_column()
|
96 |
+
|
97 |
+
with self.db() as session:
|
98 |
+
db_entities = session.execute(
|
99 |
+
select(entity_model).where(id_column.in_(string_ids))
|
100 |
+
).scalars().all()
|
101 |
+
|
102 |
+
return [self._map_db_entity_to_linker_entity(entity) for entity in db_entities]
|
103 |
+
|
104 |
+
def group_entities_hierarchically(
|
105 |
+
self,
|
106 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
107 |
+
root_type: Type[LinkerEntity],
|
108 |
+
max_levels: int = 10,
|
109 |
+
sort: bool = True,
|
110 |
+
) -> list[GroupedEntities[LinkerEntity]]:
|
111 |
+
"""
|
112 |
+
Группирует сущности по корневым элементам иерархии.
|
113 |
+
|
114 |
+
Ищет родительские связи (где сущность является target_id), поднимаясь
|
115 |
+
вверх по иерархии до `max_levels` или до нахождения `root_type`.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
entities: Список идентификаторов UUID или сущностей LinkerEntity.
|
119 |
+
root_type: Класс корневого типа (например, DocumentAsEntity).
|
120 |
+
max_levels: Максимальная глубина поиска вверх по иерархии.
|
121 |
+
sort: Флаг для сортировк�� сущностей в группах.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
Список групп сущностей `GroupedEntities`.
|
125 |
+
"""
|
126 |
+
entity_ids_list = self._normalize_entities(entities)
|
127 |
+
if not entity_ids_list:
|
128 |
+
return []
|
129 |
+
|
130 |
+
entity_model = self._entity_model_class
|
131 |
+
id_column = self._get_id_column()
|
132 |
+
root_type_str = root_type.__name__
|
133 |
+
logger.info(f"[group_hierarchically] Искомый тип корня: '{root_type_str}'")
|
134 |
+
|
135 |
+
entity_type_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None))
|
136 |
+
source_id_column = getattr(entity_model, 'source_id', None)
|
137 |
+
target_id_column = getattr(entity_model, 'target_id', None)
|
138 |
+
|
139 |
+
if not all([entity_type_column, source_id_column, target_id_column]):
|
140 |
+
raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'entity_type'/'type', 'source_id', 'target_id'")
|
141 |
+
|
142 |
+
entity_to_root_cache: Dict[str, str | None] = {}
|
143 |
+
fetched_entities: Dict[str, Base] = {}
|
144 |
+
|
145 |
+
with self.db() as session:
|
146 |
+
|
147 |
+
def _fetch_entity(entity_id_str: str) -> Base | None:
|
148 |
+
"""Загружает сущность из БД, если её еще нет в fetched_entities."""
|
149 |
+
if entity_id_str not in fetched_entities:
|
150 |
+
entity = session.get(entity_model, entity_id_str)
|
151 |
+
if entity is None:
|
152 |
+
stmt = select(entity_model).where(id_column == entity_id_str)
|
153 |
+
entity = session.execute(stmt).scalar_one_or_none()
|
154 |
+
fetched_entities[entity_id_str] = entity # Store entity or None
|
155 |
+
return fetched_entities[entity_id_str]
|
156 |
+
|
157 |
+
def _find_root(entity_id_str: str, level: int) -> str | None:
|
158 |
+
"""Рекурсивный поиск корневой сущности."""
|
159 |
+
if level > max_levels or not entity_id_str:
|
160 |
+
return None
|
161 |
+
if entity_id_str in entity_to_root_cache:
|
162 |
+
return entity_to_root_cache[entity_id_str]
|
163 |
+
|
164 |
+
db_entity = _fetch_entity(entity_id_str)
|
165 |
+
if not db_entity:
|
166 |
+
logger.warning(f"[_find_root] Не удалось найти сущность с ID {entity_id_str}")
|
167 |
+
entity_to_root_cache[entity_id_str] = None
|
168 |
+
return None
|
169 |
+
|
170 |
+
current_entity_type = getattr(db_entity, entity_type_column.name)
|
171 |
+
|
172 |
+
if current_entity_type == root_type_str:
|
173 |
+
# logger.debug(f"[_find_root] Сущность {entity_id_str} сама является корнем типа '{root_type_str}'")
|
174 |
+
entity_to_root_cache[entity_id_str] = entity_id_str
|
175 |
+
return entity_id_str
|
176 |
+
|
177 |
+
parent_id_str = getattr(db_entity, target_id_column.name, None)
|
178 |
+
|
179 |
+
root_id = None
|
180 |
+
if parent_id_str:
|
181 |
+
# logger.debug(f"[_find_root] Сущность {entity_id_str} указывает на родителя {parent_id_str} через target_id.")
|
182 |
+
root_id = _find_root(str(parent_id_str), level + 1)
|
183 |
+
|
184 |
+
entity_to_root_cache[entity_id_str] = root_id
|
185 |
+
return root_id
|
186 |
+
|
187 |
+
roots_map: Dict[str, str | None] = {}
|
188 |
+
for start_entity_id in entity_ids_list:
|
189 |
+
start_entity_id_str = str(start_entity_id)
|
190 |
+
if start_entity_id_str not in roots_map:
|
191 |
+
found_root = _find_root(start_entity_id_str, 0)
|
192 |
+
roots_map[start_entity_id_str] = found_root
|
193 |
+
# if found_root:
|
194 |
+
# logger.debug(f"[group_hierarchically] Найден корень {found_root} для сущности {start_entity_id_str}")
|
195 |
+
|
196 |
+
|
197 |
+
groups: Dict[str, List[Base]] = defaultdict(list)
|
198 |
+
initial_db_entities = session.execute(
|
199 |
+
select(entity_model).where(id_column.in_([str(eid) for eid in entity_ids_list]))
|
200 |
+
).scalars().all()
|
201 |
+
|
202 |
+
found_roots_count = 0
|
203 |
+
grouped_entities_count = 0
|
204 |
+
for db_entity in initial_db_entities:
|
205 |
+
entity_id_str = str(getattr(db_entity, id_column.name))
|
206 |
+
root_id = roots_map.get(entity_id_str)
|
207 |
+
if root_id:
|
208 |
+
groups[root_id].append(db_entity)
|
209 |
+
grouped_entities_count += 1
|
210 |
+
if len(groups[root_id]) == 1:
|
211 |
+
found_roots_count += 1
|
212 |
+
|
213 |
+
logger.info(f"[group_hierarchically] Сгруппировано {grouped_entities_count} сущностей в {len(groups)} групп (найдено {found_roots_count} уникальных корней).")
|
214 |
+
|
215 |
+
result: list[GroupedEntities[LinkerEntity]] = []
|
216 |
+
for root_id_str, db_entities_list in groups.items():
|
217 |
+
root_db_entity = fetched_entities.get(root_id_str)
|
218 |
+
if root_db_entity:
|
219 |
+
composer = self._map_db_entity_to_linker_entity(root_db_entity)
|
220 |
+
grouped_linker_entities = [
|
221 |
+
self._map_db_entity_to_linker_entity(db_e) for db_e in db_entities_list
|
222 |
+
]
|
223 |
+
|
224 |
+
if sort:
|
225 |
+
grouped_linker_entities.sort(
|
226 |
+
key=lambda entity: (
|
227 |
+
str(getattr(entity, 'groupper', getattr(entity, 'entity_type', getattr(entity, 'type', '')))),
|
228 |
+
int(getattr(entity, 'number_in_relation', getattr(entity, 'chunk_index', float('inf'))))
|
229 |
+
)
|
230 |
+
)
|
231 |
+
# logger.debug(f"[group_hierarchically] Отсортирована группа для корня {root_id_str}")
|
232 |
+
|
233 |
+
result.append(GroupedEntities(composer=composer, entities=grouped_linker_entities))
|
234 |
+
|
235 |
+
logger.info(f"[group_hierarchically] Сформировано {len(result)} объектов GroupedEntities.")
|
236 |
+
return result
|
237 |
+
|
238 |
+
|
239 |
+
def get_neighboring_entities(
|
240 |
+
self,
|
241 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
242 |
+
max_distance: int = 1,
|
243 |
+
) -> list[LinkerEntity]:
|
244 |
+
"""
|
245 |
+
Получить соседние сущности для указанных сущностей.
|
246 |
+
|
247 |
+
Соседство определяется на основе общего "родителя" (target_id сущности,
|
248 |
+
если source_id is None) и близости по `number_in_relation` или
|
249 |
+
`chunk_index` в рамках одной группы (`entity_type` или `type`).
|
250 |
+
|
251 |
+
Args:
|
252 |
+
entities: Список идентификаторов UUID или сущностей LinkerEntity.
|
253 |
+
max_distance: Максимальное расстояние (по порядку) между соседями.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
Список соседних сущностей LinkerEntity.
|
257 |
+
"""
|
258 |
+
entity_ids_list = self._normalize_entities(entities)
|
259 |
+
if not entity_ids_list or max_distance < 1:
|
260 |
+
return []
|
261 |
+
|
262 |
+
string_entity_ids = {str(eid) for eid in entity_ids_list}
|
263 |
+
entity_model = self._entity_model_class
|
264 |
+
id_column = self._get_id_column()
|
265 |
+
source_id_column = getattr(entity_model, 'source_id', None)
|
266 |
+
target_id_column = getattr(entity_model, 'target_id', None)
|
267 |
+
order_column = getattr(entity_model, 'number_in_relation', None)
|
268 |
+
group_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None))
|
269 |
+
|
270 |
+
if not all([source_id_column, target_id_column, order_column, group_column]):
|
271 |
+
raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'source_id', 'target_id', 'chunk_index'/'number_in_relation', 'entity_type'/'type'")
|
272 |
+
|
273 |
+
neighbor_entities_map: Dict[str, Base] = {}
|
274 |
+
|
275 |
+
with self.db() as session:
|
276 |
+
initial_entities_query = select(entity_model).where(id_column.in_(list(string_entity_ids)))
|
277 |
+
initial_db_entities = session.execute(initial_entities_query).scalars().all()
|
278 |
+
|
279 |
+
valid_initial_entities_info: Dict[tuple[str, str], list[Dict[str, Any]]] = defaultdict(list)
|
280 |
+
parent_group_keys: set[tuple[str, str]] = set()
|
281 |
+
|
282 |
+
for db_entity in initial_db_entities:
|
283 |
+
entity_id_str = str(getattr(db_entity, id_column.name))
|
284 |
+
source_id = getattr(db_entity, source_id_column.name, None)
|
285 |
+
target_id = getattr(db_entity, target_id_column.name, None)
|
286 |
+
group_value = getattr(db_entity, group_column.name, None)
|
287 |
+
order_value = getattr(db_entity, order_column.name, None)
|
288 |
+
|
289 |
+
if source_id is None and target_id is not None and group_value is not None and order_value is not None:
|
290 |
+
parent_id_str = str(target_id)
|
291 |
+
group_key = (parent_id_str, str(group_value))
|
292 |
+
parent_group_keys.add(group_key)
|
293 |
+
valid_initial_entities_info[group_key].append({
|
294 |
+
"id": entity_id_str,
|
295 |
+
"order": order_value
|
296 |
+
})
|
297 |
+
|
298 |
+
if not parent_group_keys:
|
299 |
+
return []
|
300 |
+
|
301 |
+
sibling_conditions = []
|
302 |
+
for parent_id, group_val in parent_group_keys:
|
303 |
+
sibling_conditions.append(
|
304 |
+
and_(
|
305 |
+
target_id_column == parent_id,
|
306 |
+
group_column == group_val,
|
307 |
+
source_id_column.is_(None),
|
308 |
+
order_column.isnot(None)
|
309 |
+
)
|
310 |
+
)
|
311 |
+
|
312 |
+
potential_siblings_query = select(entity_model).where(or_(*sibling_conditions))
|
313 |
+
potential_siblings = session.execute(potential_siblings_query).scalars().all()
|
314 |
+
|
315 |
+
for sibling_entity in potential_siblings:
|
316 |
+
sibling_id_str = str(getattr(sibling_entity, id_column.name))
|
317 |
+
|
318 |
+
if sibling_id_str in string_entity_ids:
|
319 |
+
continue
|
320 |
+
|
321 |
+
sibling_target_id = getattr(sibling_entity, target_id_column.name)
|
322 |
+
sibling_group = getattr(sibling_entity, group_column.name)
|
323 |
+
sibling_order = getattr(sibling_entity, order_column.name)
|
324 |
+
|
325 |
+
if sibling_target_id is None or sibling_group is None or sibling_order is None:
|
326 |
+
logger.warning(f"Потенциальный сиблинг {sibling_id_str} не имеет target_id, группы или порядка, хотя был выбран запросом.")
|
327 |
+
continue
|
328 |
+
|
329 |
+
sibling_parent_id_str = str(sibling_target_id)
|
330 |
+
sibling_group_str = str(sibling_group)
|
331 |
+
group_key = (sibling_parent_id_str, sibling_group_str)
|
332 |
+
|
333 |
+
if group_key in valid_initial_entities_info:
|
334 |
+
for initial_info in valid_initial_entities_info[group_key]:
|
335 |
+
initial_order = initial_info["order"]
|
336 |
+
distance = abs(sibling_order - initial_order)
|
337 |
+
|
338 |
+
if 0 < distance <= max_distance and sibling_id_str not in neighbor_entities_map:
|
339 |
+
neighbor_entities_map[sibling_id_str] = sibling_entity
|
340 |
+
break
|
341 |
+
|
342 |
+
|
343 |
+
return [self._map_db_entity_to_linker_entity(ne) for ne in neighbor_entities_map.values()]
|
344 |
+
|
345 |
+
|
346 |
+
def get_related_entities(
|
347 |
+
self,
|
348 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
349 |
+
relation_type: Type[LinkerEntity] | None = None,
|
350 |
+
as_source: bool = False,
|
351 |
+
as_target: bool = False,
|
352 |
+
as_owner: bool = False, # Добавлено
|
353 |
+
) -> List[LinkerEntity]:
|
354 |
+
"""
|
355 |
+
Получить сущности, связанные с указанными, а также сами связи.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
entities: Список идентификаторов UUID или сущностей LinkerEntity.
|
359 |
+
relation_type: Опциональный класс связи для фильтрации (например, CompositionLink).
|
360 |
+
as_source: Искать связи, где entities - источники (`source_id`).
|
361 |
+
as_target: Искать связи, где entities - цели (`target_id`).
|
362 |
+
as_owner: Искать связи, где entities - владельцы (`source_id`, предполагая связь владения).
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
Список связанных сущностей LinkerEntity и самих связей.
|
366 |
+
"""
|
367 |
+
entity_ids_list = self._normalize_entities(entities)
|
368 |
+
if not entity_ids_list:
|
369 |
+
return []
|
370 |
+
|
371 |
+
if not as_source and not as_target and not as_owner:
|
372 |
+
as_source = True
|
373 |
+
as_target = True
|
374 |
+
|
375 |
+
string_ids = [str(eid) for eid in entity_ids_list]
|
376 |
+
entity_model = self._entity_model_class
|
377 |
+
id_column = self._get_id_column()
|
378 |
+
source_id_column = getattr(entity_model, 'source_id', None)
|
379 |
+
target_id_column = getattr(entity_model, 'target_id', None)
|
380 |
+
entity_type_column = getattr(entity_model, 'entity_type', getattr(entity_model, 'type', None))
|
381 |
+
|
382 |
+
if not all([source_id_column, target_id_column, entity_type_column]):
|
383 |
+
raise AttributeError(f"Модель {entity_model.__name__} не имеет необходимых колонок: 'source_id', 'target_id', 'entity_type'/'type'")
|
384 |
+
|
385 |
+
related_db_objects_map: Dict[str, Base] = {}
|
386 |
+
|
387 |
+
relation_type_str = None
|
388 |
+
if relation_type:
|
389 |
+
relation_type_str = relation_type.__name__
|
390 |
+
|
391 |
+
with self.db() as session:
|
392 |
+
def _add_related(db_objects: Iterable[Base]):
|
393 |
+
"""Helper function to add objects and fetch related source/target entities."""
|
394 |
+
ids_to_fetch = set()
|
395 |
+
for db_obj in db_objects:
|
396 |
+
obj_id = str(getattr(db_obj, id_column.name))
|
397 |
+
if obj_id not in related_db_objects_map:
|
398 |
+
related_db_objects_map[obj_id] = db_obj
|
399 |
+
source_id = getattr(db_obj, source_id_column.name, None)
|
400 |
+
target_id = getattr(db_obj, target_id_column.name, None)
|
401 |
+
if source_id:
|
402 |
+
ids_to_fetch.add(str(source_id))
|
403 |
+
if target_id:
|
404 |
+
ids_to_fetch.add(str(target_id))
|
405 |
+
|
406 |
+
ids_to_fetch.difference_update(related_db_objects_map.keys())
|
407 |
+
if ids_to_fetch:
|
408 |
+
fetched = session.execute(
|
409 |
+
select(entity_model).where(id_column.in_(list(ids_to_fetch)))
|
410 |
+
).scalars().all()
|
411 |
+
for fetched_obj in fetched:
|
412 |
+
fetched_id = str(getattr(fetched_obj, id_column.name))
|
413 |
+
if fetched_id not in related_db_objects_map:
|
414 |
+
related_db_objects_map[fetched_id] = fetched_obj
|
415 |
+
|
416 |
+
if as_source or as_owner:
|
417 |
+
conditions = [source_id_column.in_(string_ids)]
|
418 |
+
if relation_type_str:
|
419 |
+
conditions.append(entity_type_column == relation_type_str)
|
420 |
+
source_links_query = select(entity_model).where(and_(*conditions))
|
421 |
+
source_links = session.execute(source_links_query).scalars().all()
|
422 |
+
_add_related(source_links)
|
423 |
+
|
424 |
+
if as_target:
|
425 |
+
conditions = [target_id_column.in_(string_ids)]
|
426 |
+
if relation_type_str:
|
427 |
+
conditions.append(entity_type_column == relation_type_str)
|
428 |
+
target_links_query = select(entity_model).where(and_(*conditions))
|
429 |
+
target_links = session.execute(target_links_query).scalars().all()
|
430 |
+
_add_related(target_links)
|
431 |
+
|
432 |
+
final_map: Dict[UUID, LinkerEntity] = {}
|
433 |
+
|
434 |
+
for db_obj in related_db_objects_map.values():
|
435 |
+
linker_entity = self._map_db_entity_to_linker_entity(db_obj)
|
436 |
+
if relation_type:
|
437 |
+
is_link = linker_entity.is_link()
|
438 |
+
is_relevant_link = False
|
439 |
+
if is_link:
|
440 |
+
link_source_uuid = linker_entity.source_id
|
441 |
+
link_target_uuid = linker_entity.target_id
|
442 |
+
original_uuids = {UUID(s_id) for s_id in string_ids}
|
443 |
+
|
444 |
+
if (as_source or as_owner) and link_source_uuid in original_uuids:
|
445 |
+
is_relevant_link = True
|
446 |
+
elif as_target and link_target_uuid in original_uuids:
|
447 |
+
is_relevant_link = True
|
448 |
+
|
449 |
+
if is_relevant_link and not isinstance(linker_entity, relation_type):
|
450 |
+
continue
|
451 |
+
|
452 |
+
if linker_entity.id not in final_map:
|
453 |
+
final_map[linker_entity.id] = linker_entity
|
454 |
+
|
455 |
+
return list(final_map.values())
|
lib/extractor/ntr_text_fragmentation/models/__init__.py
CHANGED
@@ -2,12 +2,13 @@
|
|
2 |
Модуль моделей данных.
|
3 |
"""
|
4 |
|
5 |
-
from .chunk import Chunk
|
6 |
from .document import DocumentAsEntity
|
7 |
-
from .linker_entity import LinkerEntity
|
8 |
|
9 |
__all__ = [
|
10 |
-
"LinkerEntity",
|
11 |
-
"DocumentAsEntity",
|
12 |
-
"
|
13 |
-
|
|
|
|
|
|
2 |
Модуль моделей данных.
|
3 |
"""
|
4 |
|
|
|
5 |
from .document import DocumentAsEntity
|
6 |
+
from .linker_entity import LinkerEntity, Entity, Link, register_entity
|
7 |
|
8 |
__all__ = [
|
9 |
+
"LinkerEntity",
|
10 |
+
"DocumentAsEntity",
|
11 |
+
"Entity",
|
12 |
+
"Link",
|
13 |
+
"register_entity",
|
14 |
+
]
|
lib/extractor/ntr_text_fragmentation/models/document.py
CHANGED
@@ -2,39 +2,51 @@
|
|
2 |
Класс для представления документа как сущности.
|
3 |
"""
|
4 |
|
5 |
-
from dataclasses import dataclass
|
6 |
|
7 |
-
from .linker_entity import
|
8 |
|
9 |
|
10 |
@register_entity
|
11 |
@dataclass
|
12 |
-
class DocumentAsEntity(
|
13 |
"""
|
14 |
Класс для представления документа как сущности в системе извлечения и сборки.
|
|
|
|
|
15 |
"""
|
16 |
|
17 |
doc_type: str = "unknown"
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
@classmethod
|
20 |
-
def
|
21 |
"""
|
22 |
Десериализует DocumentAsEntity из объекта LinkerEntity.
|
23 |
-
|
24 |
Args:
|
25 |
data: Объект LinkerEntity для преобразования в DocumentAsEntity
|
26 |
-
|
27 |
Returns:
|
28 |
Десериализованный объект DocumentAsEntity
|
29 |
"""
|
|
|
|
|
|
|
30 |
metadata = data.metadata or {}
|
31 |
-
|
32 |
-
# Получаем
|
33 |
-
doc_type = metadata.get('_doc_type', 'unknown')
|
34 |
-
|
|
|
|
|
|
|
35 |
# Создаем чистые метаданные без служебных полей
|
36 |
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
37 |
-
|
38 |
return cls(
|
39 |
id=data.id,
|
40 |
name=data.name,
|
@@ -44,6 +56,8 @@ class DocumentAsEntity(LinkerEntity):
|
|
44 |
source_id=data.source_id,
|
45 |
target_id=data.target_id,
|
46 |
number_in_relation=data.number_in_relation,
|
47 |
-
|
48 |
-
|
|
|
|
|
49 |
)
|
|
|
2 |
Класс для представления документа как сущности.
|
3 |
"""
|
4 |
|
5 |
+
from dataclasses import dataclass, field
|
6 |
|
7 |
+
from .linker_entity import Entity, register_entity
|
8 |
|
9 |
|
10 |
@register_entity
|
11 |
@dataclass
|
12 |
+
class DocumentAsEntity(Entity):
|
13 |
"""
|
14 |
Класс для представления документа как сущности в системе извлечения и сборки.
|
15 |
+
Содержит ссылки на классы стратегии чанкинга и обработчика таблиц,
|
16 |
+
использовавшихся при деструктуризации.
|
17 |
"""
|
18 |
|
19 |
doc_type: str = "unknown"
|
20 |
+
|
21 |
+
chunking_strategy_ref: str | None = None
|
22 |
+
|
23 |
+
type: str = field(default="DocumentAsEntity")
|
24 |
+
|
25 |
@classmethod
|
26 |
+
def _deserialize_to_me(cls, data: Entity) -> 'DocumentAsEntity':
|
27 |
"""
|
28 |
Десериализует DocumentAsEntity из объекта LinkerEntity.
|
29 |
+
|
30 |
Args:
|
31 |
data: Объект LinkerEntity для преобразования в DocumentAsEntity
|
32 |
+
|
33 |
Returns:
|
34 |
Десериализованный объект DocumentAsEntity
|
35 |
"""
|
36 |
+
if not isinstance(data, Entity):
|
37 |
+
raise TypeError(f"Ожидался LinkerEntity, получен {type(data)}")
|
38 |
+
|
39 |
metadata = data.metadata or {}
|
40 |
+
|
41 |
+
# Получаем поля из атрибутов или метаданных
|
42 |
+
doc_type = getattr(data, 'doc_type', metadata.get('_doc_type', 'unknown'))
|
43 |
+
strategy_ref = getattr(
|
44 |
+
data, 'chunking_strategy_ref', metadata.get('_chunking_strategy_ref', None)
|
45 |
+
)
|
46 |
+
|
47 |
# Создаем чистые метаданные без служебных полей
|
48 |
clean_metadata = {k: v for k, v in metadata.items() if not k.startswith('_')}
|
49 |
+
|
50 |
return cls(
|
51 |
id=data.id,
|
52 |
name=data.name,
|
|
|
56 |
source_id=data.source_id,
|
57 |
target_id=data.target_id,
|
58 |
number_in_relation=data.number_in_relation,
|
59 |
+
groupper=data.groupper,
|
60 |
+
type=cls.__name__,
|
61 |
+
doc_type=doc_type,
|
62 |
+
chunking_strategy_ref=strategy_ref,
|
63 |
)
|
lib/extractor/ntr_text_fragmentation/models/linker_entity.py
CHANGED
@@ -2,56 +2,75 @@
|
|
2 |
Базовый абстрактный класс для всех сущностей с поддержкой триплетного подхода.
|
3 |
"""
|
4 |
|
|
|
5 |
import uuid
|
6 |
-
from abc import abstractmethod
|
7 |
from dataclasses import dataclass, field, fields
|
8 |
from uuid import UUID
|
9 |
|
|
|
|
|
10 |
|
11 |
@dataclass
|
12 |
class LinkerEntity:
|
13 |
"""
|
14 |
Общий класс для всех сущностей в системе извлечения и сборки.
|
15 |
Поддерживает триплетный подход, где каждая сущность может опционально связывать две другие сущности.
|
16 |
-
|
17 |
Attributes:
|
18 |
id (UUID): Уникальный идентификатор сущности.
|
19 |
name (str): Название сущности.
|
20 |
text (str): Текстое представление сущности.
|
21 |
in_search_text (str | None): Текст для поиска. Если задан, используется в __str__, иначе используется обычное представление.
|
22 |
metadata (dict): Метаданные сущности.
|
23 |
-
source_id (UUID | None): Опциональный идентификатор исходной сущности.
|
24 |
Если указан, эта сущность является связью.
|
25 |
-
target_id (UUID | None): Опциональный идентификатор целевой сущности.
|
26 |
Если указан, эта сущность является связью.
|
27 |
-
number_in_relation (int | None): Используется в случае связей один-ко-многим,
|
28 |
указывает номер целевой сущности в списке.
|
29 |
type (str): Тип сущности.
|
30 |
"""
|
31 |
|
32 |
-
id: UUID
|
33 |
-
name: str
|
34 |
-
text: str
|
35 |
-
metadata: dict
|
36 |
in_search_text: str | None = None
|
37 |
source_id: UUID | None = None
|
38 |
target_id: UUID | None = None
|
39 |
number_in_relation: int | None = None
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def __post_init__(self):
|
43 |
if self.id is None:
|
44 |
self.id = uuid.uuid4()
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
(self.source_id is None and self.target_id is not None):
|
49 |
-
raise ValueError("source_id и target_id должны быть либо оба указаны, либо оба None")
|
50 |
|
51 |
def is_link(self) -> bool:
|
52 |
"""
|
53 |
Проверяет, является ли сущность связью (имеет и source_id, и target_id).
|
54 |
-
|
55 |
Returns:
|
56 |
bool: True, если сущность является связью, иначе False
|
57 |
"""
|
@@ -85,7 +104,7 @@ class LinkerEntity:
|
|
85 |
and self.text == other.text
|
86 |
and self.type == other.type
|
87 |
)
|
88 |
-
|
89 |
# Если мы имеем дело со связями, также проверяем поля связи
|
90 |
if self.is_link() or other.is_link():
|
91 |
return (
|
@@ -93,62 +112,51 @@ class LinkerEntity:
|
|
93 |
and self.source_id == other.source_id
|
94 |
and self.target_id == other.target_id
|
95 |
)
|
96 |
-
|
97 |
return basic_equality
|
98 |
|
99 |
def serialize(self) -> 'LinkerEntity':
|
100 |
"""
|
101 |
-
Сериализует сущность в
|
102 |
"""
|
103 |
-
|
104 |
-
|
|
|
105 |
|
106 |
-
#
|
107 |
-
|
108 |
-
for attr_name in dir(self):
|
109 |
-
# Пропускаем служебные атрибуты, методы и уже известные поля
|
110 |
-
if (
|
111 |
-
attr_name.startswith('_')
|
112 |
-
or attr_name in known_fields
|
113 |
-
or callable(getattr(self, attr_name))
|
114 |
-
):
|
115 |
-
continue
|
116 |
-
|
117 |
-
# Добавляем дополнительные поля в словарь
|
118 |
-
dict_entity[attr_name] = getattr(self, attr_name)
|
119 |
|
120 |
# Преобразуем имена дополнительных полей, добавляя префикс "_"
|
121 |
-
|
|
|
|
|
122 |
|
123 |
-
# Объединяем с существующими метаданными
|
124 |
-
|
125 |
|
126 |
result_type = self.type
|
127 |
if result_type == "Entity":
|
128 |
result_type = self.__class__.__name__
|
129 |
|
130 |
-
# Создаем базовый объект LinkerEntity
|
131 |
return LinkerEntity(
|
132 |
id=self.id,
|
133 |
name=self.name,
|
134 |
text=self.text,
|
135 |
in_search_text=self.in_search_text,
|
136 |
-
metadata=
|
137 |
source_id=self.source_id,
|
138 |
target_id=self.target_id,
|
139 |
number_in_relation=self.number_in_relation,
|
|
|
140 |
type=result_type,
|
141 |
)
|
142 |
|
143 |
-
|
144 |
-
@abstractmethod
|
145 |
-
def deserialize(cls, data: 'LinkerEntity') -> 'Self':
|
146 |
"""
|
147 |
-
Десериализует сущность
|
148 |
"""
|
149 |
-
|
150 |
-
f"Метод deserialize для класса {cls.__class__.__name__} не реализован"
|
151 |
-
)
|
152 |
|
153 |
# Реестр для хранения всех наследников LinkerEntity
|
154 |
_entity_classes = {}
|
@@ -157,61 +165,85 @@ class LinkerEntity:
|
|
157 |
def register_entity_class(cls, entity_class):
|
158 |
"""
|
159 |
Регистрирует класс-наследник в реестре.
|
160 |
-
|
161 |
Args:
|
162 |
entity_class: Класс для регистрации
|
163 |
"""
|
164 |
entity_type = entity_class.__name__
|
165 |
cls._entity_classes[entity_type] = entity_class
|
166 |
-
# Также регистрируем по типу, если он отличается от имени класса
|
167 |
if hasattr(entity_class, 'type') and isinstance(entity_class.type, str):
|
168 |
cls._entity_classes[entity_class.type] = entity_class
|
169 |
-
|
170 |
@classmethod
|
171 |
-
def
|
172 |
"""
|
173 |
Десериализует сущность в нужный тип на основе поля type.
|
174 |
-
|
175 |
Args:
|
176 |
data: Сериализованная сущность типа LinkerEntity
|
177 |
-
|
178 |
Returns:
|
179 |
Десериализованная сущность правильного типа
|
180 |
"""
|
181 |
# Получаем тип сущности
|
182 |
entity_type = data.type
|
183 |
-
|
184 |
# Проверяем реестр классов
|
185 |
if entity_type in cls._entity_classes:
|
186 |
try:
|
187 |
-
return cls._entity_classes[entity_type].
|
188 |
-
except
|
189 |
-
|
190 |
return data
|
191 |
-
|
192 |
-
# Если тип не найден в реестре, просто возвращаем исходную сущность
|
193 |
-
# Больше не используем опасное сканирование sys.modules
|
194 |
return data
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
# Декоратор для регистрации производных классов
|
198 |
def register_entity(cls):
|
199 |
"""
|
200 |
Декоратор для регистрации классов-наследников LinkerEntity.
|
201 |
-
|
202 |
Пример использования:
|
203 |
-
|
204 |
@register_entity
|
205 |
class MyEntity(LinkerEntity):
|
206 |
type = "my_entity"
|
207 |
-
|
208 |
Args:
|
209 |
cls: Класс, который нужно зарегистрировать
|
210 |
-
|
211 |
Returns:
|
212 |
Исходный класс (без изменений)
|
213 |
"""
|
214 |
# Регистрируем класс в реестр, используя его имя или указанный тип
|
215 |
-
entity_type =
|
216 |
LinkerEntity._entity_classes[entity_type] = cls
|
|
|
|
|
|
|
|
|
217 |
return cls
|
|
|
2 |
Базовый абстрактный класс для всех сущностей с поддержкой триплетного подхода.
|
3 |
"""
|
4 |
|
5 |
+
import logging
|
6 |
import uuid
|
|
|
7 |
from dataclasses import dataclass, field, fields
|
8 |
from uuid import UUID
|
9 |
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
|
13 |
@dataclass
|
14 |
class LinkerEntity:
|
15 |
"""
|
16 |
Общий класс для всех сущностей в системе извлечения и сборки.
|
17 |
Поддерживает триплетный подход, где каждая сущность может опционально связывать две другие сущности.
|
18 |
+
|
19 |
Attributes:
|
20 |
id (UUID): Уникальный идентификатор сущности.
|
21 |
name (str): Название сущности.
|
22 |
text (str): Текстое представление сущности.
|
23 |
in_search_text (str | None): Текст для поиска. Если задан, используется в __str__, иначе используется обычное представление.
|
24 |
metadata (dict): Метаданные сущности.
|
25 |
+
source_id (UUID | None): Опциональный идентификатор исходной сущности.
|
26 |
Если указан, эта сущность является связью.
|
27 |
+
target_id (UUID | None): Опциональный идентификатор целевой сущности.
|
28 |
Если указан, эта сущность является связью.
|
29 |
+
number_in_relation (int | None): Используется в случае связей один-ко-многим,
|
30 |
указывает номер целевой сущности в списке.
|
31 |
type (str): Тип сущности.
|
32 |
"""
|
33 |
|
34 |
+
id: UUID = field(default_factory=uuid.uuid4)
|
35 |
+
name: str = field(default="")
|
36 |
+
text: str = field(default="")
|
37 |
+
metadata: dict = field(default_factory=dict)
|
38 |
in_search_text: str | None = None
|
39 |
source_id: UUID | None = None
|
40 |
target_id: UUID | None = None
|
41 |
number_in_relation: int | None = None
|
42 |
+
groupper: str | None = None
|
43 |
+
type: str | None = None
|
44 |
+
|
45 |
+
@property
|
46 |
+
def owner_id(self) -> UUID | None:
|
47 |
+
"""
|
48 |
+
Возвращает идентификатор владельца сущности.
|
49 |
+
"""
|
50 |
+
if self.is_link():
|
51 |
+
return None
|
52 |
+
return self.target_id
|
53 |
+
|
54 |
+
@owner_id.setter
|
55 |
+
def owner_id(self, value: UUID | None):
|
56 |
+
"""
|
57 |
+
Устанавливает идентификатор владельца сущности.
|
58 |
+
"""
|
59 |
+
if self.is_link():
|
60 |
+
raise ValueError("Связь не может иметь владельца")
|
61 |
+
self.target_id = value
|
62 |
|
63 |
def __post_init__(self):
|
64 |
if self.id is None:
|
65 |
self.id = uuid.uuid4()
|
66 |
+
|
67 |
+
if self.type is None:
|
68 |
+
self.type = self.__class__.__name__
|
|
|
|
|
69 |
|
70 |
def is_link(self) -> bool:
|
71 |
"""
|
72 |
Проверяет, является ли сущность связью (имеет и source_id, и target_id).
|
73 |
+
|
74 |
Returns:
|
75 |
bool: True, если сущность является связью, иначе False
|
76 |
"""
|
|
|
104 |
and self.text == other.text
|
105 |
and self.type == other.type
|
106 |
)
|
107 |
+
|
108 |
# Если мы имеем дело со связями, также проверяем поля связи
|
109 |
if self.is_link() or other.is_link():
|
110 |
return (
|
|
|
112 |
and self.source_id == other.source_id
|
113 |
and self.target_id == other.target_id
|
114 |
)
|
115 |
+
|
116 |
return basic_equality
|
117 |
|
118 |
def serialize(self) -> 'LinkerEntity':
|
119 |
"""
|
120 |
+
Сериализует сущность в базовый класс `LinkerEntity`, сохраняя все дополнительные поля в метаданные.
|
121 |
"""
|
122 |
+
base_fields = {f.name for f in fields(LinkerEntity)}
|
123 |
+
current_fields = {f.name for f in fields(self.__class__)}
|
124 |
+
extra_field_names = current_fields - base_fields
|
125 |
|
126 |
+
# Собираем только дополнительные поля, определенные в подклассе
|
127 |
+
extra_fields_dict = {name: getattr(self, name) for name in extra_field_names}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Преобразуем имена дополнительных полей, добавляя префикс "_"
|
130 |
+
prefixed_extra_fields = {
|
131 |
+
f'_{name}': value for name, value in extra_fields_dict.items()
|
132 |
+
}
|
133 |
|
134 |
+
# Объединяем с существующими метаданными (если они были установлены вручную)
|
135 |
+
final_metadata = {**prefixed_extra_fields, **self.metadata}
|
136 |
|
137 |
result_type = self.type
|
138 |
if result_type == "Entity":
|
139 |
result_type = self.__class__.__name__
|
140 |
|
141 |
+
# Создаем базовый объект LinkerEntity
|
142 |
return LinkerEntity(
|
143 |
id=self.id,
|
144 |
name=self.name,
|
145 |
text=self.text,
|
146 |
in_search_text=self.in_search_text,
|
147 |
+
metadata=final_metadata, # Используем собранные метаданные
|
148 |
source_id=self.source_id,
|
149 |
target_id=self.target_id,
|
150 |
number_in_relation=self.number_in_relation,
|
151 |
+
groupper=self.groupper,
|
152 |
type=result_type,
|
153 |
)
|
154 |
|
155 |
+
def deserialize(self) -> 'LinkerEntity':
|
|
|
|
|
156 |
"""
|
157 |
+
Десериализует сущность в нужный тип на основе поля type.
|
158 |
"""
|
159 |
+
return self._deserialize(self)
|
|
|
|
|
160 |
|
161 |
# Реестр для хранения всех наследников LinkerEntity
|
162 |
_entity_classes = {}
|
|
|
165 |
def register_entity_class(cls, entity_class):
|
166 |
"""
|
167 |
Регистрирует класс-наследник в реестре.
|
168 |
+
|
169 |
Args:
|
170 |
entity_class: Класс для регистрации
|
171 |
"""
|
172 |
entity_type = entity_class.__name__
|
173 |
cls._entity_classes[entity_type] = entity_class
|
|
|
174 |
if hasattr(entity_class, 'type') and isinstance(entity_class.type, str):
|
175 |
cls._entity_classes[entity_class.type] = entity_class
|
176 |
+
|
177 |
@classmethod
|
178 |
+
def _deserialize(cls, data: 'LinkerEntity') -> 'LinkerEntity':
|
179 |
"""
|
180 |
Десериализует сущность в нужный тип на основе поля type.
|
181 |
+
|
182 |
Args:
|
183 |
data: Сериализованная сущность типа LinkerEntity
|
184 |
+
|
185 |
Returns:
|
186 |
Десериализованная сущность правильного типа
|
187 |
"""
|
188 |
# Получаем тип сущности
|
189 |
entity_type = data.type
|
190 |
+
|
191 |
# Проверяем реестр классов
|
192 |
if entity_type in cls._entity_classes:
|
193 |
try:
|
194 |
+
return cls._entity_classes[entity_type]._deserialize_to_me(data)
|
195 |
+
except Exception as e:
|
196 |
+
logger.error(f"Ошибка при вызове _deserialize_to_me для {entity_type}: {e}", exc_info=True)
|
197 |
return data
|
198 |
+
|
|
|
|
|
199 |
return data
|
200 |
|
201 |
+
@classmethod
|
202 |
+
def _deserialize_to_me(cls, data: 'LinkerEntity') -> 'LinkerEntity':
|
203 |
+
"""
|
204 |
+
Десериализует сущность в нужный тип на основе поля type.
|
205 |
+
"""
|
206 |
+
return cls(
|
207 |
+
id=data.id,
|
208 |
+
name=data.name,
|
209 |
+
text=data.text,
|
210 |
+
in_search_text=data.in_search_text,
|
211 |
+
metadata=data.metadata,
|
212 |
+
source_id=data.source_id,
|
213 |
+
target_id=data.target_id,
|
214 |
+
number_in_relation=data.number_in_relation,
|
215 |
+
type=data.type,
|
216 |
+
groupper=data.groupper,
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
# Алиасы для удобства
|
221 |
+
Link = LinkerEntity
|
222 |
+
Entity = LinkerEntity
|
223 |
+
|
224 |
|
225 |
# Декоратор для регистрации производных классов
|
226 |
def register_entity(cls):
|
227 |
"""
|
228 |
Декоратор для регистрации классов-наследников LinkerEntity.
|
229 |
+
|
230 |
Пример использования:
|
231 |
+
|
232 |
@register_entity
|
233 |
class MyEntity(LinkerEntity):
|
234 |
type = "my_entity"
|
235 |
+
|
236 |
Args:
|
237 |
cls: Класс, который нужно зарегистрировать
|
238 |
+
|
239 |
Returns:
|
240 |
Исходный класс (без изменений)
|
241 |
"""
|
242 |
# Регистрируем класс в реестр, используя его имя или указанный тип
|
243 |
+
entity_type = cls.__name__
|
244 |
LinkerEntity._entity_classes[entity_type] = cls
|
245 |
+
|
246 |
+
if hasattr(cls, 'type') and isinstance(cls.type, str):
|
247 |
+
LinkerEntity._entity_classes[cls.type] = cls
|
248 |
+
|
249 |
return cls
|
lib/extractor/ntr_text_fragmentation/repositories/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .entity_repository import EntityRepository, GroupedEntities
|
2 |
+
from .in_memory_repository import InMemoryEntityRepository
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"EntityRepository",
|
6 |
+
"GroupedEntities",
|
7 |
+
"InMemoryEntityRepository",
|
8 |
+
]
|
lib/extractor/ntr_text_fragmentation/repositories/entity_repository.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Интерфейс репозитория сущностей.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Generic, Iterable, Type, TypeVar
|
8 |
+
from uuid import UUID
|
9 |
+
|
10 |
+
from ..models import LinkerEntity
|
11 |
+
|
12 |
+
T = TypeVar('T', bound=LinkerEntity)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class GroupedEntities(Generic[T]):
|
17 |
+
composer: T
|
18 |
+
entities: list[LinkerEntity]
|
19 |
+
|
20 |
+
|
21 |
+
class EntityRepository(ABC):
|
22 |
+
"""
|
23 |
+
Абстрактный интерфейс для доступа к хранилищу сущностей.
|
24 |
+
Позволяет InjectionBuilder получать нужные сущности независимо от их хранилища.
|
25 |
+
|
26 |
+
Этот интерфейс определяет только методы для получения сущностей.
|
27 |
+
Логика сохранения и изменения сущностей остается за пределами этого интерфейса
|
28 |
+
и должна быть реализована в конкретных классах, расширяющих данный интерфейс.
|
29 |
+
"""
|
30 |
+
|
31 |
+
@abstractmethod
|
32 |
+
def get_entities_by_ids(
|
33 |
+
self,
|
34 |
+
entity_ids: Iterable[UUID],
|
35 |
+
) -> list[LinkerEntity]:
|
36 |
+
"""
|
37 |
+
Получить сущности по списку идентификаторов.
|
38 |
+
Может возвращать экземпляры подклассов LinkerEntity.
|
39 |
+
"""
|
40 |
+
pass
|
41 |
+
|
42 |
+
@abstractmethod
|
43 |
+
def group_entities_hierarchically(
|
44 |
+
self,
|
45 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
46 |
+
root_type: Type[LinkerEntity],
|
47 |
+
max_levels: int = 10,
|
48 |
+
sort: bool = True,
|
49 |
+
) -> list[GroupedEntities[LinkerEntity]]:
|
50 |
+
"""
|
51 |
+
Группирует сущности по корневым элементам иерархии, поддерживая
|
52 |
+
многоуровневые связи (например, строка → подтаблица → таблица → документ).
|
53 |
+
|
54 |
+
Args:
|
55 |
+
entities: Список идентификаторов или сущностей для группировки
|
56 |
+
root_type: Корневой тип сущностей для группировки (например, DocumentAsEntity)
|
57 |
+
max_levels: Максимальная глубина поиска корневого элемента
|
58 |
+
sort: Флаг для сортировки сущностей в группах по их позициям
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Список групп сущностей, объединенных по корневому объекту
|
62 |
+
"""
|
63 |
+
pass
|
64 |
+
|
65 |
+
@abstractmethod
|
66 |
+
def get_neighboring_entities(
|
67 |
+
self,
|
68 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
69 |
+
max_distance: int = 1,
|
70 |
+
) -> list[LinkerEntity]:
|
71 |
+
"""
|
72 |
+
Получить соседние сущности для указанных сущностей.
|
73 |
+
Порядок определяется через CompositionLink и number_in_relation.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
entities: Список идентификаторов сущностей (UUID) или самих сущностей (LinkerEntity).
|
77 |
+
max_distance: Максимальное расстояние между сущностями (по умолчанию 1).
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Список соседних сущностей.
|
81 |
+
"""
|
82 |
+
pass
|
83 |
+
|
84 |
+
@abstractmethod
|
85 |
+
def get_related_entities(
|
86 |
+
self,
|
87 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
88 |
+
relation_type: Type[LinkerEntity] | None = None,
|
89 |
+
as_source: bool = False,
|
90 |
+
as_target: bool = False,
|
91 |
+
as_owner: bool = False,
|
92 |
+
) -> list[LinkerEntity]:
|
93 |
+
"""
|
94 |
+
Получить сущности, связанные с указанными. Возвращает как сущности, так и связи к ним ведущие.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
entities: Список идентификаторов сущностей (UUID) или самих сущностей (LinkerEntity).
|
98 |
+
relation_type: Опциональный тип связи для фильтрации (например, CompositionLink)
|
99 |
+
as_source: Искать связи, где entities - источники
|
100 |
+
as_target: Искать связи, где entities - цели
|
101 |
+
as_owner: Искать связи, где entities - владельцы (связи-композиции)
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Список связанных сущностей и самих связей
|
105 |
+
"""
|
106 |
+
pass
|
lib/extractor/ntr_text_fragmentation/repositories/in_memory_repository.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import Iterable, Type
|
4 |
+
from uuid import UUID
|
5 |
+
|
6 |
+
from ..models import LinkerEntity
|
7 |
+
from .entity_repository import EntityRepository, GroupedEntities
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
class InMemoryEntityRepository(EntityRepository):
|
13 |
+
"""
|
14 |
+
Реализация EntityRepository, хранящая все сущности в памяти.
|
15 |
+
Обеспечивает обратную совместимость и используется для тестирования.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, entities: list[LinkerEntity] | None = None):
|
19 |
+
"""
|
20 |
+
Инициализация репозитория с начальным списком сущностей.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
entities: Начальный список сущностей
|
24 |
+
"""
|
25 |
+
self.entities = entities or []
|
26 |
+
self.entities_by_id: dict[UUID, LinkerEntity] = {}
|
27 |
+
self.relations_by_source: dict[UUID, list[LinkerEntity]] = defaultdict(list)
|
28 |
+
self.relations_by_target: dict[UUID, list[LinkerEntity]] = defaultdict(list)
|
29 |
+
self.compositions: dict[UUID, list[LinkerEntity]] = defaultdict(list)
|
30 |
+
|
31 |
+
self._build_indices()
|
32 |
+
|
33 |
+
def _build_indices(self) -> None:
|
34 |
+
"""
|
35 |
+
Строит индексы для быстрого доступа.
|
36 |
+
Использует LinkerEntity.deserialize для возможной типизации связей.
|
37 |
+
"""
|
38 |
+
self.entities_by_id.clear()
|
39 |
+
self.relations_by_source.clear()
|
40 |
+
self.relations_by_target.clear()
|
41 |
+
self.compositions.clear()
|
42 |
+
|
43 |
+
for entity in self.entities:
|
44 |
+
try:
|
45 |
+
deserialized_entity = LinkerEntity._deserialize(entity)
|
46 |
+
except Exception as e:
|
47 |
+
logger.warning(f"Error deserializing entity: {e}")
|
48 |
+
deserialized_entity = entity
|
49 |
+
|
50 |
+
self.entities_by_id[deserialized_entity.id] = deserialized_entity
|
51 |
+
|
52 |
+
if deserialized_entity.is_link():
|
53 |
+
self.relations_by_source[deserialized_entity.source_id].append(
|
54 |
+
deserialized_entity
|
55 |
+
)
|
56 |
+
self.relations_by_target[deserialized_entity.target_id].append(
|
57 |
+
deserialized_entity
|
58 |
+
)
|
59 |
+
|
60 |
+
if deserialized_entity.owner_id is not None:
|
61 |
+
self.compositions[deserialized_entity.owner_id].append(
|
62 |
+
deserialized_entity
|
63 |
+
)
|
64 |
+
|
65 |
+
logger.info(f"Построены индексы для {len(self.entities)} сущностей.")
|
66 |
+
logger.info(f"Всего сущностей: {len(self.entities_by_id)}")
|
67 |
+
logger.info(f"Всего связей: {len(self.relations_by_source)}")
|
68 |
+
logger.info(f"Всего композиций: {len(self.compositions)}")
|
69 |
+
|
70 |
+
def _normalize_entities(
|
71 |
+
self, entities: Iterable[UUID] | Iterable[LinkerEntity]
|
72 |
+
) -> list[UUID]:
|
73 |
+
"""
|
74 |
+
Преобразует входные данные в список UUID.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
entities: Итерируемый объект с UUID или LinkerEntity
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
list[UUID]: Список идентификаторов
|
81 |
+
"""
|
82 |
+
result = []
|
83 |
+
for entity in entities:
|
84 |
+
if isinstance(entity, UUID):
|
85 |
+
result.append(entity)
|
86 |
+
elif isinstance(entity, LinkerEntity):
|
87 |
+
result.append(entity.id)
|
88 |
+
return result
|
89 |
+
|
90 |
+
def get_entities_by_ids(
|
91 |
+
self, entities: Iterable[UUID] | Iterable[LinkerEntity]
|
92 |
+
) -> list[LinkerEntity]:
|
93 |
+
"""
|
94 |
+
Получить сущности по списку идентификаторов или сущностей.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
entities: Список идентификаторов или сущностей
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
list[LinkerEntity]: Список найденных сущностей
|
101 |
+
"""
|
102 |
+
entity_ids = self._normalize_entities(entities)
|
103 |
+
return [
|
104 |
+
self.entities_by_id[eid] for eid in entity_ids if eid in self.entities_by_id
|
105 |
+
]
|
106 |
+
|
107 |
+
def group_entities_hierarchically(
|
108 |
+
self,
|
109 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
110 |
+
root_type: Type[LinkerEntity],
|
111 |
+
max_levels: int = 10,
|
112 |
+
sort: bool = True,
|
113 |
+
) -> list[GroupedEntities]:
|
114 |
+
"""
|
115 |
+
Группирует сущности по корневым элементам иерархии, поддерживая
|
116 |
+
многоуровневые связи (например, строка → подтаблица → таблица → документ).
|
117 |
+
|
118 |
+
Args:
|
119 |
+
entities: Список идентификаторов или сущностей для группировки
|
120 |
+
root_type: Корневой тип сущностей для группировки (например, DocumentAsEntity)
|
121 |
+
max_levels: Максимальн��я глубина поиска корневого элемента
|
122 |
+
sort: Флаг для сортировки сущностей в группах по их позициям
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Список групп сущностей, объединенных по корневому объекту
|
126 |
+
"""
|
127 |
+
entity_ids = self._normalize_entities(entities)
|
128 |
+
|
129 |
+
# Словарь для хранения найденных корневых элементов для каждой сущности
|
130 |
+
entity_to_root: dict[UUID, UUID] = {}
|
131 |
+
|
132 |
+
# Если включена сортировка, соберем информацию о позициях сущностей
|
133 |
+
entity_positions: dict[UUID, tuple[str, int]] = {}
|
134 |
+
if sort:
|
135 |
+
for entity_id in entity_ids:
|
136 |
+
entity = self.entities_by_id.get(entity_id)
|
137 |
+
if entity:
|
138 |
+
entity_positions[entity_id] = (
|
139 |
+
entity.groupper,
|
140 |
+
entity.number_in_relation,
|
141 |
+
)
|
142 |
+
|
143 |
+
# Функция для нахождения корневого элемента для сущности
|
144 |
+
def find_root(
|
145 |
+
entity_id: UUID, visited: set | None = None, level: int = 0
|
146 |
+
) -> UUID | None:
|
147 |
+
# Проверка на максимальную глубину поиска
|
148 |
+
if level >= max_levels:
|
149 |
+
return None
|
150 |
+
|
151 |
+
# Инициализация множества посещенных узлов для отслеживания пути
|
152 |
+
if visited is None:
|
153 |
+
visited = set()
|
154 |
+
|
155 |
+
# Проверка, не обрабатывали ли мы уже эту сущность
|
156 |
+
if entity_id in visited:
|
157 |
+
return None
|
158 |
+
|
159 |
+
# Добавляем текущую сущность в посещенные
|
160 |
+
visited.add(entity_id)
|
161 |
+
|
162 |
+
# Проверяем, есть ли уже найденный корень для этой сущности
|
163 |
+
if entity_id in entity_to_root:
|
164 |
+
return entity_to_root[entity_id]
|
165 |
+
|
166 |
+
# Проверяем, является ли сама сущность корневым типом
|
167 |
+
entity = self.entities_by_id.get(entity_id)
|
168 |
+
|
169 |
+
if entity and isinstance(entity, root_type):
|
170 |
+
return entity_id
|
171 |
+
|
172 |
+
# Получаем родительскую сущность через owner_id
|
173 |
+
if entity and entity.owner_id:
|
174 |
+
parent_root = find_root(entity.owner_id, visited, level + 1)
|
175 |
+
if parent_root:
|
176 |
+
return parent_root
|
177 |
+
|
178 |
+
return None
|
179 |
+
|
180 |
+
# Находим корневой элемент для каждой сущности
|
181 |
+
for entity_id in entity_ids:
|
182 |
+
root_id = find_root(entity_id)
|
183 |
+
if root_id:
|
184 |
+
entity_to_root[entity_id] = root_id
|
185 |
+
|
186 |
+
logger.info(f"Найдены корневые элементы для {len(entity_to_root)} сущностей из общего количества {len(entity_ids)}.")
|
187 |
+
|
188 |
+
# Группируем сущности по корневым элементам
|
189 |
+
root_to_entities: dict[UUID, list[LinkerEntity]] = defaultdict(list)
|
190 |
+
|
191 |
+
for entity_id in entity_ids:
|
192 |
+
if entity_id in entity_to_root:
|
193 |
+
root_id = entity_to_root[entity_id]
|
194 |
+
entity = self.entities_by_id.get(entity_id)
|
195 |
+
if entity:
|
196 |
+
root_to_entities[root_id].append(entity)
|
197 |
+
|
198 |
+
# Формируем результат
|
199 |
+
result = []
|
200 |
+
for root_id, entities_list in root_to_entities.items():
|
201 |
+
root_entity = self.entities_by_id.get(root_id)
|
202 |
+
if root_entity:
|
203 |
+
# Сортируем сущности при формировании групп, если нужно
|
204 |
+
if sort:
|
205 |
+
entities_list.sort(
|
206 |
+
key=lambda entity: entity_positions.get(
|
207 |
+
entity.id,
|
208 |
+
("", float('inf')), # Сущности без позиции в конец
|
209 |
+
)
|
210 |
+
)
|
211 |
+
|
212 |
+
result.append(
|
213 |
+
GroupedEntities(composer=root_entity, entities=entities_list)
|
214 |
+
)
|
215 |
+
|
216 |
+
return result
|
217 |
+
|
218 |
+
def get_neighboring_entities(
|
219 |
+
self,
|
220 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
221 |
+
max_distance: int = 1,
|
222 |
+
) -> list[LinkerEntity]:
|
223 |
+
"""
|
224 |
+
Получает соседние сущности в пределах указанного расстояния в рамках одной композиционной группы.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
entities: Список идентификаторов или сущностей
|
228 |
+
max_distance: Максимальное расстояние между сущностями
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
list[LinkerEntity]: Список соседних сущностей
|
232 |
+
"""
|
233 |
+
entity_ids = self._normalize_entities(entities)
|
234 |
+
|
235 |
+
if not entity_ids:
|
236 |
+
return []
|
237 |
+
|
238 |
+
entities = self.get_entities_by_ids(entity_ids)
|
239 |
+
entities = [entity for entity in entities if entity.owner_id is not None]
|
240 |
+
|
241 |
+
# Ищем соседей
|
242 |
+
neighbors = {
|
243 |
+
entity.owner_id: [
|
244 |
+
sibling
|
245 |
+
for sibling in self.compositions.get(entity.owner_id, [])
|
246 |
+
if (
|
247 |
+
sibling.groupper == entity.groupper
|
248 |
+
and abs(sibling.number_in_relation - entity.number_in_relation)
|
249 |
+
<= max_distance
|
250 |
+
)
|
251 |
+
]
|
252 |
+
for entity in entities
|
253 |
+
}
|
254 |
+
|
255 |
+
neighbors = {
|
256 |
+
owner_id: sorted(
|
257 |
+
neighbors, key=lambda x: (x.groupper, x.number_in_relation)
|
258 |
+
)
|
259 |
+
for owner_id, neighbors in neighbors.items()
|
260 |
+
}
|
261 |
+
|
262 |
+
neighbor_ids = {
|
263 |
+
owner_id: [neighbor.id for neighbor in neighbors]
|
264 |
+
for owner_id, neighbors in neighbors.items()
|
265 |
+
}
|
266 |
+
|
267 |
+
# Собираем все ID соседей и исключаем исходные сущности
|
268 |
+
all_neighbor_ids = set(sum(neighbor_ids.values(), [])) - set(entity_ids)
|
269 |
+
|
270 |
+
return self.get_entities_by_ids(all_neighbor_ids)
|
271 |
+
|
272 |
+
def get_related_entities(
|
273 |
+
self,
|
274 |
+
entities: Iterable[UUID] | Iterable[LinkerEntity],
|
275 |
+
relation_type: Type[LinkerEntity] | None = None,
|
276 |
+
as_source: bool = False,
|
277 |
+
as_target: bool = False,
|
278 |
+
as_owner: bool = False,
|
279 |
+
) -> list[LinkerEntity]:
|
280 |
+
"""
|
281 |
+
Получает связанные сущности и их связи.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
entities: Список идентификаторов или сущностей
|
285 |
+
relation_type: Тип связи для фильтрации
|
286 |
+
as_source: Искать связи, где entities являются источниками
|
287 |
+
as_target: Искать связи, где entities являются целями
|
288 |
+
as_owner: Искать связи, где entities являются владельцами (связи-композиции)
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
list[LinkerEntity]: Список связанных сущностей и их связей
|
292 |
+
"""
|
293 |
+
entity_ids = self._normalize_entities(entities)
|
294 |
+
result = set()
|
295 |
+
|
296 |
+
# Если не указано направление, ищем в обе стороны
|
297 |
+
if not as_source and not as_target and not as_owner:
|
298 |
+
as_source = True
|
299 |
+
as_target = True
|
300 |
+
as_owner = True
|
301 |
+
|
302 |
+
# Поиск связей, где entities являются источниками
|
303 |
+
if as_source:
|
304 |
+
for entity_id in entity_ids:
|
305 |
+
for relation in self.relations_by_source.get(entity_id, []):
|
306 |
+
if relation_type is None or isinstance(relation, relation_type):
|
307 |
+
result.add(relation)
|
308 |
+
if relation.target_id in self.entities_by_id:
|
309 |
+
result.add(self.entities_by_id[relation.target_id])
|
310 |
+
|
311 |
+
# Поиск связей, где entities являются целями
|
312 |
+
if as_target:
|
313 |
+
for entity_id in entity_ids:
|
314 |
+
for relation in self.relations_by_target.get(entity_id, []):
|
315 |
+
if relation_type is None or isinstance(relation, relation_type):
|
316 |
+
result.add(relation)
|
317 |
+
if relation.source_id in self.entities_by_id:
|
318 |
+
result.add(self.entities_by_id[relation.source_id])
|
319 |
+
|
320 |
+
# Поиск связей, где entities являются владельцами
|
321 |
+
if as_owner:
|
322 |
+
for entity_id in entity_ids:
|
323 |
+
for child in self.compositions.get(entity_id, []):
|
324 |
+
if relation_type is None or isinstance(child, relation_type):
|
325 |
+
result.add(child)
|
326 |
+
|
327 |
+
return list(result)
|
328 |
+
|
329 |
+
def add_entities(self, entities: list[LinkerEntity]) -> None:
|
330 |
+
"""Добавляет сущности в репозиторий и перестраивает индексы."""
|
331 |
+
self.entities.extend(entities)
|
332 |
+
self._build_indices()
|
333 |
+
|
334 |
+
def set_entities(self, entities: list[LinkerEntity]) -> None:
|
335 |
+
"""Устанавливает сущности в репозиторий и перестраивает индексы."""
|
336 |
+
self.entities = entities
|
337 |
+
self._build_indices()
|
lib/extractor/scripts/test_chunking.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""
|
3 |
+
Скрипт для визуального тестирования процесса чанкинга и сборки документа.
|
4 |
+
|
5 |
+
Этот скрипт:
|
6 |
+
1. Считывает test_input/test.docx с помощью UniversalParser
|
7 |
+
2. Чанкит документ через Destructurer с fixed_size-стратегией
|
8 |
+
3. Сохраняет результат чанкинга в test_output/test.csv
|
9 |
+
4. Выбирает 20-30 случайных чанков из CSV
|
10 |
+
5. Создает InjectionBuilder с InMemoryEntityRepository
|
11 |
+
6. Собирает текст из выбранных чанков
|
12 |
+
7. Сохраняет результат в test_output/test_builded.txt
|
13 |
+
"""
|
14 |
+
|
15 |
+
import json
|
16 |
+
import logging
|
17 |
+
import os
|
18 |
+
import random
|
19 |
+
from pathlib import Path
|
20 |
+
from typing import List
|
21 |
+
from uuid import UUID
|
22 |
+
|
23 |
+
import pandas as pd
|
24 |
+
from ntr_fileparser import UniversalParser
|
25 |
+
from ntr_text_fragmentation import (DocumentAsEntity, EntitiesExtractor,
|
26 |
+
InjectionBuilder, InMemoryEntityRepository,
|
27 |
+
LinkerEntity)
|
28 |
+
|
29 |
+
|
30 |
+
def setup_logging() -> None:
|
31 |
+
"""Настройка логгирования."""
|
32 |
+
logging.basicConfig(
|
33 |
+
level=logging.INFO,
|
34 |
+
format="%(asctime)s - %(levelname)s - [%(pathname)s:%(lineno)d] - %(message)s",
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def ensure_directories() -> None:
|
39 |
+
"""Проверка наличия необходимых директорий."""
|
40 |
+
for directory in ["test_input", "test_output"]:
|
41 |
+
Path(directory).mkdir(parents=True, exist_ok=True)
|
42 |
+
|
43 |
+
|
44 |
+
def save_entities_to_csv(entities: List[LinkerEntity], csv_path: str) -> None:
|
45 |
+
"""
|
46 |
+
Сохраняет сущности в CSV файл.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
entities: Список сущностей
|
50 |
+
csv_path: Путь для сохранения CSV файла
|
51 |
+
"""
|
52 |
+
data = []
|
53 |
+
for entity in entities:
|
54 |
+
# Базовые поля для всех типов сущностей
|
55 |
+
entity_dict = {
|
56 |
+
"id": str(entity.id),
|
57 |
+
"type": entity.type,
|
58 |
+
"name": entity.name,
|
59 |
+
"text": entity.text,
|
60 |
+
"metadata": json.dumps(entity.metadata or {}, ensure_ascii=False),
|
61 |
+
"in_search_text": entity.in_search_text,
|
62 |
+
"source_id": str(entity.source_id) if entity.source_id else None,
|
63 |
+
"target_id": str(entity.target_id) if entity.target_id else None,
|
64 |
+
"number_in_relation": entity.number_in_relation,
|
65 |
+
"groupper": entity.groupper,
|
66 |
+
"type": entity.type,
|
67 |
+
}
|
68 |
+
|
69 |
+
# Дополнительные поля специфичные для подклассов (если они есть в __dict__)
|
70 |
+
# Это не самый надежный способ, но для скрипта визуализации может подойти
|
71 |
+
# Сериализация LinkerEntity теперь должна сама класть доп поля в metadata
|
72 |
+
# for key, value in entity.__dict__.items():
|
73 |
+
# if key not in entity_dict and not key.startswith('_'):
|
74 |
+
# entity_dict[key] = value
|
75 |
+
|
76 |
+
data.append(entity_dict)
|
77 |
+
|
78 |
+
df = pd.DataFrame(data)
|
79 |
+
# Указываем кодировку UTF-8 при записи CSV
|
80 |
+
df.to_csv(csv_path, index=False, encoding='utf-8')
|
81 |
+
logging.info(f"Сохранено {len(entities)} сущностей в {csv_path}")
|
82 |
+
|
83 |
+
|
84 |
+
def load_entities_from_csv(csv_path: str) -> List[LinkerEntity]:
|
85 |
+
"""
|
86 |
+
Загружает сущности из CSV файла.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
csv_path: Путь к CSV файлу
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
Список сущностей
|
93 |
+
"""
|
94 |
+
df = pd.read_csv(csv_path)
|
95 |
+
entities = []
|
96 |
+
|
97 |
+
for _, row in df.iterrows():
|
98 |
+
# Обработка метаданных
|
99 |
+
metadata_str = row.get("metadata", "{}")
|
100 |
+
try:
|
101 |
+
# Используем json.loads для парсинга JSON строки
|
102 |
+
metadata = (
|
103 |
+
json.loads(metadata_str)
|
104 |
+
if pd.notna(metadata_str) and metadata_str
|
105 |
+
else {}
|
106 |
+
)
|
107 |
+
except json.JSONDecodeError: # Ловим ошибку JSON
|
108 |
+
logging.warning(
|
109 |
+
f"Не удалось распарсить метаданные JSON: {metadata_str}. Используется пустой словарь."
|
110 |
+
)
|
111 |
+
metadata = {}
|
112 |
+
|
113 |
+
# Общие поля для всех типов сущностей
|
114 |
+
# Преобразуем ID обратно в UUID
|
115 |
+
entity_id = row['id']
|
116 |
+
if isinstance(entity_id, str):
|
117 |
+
try:
|
118 |
+
entity_id = UUID(entity_id)
|
119 |
+
except ValueError:
|
120 |
+
logging.warning(
|
121 |
+
f"Неверный формат UUID для id: {entity_id}. Пропускаем сущность."
|
122 |
+
)
|
123 |
+
continue
|
124 |
+
|
125 |
+
common_args = {
|
126 |
+
"id": entity_id,
|
127 |
+
"name": row["name"] if pd.notna(row.get("name")) else "",
|
128 |
+
"text": row["text"] if pd.notna(row.get("text")) else "",
|
129 |
+
"metadata": metadata,
|
130 |
+
"in_search_text": (
|
131 |
+
row["in_search_text"] if pd.notna(row.get('in_search_text')) else None
|
132 |
+
),
|
133 |
+
"type": (
|
134 |
+
row["type"] if pd.notna(row.get('type')) else LinkerEntity.__name__
|
135 |
+
), # Используем базовый тип, если не указан
|
136 |
+
"groupper": row["groupper"] if pd.notna(row.get("groupper")) else None,
|
137 |
+
}
|
138 |
+
|
139 |
+
# Добавляем поля связи, если они есть, преобразуя в UUID
|
140 |
+
source_id_str = row.get("source_id")
|
141 |
+
target_id_str = row.get("target_id")
|
142 |
+
|
143 |
+
if pd.notna(source_id_str):
|
144 |
+
try:
|
145 |
+
common_args["source_id"] = UUID(source_id_str)
|
146 |
+
except ValueError:
|
147 |
+
logging.warning(
|
148 |
+
f"Неверный формат UUID для source_id: {source_id_str}. Пропускаем поле."
|
149 |
+
)
|
150 |
+
if pd.notna(target_id_str):
|
151 |
+
try:
|
152 |
+
common_args["target_id"] = UUID(target_id_str)
|
153 |
+
except ValueError:
|
154 |
+
logging.warning(
|
155 |
+
f"Неверный формат UUID для target_id: {target_id_str}. Пропускаем поле."
|
156 |
+
)
|
157 |
+
|
158 |
+
if pd.notna(row.get("number_in_relation")):
|
159 |
+
try:
|
160 |
+
common_args["number_in_relation"] = int(row["number_in_relation"])
|
161 |
+
except ValueError:
|
162 |
+
logging.warning(
|
163 |
+
f"Неверный формат для number_in_relation: {row['number_in_relation']}. Пропускаем поле."
|
164 |
+
)
|
165 |
+
|
166 |
+
# Пытаемся десериализовать в конкретный тип, если он известен
|
167 |
+
entity_class = LinkerEntity._entity_classes.get(
|
168 |
+
common_args["type"], LinkerEntity
|
169 |
+
)
|
170 |
+
try:
|
171 |
+
# Создаем экземпляр, передавая только те аргументы, которые ожидает класс
|
172 |
+
# (используя LinkerEntity._deserialize_to_me как пример, но нужно убедиться,
|
173 |
+
# что он принимает все нужные поля или имеет **kwargs)
|
174 |
+
# Пока создаем базовый LinkerEntity, т.к. подклассы могут требовать специфичные поля
|
175 |
+
# которых нет в CSV или в common_args
|
176 |
+
entity = LinkerEntity(**common_args)
|
177 |
+
# Если нужно строгое восстановление типов, потребуется более сложная логика
|
178 |
+
# с проверкой полей каждого подкласса
|
179 |
+
except TypeError as e:
|
180 |
+
logging.warning(
|
181 |
+
f"Ошибка создания экземпляра {entity_class.__name__} для ID {common_args['id']}: {e}. Создан базовый LinkerEntity."
|
182 |
+
)
|
183 |
+
entity = LinkerEntity(**common_args) # Откат к базовому классу
|
184 |
+
|
185 |
+
entities.append(entity)
|
186 |
+
|
187 |
+
logging.info(f"Загружено {len(entities)} сущностей из {csv_path}")
|
188 |
+
return entities
|
189 |
+
|
190 |
+
|
191 |
+
def main() -> None:
|
192 |
+
"""Основная функция скрипта."""
|
193 |
+
setup_logging()
|
194 |
+
ensure_directories()
|
195 |
+
|
196 |
+
# Пути к файлам
|
197 |
+
input_doc_path = "test_input/test2.docx"
|
198 |
+
output_csv_path = "test_output/test2.csv"
|
199 |
+
output_text_path = "test_output/test2.md"
|
200 |
+
|
201 |
+
# Проверка наличия входного файла
|
202 |
+
if not os.path.exists(input_doc_path):
|
203 |
+
logging.error(f"Файл {input_doc_path} не найден!")
|
204 |
+
return
|
205 |
+
|
206 |
+
logging.info(f"Парсинг документа {input_doc_path}")
|
207 |
+
|
208 |
+
try:
|
209 |
+
# Шаг 1: Парсинг документа дважды, как если бы это были два разных документа
|
210 |
+
parser = UniversalParser()
|
211 |
+
document1 = parser.parse_by_path(input_doc_path)
|
212 |
+
document2 = parser.parse_by_path(input_doc_path)
|
213 |
+
|
214 |
+
# Меняем название второго документа, чтобы отличить его
|
215 |
+
document2.name = document2.name + "_copy" if document2.name else "copy_doc"
|
216 |
+
|
217 |
+
# Шаг 2: Чанкинг и извлечение таблиц с использованием EntitiesExtractor
|
218 |
+
all_entities = []
|
219 |
+
|
220 |
+
# Обработка первого документа
|
221 |
+
logging.info("Начало процесса деструктуризации первого документа")
|
222 |
+
# Инициализируем экстрактор без документа (используем дефолтные настройки или настроим позже)
|
223 |
+
extractor1 = EntitiesExtractor()
|
224 |
+
# Настройка чанкинга
|
225 |
+
extractor1.configure_chunking(
|
226 |
+
strategy_name="fixed_size",
|
227 |
+
strategy_params={
|
228 |
+
"words_per_chunk": 50,
|
229 |
+
"overlap_words": 25,
|
230 |
+
"respect_sentence_boundaries": True, # Добавлено по запросу
|
231 |
+
},
|
232 |
+
)
|
233 |
+
# Настройка извлечения таблиц
|
234 |
+
extractor1.configure_tables_extraction(process_tables=True)
|
235 |
+
# Выполнение деструктуризации
|
236 |
+
entities1 = extractor1.extract(document1)
|
237 |
+
|
238 |
+
# Находим ID документа 1
|
239 |
+
doc1_entity = next((e for e in entities1 if e.type == DocumentAsEntity.__name__), None)
|
240 |
+
if not doc1_entity:
|
241 |
+
logging.error("Не удалось найти DocumentAsEntity для первого документа!")
|
242 |
+
return
|
243 |
+
doc1_id = doc1_entity.id
|
244 |
+
logging.info(f"ID первого документа: {doc1_id}")
|
245 |
+
|
246 |
+
logging.info(f"Получено {len(entities1)} сущностей из первого документа")
|
247 |
+
all_entities.extend(entities1)
|
248 |
+
|
249 |
+
# Обработка второго документа
|
250 |
+
logging.info("Начало процесса деструктуризации второго документа")
|
251 |
+
# Инициализируем экстрактор без документа
|
252 |
+
extractor2 = EntitiesExtractor()
|
253 |
+
# Настройка чанкинга (те же параметры)
|
254 |
+
extractor2.configure_chunking(
|
255 |
+
strategy_name="fixed_size",
|
256 |
+
strategy_params={
|
257 |
+
"words_per_chunk": 50,
|
258 |
+
"overlap_words": 25,
|
259 |
+
"respect_sentence_boundaries": True,
|
260 |
+
},
|
261 |
+
)
|
262 |
+
# Настройка извлечения таблиц
|
263 |
+
extractor2.configure_tables_extraction(process_tables=True)
|
264 |
+
# Выполнение деструктуризации
|
265 |
+
entities2 = extractor2.extract(document2)
|
266 |
+
|
267 |
+
# Находим ID документа 2
|
268 |
+
doc2_entity = next((e for e in entities2 if e.type == DocumentAsEntity.__name__), None)
|
269 |
+
if not doc2_entity:
|
270 |
+
logging.error("Не удалось найти DocumentAsEntity для второго документа!")
|
271 |
+
return
|
272 |
+
doc2_id = doc2_entity.id
|
273 |
+
logging.info(f"ID второго документа: {doc2_id}")
|
274 |
+
|
275 |
+
logging.info(f"Получено {len(entities2)} сущностей из второго документа")
|
276 |
+
all_entities.extend(entities2)
|
277 |
+
|
278 |
+
logging.info(
|
279 |
+
f"Всего получено {len(all_entities)} сущностей из обоих документов"
|
280 |
+
)
|
281 |
+
|
282 |
+
# Шаг 3: Сохранение результатов чанкинга в CSV
|
283 |
+
save_entities_to_csv(all_entities, output_csv_path)
|
284 |
+
|
285 |
+
# Шаг 4: Загрузка сущностей из CSV и выбор случайных чанков
|
286 |
+
loaded_entities = load_entities_from_csv(output_csv_path)
|
287 |
+
|
288 |
+
# Шаг 5: Создание InjectionBuilder с InMemoryEntityRepository
|
289 |
+
# Сначала создаем репозиторий со ВСЕМИ загруженными сущностями
|
290 |
+
repository = InMemoryEntityRepository(loaded_entities)
|
291 |
+
builder = InjectionBuilder(repository=repository)
|
292 |
+
|
293 |
+
# Фильтрация только чанков (сущностей с in_search_text)
|
294 |
+
# Убедимся, что работаем с десериализованными сущностями из репозитория
|
295 |
+
# (Репозиторий уже десериализует при инициализации, если нужно)
|
296 |
+
all_entities_from_repo = repository.get_entities_by_ids(
|
297 |
+
[e.id for e in loaded_entities]
|
298 |
+
)
|
299 |
+
# Выбираем все сущности с in_search_text
|
300 |
+
selectable_entities = [
|
301 |
+
e for e in all_entities_from_repo if e.in_search_text is not None
|
302 |
+
]
|
303 |
+
|
304 |
+
# Выбор случайных сущностей (от 20 до 30, но не более доступных)
|
305 |
+
num_entities_to_select = min(random.randint(100, 500), len(selectable_entities))
|
306 |
+
if num_entities_to_select > 0:
|
307 |
+
selected_entities = random.sample(
|
308 |
+
selectable_entities, num_entities_to_select
|
309 |
+
)
|
310 |
+
selected_ids = [entity.id for entity in selected_entities]
|
311 |
+
logging.info(
|
312 |
+
f"Выбрано {len(selected_ids)} случайных ID сущностей (с in_search_text) для сборки"
|
313 |
+
)
|
314 |
+
|
315 |
+
# Дополнительная статистика по документам
|
316 |
+
# Используем репозиторий для получения информации о владельцах
|
317 |
+
selected_entities_details = repository.get_entities_by_ids(selected_ids)
|
318 |
+
# Считаем на основе owner_id
|
319 |
+
doc1_entities_count = sum(1 for e in selected_entities_details if e.owner_id == doc1_id)
|
320 |
+
doc2_entities_count = sum(1 for e in selected_entities_details if e.owner_id == doc2_id)
|
321 |
+
other_owner_count = len(selected_entities_details) - (doc1_entities_count + doc2_entities_count)
|
322 |
+
|
323 |
+
logging.info(
|
324 |
+
f"Из них {doc1_entities_count} принадлежат первому документу (ID: {doc1_id}), "
|
325 |
+
f"{doc2_entities_count} второму (ID: {doc2_id}) (на основе owner_id). "
|
326 |
+
f"{other_owner_count} имеют другого владельца (вероятно, таблицы/строки)."
|
327 |
+
)
|
328 |
+
|
329 |
+
else:
|
330 |
+
logging.warning("Не найдено сущностей с in_search_text для выбора.")
|
331 |
+
selected_ids = []
|
332 |
+
selected_entities = [] # Добавлено для ясности
|
333 |
+
|
334 |
+
# Шаг 6: Сборка текста из выбранных ID
|
335 |
+
logging.info("Начало сборки текста из выбранных ID")
|
336 |
+
# Передаем ID, а не сущности, т.к. builder сам их получит из репозитория
|
337 |
+
assembled_text = builder.build(
|
338 |
+
selected_ids, include_tables=True
|
339 |
+
) # Включаем таблицы
|
340 |
+
|
341 |
+
# Шаг 7: Сохранение результата в файл
|
342 |
+
with open(output_text_path, "w", encoding="utf-8") as f:
|
343 |
+
f.write(assembled_text.replace('\n', '\n\n'))
|
344 |
+
|
345 |
+
logging.info(f"Результат сборки сохранен в {output_text_path}")
|
346 |
+
|
347 |
+
except Exception as e:
|
348 |
+
logging.error(f"Произошла ошибка: {e}", exc_info=True)
|
349 |
+
|
350 |
+
|
351 |
+
if __name__ == "__main__":
|
352 |
+
main()
|
lib/extractor/tests/chunking/test_chunking_registry.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для реестра стратегий чанкинга _ChunkingRegistry.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import pytest
|
6 |
+
from ntr_text_fragmentation.chunking import (ChunkingStrategy,
|
7 |
+
_ChunkingRegistry,
|
8 |
+
chunking_registry,
|
9 |
+
register_chunking_strategy)
|
10 |
+
|
11 |
+
|
12 |
+
# Фикстуры
|
13 |
+
class MockStrategy(ChunkingStrategy):
|
14 |
+
"""Мок-стратегия для тестов."""
|
15 |
+
|
16 |
+
def chunk(self, document, doc_entity):
|
17 |
+
pass
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def dechunk(cls, repository, filtered_entities):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
@pytest.fixture
|
25 |
+
def clean_registry() -> _ChunkingRegistry:
|
26 |
+
"""Фикстура для получения чистого экземпляра реестра."""
|
27 |
+
# Создаем новый экземпляр, чтобы не влиять на глобальный chunking_registry
|
28 |
+
return _ChunkingRegistry()
|
29 |
+
|
30 |
+
|
31 |
+
@pytest.fixture
|
32 |
+
def populated_registry(clean_registry: _ChunkingRegistry) -> _ChunkingRegistry:
|
33 |
+
"""Фикстура для реестра с зарегистрированными стратегиями."""
|
34 |
+
clean_registry.register("mock1", MockStrategy)
|
35 |
+
clean_registry.register("mock2", MockStrategy)
|
36 |
+
return clean_registry
|
37 |
+
|
38 |
+
|
39 |
+
# Тесты
|
40 |
+
def test_register(clean_registry: _ChunkingRegistry):
|
41 |
+
"""Тест регистрации стратегии."""
|
42 |
+
assert len(clean_registry) == 0
|
43 |
+
clean_registry.register("test_strategy", MockStrategy)
|
44 |
+
assert len(clean_registry) == 1
|
45 |
+
assert "test_strategy" in clean_registry
|
46 |
+
assert clean_registry.get("test_strategy") is MockStrategy
|
47 |
+
|
48 |
+
def test_get(populated_registry: _ChunkingRegistry):
|
49 |
+
"""Тест получения стратегии по имени."""
|
50 |
+
strategy = populated_registry.get("mock1")
|
51 |
+
assert strategy is MockStrategy
|
52 |
+
|
53 |
+
# Тест получения несуществующей стратегии
|
54 |
+
with pytest.raises(KeyError):
|
55 |
+
populated_registry.get("nonexistent")
|
56 |
+
|
57 |
+
def test_getitem(populated_registry: _ChunkingRegistry):
|
58 |
+
"""Тест получения стратегии через __getitem__."""
|
59 |
+
strategy = populated_registry["mock1"]
|
60 |
+
assert strategy is MockStrategy
|
61 |
+
|
62 |
+
# Тест получения несуществующей стратегии
|
63 |
+
with pytest.raises(KeyError):
|
64 |
+
_ = populated_registry["nonexistent"]
|
65 |
+
|
66 |
+
def test_get_names(populated_registry: _ChunkingRegistry):
|
67 |
+
"""Тест получения списка имен зарегистрированных стратегий."""
|
68 |
+
names = populated_registry.get_names()
|
69 |
+
assert isinstance(names, list)
|
70 |
+
assert len(names) == 2
|
71 |
+
assert "mock1" in names
|
72 |
+
assert "mock2" in names
|
73 |
+
|
74 |
+
def test_len(populated_registry: _ChunkingRegistry):
|
75 |
+
"""Тест получения количества зарегистрированных стратегий."""
|
76 |
+
assert len(populated_registry) == 2
|
77 |
+
|
78 |
+
def test_contains(populated_registry: _ChunkingRegistry):
|
79 |
+
"""Тест проверки наличия стратегии."""
|
80 |
+
assert "mock1" in populated_registry
|
81 |
+
assert "nonexistent" not in populated_registry
|
82 |
+
# Проверка по самому классу стратегии (экземпляры не хранятся)
|
83 |
+
assert MockStrategy in populated_registry
|
84 |
+
class AnotherStrategy(ChunkingStrategy): # type: ignore
|
85 |
+
def chunk(self, document, doc_entity): pass
|
86 |
+
@classmethod
|
87 |
+
def dechunk(cls, repository, filtered_entities): pass
|
88 |
+
assert AnotherStrategy not in populated_registry
|
89 |
+
|
90 |
+
def test_decorator_register():
|
91 |
+
"""Тест декоратора register_chunking_strategy."""
|
92 |
+
# Сохраняем текущее состояние глобального реестра
|
93 |
+
original_registry_state = chunking_registry._chunking_strategies.copy()
|
94 |
+
original_len = len(chunking_registry)
|
95 |
+
|
96 |
+
@register_chunking_strategy("decorated_strategy")
|
97 |
+
class DecoratedStrategy(ChunkingStrategy):
|
98 |
+
def chunk(self, document, doc_entity):
|
99 |
+
pass
|
100 |
+
@classmethod
|
101 |
+
def dechunk(cls, repository, filtered_entities):
|
102 |
+
pass
|
103 |
+
|
104 |
+
assert len(chunking_registry) == original_len + 1
|
105 |
+
assert "decorated_strategy" in chunking_registry
|
106 |
+
assert chunking_registry.get("decorated_strategy") is DecoratedStrategy
|
107 |
+
|
108 |
+
# Тест регистрации с именем по умолчанию (имя класса)
|
109 |
+
@register_chunking_strategy()
|
110 |
+
class DefaultNameStrategy(ChunkingStrategy):
|
111 |
+
def chunk(self, document, doc_entity):
|
112 |
+
pass
|
113 |
+
@classmethod
|
114 |
+
def dechunk(cls, repository, filtered_entities):
|
115 |
+
pass
|
116 |
+
|
117 |
+
assert len(chunking_registry) == original_len + 2
|
118 |
+
assert "DefaultNameStrategy" in chunking_registry
|
119 |
+
assert chunking_registry.get("DefaultNameStrategy") is DefaultNameStrategy
|
120 |
+
|
121 |
+
# Восстанавливаем исходное состояние глоб��льного реестра
|
122 |
+
chunking_registry._chunking_strategies = original_registry_state
|
lib/extractor/tests/chunking/test_fixed_size_chunking.py
CHANGED
@@ -1,334 +1,355 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import pytest
|
4 |
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
]
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
)
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
@pytest.fixture
|
43 |
-
def
|
44 |
-
"""
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
),
|
64 |
-
ParsedTextBlock(
|
65 |
-
text="Седьмой параграф содержит несколько предложений разной длины. Каждое предложение имеет свою структуру. И все они должны корректно обрабатываться."
|
66 |
-
),
|
67 |
-
ParsedTextBlock(
|
68 |
-
text="Восьмой параграф начинается с длинного предложения, которое также должно быть разбито на несколько чанков, так как оно содержит много слов и не помещается в один чанк стандартного размера. А затем идет короткое предложение."
|
69 |
-
),
|
70 |
-
ParsedTextBlock(
|
71 |
-
text="Девятый параграф. Содержит разные предложения. С разной пунктуацией. И разной структурой."
|
72 |
-
),
|
73 |
-
ParsedTextBlock(
|
74 |
-
text="Десятый параграф начинается с короткого предложения. Затем идет длинное предложение, которое должно быть разбито на несколько чанков, потому что оно содержит много слов и не помещается в один чанк стандартного размера. И заканчивается коротким предложением."
|
75 |
-
),
|
76 |
]
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
)
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
# Разбиваем документ
|
128 |
-
entities = strategy.chunk(doc, doc_entity)
|
129 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
130 |
-
|
131 |
-
# Проверяем, что длинное предложение было разбито на несколько чанков
|
132 |
-
assert len(chunks) > 1
|
133 |
-
|
134 |
-
# Собираем документ обратно
|
135 |
-
result_text = strategy.dechunk(chunks)
|
136 |
-
|
137 |
-
# Проверяем корректность сборки
|
138 |
-
original_words = set(text.split())
|
139 |
-
result_words = set(result_text.split())
|
140 |
-
assert original_words.issubset(result_words)
|
141 |
-
|
142 |
-
# Проверяем, что все предложения сохранились
|
143 |
-
original_sentences = set(s.strip() for s in text.split('.'))
|
144 |
-
result_sentences = set(s.strip() for s in result_text.split('.'))
|
145 |
-
assert original_sentences.issubset(result_sentences)
|
146 |
-
|
147 |
-
def test_empty_document(self, doc_entity):
|
148 |
-
"""Тест обработки пустого документа."""
|
149 |
-
doc = ParsedDocument(name="empty.txt", type="text", paragraphs=[])
|
150 |
-
|
151 |
-
strategy = FixedSizeChunkingStrategy()
|
152 |
-
|
153 |
-
# Разбиваем документ
|
154 |
-
entities = strategy.chunk(doc, doc_entity)
|
155 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
156 |
-
|
157 |
-
# Проверяем, что чанков нет
|
158 |
-
assert len(chunks) == 0
|
159 |
-
|
160 |
-
# Проверяем, что сборка пустого документа возвращает пустую строку
|
161 |
-
result_text = strategy.dechunk(chunks)
|
162 |
-
assert result_text == ""
|
163 |
-
|
164 |
-
def test_special_characters_and_punctuation(self, doc_entity):
|
165 |
-
"""Тест обработки текста со специальными символами и пунктуацией."""
|
166 |
-
text = (
|
167 |
-
"Текст с разными символами: !@#$%^&*(). "
|
168 |
-
"Скобки (внутри) и [квадратные]. "
|
169 |
-
"Кавычки «елочки» и \"прямые\". "
|
170 |
-
"Тире — и дефис-. "
|
171 |
-
"Многоточие... и запятые, в разных местах."
|
172 |
-
)
|
173 |
-
doc = ParsedDocument(
|
174 |
-
name="test_document.txt",
|
175 |
-
type="text",
|
176 |
-
paragraphs=[ParsedTextBlock(text=text)],
|
177 |
-
)
|
178 |
-
|
179 |
-
strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2)
|
180 |
-
|
181 |
-
# Разбиваем документ
|
182 |
-
entities = strategy.chunk(doc, doc_entity)
|
183 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
184 |
-
|
185 |
-
# Собираем документ обратно
|
186 |
-
result_text = strategy.dechunk(chunks)
|
187 |
-
|
188 |
-
# Проверяем, что все специальные символы сохранились
|
189 |
-
special_chars = set('!@#$%^&*()[]«»"—...')
|
190 |
-
result_chars = set(result_text)
|
191 |
-
assert special_chars.issubset(result_chars)
|
192 |
-
|
193 |
-
# Проверяем, что текст совпадает с оригиналом
|
194 |
-
assert result_text == text
|
195 |
-
|
196 |
-
def test_large_document_chunking(self, large_document, doc_entity):
|
197 |
-
"""Тест нарезки и сборки большого документа с множеством параграфов."""
|
198 |
-
strategy = FixedSizeChunkingStrategy(words_per_chunk=20, overlap_words=5)
|
199 |
-
|
200 |
-
# Разбиваем документ
|
201 |
-
entities = strategy.chunk(large_document, doc_entity)
|
202 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
203 |
-
|
204 |
-
# Проверяем, что документ был разбит на несколько чанков
|
205 |
-
assert len(chunks) > 1
|
206 |
-
|
207 |
-
# Собираем документ обратно
|
208 |
-
result_text = strategy.dechunk(chunks)
|
209 |
-
|
210 |
-
# Получаем оригинальный текст
|
211 |
-
original_paragraphs = [p.text for p in large_document.paragraphs]
|
212 |
-
|
213 |
-
# Проверяем, что все параграфы сохранились
|
214 |
-
result_paragraphs = result_text.split('\n')
|
215 |
-
assert len(result_paragraphs) == len(original_paragraphs)
|
216 |
-
|
217 |
-
# Проверяем, что каждый параграф совпадает с оригиналом
|
218 |
-
for orig, res in zip(original_paragraphs, result_paragraphs):
|
219 |
-
assert orig.strip() == res.strip()
|
220 |
-
|
221 |
-
def test_exact_text_comparison(self, sample_document, doc_entity):
|
222 |
-
"""Тест точного сравнения текстов после нарезки и сборки."""
|
223 |
-
strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2)
|
224 |
-
|
225 |
-
# Разбиваем документ
|
226 |
-
entities = strategy.chunk(sample_document, doc_entity)
|
227 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
228 |
-
|
229 |
-
# Собираем документ обратно
|
230 |
-
result_text = strategy.dechunk(chunks)
|
231 |
-
|
232 |
-
# Получаем оригинальный текст по параграфам
|
233 |
-
original_paragraphs = [p.text for p in sample_document.paragraphs]
|
234 |
-
|
235 |
-
# Проверяем, что все параграфы сохранились
|
236 |
-
result_paragraphs = result_text.split('\n')
|
237 |
-
assert len(result_paragraphs) == len(original_paragraphs)
|
238 |
-
|
239 |
-
# Проверяем, что каждый параграф совпадает с оригиналом
|
240 |
-
for orig, res in zip(original_paragraphs, result_paragraphs):
|
241 |
-
assert orig.strip() == res.strip()
|
242 |
-
|
243 |
-
def test_non_sequential_chunks(self, large_document, doc_entity):
|
244 |
-
"""Тест обработки непоследовательных чанков с вставкой многоточий."""
|
245 |
-
strategy = FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=2)
|
246 |
-
|
247 |
-
# Разбиваем документ
|
248 |
-
entities = strategy.chunk(large_document, doc_entity)
|
249 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
250 |
-
|
251 |
-
# Проверяем, что получили достаточное количество чанков
|
252 |
-
assert len(chunks) >= 5, "Для теста нужно не менее 5 чанков"
|
253 |
-
|
254 |
-
# Отсортируем чанки по индексу
|
255 |
-
sorted_chunks = sorted(chunks, key=lambda c: c.chunk_index or 0)
|
256 |
-
|
257 |
-
# Выберем несколько несмежных чанков (например, 0, 1, 3, 4, 7)
|
258 |
-
selected_indices = [0, 1, 3, 4, 7]
|
259 |
-
selected_chunks = [sorted_chunks[i] for i in selected_indices if i < len(sorted_chunks)]
|
260 |
-
|
261 |
-
# Перемешаем чанки, чтобы убедиться, что сортировка работает
|
262 |
-
import random
|
263 |
-
random.shuffle(selected_chunks)
|
264 |
-
|
265 |
-
# Собираем документ из несмежных чанков
|
266 |
-
result_text = strategy.dechunk(selected_chunks)
|
267 |
-
|
268 |
-
# Проверяем наличие многоточий между непоследовательными чанками
|
269 |
-
assert "\n\n...\n\n" in result_text, "В тексте должно быть многоточие между непоследовательными чанками"
|
270 |
-
|
271 |
-
# Подсчитываем количество многоточий, должно быть 2 группы разрыва (между 1-3 и 4-7)
|
272 |
-
ellipsis_count = result_text.count("\n\n...\n\n")
|
273 |
-
assert ellipsis_count == 2, f"Ожидалось 2 многоточия, получено {ellipsis_count}"
|
274 |
-
|
275 |
-
# Проверяем, что чанки с индексами 0 и 1 идут без многоточия между ними
|
276 |
-
# Для этого находим текст первого чанка и проверяем, что после него нет многоточия
|
277 |
-
first_chunk_text = sorted_chunks[0].text
|
278 |
-
second_chunk_text = sorted_chunks[1].text
|
279 |
-
|
280 |
-
# Проверяем, что текст первого чанка не заканчивается многоточием
|
281 |
-
first_chunk_position = result_text.find(first_chunk_text)
|
282 |
-
second_chunk_position = result_text.find(second_chunk_text, first_chunk_position)
|
283 |
-
|
284 |
-
# Текст между первым и вторым чанком не должен содержать многоточие
|
285 |
-
text_between = result_text[first_chunk_position + len(first_chunk_text):second_chunk_position]
|
286 |
-
assert "\n\n...\n\n" not in text_between, "Не должно быть многоточия между последовательными чанками"
|
287 |
-
|
288 |
-
def test_overlap_addition_in_dechunk(self, large_document, doc_entity):
|
289 |
-
"""Тест добавления нахлеста при сборке чанков."""
|
290 |
-
strategy = FixedSizeChunkingStrategy(words_per_chunk=15, overlap_words=5)
|
291 |
-
|
292 |
-
# Разбиваем документ
|
293 |
-
entities = strategy.chunk(large_document, doc_entity)
|
294 |
-
chunks = [e for e in entities if e.type == "FixedSizeChunk"]
|
295 |
-
|
296 |
-
# Отбираем несколько чанков с непустыми overlap_left и overlap_right
|
297 |
-
overlapping_chunks = []
|
298 |
-
for chunk in chunks:
|
299 |
-
if hasattr(chunk, 'overlap_left') and hasattr(chunk, 'overlap_right'):
|
300 |
-
if chunk.overlap_left and chunk.overlap_right:
|
301 |
-
overlapping_chunks.append(chunk)
|
302 |
-
if len(overlapping_chunks) >= 3:
|
303 |
-
break
|
304 |
-
|
305 |
-
# Проверяем, что нашли подходящие чанки
|
306 |
-
assert len(overlapping_chunks) > 0, "Не найдены чанки с нахлестом"
|
307 |
-
|
308 |
-
# Собираем чанки
|
309 |
-
result_text = strategy.dechunk(overlapping_chunks)
|
310 |
-
|
311 |
-
# Проверяем, что нахлесты включены в результат
|
312 |
-
for chunk in overlapping_chunks:
|
313 |
-
if hasattr(chunk, 'overlap_left') and chunk.overlap_left:
|
314 |
-
# Хотя бы часть нахлеста должна присутствовать в тексте
|
315 |
-
# Берем первые три слова нахлеста для проверки
|
316 |
-
overlap_words = chunk.overlap_left.split()[:3]
|
317 |
-
if overlap_words:
|
318 |
-
overlap_sample = " ".join(overlap_words)
|
319 |
-
assert overlap_sample in result_text, f"Левый нахлест не найден в результате: {overlap_sample}"
|
320 |
-
|
321 |
-
if hasattr(chunk, 'overlap_right') and chunk.overlap_right:
|
322 |
-
# Аналогично проверяем правый нахлест
|
323 |
-
overlap_words = chunk.overlap_right.split()[:3]
|
324 |
-
if overlap_words:
|
325 |
-
overlap_sample = " ".join(overlap_words)
|
326 |
-
assert overlap_sample in result_text, f"Правый нахлест не найден в результате: {overlap_sample}"
|
327 |
-
|
328 |
-
# Проверяем обработку предложений
|
329 |
-
for chunk in overlapping_chunks:
|
330 |
-
if hasattr(chunk, 'left_sentence_part') and chunk.left_sentence_part:
|
331 |
-
assert chunk.left_sentence_part in result_text, "Левая часть предложения не найдена в результате"
|
332 |
-
|
333 |
-
if hasattr(chunk, 'right_sentence_part') and chunk.right_sentence_part:
|
334 |
-
assert chunk.right_sentence_part in result_text, "Правая часть предложения не найдена в результате"
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для стратегии чанкинга FixedSizeChunkingStrategy.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
from uuid import uuid4
|
7 |
|
8 |
import pytest
|
9 |
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
10 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size.fixed_size_chunk import \
|
11 |
+
FixedSizeChunk
|
12 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import (
|
13 |
+
FIXED_SIZE, FixedSizeChunkingStrategy)
|
14 |
+
from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \
|
15 |
+
InMemoryEntityRepository
|
16 |
+
from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity
|
17 |
+
|
18 |
+
|
19 |
+
# --- Фикстуры ---
|
20 |
+
@pytest.fixture
|
21 |
+
def sample_text() -> str:
|
22 |
+
"""Пример текста для тестов."""
|
23 |
+
return (
|
24 |
+
"Это первое предложение. Второе предложение немного длиннее. "
|
25 |
+
"Третье! Четвертое? И пятое.\n"
|
26 |
+
"Новый параграф начинается здесь. Он содержит еще одно предложение. "
|
27 |
+
"И заканчивается тут.\n"
|
28 |
+
"Последний параграф."
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
@pytest.fixture
|
33 |
+
def parsed_document(sample_text: str) -> ParsedDocument:
|
34 |
+
"""Фикстура для ParsedDocument."""
|
35 |
+
paragraphs = [
|
36 |
+
ParsedTextBlock(text=p) for p in sample_text.split('\n') if p
|
37 |
+
]
|
38 |
+
return ParsedDocument(name="test_doc.txt", type="txt", paragraphs=paragraphs)
|
39 |
+
|
40 |
+
|
41 |
+
@pytest.fixture
|
42 |
+
def doc_entity() -> DocumentAsEntity:
|
43 |
+
"""Фикстура для DocumentAsEntity."""
|
44 |
+
return DocumentAsEntity(id=uuid4(), name="test_doc")
|
45 |
+
|
46 |
+
|
47 |
+
@pytest.fixture(scope="module")
|
48 |
+
def default_strategy() -> FixedSizeChunkingStrategy:
|
49 |
+
"""Стратегия с настройками по умолчанию."""
|
50 |
+
return FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=3)
|
51 |
+
|
52 |
+
|
53 |
+
@pytest.fixture(scope="module")
|
54 |
+
def no_sentence_boundary_strategy() -> FixedSizeChunkingStrategy:
|
55 |
+
"""Стратегия без учета границ предложений."""
|
56 |
+
return FixedSizeChunkingStrategy(
|
57 |
+
words_per_chunk=10, overlap_words=3, respect_sentence_boundaries=False
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
@pytest.fixture
|
62 |
+
def extracted_words(parsed_document: ParsedDocument, default_strategy: FixedSizeChunkingStrategy) -> list[str]:
|
63 |
+
"""Извлеченные слова из sample_text."""
|
64 |
+
# Используем приватный метод для консистентности
|
65 |
+
return default_strategy._extract_words(parsed_document)
|
66 |
+
|
67 |
+
|
68 |
+
@pytest.fixture
|
69 |
+
def chunked_entities(
|
70 |
+
default_strategy: FixedSizeChunkingStrategy,
|
71 |
+
parsed_document: ParsedDocument,
|
72 |
+
doc_entity: DocumentAsEntity
|
73 |
+
) -> list[LinkerEntity]:
|
74 |
+
"""Результат чанкинга документа стратегией по умолчанию."""
|
75 |
+
return default_strategy.chunk(parsed_document, doc_entity)
|
76 |
+
|
77 |
+
|
78 |
+
# --- Тесты инициализации и вспомогательных методов ---
|
79 |
+
class TestFixedSizeChunkingStrategyInitAndHelpers:
|
80 |
+
"""Тесты инициализации и приватных методов FixedSizeChunkingStrategy."""
|
81 |
+
|
82 |
+
def test_init_validation(self):
|
83 |
+
"""Тест валидации параметров при инициализации."""
|
84 |
+
with pytest.raises(ValueError, match="overlap_words должен быть меньше words_per_chunk"):
|
85 |
+
FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=10)
|
86 |
+
with pytest.raises(ValueError, match="overlap_words должен быть меньше words_per_chunk"):
|
87 |
+
FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=15)
|
88 |
+
with pytest.raises(ValueError, match="words_per_chunk должен быть > 0"):
|
89 |
+
FixedSizeChunkingStrategy(words_per_chunk=0, overlap_words=0)
|
90 |
+
with pytest.raises(ValueError, match="overlap_words >= 0"):
|
91 |
+
FixedSizeChunkingStrategy(words_per_chunk=10, overlap_words=-1)
|
92 |
+
|
93 |
+
def test_extract_words(self, parsed_document: ParsedDocument, default_strategy: FixedSizeChunkingStrategy):
|
94 |
+
"""Тест метода _extract_words."""
|
95 |
+
words = default_strategy._extract_words(parsed_document)
|
96 |
+
assert words == [
|
97 |
+
"Это", "первое", "предложение.", "Второе", "предложение", "немного", "длиннее.",
|
98 |
+
"Третье!", "Четвертое?", "И", "пятое.", "\n",
|
99 |
+
"Новый", "параграф", "начинается", "здесь.", "Он", "содержит", "еще", "одно", "предложение.",
|
100 |
+
"И", "заканчивается", "тут.", "\n",
|
101 |
+
"Последний", "параграф."
|
102 |
]
|
103 |
+
# Проверяем на пустом документе
|
104 |
+
empty_doc = ParsedDocument()
|
105 |
+
assert default_strategy._extract_words(empty_doc) == []
|
106 |
+
|
107 |
+
def test_prepare_chunk_text(self, extracted_words: list[str], default_strategy: FixedSizeChunkingStrategy):
|
108 |
+
"""Тест метода _prepare_chunk_text."""
|
109 |
+
# Первые 5 слов
|
110 |
+
text = default_strategy._prepare_chunk_text(extracted_words, 0, 5)
|
111 |
+
assert text == "Это первое предложение. Второе предложение"
|
112 |
+
# Слова с переносом строки
|
113 |
+
text = default_strategy._prepare_chunk_text(extracted_words, 10, 15)
|
114 |
+
assert text == "пятое.\nНовый параграф начинается"
|
115 |
+
# Пустой срез
|
116 |
+
text = default_strategy._prepare_chunk_text(extracted_words, 5, 5)
|
117 |
+
assert text == ""
|
118 |
+
|
119 |
+
def test_find_sentence_boundary(self, default_strategy: FixedSizeChunkingStrategy):
|
120 |
+
"""Тест метода _find_sentence_boundary."""
|
121 |
+
# Ищем левую часть (после знака препинания)
|
122 |
+
text1 = "Some text. This part should be found."
|
123 |
+
assert default_strategy._find_sentence_boundary(text1, True) == "This part should be found."
|
124 |
+
text2 = "No punctuation here"
|
125 |
+
assert default_strategy._find_sentence_boundary(text2, True) == ""
|
126 |
+
text3 = "Ends with dot."
|
127 |
+
assert default_strategy._find_sentence_boundary(text3, True) == ""
|
128 |
+
text4 = "Multiple sentences. Second one? Third one!"
|
129 |
+
assert default_strategy._find_sentence_boundary(text4, True) == ""
|
130 |
+
|
131 |
+
# Ищем правую часть (до знака препинания)
|
132 |
+
text5 = "Find this part. Rest of text."
|
133 |
+
assert default_strategy._find_sentence_boundary(text5, False) == "Find this part."
|
134 |
+
text6 = "No punctuation here"
|
135 |
+
assert default_strategy._find_sentence_boundary(text6, False) == "No punctuation here"
|
136 |
+
text7 = "Ends with dot."
|
137 |
+
assert default_strategy._find_sentence_boundary(text7, False) == "Ends with dot."
|
138 |
+
text8 = "Multiple sentences. Second one? Third one!"
|
139 |
+
assert default_strategy._find_sentence_boundary(text8, False) == "Multiple sentences."
|
140 |
+
text9 = ""
|
141 |
+
assert default_strategy._find_sentence_boundary(text9, False) == ""
|
142 |
+
|
143 |
+
def test_calculate_boundaries(self, extracted_words: list[str], default_strategy: FixedSizeChunkingStrategy):
|
144 |
+
"""Тест метода _calculate_boundaries (с respect_sentence_boundaries=True)."""
|
145 |
+
# Пример для первого чанка (индексы 0-10, overlap=3)
|
146 |
+
# overlap_left_start=0, chunk_start=0, chunk_end=10, overlap_right_end=13
|
147 |
+
left_part, right_part, left_overlap, right_overlap = default_strategy._calculate_boundaries(
|
148 |
+
extracted_words, 0, 10, len(extracted_words)
|
149 |
)
|
150 |
+
assert left_overlap == "" # Нет левого оверлапа
|
151 |
+
assert left_part == "" # Т.к. левый оверлап пустой
|
152 |
+
# Правый оверлап: слова с 10 по 13 (исключая 13) -> "пятое. \n Новый"
|
153 |
+
assert right_overlap == "пятое.\nНовый"
|
154 |
+
# Правая часть предложения: ищем до первого знака в right_overlap -> "пятое."
|
155 |
+
assert right_part == "пятое."
|
156 |
+
|
157 |
+
# Пример для второго чанка (индексы 7-17, step=7)
|
158 |
+
# overlap_left_start=4, chunk_start=7, chunk_end=17, overlap_right_end=20
|
159 |
+
left_part, right_part, left_overlap, right_overlap = default_strategy._calculate_boundaries(
|
160 |
+
extracted_words, 7, 17, len(extracted_words)
|
161 |
+
)
|
162 |
+
# Левый оверлап: слова 4-7 -> "предложение немного длиннее."
|
163 |
+
assert left_overlap == "предложение немного длиннее."
|
164 |
+
# Левая часть предложения: ищем после последнего знака в left_overlap -> ""
|
165 |
+
assert left_part == ""
|
166 |
+
# Правый оверлап: слова 17-20 -> "содержит еще одно"
|
167 |
+
assert right_overlap == "содержит еще одно"
|
168 |
+
# Правая часть предложения: ищем до первого знака -> нет знаков, берем всё -> "содержит еще одно"
|
169 |
+
assert right_part == "содержит еще одно"
|
170 |
+
|
171 |
+
def test_calculate_boundaries_no_respect(self, extracted_words: list[str], no_sentence_boundary_strategy: FixedSizeChunkingStrategy):
|
172 |
+
"""Тест _calculate_boundaries с respect_sentence_boundaries=False."""
|
173 |
+
left_part, right_part, left_overlap, right_overlap = no_sentence_boundary_strategy._calculate_boundaries(
|
174 |
+
extracted_words, 7, 17, len(extracted_words)
|
175 |
)
|
176 |
+
assert left_overlap == "предложение немного длиннее."
|
177 |
+
assert right_overlap == "содержит еще одно"
|
178 |
+
# left/right_sentence_part должны быть пустыми, т.к. respect_sentence_boundaries=False
|
179 |
+
assert left_part == ""
|
180 |
+
assert right_part == ""
|
181 |
+
|
182 |
+
def test_clean_final_text(self, default_strategy: FixedSizeChunkingStrategy):
|
183 |
+
"""Тест метода _clean_final_text."""
|
184 |
+
text = " Too many spaces. \n\n\nMultiple newlines. space before punct . ( parentheses ) \n space after newline "
|
185 |
+
cleaned = FixedSizeChunkingStrategy._clean_final_text(text)
|
186 |
+
expected = "Too many spaces.\n\nMultiple newlines. space before punct.(parentheses)\nspace after newline"
|
187 |
+
assert cleaned == expected
|
188 |
+
|
189 |
+
|
190 |
+
# --- Тесты метода chunk ---
|
191 |
+
class TestFixedSizeChunkingStrategyChunk:
|
192 |
+
"""Тесты основного метода chunk."""
|
193 |
+
|
194 |
+
def test_chunk_defaults(self, chunked_entities: list[LinkerEntity], doc_entity: DocumentAsEntity):
|
195 |
+
"""Тест чанкинга с настройками по умолчанию."""
|
196 |
+
# Ожидаемое кол-во слов = 29. chunk=10, overlap=3, step=7.
|
197 |
+
# Чанки: 0-10, 7-17, 14-24, 21-29(обрезано), 28-29(остаток)
|
198 |
+
assert len(chunked_entities) == 5
|
199 |
+
|
200 |
+
# Проверка первого чанка
|
201 |
+
chunk0 = chunked_entities[0]
|
202 |
+
assert isinstance(chunk0, FixedSizeChunk)
|
203 |
+
assert chunk0.owner_id == doc_entity.id
|
204 |
+
assert chunk0.number_in_relation == 0
|
205 |
+
assert chunk0.groupper == "chunk"
|
206 |
+
assert chunk0.text == "Это первое предложение. Второе предложение немного длиннее." # step=7 слов
|
207 |
+
assert chunk0.in_search_text == "Это первое предложение. Второе предложение немного длиннее. Третье! Четвертое? И" # chunk=10 слов
|
208 |
+
assert chunk0.token_count == 10
|
209 |
+
assert chunk0.left_sentence_part == ""
|
210 |
+
assert chunk0.right_sentence_part == "пятое." # Из правого overlap "пятое.\nНовый"
|
211 |
+
assert chunk0.overlap_left == ""
|
212 |
+
assert chunk0.overlap_right == "пятое.\nНовый"
|
213 |
+
|
214 |
+
# Проверка второго чанка
|
215 |
+
chunk1 = chunked_entities[1]
|
216 |
+
assert isinstance(chunk1, FixedSizeChunk)
|
217 |
+
assert chunk1.number_in_relation == 1
|
218 |
+
assert chunk1.text == "Третье! Четвертое? И пятое.\nНовый параграф начинается" # Индексы 7-14 (step=7)
|
219 |
+
assert chunk1.in_search_text == "Третье! Четвертое? И пятое.\nНовый параграф начинается здесь. Он" # Индексы 7-17 (chunk=10)
|
220 |
+
assert chunk1.token_count == 10
|
221 |
+
assert chunk1.left_sentence_part == "" # Из левого overlap "предложение немного длиннее."
|
222 |
+
assert chunk1.right_sentence_part == "содержит еще одно" # Из правого overlap "содержит еще одно предложение."
|
223 |
+
assert chunk1.overlap_left == "предложение немного длиннее."
|
224 |
+
assert chunk1.overlap_right == "содержит еще одно предложение."
|
225 |
+
|
226 |
+
# Проверка последнего чанка (остаток)
|
227 |
+
chunk4 = chunked_entities[4]
|
228 |
+
assert isinstance(chunk4, FixedSizeChunk)
|
229 |
+
assert chunk4.number_in_relation == 4
|
230 |
+
assert chunk4.text == "параграф." # Индекс 28 (step=1)
|
231 |
+
assert chunk4.in_search_text == "параграф." # Индекс 28 (chunk=1, остаток)
|
232 |
+
assert chunk4.token_count == 1
|
233 |
+
assert chunk4.left_sentence_part == "" # Из левого overlap "\nПоследний"
|
234 |
+
assert chunk4.right_sentence_part == "" # Правый overlap пустой
|
235 |
+
assert chunk4.overlap_left == "\nПоследний"
|
236 |
+
assert chunk4.overlap_right == ""
|
237 |
+
|
238 |
+
def test_chunk_no_sentence_boundary(
|
239 |
+
self,
|
240 |
+
no_sentence_boundary_strategy: FixedSizeChunkingStrategy,
|
241 |
+
parsed_document: ParsedDocument,
|
242 |
+
doc_entity: DocumentAsEntity
|
243 |
+
):
|
244 |
+
"""Тест чанкинга без учета границ предложений."""
|
245 |
+
chunks = no_sentence_boundary_strategy.chunk(parsed_document, doc_entity)
|
246 |
+
assert len(chunks) == 5
|
247 |
+
chunk0 = chunks[0]
|
248 |
+
assert isinstance(chunk0, FixedSizeChunk)
|
249 |
+
# left/right_sentence_part должны быть пустыми
|
250 |
+
assert chunk0.left_sentence_part == ""
|
251 |
+
assert chunk0.right_sentence_part == ""
|
252 |
+
assert chunk0.overlap_right == "пятое.\nНовый" # overlap сам по себе остается
|
253 |
+
|
254 |
+
chunk1 = chunks[1]
|
255 |
+
assert isinstance(chunk1, FixedSizeChunk)
|
256 |
+
assert chunk1.left_sentence_part == ""
|
257 |
+
assert chunk1.right_sentence_part == ""
|
258 |
+
assert chunk1.overlap_left == "предложение немного длиннее."
|
259 |
+
assert chunk1.overlap_right == "содержит еще одно предложение."
|
260 |
+
|
261 |
+
def test_chunk_empty_document(self, default_strategy: FixedSizeChunkingStrategy, doc_entity: DocumentAsEntity):
|
262 |
+
"""Тест чанкинга пустого документа."""
|
263 |
+
empty_doc = ParsedDocument()
|
264 |
+
chunks = default_strategy.chunk(empty_doc, doc_entity)
|
265 |
+
assert chunks == []
|
266 |
+
|
267 |
+
def test_chunk_short_document(self, default_strategy: FixedSizeChunkingStrategy, doc_entity: DocumentAsEntity):
|
268 |
+
"""Тест чанкинга очень короткого документа."""
|
269 |
+
short_doc = ParsedDocument(paragraphs=[ParsedTextBlock(text="One two three.")])
|
270 |
+
chunks = default_strategy.chunk(short_doc, doc_entity)
|
271 |
+
assert len(chunks) == 1
|
272 |
+
chunk0 = chunks[0]
|
273 |
+
assert isinstance(chunk0, FixedSizeChunk)
|
274 |
+
assert chunk0.text == "One two three."
|
275 |
+
assert chunk0.in_search_text == "One two three."
|
276 |
+
assert chunk0.token_count == 3
|
277 |
+
assert chunk0.left_sentence_part == ""
|
278 |
+
assert chunk0.right_sentence_part == "" # Нет правого оверлапа
|
279 |
+
|
280 |
+
|
281 |
+
# --- Тесты метода dechunk ---
|
282 |
+
class TestFixedSizeChunkingStrategyDechunk:
|
283 |
+
"""Тесты classmethod dechunk."""
|
284 |
|
285 |
@pytest.fixture
|
286 |
+
def mock_repository(self, chunked_entities: list[LinkerEntity]) -> InMemoryEntityRepository:
|
287 |
+
"""Мок-репозиторий с чанками."""
|
288 |
+
# В dechunk репозиторий пока не используется, но передадим его
|
289 |
+
return InMemoryEntityRepository(chunked_entities)
|
290 |
+
|
291 |
+
def test_dechunk_full_sequence(self, mock_repository: InMemoryEntityRepository, chunked_entities: list[LinkerEntity]):
|
292 |
+
"""Тест сборки полной последовательности чанков."""
|
293 |
+
# Передаем все чанки
|
294 |
+
assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, chunked_entities)
|
295 |
+
|
296 |
+
# Ожидаем, что текст будет собран с использованием left/right_sentence_part
|
297 |
+
# chunk0.left + chunk0.text + chunk1.text + ... + chunkN.text + chunkN.right
|
298 |
+
# chunk0.left_sentence_part = ""
|
299 |
+
# chunk4.right_sentence_part = ""
|
300 |
+
expected_parts = [
|
301 |
+
chunked_entities[0].text, # "Это первое предложение. Второе предложение немного длиннее."
|
302 |
+
chunked_entities[1].text, # "Третье! Четвертое? И пятое.\nНовый параграф начинается"
|
303 |
+
chunked_entities[2].text, # "здесь. Он содержит еще одно предложение."
|
304 |
+
chunked_entities[3].text, # "И заканчивается тут.\nПоследний"
|
305 |
+
chunked_entities[4].text, # "параграф."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
]
|
307 |
+
expected_raw = " ".join(expected_parts)
|
308 |
+
# Применяем очистку, как в методе _build_sequenced_chunks
|
309 |
+
expected_cleaned = FixedSizeChunkingStrategy._clean_final_text(expected_raw)
|
310 |
+
|
311 |
+
# Сравниваем с очищенным результатом
|
312 |
+
assert assembled_text == expected_cleaned
|
313 |
+
# Проверим, что переносы строк сохранились где надо
|
314 |
+
assert "пятое.\nНовый" in assembled_text
|
315 |
+
assert "тут.\nПоследний" in assembled_text
|
316 |
+
|
317 |
+
def test_dechunk_with_gap(self, mock_repository: InMemoryEntityRepository, chunked_entities: list[LinkerEntity]):
|
318 |
+
"""Тест сборки с пропуском чанка."""
|
319 |
+
# Удаляем chunk 1 (индекс 1)
|
320 |
+
filtered_chunks = [chunked_entities[0]] + chunked_entities[2:]
|
321 |
+
assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, filtered_chunks)
|
322 |
+
|
323 |
+
# Группа 1: chunk 0
|
324 |
+
group1_parts = [
|
325 |
+
chunked_entities[0].left_sentence_part, # ""
|
326 |
+
chunked_entities[0].text, # "Это первое предложение. Второе предложение немного длиннее."
|
327 |
+
chunked_entities[0].right_sentence_part # "пятое."
|
328 |
+
]
|
329 |
+
group1_text = FixedSizeChunkingStrategy._clean_final_text(" ".join(filter(None, group1_parts)))
|
330 |
+
|
331 |
+
# Группа 2: chunks 2, 3, 4
|
332 |
+
group2_parts = [
|
333 |
+
chunked_entities[2].left_sentence_part, # "здесь. Он"
|
334 |
+
chunked_entities[2].text, # "здесь. Он содержит еще одно предложение."
|
335 |
+
chunked_entities[3].text, # "И заканчивается тут.\nПоследний"
|
336 |
+
chunked_entities[4].text, # "параграф."
|
337 |
+
chunked_entities[4].right_sentence_part # ""
|
338 |
+
]
|
339 |
+
group2_text = FixedSizeChunkingStrategy._clean_final_text(" ".join(filter(None, group2_parts)))
|
340 |
+
|
341 |
+
expected_text = f"{group1_text}\n...\n{group2_text}"
|
342 |
+
assert assembled_text == expected_text
|
343 |
+
|
344 |
+
def test_dechunk_not_fixed_size_chunk(self, mock_repository: InMemoryEntityRepository, doc_entity: DocumentAsEntity):
|
345 |
+
"""Тест сборки, если передан не FixedSizeChunk."""
|
346 |
+
# Создаем обычный LinkerEntity вместо FixedSizeChunk
|
347 |
+
non_fsc = LinkerEntity(id=uuid4(), name="not a chunk", text="some text", target_id=doc_entity.id, number_in_relation=0)
|
348 |
+
assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, [non_fsc])
|
349 |
+
# Ожидаем просто текст из .text
|
350 |
+
assert assembled_text == "some text"
|
351 |
+
|
352 |
+
def test_dechunk_empty_list(self, mock_repository: InMemoryEntityRepository):
|
353 |
+
"""Тест сборки пустого списка чанков."""
|
354 |
+
assembled_text = FixedSizeChunkingStrategy.dechunk(mock_repository, [])
|
355 |
+
assert assembled_text == ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/extractor/tests/conftest.py
CHANGED
@@ -9,47 +9,3 @@ import pytest
|
|
9 |
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
|
10 |
from tests.custom_entity import CustomEntity # Импортируем наш кастомный класс
|
11 |
|
12 |
-
|
13 |
-
@pytest.fixture
|
14 |
-
def sample_entity():
|
15 |
-
"""
|
16 |
-
Фикстура, возвращающая экземпляр LinkerEntity с предустановленными значениями.
|
17 |
-
"""
|
18 |
-
return LinkerEntity(
|
19 |
-
id=UUID('12345678-1234-5678-1234-567812345678'),
|
20 |
-
name="Тестовая сущность",
|
21 |
-
text="Текст тестовой сущности",
|
22 |
-
metadata={"test_key": "test_value"}
|
23 |
-
)
|
24 |
-
|
25 |
-
|
26 |
-
@pytest.fixture
|
27 |
-
def sample_custom_entity():
|
28 |
-
"""
|
29 |
-
Фикстура, возвращающая экземпляр CustomEntity с предустановленными значениями.
|
30 |
-
"""
|
31 |
-
return CustomEntity(
|
32 |
-
id=UUID('87654321-8765-4321-8765-432187654321'),
|
33 |
-
name="Тестовый кастомный объект",
|
34 |
-
text="Текст кастомного объекта",
|
35 |
-
metadata={"original_key": "original_value"},
|
36 |
-
in_search_text="Текст для поиска кастомного объекта",
|
37 |
-
custom_field1="custom_value",
|
38 |
-
custom_field2=42
|
39 |
-
)
|
40 |
-
|
41 |
-
|
42 |
-
@pytest.fixture
|
43 |
-
def sample_link():
|
44 |
-
"""
|
45 |
-
Фикстура, возвращающая экземпляр LinkerEntity с предустановленными значениями связи.
|
46 |
-
"""
|
47 |
-
return LinkerEntity(
|
48 |
-
id=UUID('98765432-9876-5432-9876-543298765432'),
|
49 |
-
name="Тестовая связь",
|
50 |
-
text="Текст тестовой связи",
|
51 |
-
metadata={"test_key": "test_value"},
|
52 |
-
source_id=UUID('12345678-1234-5678-1234-567812345678'),
|
53 |
-
target_id=UUID('87654321-8765-4321-8765-432187654321'),
|
54 |
-
type="Link"
|
55 |
-
)
|
|
|
9 |
from ntr_text_fragmentation.models.linker_entity import LinkerEntity
|
10 |
from tests.custom_entity import CustomEntity # Импортируем наш кастомный класс
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/extractor/tests/core/test_extractor.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для EntitiesExtractor.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from unittest.mock import MagicMock, patch
|
6 |
+
from uuid import UUID, uuid4
|
7 |
+
|
8 |
+
import pytest
|
9 |
+
from ntr_fileparser import ParsedDocument, ParsedTextBlock
|
10 |
+
# Импортируем конкретную стратегию и процессор для мокирования
|
11 |
+
from ntr_text_fragmentation.additors.tables_processor import TablesProcessor
|
12 |
+
from ntr_text_fragmentation.chunking import chunking_registry
|
13 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import (
|
14 |
+
FIXED_SIZE, FixedSizeChunkingStrategy)
|
15 |
+
from ntr_text_fragmentation.core.extractor import EntitiesExtractor
|
16 |
+
from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity
|
17 |
+
|
18 |
+
|
19 |
+
# --- Фикстуры ---
|
20 |
+
@pytest.fixture
|
21 |
+
def mock_document() -> ParsedDocument:
|
22 |
+
"""Мок ParsedDocument."""
|
23 |
+
return ParsedDocument(
|
24 |
+
name="mock_doc.pdf",
|
25 |
+
type="pdf",
|
26 |
+
paragraphs=[ParsedTextBlock(text="Paragraph 1."), ParsedTextBlock(text="Paragraph 2.")],
|
27 |
+
# Можно добавить таблицы и т.д., если нужно тестировать их обработку
|
28 |
+
)
|
29 |
+
|
30 |
+
@pytest.fixture
|
31 |
+
def mock_chunk() -> LinkerEntity:
|
32 |
+
"""Мок сущности чанка."""
|
33 |
+
return LinkerEntity(id=uuid4(), name="mock_chunk", text="chunk text", type="Chunk")
|
34 |
+
|
35 |
+
@pytest.fixture
|
36 |
+
def mock_table_entity() -> LinkerEntity:
|
37 |
+
"""Мок сущности таблицы."""
|
38 |
+
return LinkerEntity(id=uuid4(), name="mock_table", text="table text", type="TableEntity")
|
39 |
+
|
40 |
+
@pytest.fixture
|
41 |
+
def mock_strategy_instance(mock_chunk: LinkerEntity) -> MagicMock:
|
42 |
+
"""Мок экземпляра стратегии чанкинга."""
|
43 |
+
instance = MagicMock(spec=FixedSizeChunkingStrategy)
|
44 |
+
# Мокируем метод chunk, чтобы он возвращал предопределенный чанк
|
45 |
+
instance.chunk.return_value = [mock_chunk]
|
46 |
+
return instance
|
47 |
+
|
48 |
+
@pytest.fixture
|
49 |
+
def mock_tables_processor_instance(mock_table_entity: LinkerEntity) -> MagicMock:
|
50 |
+
"""Мок экземпляра процессора таблиц."""
|
51 |
+
instance = MagicMock(spec=TablesProcessor)
|
52 |
+
# Мокируем метод extract, чтобы он возвращал предопределенную сущность таблицы
|
53 |
+
instance.extract.return_value = [mock_table_entity]
|
54 |
+
return instance
|
55 |
+
|
56 |
+
@pytest.fixture(autouse=True)
|
57 |
+
def mock_registry_and_processors(
|
58 |
+
mock_strategy_instance: MagicMock,
|
59 |
+
mock_tables_processor_instance: MagicMock
|
60 |
+
):
|
61 |
+
"""Мокирует реестр стратегий и конструкторы процессоров."""
|
62 |
+
# Мокируем реестр, чтобы он возвращал наш мок-класс стратегии
|
63 |
+
mock_strategy_class = MagicMock(return_value=mock_strategy_instance)
|
64 |
+
with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class}, clear=True):
|
65 |
+
# Мокируем конструктор TablesProcessor, чтобы он возвращал наш мок-экземпляр
|
66 |
+
with patch('ntr_text_fragmentation.core.extractor.TablesProcessor', return_value=mock_tables_processor_instance):
|
67 |
+
yield
|
68 |
+
|
69 |
+
|
70 |
+
# --- Тесты --- #
|
71 |
+
class TestEntitiesExtractor:
|
72 |
+
"""Тесты для EntitiesExtractor."""
|
73 |
+
|
74 |
+
def test_init_defaults(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock):
|
75 |
+
"""Тест инициализации с настройками по умолчанию."""
|
76 |
+
extractor = EntitiesExtractor()
|
77 |
+
# По умолчанию используется FIXED_SIZE стратегия и process_tables=True
|
78 |
+
assert extractor.strategy is mock_strategy_instance
|
79 |
+
assert extractor._strategy_name == FIXED_SIZE
|
80 |
+
assert extractor.tables_processor is mock_tables_processor_instance
|
81 |
+
|
82 |
+
def test_init_custom_strategy(self, mock_strategy_instance: MagicMock):
|
83 |
+
"""Тест инициализации с указанием стратегии и параметров."""
|
84 |
+
strategy_params = {'words_per_chunk': 100, 'overlap_words': 10}
|
85 |
+
# Ожидаем, что конструктор мок-стратегии будет вызван с этими параметрами
|
86 |
+
extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, strategy_params=strategy_params, process_tables=False)
|
87 |
+
|
88 |
+
# Проверяем, что конструктор мок-стратегии был вызван с правильными параметрами
|
89 |
+
mock_strategy_class = chunking_registry[FIXED_SIZE]
|
90 |
+
mock_strategy_class.assert_called_once_with(**strategy_params)
|
91 |
+
assert extractor.strategy is mock_strategy_instance
|
92 |
+
assert extractor._strategy_name == FIXED_SIZE
|
93 |
+
assert extractor.tables_processor is None # process_tables=False
|
94 |
+
|
95 |
+
def test_init_invalid_strategy_name(self):
|
96 |
+
"""Тест инициализации с невалидным именем стратегии."""
|
97 |
+
with pytest.raises(ValueError, match="Неизвестная стратегия: invalid_strategy"):
|
98 |
+
EntitiesExtractor(strategy_name="invalid_strategy")
|
99 |
+
|
100 |
+
def test_configure_chunking(self, mock_strategy_instance: MagicMock):
|
101 |
+
"""Тест переконфигурации стратегии чанкинга."""
|
102 |
+
extractor = EntitiesExtractor(process_tables=False) # Изначально без стратегии
|
103 |
+
assert extractor.strategy is None
|
104 |
+
|
105 |
+
params = {'words_per_chunk': 20}
|
106 |
+
extractor.configure_chunking(strategy_name=FIXED_SIZE, strategy_params=params)
|
107 |
+
|
108 |
+
mock_strategy_class = chunking_registry[FIXED_SIZE]
|
109 |
+
mock_strategy_class.assert_called_once_with(**params)
|
110 |
+
assert extractor.strategy is mock_strategy_instance
|
111 |
+
assert extractor._strategy_name == FIXED_SIZE
|
112 |
+
|
113 |
+
def test_configure_chunking_invalid_params(self):
|
114 |
+
"""Тест ошибки при неверных параметрах для стратегии."""
|
115 |
+
# Настроим мок-класс стратегии, чтобы он вызывал TypeError при инициализации
|
116 |
+
mock_strategy_class_error = MagicMock(side_effect=TypeError("Invalid param"))
|
117 |
+
with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class_error}):
|
118 |
+
extractor = EntitiesExtractor(process_tables=False)
|
119 |
+
with pytest.raises(ValueError, match="Ошибка при попытке инициализировать стратегию"):
|
120 |
+
extractor.configure_chunking(strategy_name=FIXED_SIZE, strategy_params={"invalid": 1})
|
121 |
+
|
122 |
+
def test_configure_tables_extraction(self, mock_tables_processor_instance: MagicMock):
|
123 |
+
"""Тест переконфигурации извлечения таблиц."""
|
124 |
+
extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False) # Изначально без таблиц
|
125 |
+
assert extractor.tables_processor is None
|
126 |
+
|
127 |
+
extractor.configure_tables_extraction(process_tables=True)
|
128 |
+
assert extractor.tables_processor is mock_tables_processor_instance
|
129 |
+
|
130 |
+
extractor.configure_tables_extraction(process_tables=False)
|
131 |
+
# Экземпляр процессора создается, но не используется, если process_tables=False в destructure
|
132 |
+
# Однако configure_tables_extraction устанавливает его. Проверим это.
|
133 |
+
# Ожидаем, что конструктор TablesProcessor будет вызван при configure_tables_extraction(True)
|
134 |
+
# и при configure_tables_extraction(False) он не обнулится?
|
135 |
+
# Судя по коду configure_tables_extraction, он всегда создает новый TablesProcessor.
|
136 |
+
# Давайте уточним тест, что процессор создается.
|
137 |
+
with patch('ntr_text_fragmentation.core.extractor.TablesProcessor') as mock_constructor:
|
138 |
+
extractor.configure_tables_extraction(process_tables=True)
|
139 |
+
mock_constructor.assert_called_once()
|
140 |
+
assert extractor.tables_processor is not None
|
141 |
+
|
142 |
+
mock_constructor.reset_mock()
|
143 |
+
extractor.configure_tables_extraction(process_tables=False)
|
144 |
+
# Повторный вызов конструктора! Это может быть неэффективно, но тест должен отражать код.
|
145 |
+
mock_constructor.assert_called_once()
|
146 |
+
assert extractor.tables_processor is not None # Экземпляр остается
|
147 |
+
|
148 |
+
def test_configure_chaining(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock):
|
149 |
+
"""Тест цепочки вызовов configure."""
|
150 |
+
extractor = EntitiesExtractor(strategy_name=None, process_tables=None) # Полностью пустой
|
151 |
+
assert extractor.strategy is None
|
152 |
+
assert extractor.tables_processor is None
|
153 |
+
|
154 |
+
returned_extractor = extractor.configure(strategy_name=FIXED_SIZE, process_tables=True)
|
155 |
+
|
156 |
+
assert returned_extractor is extractor # Должен возвращать себя
|
157 |
+
assert extractor.strategy is mock_strategy_instance
|
158 |
+
assert extractor.tables_processor is mock_tables_processor_instance
|
159 |
+
|
160 |
+
# Переконфигурируем только стратегию
|
161 |
+
new_params = {"words_per_chunk": 50}
|
162 |
+
extractor.configure(strategy_name=FIXED_SIZE, strategy_params=new_params)
|
163 |
+
mock_strategy_class = chunking_registry[FIXED_SIZE]
|
164 |
+
# Конструктор стратегии вызывается повторно
|
165 |
+
mock_strategy_class.assert_called_with(**new_params)
|
166 |
+
assert extractor.tables_processor is mock_tables_processor_instance # Процессор таблиц не изменился
|
167 |
+
|
168 |
+
# Переконфигурируем только таблицы
|
169 |
+
extractor.configure(process_tables=False)
|
170 |
+
# tables_processor создается, но использоваться не будет
|
171 |
+
assert extractor.tables_processor is not None
|
172 |
+
assert extractor.strategy is mock_strategy_instance # Стратегия не изменилась
|
173 |
+
|
174 |
+
def test_destructure_calls_chunk_and_extract(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock, mock_chunk: LinkerEntity, mock_table_entity: LinkerEntity):
|
175 |
+
"""Тест, что destructure вызывает chunk и extract."""
|
176 |
+
extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=True)
|
177 |
+
entities = extractor.destructure(mock_document)
|
178 |
+
|
179 |
+
# Проверяем вызов chunk стратегии
|
180 |
+
mock_strategy_instance.chunk.assert_called_once()
|
181 |
+
# Проверяем аргументы вызова chunk
|
182 |
+
call_args, _ = mock_strategy_instance.chunk.call_args
|
183 |
+
assert call_args[0] is mock_document
|
184 |
+
assert isinstance(call_args[1], DocumentAsEntity)
|
185 |
+
assert call_args[1].name == "mock_doc.pdf" # Имя документа передано
|
186 |
+
# Проверяем, что стратегия записалась в DocumentAsEntity
|
187 |
+
assert call_args[1].chunking_strategy_ref == FIXED_SIZE
|
188 |
+
|
189 |
+
# Проверяем вызов extract процессора таблиц
|
190 |
+
mock_tables_processor_instance.extract.assert_called_once()
|
191 |
+
# Проверяем аргументы вызова extract
|
192 |
+
call_args_table, _ = mock_tables_processor_instance.extract.call_args
|
193 |
+
assert call_args_table[0] is mock_document
|
194 |
+
assert isinstance(call_args_table[1], DocumentAsEntity)
|
195 |
+
assert call_args_table[1].name == "mock_doc.pdf"
|
196 |
+
|
197 |
+
# Проверяем результат: должен содержать DocumentAsEntity, результат chunk, результат extract
|
198 |
+
assert len(entities) == 3
|
199 |
+
entity_types = {type(e) for e in entities}
|
200 |
+
# Все сущности сериализованы в LinkerEntity
|
201 |
+
assert entity_types == {LinkerEntity}
|
202 |
+
|
203 |
+
entity_ids = {e.id for e in entities}
|
204 |
+
# Проверяем наличие ID моков (после сериализации ID сохраняются)
|
205 |
+
assert mock_chunk.id in entity_ids
|
206 |
+
assert mock_table_entity.id in entity_ids
|
207 |
+
# Проверяем наличие ID документа (он создается внутри)
|
208 |
+
doc_entity_id = next(e.id for e in entities if e.type == "DocumentAsEntity")
|
209 |
+
assert isinstance(doc_entity_id, UUID)
|
210 |
+
|
211 |
+
def test_destructure_only_chunking(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock, mock_chunk: LinkerEntity):
|
212 |
+
"""Тест destructure только с чанкингом."""
|
213 |
+
extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False)
|
214 |
+
entities = extractor.destructure(mock_document)
|
215 |
+
|
216 |
+
mock_strategy_instance.chunk.assert_called_once()
|
217 |
+
mock_tables_processor_instance.extract.assert_not_called() # extract не должен вызываться
|
218 |
+
|
219 |
+
assert len(entities) == 2 # DocumentAsEntity + chunk
|
220 |
+
entity_types = {e.type for e in entities}
|
221 |
+
assert "DocumentAsEntity" in entity_types
|
222 |
+
assert "Chunk" in entity_types
|
223 |
+
|
224 |
+
def test_destructure_no_strategy_no_tables(self, mock_document: ParsedDocument, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock):
|
225 |
+
"""Тест destructure без стратегии и без таблиц."""
|
226 |
+
# Убираем стратегию из реестра на время теста
|
227 |
+
with patch.dict(chunking_registry._chunking_strategies, {}, clear=True):
|
228 |
+
extractor = EntitiesExtractor(strategy_name=None, process_tables=False)
|
229 |
+
entities = extractor.destructure(mock_document)
|
230 |
+
|
231 |
+
mock_strategy_instance.chunk.assert_not_called()
|
232 |
+
mock_tables_processor_instance.extract.assert_not_called()
|
233 |
+
|
234 |
+
assert len(entities) == 1 # Только DocumentAsEntity
|
235 |
+
assert entities[0].type == "DocumentAsEntity"
|
236 |
+
assert entities[0].name == "mock_doc.pdf"
|
237 |
+
|
238 |
+
def test_destructure_with_string_input(self, mock_strategy_instance: MagicMock, mock_tables_processor_instance: MagicMock):
|
239 |
+
"""Тест destructure с входной строкой."""
|
240 |
+
input_string = "Это тестовая строка.\nВторая строка."
|
241 |
+
extractor = EntitiesExtractor(strategy_name=FIXED_SIZE, process_tables=False)
|
242 |
+
entities = extractor.destructure(input_string)
|
243 |
+
|
244 |
+
# Проверяем, что chunk был вызван с созданным ParsedDocument
|
245 |
+
mock_strategy_instance.chunk.assert_called_once()
|
246 |
+
call_args, _ = mock_strategy_instance.chunk.call_args
|
247 |
+
assert isinstance(call_args[0], ParsedDocument)
|
248 |
+
assert call_args[0].name == "unknown"
|
249 |
+
assert call_args[0].type == "PlainText"
|
250 |
+
assert len(call_args[0].paragraphs) == 2
|
251 |
+
assert call_args[0].paragraphs[0].text == "Это тестовая строка."
|
252 |
+
assert isinstance(call_args[1], DocumentAsEntity)
|
253 |
+
|
254 |
+
mock_tables_processor_instance.extract.assert_not_called()
|
255 |
+
|
256 |
+
assert len(entities) == 2 # Document + chunk
|
257 |
+
|
258 |
+
def test_destructure_runtime_error_no_strategy(self, mock_document: ParsedDocument):
|
259 |
+
"""Тест RuntimeError, если стратегия не сконфигурирована, но вызывается _chunk."""
|
260 |
+
# Этот тест немного искусственный, т.к. destructure не вызовет _chunk, если strategy is None
|
261 |
+
# Но проверим сам метод _chunk на всякий случай
|
262 |
+
extractor = EntitiesExtractor(strategy_name=None, process_tables=False)
|
263 |
+
doc_entity = extractor._create_document_entity(mock_document)
|
264 |
+
with pytest.raises(RuntimeError, match="Стратегия чанкинга не выставлена"):
|
265 |
+
extractor._chunk(mock_document, doc_entity)
|
lib/extractor/tests/core/test_in_memory_repository.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для InMemoryEntityRepository.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from uuid import UUID, uuid4
|
6 |
+
|
7 |
+
import pytest
|
8 |
+
# Импортируем все необходимые типы сущностей
|
9 |
+
from ntr_text_fragmentation.additors.tables.models import (SubTableEntity,
|
10 |
+
TableEntity,
|
11 |
+
TableRowEntity)
|
12 |
+
from ntr_text_fragmentation.chunking.models import Chunk
|
13 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size.fixed_size_chunk import \
|
14 |
+
FixedSizeChunk
|
15 |
+
from ntr_text_fragmentation.core.entity_repository import GroupedEntities
|
16 |
+
from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \
|
17 |
+
InMemoryEntityRepository
|
18 |
+
from ntr_text_fragmentation.models import DocumentAsEntity, LinkerEntity
|
19 |
+
|
20 |
+
# --- Фикстуры ---
|
21 |
+
|
22 |
+
@pytest.fixture
|
23 |
+
def doc1_id() -> UUID:
|
24 |
+
return uuid4()
|
25 |
+
|
26 |
+
@pytest.fixture
|
27 |
+
def doc2_id() -> UUID:
|
28 |
+
return uuid4()
|
29 |
+
|
30 |
+
@pytest.fixture
|
31 |
+
def table1_id() -> UUID:
|
32 |
+
return uuid4()
|
33 |
+
|
34 |
+
@pytest.fixture
|
35 |
+
def subtable1_id() -> UUID:
|
36 |
+
return uuid4()
|
37 |
+
|
38 |
+
@pytest.fixture
|
39 |
+
def doc1(doc1_id: UUID) -> DocumentAsEntity:
|
40 |
+
return DocumentAsEntity(id=doc1_id, name="doc1")
|
41 |
+
|
42 |
+
@pytest.fixture
|
43 |
+
def doc2(doc2_id: UUID) -> DocumentAsEntity:
|
44 |
+
return DocumentAsEntity(id=doc2_id, name="doc2")
|
45 |
+
|
46 |
+
@pytest.fixture
|
47 |
+
def chunk1_doc1(doc1_id: UUID) -> Chunk:
|
48 |
+
return Chunk(id=uuid4(), name="chunk1_doc1", text="text1", target_id=doc1_id, number_in_relation=0, groupper="chunk")
|
49 |
+
|
50 |
+
@pytest.fixture
|
51 |
+
def chunk2_doc1(doc1_id: UUID) -> Chunk:
|
52 |
+
return Chunk(id=uuid4(), name="chunk2_doc1", text="text2", target_id=doc1_id, number_in_relation=1, groupper="chunk")
|
53 |
+
|
54 |
+
@pytest.fixture
|
55 |
+
def chunk3_doc1(doc1_id: UUID) -> Chunk:
|
56 |
+
# Пропускаем номер 2 для теста соседей и группировки
|
57 |
+
return Chunk(id=uuid4(), name="chunk3_doc1", text="text3", target_id=doc1_id, number_in_relation=3, groupper="chunk")
|
58 |
+
|
59 |
+
@pytest.fixture
|
60 |
+
def chunk1_doc2(doc2_id: UUID) -> Chunk:
|
61 |
+
return Chunk(id=uuid4(), name="chunk1_doc2", text="text_doc2", target_id=doc2_id, number_in_relation=0, groupper="chunk")
|
62 |
+
|
63 |
+
@pytest.fixture
|
64 |
+
def table1_doc1(doc1_id: UUID, table1_id: UUID) -> TableEntity:
|
65 |
+
return TableEntity(id=table1_id, name="table1", target_id=doc1_id, number_in_relation=0, groupper="table")
|
66 |
+
|
67 |
+
@pytest.fixture
|
68 |
+
def row1_table1(table1_id: UUID) -> TableRowEntity:
|
69 |
+
return TableRowEntity(id=uuid4(), name="row1_table1", cells=["r1c1", "r1c2"], target_id=table1_id, number_in_relation=0, groupper="row")
|
70 |
+
|
71 |
+
@pytest.fixture
|
72 |
+
def row2_table1(table1_id: UUID) -> TableRowEntity:
|
73 |
+
return TableRowEntity(id=uuid4(), name="row2_table1", cells=["r2c1", "r2c2"], target_id=table1_id, number_in_relation=1, groupper="row")
|
74 |
+
|
75 |
+
@pytest.fixture
|
76 |
+
def subtable1_table1(table1_id: UUID, subtable1_id: UUID) -> SubTableEntity:
|
77 |
+
return SubTableEntity(id=subtable1_id, name="subtable1", target_id=table1_id, number_in_relation=2, groupper="subtable")
|
78 |
+
|
79 |
+
@pytest.fixture
|
80 |
+
def row1_subtable1(subtable1_id: UUID) -> TableRowEntity:
|
81 |
+
return TableRowEntity(id=uuid4(), name="row1_subtable1", cells=["sr1c1"], target_id=subtable1_id, number_in_relation=0, groupper="subrow")
|
82 |
+
|
83 |
+
@pytest.fixture
|
84 |
+
def link1(chunk1_doc1: Chunk, chunk2_doc1: Chunk) -> LinkerEntity:
|
85 |
+
# Пример кастомной связи
|
86 |
+
return LinkerEntity(id=uuid4(), name="link1", source_id=chunk1_doc1.id, target_id=chunk2_doc1.id, type="CustomLink")
|
87 |
+
|
88 |
+
@pytest.fixture
|
89 |
+
def all_entities(
|
90 |
+
doc1: DocumentAsEntity,
|
91 |
+
doc2: DocumentAsEntity,
|
92 |
+
chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk,
|
93 |
+
chunk1_doc2: Chunk,
|
94 |
+
table1_doc1: TableEntity,
|
95 |
+
row1_table1: TableRowEntity, row2_table1: TableRowEntity,
|
96 |
+
subtable1_table1: SubTableEntity,
|
97 |
+
row1_subtable1: TableRowEntity,
|
98 |
+
link1: LinkerEntity
|
99 |
+
) -> list[LinkerEntity]:
|
100 |
+
return [
|
101 |
+
doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk1_doc2,
|
102 |
+
table1_doc1, row1_table1, row2_table1, subtable1_table1, row1_subtable1, link1
|
103 |
+
]
|
104 |
+
|
105 |
+
@pytest.fixture
|
106 |
+
def empty_repository() -> InMemoryEntityRepository:
|
107 |
+
return InMemoryEntityRepository()
|
108 |
+
|
109 |
+
@pytest.fixture
|
110 |
+
def populated_repository(all_entities: list[LinkerEntity]) -> InMemoryEntityRepository:
|
111 |
+
return InMemoryEntityRepository(all_entities)
|
112 |
+
|
113 |
+
|
114 |
+
# --- Тесты --- #
|
115 |
+
class TestInMemoryEntityRepository:
|
116 |
+
"""Тесты для InMemoryEntityRepository."""
|
117 |
+
|
118 |
+
def test_init_empty(self, empty_repository: InMemoryEntityRepository):
|
119 |
+
"""Тест инициализации пустого репозитория."""
|
120 |
+
assert empty_repository.entities == []
|
121 |
+
assert empty_repository.entities_by_id == {}
|
122 |
+
assert empty_repository.relations_by_source == {}
|
123 |
+
assert empty_repository.relations_by_target == {}
|
124 |
+
assert empty_repository.compositions == {}
|
125 |
+
|
126 |
+
def test_init_with_entities(self, populated_repository: InMemoryEntityRepository, all_entities: list[LinkerEntity], doc1_id: UUID, chunk1_doc1: Chunk, link1: LinkerEntity):
|
127 |
+
"""Тест инициализации с сущностями и построения индексов."""
|
128 |
+
assert len(populated_repository.entities) == len(all_entities)
|
129 |
+
assert len(populated_repository.entities_by_id) == len(all_entities)
|
130 |
+
assert doc1_id in populated_repository.entities_by_id
|
131 |
+
assert chunk1_doc1.id in populated_repository.entities_by_id
|
132 |
+
|
133 |
+
# Проверка индекса compositions (owner_id)
|
134 |
+
assert doc1_id in populated_repository.compositions
|
135 |
+
assert len(populated_repository.compositions[doc1_id]) == 4 # chunk1, chunk2, chunk3, table1
|
136 |
+
doc1_children_ids = {e.id for e in populated_repository.compositions[doc1_id]}
|
137 |
+
assert chunk1_doc1.id in doc1_children_ids
|
138 |
+
|
139 |
+
# Проверка индекса relations
|
140 |
+
assert link1.source_id in populated_repository.relations_by_source
|
141 |
+
assert link1 in populated_repository.relations_by_source[link1.source_id]
|
142 |
+
assert link1.target_id in populated_repository.relations_by_target
|
143 |
+
assert link1 in populated_repository.relations_by_target[link1.target_id]
|
144 |
+
|
145 |
+
def test_add_entities(self, empty_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk):
|
146 |
+
"""Тест добавления сущностей."""
|
147 |
+
empty_repository.add_entities([chunk1_doc1])
|
148 |
+
assert len(empty_repository.entities) == 1
|
149 |
+
assert chunk1_doc1.id in empty_repository.entities_by_id
|
150 |
+
assert chunk1_doc1.owner_id in empty_repository.compositions
|
151 |
+
|
152 |
+
empty_repository.add_entities([chunk2_doc1])
|
153 |
+
assert len(empty_repository.entities) == 2
|
154 |
+
assert chunk2_doc1.id in empty_repository.entities_by_id
|
155 |
+
assert len(empty_repository.compositions[chunk1_doc1.owner_id]) == 2
|
156 |
+
|
157 |
+
def test_set_entities(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk):
|
158 |
+
"""Тест установки (замены) сущностей."""
|
159 |
+
initial_count = len(populated_repository.entities)
|
160 |
+
populated_repository.set_entities([chunk1_doc1, chunk2_doc1])
|
161 |
+
assert len(populated_repository.entities) == 2
|
162 |
+
assert len(populated_repository.entities_by_id) == 2
|
163 |
+
assert chunk1_doc1.id in populated_repository.entities_by_id
|
164 |
+
assert chunk2_doc1.id in populated_repository.entities_by_id
|
165 |
+
assert len(populated_repository.compositions) == 1 # Только один owner_id
|
166 |
+
assert len(populated_repository.compositions[chunk1_doc1.owner_id]) == 2
|
167 |
+
assert len(populated_repository.relations_by_source) == 0 # Старые связи удалены
|
168 |
+
|
169 |
+
def test_get_entities_by_ids(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk1_doc2: Chunk):
|
170 |
+
"""Тест получения сущностей по ID."""
|
171 |
+
ids_to_get = [chunk1_doc1.id, chunk1_doc2.id, uuid4()] # Последний ID не существует
|
172 |
+
result = populated_repository.get_entities_by_ids(ids_to_get)
|
173 |
+
assert len(result) == 2
|
174 |
+
result_ids = {e.id for e in result}
|
175 |
+
assert chunk1_doc1.id in result_ids
|
176 |
+
assert chunk1_doc2.id in result_ids
|
177 |
+
|
178 |
+
# Тест с передачей самих сущностей
|
179 |
+
result_from_entities = populated_repository.get_entities_by_ids([chunk1_doc1, chunk1_doc2])
|
180 |
+
assert len(result_from_entities) == 2
|
181 |
+
assert chunk1_doc1 in result_from_entities
|
182 |
+
assert chunk1_doc2 in result_from_entities
|
183 |
+
|
184 |
+
# Тест с пустым списком ID
|
185 |
+
assert populated_repository.get_entities_by_ids([]) == []
|
186 |
+
|
187 |
+
# --- Тесты group_entities_hierarchically ---
|
188 |
+
def test_group_by_doc_simple(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk, table1_doc1: TableEntity):
|
189 |
+
"""Тест простой группировки по документу."""
|
190 |
+
entities_to_group = [chunk1_doc1.id, chunk2_doc1.id, chunk3_doc1.id, table1_doc1.id]
|
191 |
+
groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity)
|
192 |
+
|
193 |
+
assert len(groups) == 1
|
194 |
+
group = groups[0]
|
195 |
+
assert isinstance(group, GroupedEntities)
|
196 |
+
assert group.composer == doc1
|
197 |
+
assert len(group.entities) == 4
|
198 |
+
grouped_ids = {e.id for e in group.entities}
|
199 |
+
assert chunk1_doc1.id in grouped_ids
|
200 |
+
assert chunk2_doc1.id in grouped_ids
|
201 |
+
assert chunk3_doc1.id in grouped_ids
|
202 |
+
assert table1_doc1.id in grouped_ids
|
203 |
+
# Проверка сортировки по number_in_relation
|
204 |
+
assert [e.number_in_relation for e in group.entities] == [0, 0, 1, 3]
|
205 |
+
|
206 |
+
def test_group_by_doc_multi_level(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, row1_table1: TableRowEntity, row1_subtable1: TableRowEntity):
|
207 |
+
"""Тест иерархической группировки (строка -> подтаблица -> таблица -> документ)."""
|
208 |
+
entities_to_group = [row1_table1.id, row1_subtable1.id]
|
209 |
+
groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity)
|
210 |
+
|
211 |
+
assert len(groups) == 1
|
212 |
+
group = groups[0]
|
213 |
+
assert group.composer == doc1
|
214 |
+
assert len(group.entities) == 2
|
215 |
+
grouped_ids = {e.id for e in group.entities}
|
216 |
+
assert row1_table1.id in grouped_ids
|
217 |
+
assert row1_subtable1.id in grouped_ids
|
218 |
+
# Проверка сортировки (строка подтаблицы идет после строки таблицы)
|
219 |
+
# row1_table1: owner=table1 (num=0), row1_subtable1: owner=subtable1 (num=0) -> owner=table1 (num=2)
|
220 |
+
# Порядок сортировки по умолчанию не определен для разных уровней иерархии, только внутри одного owner
|
221 |
+
# Позиция определяется как (groupper, number_in_relation)
|
222 |
+
assert group.entities[0].id == row1_table1.id
|
223 |
+
assert group.entities[1].id == row1_subtable1.id
|
224 |
+
|
225 |
+
def test_group_by_doc_multiple_docs(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, doc2: DocumentAsEntity, chunk1_doc1: Chunk, chunk1_doc2: Chunk):
|
226 |
+
"""Тест группировки сущностей из разных документов."""
|
227 |
+
entities_to_group = [chunk1_doc1.id, chunk1_doc2.id]
|
228 |
+
groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity)
|
229 |
+
|
230 |
+
assert len(groups) == 2
|
231 |
+
composers = {g.composer for g in groups}
|
232 |
+
assert doc1 in composers
|
233 |
+
assert doc2 in composers
|
234 |
+
|
235 |
+
for group in groups:
|
236 |
+
if group.composer == doc1:
|
237 |
+
assert len(group.entities) == 1
|
238 |
+
assert group.entities[0] == chunk1_doc1
|
239 |
+
elif group.composer == doc2:
|
240 |
+
assert len(group.entities) == 1
|
241 |
+
assert group.entities[0] == chunk1_doc2
|
242 |
+
|
243 |
+
def test_group_no_sort(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk3_doc1: Chunk):
|
244 |
+
"""Тест группировки без сортировки."""
|
245 |
+
# Передаем в обратном порядке
|
246 |
+
entities_to_group = [chunk3_doc1.id, chunk1_doc1.id]
|
247 |
+
groups = populated_repository.group_entities_hierarchically(entities_to_group, DocumentAsEntity, sort=False)
|
248 |
+
assert len(groups) == 1
|
249 |
+
group = groups[0]
|
250 |
+
# Порядок должен сохраниться как в entities_to_group
|
251 |
+
assert group.entities[0].id == chunk3_doc1.id
|
252 |
+
assert group.entities[1].id == chunk1_doc1.id
|
253 |
+
|
254 |
+
def test_group_empty(self, populated_repository: InMemoryEntityRepository):
|
255 |
+
"""Тест группировки пустого списка."""
|
256 |
+
groups = populated_repository.group_entities_hierarchically([], DocumentAsEntity)
|
257 |
+
assert groups == []
|
258 |
+
|
259 |
+
def test_group_max_levels(self, populated_repository: InMemoryEntityRepository, row1_subtable1: TableRowEntity):
|
260 |
+
"""Тест ограничения глубины поиска родителя."""
|
261 |
+
# Уровень 1: SubTable, Уровень 2: Table, Уровень 3: Document
|
262 |
+
# Ищем корень DocumentAsEntity (нужно 3 уровня)
|
263 |
+
groups_ok = populated_repository.group_entities_hierarchically([row1_subtable1.id], DocumentAsEntity, max_levels=3)
|
264 |
+
assert len(groups_ok) == 1
|
265 |
+
assert groups_ok[0].composer.type == "DocumentAsEntity"
|
266 |
+
|
267 |
+
# Ищем с max_levels=2 (должен не найти Document)
|
268 |
+
groups_fail = populated_repository.group_entities_hierarchically([row1_subtable1.id], DocumentAsEntity, max_levels=2)
|
269 |
+
assert len(groups_fail) == 0
|
270 |
+
|
271 |
+
# Ищем корень TableEntity (нужно 2 уровня)
|
272 |
+
groups_table_ok = populated_repository.group_entities_hierarchically([row1_subtable1.id], TableEntity, max_levels=2)
|
273 |
+
assert len(groups_table_ok) == 1
|
274 |
+
assert groups_table_ok[0].composer.type == "TableEntity"
|
275 |
+
|
276 |
+
groups_table_fail = populated_repository.group_entities_hierarchically([row1_subtable1.id], TableEntity, max_levels=1)
|
277 |
+
assert len(groups_table_fail) == 0
|
278 |
+
|
279 |
+
# --- Тесты get_neighboring_entities ---
|
280 |
+
def test_get_neighbors_distance_1(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk):
|
281 |
+
"""Тест получения соседей с distance=1."""
|
282 |
+
# Соседи для chunk2 (индекс 1)
|
283 |
+
neighbors = populated_repository.get_neighboring_entities([chunk2_doc1.id], max_distance=1)
|
284 |
+
neighbor_ids = {e.id for e in neighbors}
|
285 |
+
assert len(neighbors) == 2
|
286 |
+
assert chunk1_doc1.id in neighbor_ids # Сосед слева (индекс 0)
|
287 |
+
assert chunk3_doc1.id not in neighbor_ids # Сосед справа (индекс 3) - далеко
|
288 |
+
|
289 |
+
# Соседи для chunk1 (индекс 0)
|
290 |
+
neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=1)
|
291 |
+
neighbor_ids = {e.id for e in neighbors}
|
292 |
+
assert len(neighbors) == 1
|
293 |
+
assert chunk2_doc1.id in neighbor_ids # Сосед справа (индекс 1)
|
294 |
+
|
295 |
+
# Соседи для chunk3 (индекс 3)
|
296 |
+
neighbors = populated_repository.get_neighboring_entities([chunk3_doc1.id], max_distance=1)
|
297 |
+
neighbor_ids = {e.id for e in neighbors}
|
298 |
+
# Сосед слева chunk2 (индекс 1) слишком далеко (diff = 2)
|
299 |
+
assert len(neighbors) == 0
|
300 |
+
|
301 |
+
def test_get_neighbors_distance_2(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk):
|
302 |
+
"""Тест получения соседей с distance=2."""
|
303 |
+
neighbors = populated_repository.get_neighboring_entities([chunk2_doc1.id], max_distance=2)
|
304 |
+
neighbor_ids = {e.id for e in neighbors}
|
305 |
+
assert len(neighbors) == 2
|
306 |
+
assert chunk1_doc1.id in neighbor_ids # Сосед слева (индекс 0, diff=1)
|
307 |
+
assert chunk3_doc1.id in neighbor_ids # Сосед справа (индекс 3, diff=2)
|
308 |
+
|
309 |
+
def test_get_neighbors_multiple_entities(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, chunk3_doc1: Chunk):
|
310 |
+
"""Тест получения соседей для нескольких сущностей."""
|
311 |
+
neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id, chunk3_doc1.id], max_distance=1)
|
312 |
+
neighbor_ids = {e.id for e in neighbors}
|
313 |
+
# Сосед chunk1 -> chunk2
|
314 |
+
# Соседей у chunk3 нет (с distance=1)
|
315 |
+
assert len(neighbors) == 1
|
316 |
+
assert chunk2_doc1.id in neighbor_ids
|
317 |
+
|
318 |
+
def test_get_neighbors_different_owners(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk1_doc2: Chunk):
|
319 |
+
"""Тест: соседи ищутся только в рамках одного owner."""
|
320 |
+
neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=5) # Большая дистанция
|
321 |
+
# Должен найти только chunk2_doc1 и chunk3_doc1, но не chunk1_doc2
|
322 |
+
neighbor_ids = {e.id for e in neighbors}
|
323 |
+
assert len(neighbors) == 2
|
324 |
+
assert chunk1_doc2.id not in neighbor_ids
|
325 |
+
|
326 |
+
def test_get_neighbors_different_groupers(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, table1_doc1: TableEntity):
|
327 |
+
"""Тест: соседи ищутся только в рамках одного groupper."""
|
328 |
+
# У chunk1_doc1 groupper='chunk', у table1_doc1 groupper='table'
|
329 |
+
# Оба принадлежат doc1, number_in_relation = 0 у обоих
|
330 |
+
neighbors = populated_repository.get_neighboring_entities([chunk1_doc1.id], max_distance=1)
|
331 |
+
neighbor_ids = {e.id for e in neighbors}
|
332 |
+
assert table1_doc1.id not in neighbor_ids # Не должен найти таблицу
|
333 |
+
|
334 |
+
def test_get_neighbors_no_owner(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity):
|
335 |
+
"""Тест: сущности без owner_id не должны иметь соседей."""
|
336 |
+
neighbors = populated_repository.get_neighboring_entities([doc1.id], max_distance=1)
|
337 |
+
assert len(neighbors) == 0
|
338 |
+
|
339 |
+
def test_get_neighbors_empty(self, populated_repository: InMemoryEntityRepository):
|
340 |
+
"""Тест получения соседей для пустого списка."""
|
341 |
+
neighbors = populated_repository.get_neighboring_entities([], max_distance=1)
|
342 |
+
assert neighbors == []
|
343 |
+
|
344 |
+
# --- Тесты get_related_entities ---
|
345 |
+
def test_get_related_as_source(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, link1: LinkerEntity):
|
346 |
+
"""Тест поиска связей, где сущность - источник."""
|
347 |
+
related = populated_repository.get_related_entities([chunk1_doc1.id], as_source=True)
|
348 |
+
related_ids = {e.id for e in related}
|
349 |
+
assert len(related) == 2 # Сама связь + цель связи
|
350 |
+
assert link1.id in related_ids
|
351 |
+
assert link1.target_id in related_ids # chunk2_doc1.id
|
352 |
+
|
353 |
+
def test_get_related_as_target(self, populated_repository: InMemoryEntityRepository, chunk2_doc1: Chunk, link1: LinkerEntity):
|
354 |
+
"""Тест поиска связей, где сущность - цель."""
|
355 |
+
related = populated_repository.get_related_entities([chunk2_doc1.id], as_target=True)
|
356 |
+
related_ids = {e.id for e in related}
|
357 |
+
assert len(related) == 2 # Сама связь + источник связи
|
358 |
+
assert link1.id in related_ids
|
359 |
+
assert link1.source_id in related_ids # chunk1_doc1.id
|
360 |
+
|
361 |
+
def test_get_related_as_owner(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, table1_doc1: TableEntity):
|
362 |
+
"""Тест поиска дочерних сущностей (по owner_id)."""
|
363 |
+
related = populated_repository.get_related_entities([doc1.id], as_owner=True)
|
364 |
+
related_ids = {e.id for e in related}
|
365 |
+
# Ожидаем chunk1, chunk2, chunk3, table1
|
366 |
+
assert len(related) == 4
|
367 |
+
assert chunk1_doc1.id in related_ids
|
368 |
+
assert table1_doc1.id in related_ids
|
369 |
+
|
370 |
+
def test_get_related_all_directions(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, chunk2_doc1: Chunk, link1: LinkerEntity):
|
371 |
+
"""Тест поиска связей во всех направлениях (по умолчанию)."""
|
372 |
+
related_c1 = populated_repository.get_related_entities([chunk1_doc1.id]) # source для link1, child для doc1
|
373 |
+
related_c1_ids = {e.id for e in related_c1}
|
374 |
+
assert len(related_c1) == 2 # link1, chunk2_doc1 (target)
|
375 |
+
assert link1.id in related_c1_ids
|
376 |
+
assert chunk2_doc1.id in related_c1_ids
|
377 |
+
|
378 |
+
related_c2 = populated_repository.get_related_entities([chunk2_doc1.id]) # target для link1, child для doc1
|
379 |
+
related_c2_ids = {e.id for e in related_c2}
|
380 |
+
assert len(related_c2) == 2 # link1, chunk1_doc1 (source)
|
381 |
+
assert link1.id in related_c2_ids
|
382 |
+
assert chunk1_doc1.id in related_c2_ids
|
383 |
+
|
384 |
+
related_doc = populated_repository.get_related_entities([doc1.id]) # owner для chunk1/2/3, table1
|
385 |
+
related_doc_ids = {e.id for e in related_doc}
|
386 |
+
assert len(related_doc) == 4 # chunk1, chunk2, chunk3, table1
|
387 |
+
assert chunk1_doc1.id in related_doc_ids
|
388 |
+
|
389 |
+
def test_get_related_filter_by_type(self, populated_repository: InMemoryEntityRepository, doc1: DocumentAsEntity, chunk1_doc1: Chunk, table1_doc1: TableEntity):
|
390 |
+
"""Тест фильтрации связей по типу."""
|
391 |
+
# Ищем только чанки, принадлежащие doc1
|
392 |
+
related_chunks = populated_repository.get_related_entities([doc1.id], as_owner=True, relation_type=Chunk)
|
393 |
+
related_ids = {e.id for e in related_chunks}
|
394 |
+
assert len(related_chunks) == 3
|
395 |
+
assert chunk1_doc1.id in related_ids
|
396 |
+
assert table1_doc1.id not in related_ids
|
397 |
+
|
398 |
+
# Ищем только таблицы, принадлежащие doc1
|
399 |
+
related_tables = populated_repository.get_related_entities([doc1.id], as_owner=True, relation_type=TableEntity)
|
400 |
+
assert len(related_tables) == 1
|
401 |
+
assert related_tables[0].id == table1_doc1.id
|
402 |
+
|
403 |
+
# Ищем только связи типа CustomLink, где chunk1 - источник
|
404 |
+
related_custom_link = populated_repository.get_related_entities([chunk1_doc1.id], as_source=True, relation_type=LinkerEntity) # Используем базовый тип, т.к. CustomLink не регистрировали
|
405 |
+
related_custom_link_ids = {e.id for e in related_custom_link}
|
406 |
+
assert len(related_custom_link) == 2
|
407 |
+
assert link1.id in related_custom_link_ids
|
408 |
+
|
409 |
+
def test_get_related_multiple_entities_input(self, populated_repository: InMemoryEntityRepository, chunk1_doc1: Chunk, chunk2_doc1: Chunk, link1: LinkerEntity):
|
410 |
+
"""Тест поиска связей для нескольких сущностей одновременно."""
|
411 |
+
related = populated_repository.get_related_entities([chunk1_doc1.id, chunk2_doc1.id], as_source=True)
|
412 |
+
related_ids = {e.id for e in related}
|
413 |
+
# chunk1 -> link1 -> chunk2
|
414 |
+
# chunk2 -> нет связей как source
|
415 |
+
assert len(related) == 2 # link1, chunk2
|
416 |
+
assert link1.id in related_ids
|
417 |
+
assert link1.target_id in related_ids
|
418 |
+
|
419 |
+
def test_get_related_no_relations(self, populated_repository: InMemoryEntityRepository, doc2: DocumentAsEntity):
|
420 |
+
"""Тест поиска связей для сущности без связей."""
|
421 |
+
related = populated_repository.get_related_entities([doc2.id]) # У doc2 есть только дочерний chunk1_doc2
|
422 |
+
related_ids = {e.id for e in related}
|
423 |
+
assert len(related) == 1 # Находит только дочерний chunk1_doc2
|
424 |
+
assert chunk1_doc2.id in related_ids
|
425 |
+
|
426 |
+
# Ищем только source/target связи для doc2
|
427 |
+
related_links = populated_repository.get_related_entities([doc2.id], as_source=True, as_target=True)
|
428 |
+
assert len(related_links) == 0
|
429 |
+
|
430 |
+
def test_get_related_empty_input(self, populated_repository: InMemoryEntityRepository):
|
431 |
+
"""Тест поиска связей для пустого списка."""
|
432 |
+
related = populated_repository.get_related_entities([])
|
433 |
+
assert related == []
|
lib/extractor/tests/core/test_injection_builder.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для InjectionBuilder.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from unittest.mock import MagicMock, patch
|
6 |
+
from uuid import UUID, uuid4
|
7 |
+
|
8 |
+
import pytest
|
9 |
+
from ntr_text_fragmentation.additors import TablesProcessor
|
10 |
+
from ntr_text_fragmentation.chunking import (ChunkingStrategy,
|
11 |
+
chunking_registry,
|
12 |
+
register_chunking_strategy)
|
13 |
+
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import (
|
14 |
+
FIXED_SIZE, FixedSizeChunkingStrategy)
|
15 |
+
from ntr_text_fragmentation.core.entity_repository import EntityRepository
|
16 |
+
from lib.extractor.ntr_text_fragmentation.repositories.in_memory_repository import \
|
17 |
+
InMemoryEntityRepository
|
18 |
+
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder
|
19 |
+
from ntr_text_fragmentation.models import (Chunk, DocumentAsEntity,
|
20 |
+
LinkerEntity, TableEntity,
|
21 |
+
TableRowEntity)
|
22 |
+
|
23 |
+
# --- Фикстуры --- #
|
24 |
+
|
25 |
+
# Используем реальные ID для связей в фикстурах
|
26 |
+
DOC1_ID = uuid4()
|
27 |
+
DOC2_ID = uuid4()
|
28 |
+
TABLE1_ID = uuid4()
|
29 |
+
CHUNK1_ID = uuid4()
|
30 |
+
CHUNK2_ID = uuid4()
|
31 |
+
CHUNK3_ID = uuid4()
|
32 |
+
CHUNK4_ID = uuid4()
|
33 |
+
ROW1_ID = uuid4()
|
34 |
+
ROW2_ID = uuid4()
|
35 |
+
|
36 |
+
|
37 |
+
@pytest.fixture
|
38 |
+
def doc1() -> DocumentAsEntity:
|
39 |
+
return DocumentAsEntity(id=DOC1_ID, name="Document 1", chunking_strategy_ref=FIXED_SIZE)
|
40 |
+
|
41 |
+
@pytest.fixture
|
42 |
+
def doc2() -> DocumentAsEntity:
|
43 |
+
return DocumentAsEntity(id=DOC2_ID, name="Document 2", chunking_strategy_ref=FIXED_SIZE)
|
44 |
+
|
45 |
+
# Используем FixedSizeChunk для тестирования dechunk
|
46 |
+
@pytest.fixture
|
47 |
+
def chunk1_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk:
|
48 |
+
return FixedSizeChunkingStrategy.FixedSizeChunk(
|
49 |
+
id=CHUNK1_ID, name="c1d1", text="Chunk 1 text.", owner_id=doc1.id,
|
50 |
+
number_in_relation=0, groupper="chunk",
|
51 |
+
right_sentence_part="More from chunk 1."
|
52 |
+
)
|
53 |
+
|
54 |
+
@pytest.fixture
|
55 |
+
def chunk2_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk:
|
56 |
+
return FixedSizeChunkingStrategy.FixedSizeChunk(
|
57 |
+
id=CHUNK2_ID, name="c2d1", text="Chunk 2 is here.", owner_id=doc1.id,
|
58 |
+
number_in_relation=1, groupper="chunk",
|
59 |
+
left_sentence_part="Continuation from chunk 1.", right_sentence_part="End of chunk 2."
|
60 |
+
)
|
61 |
+
|
62 |
+
@pytest.fixture
|
63 |
+
def chunk3_doc1(doc1: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk:
|
64 |
+
# Пропуск для теста разрыва
|
65 |
+
return FixedSizeChunkingStrategy.FixedSizeChunk(
|
66 |
+
id=CHUNK3_ID, name="c3d1", text="Chunk 3 after gap.", owner_id=doc1.id,
|
67 |
+
number_in_relation=3, groupper="chunk",
|
68 |
+
left_sentence_part="Start after gap."
|
69 |
+
)
|
70 |
+
|
71 |
+
@pytest.fixture
|
72 |
+
def chunk4_doc2(doc2: DocumentAsEntity) -> FixedSizeChunkingStrategy.FixedSizeChunk:
|
73 |
+
return FixedSizeChunkingStrategy.FixedSizeChunk(
|
74 |
+
id=CHUNK4_ID, name="c4d2", text="Chunk 4 from doc 2.", owner_id=doc2.id,
|
75 |
+
number_in_relation=0, groupper="chunk"
|
76 |
+
)
|
77 |
+
|
78 |
+
@pytest.fixture
|
79 |
+
def table1_doc1(doc1: DocumentAsEntity) -> TableEntity:
|
80 |
+
return TableEntity(id=TABLE1_ID, name="t1d1", text="Table 1 representation", owner_id=doc1.id, number_in_relation=0, groupper="table")
|
81 |
+
|
82 |
+
@pytest.fixture
|
83 |
+
def row1_table1(table1_doc1: TableEntity) -> TableRowEntity:
|
84 |
+
return TableRowEntity(id=ROW1_ID, name="r1t1", cells=["a", "b"], owner_id=table1_doc1.id, number_in_relation=0, groupper="row")
|
85 |
+
|
86 |
+
@pytest.fixture
|
87 |
+
def row2_table1(table1_doc1: TableEntity) -> TableRowEntity:
|
88 |
+
return TableRowEntity(id=ROW2_ID, name="r2t1", cells=["c", "d"], owner_id=table1_doc1.id, number_in_relation=1, groupper="row")
|
89 |
+
|
90 |
+
|
91 |
+
@pytest.fixture
|
92 |
+
def all_test_entities(
|
93 |
+
doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk4_doc2,
|
94 |
+
table1_doc1, row1_table1, row2_table1
|
95 |
+
) -> list[LinkerEntity]:
|
96 |
+
# Собираем все созданные сущности
|
97 |
+
return [
|
98 |
+
doc1, doc2, chunk1_doc1, chunk2_doc1, chunk3_doc1, chunk4_doc2,
|
99 |
+
table1_doc1, row1_table1, row2_table1
|
100 |
+
]
|
101 |
+
|
102 |
+
@pytest.fixture
|
103 |
+
def serialized_entities(all_test_entities: list[LinkerEntity]) -> list[LinkerEntity]:
|
104 |
+
# Сериализуем, как они хранятся в репозитории
|
105 |
+
return [e.serialize() for e in all_test_entities]
|
106 |
+
|
107 |
+
@pytest.fixture
|
108 |
+
def test_repository(serialized_entities: list[LinkerEntity]) -> InMemoryEntityRepository:
|
109 |
+
# Репозиторий для тестов InjectionBuilder
|
110 |
+
return InMemoryEntityRepository(serialized_entities)
|
111 |
+
|
112 |
+
# --- Моки --- #
|
113 |
+
|
114 |
+
@pytest.fixture
|
115 |
+
def mock_strategy_class() -> MagicMock:
|
116 |
+
"""Мок класса стратегии чанкинга."""
|
117 |
+
mock_cls = MagicMock(spec=FixedSizeChunkingStrategy)
|
118 |
+
# Мокируем classmethod dechunk
|
119 |
+
mock_cls.dechunk.return_value = "[Dechunked Text]"
|
120 |
+
return mock_cls
|
121 |
+
|
122 |
+
@pytest.fixture
|
123 |
+
def mock_tables_processor_build() -> MagicMock:
|
124 |
+
"""Мок статического метода TablesProcessor.build."""
|
125 |
+
with patch.object(TablesProcessor, 'build', return_value="[Built Tables]") as mock_build:
|
126 |
+
yield mock_build
|
127 |
+
|
128 |
+
@pytest.fixture(autouse=True)
|
129 |
+
def setup_mocks(mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock):
|
130 |
+
"""Автоматически применяет моки для реестра и процессора таблиц."""
|
131 |
+
# Регистрируем мок-стратегию в реестре
|
132 |
+
with patch.dict(chunking_registry._chunking_strategies, {FIXED_SIZE: mock_strategy_class}, clear=True):
|
133 |
+
yield # Позволяем тестам выполниться с моками
|
134 |
+
|
135 |
+
|
136 |
+
# --- Тесты --- #
|
137 |
+
class TestInjectionBuilder:
|
138 |
+
"""Тесты для InjectionBuilder."""
|
139 |
+
|
140 |
+
def test_init_with_repository(self, test_repository: InMemoryEntityRepository):
|
141 |
+
"""Тест инициализации с репозиторием."""
|
142 |
+
builder = InjectionBuilder(repository=test_repository)
|
143 |
+
assert builder.repository is test_repository
|
144 |
+
|
145 |
+
def test_init_with_entities(self, serialized_entities: list[LinkerEntity]):
|
146 |
+
"""Тест инициализации со списком сущностей."""
|
147 |
+
builder = InjectionBuilder(entities=serialized_entities)
|
148 |
+
assert isinstance(builder.repository, InMemoryEntityRepository)
|
149 |
+
assert builder.repository.entities == serialized_entities
|
150 |
+
|
151 |
+
def test_init_errors(self, test_repository: InMemoryEntityRepository, serialized_entities: list[LinkerEntity]):
|
152 |
+
"""Тест ошибок при инициализации."""
|
153 |
+
with pytest.raises(ValueError, match="Необходимо указать либо repository, либо entities"):
|
154 |
+
InjectionBuilder()
|
155 |
+
with pytest.raises(ValueError, match="Использование одновременно repository и entities не допускается"):
|
156 |
+
InjectionBuilder(repository=test_repository, entities=serialized_entities)
|
157 |
+
|
158 |
+
def test_build_simple(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, chunk2_doc1: Chunk):
|
159 |
+
"""Тест простого сценария сборки (только чанки одного документа)."""
|
160 |
+
builder = InjectionBuilder(repository=test_repository)
|
161 |
+
selected_ids = [chunk1_doc1.id, chunk2_doc1.id]
|
162 |
+
|
163 |
+
# Ожидаем, что dechunk будет вызван с правильными аргументами
|
164 |
+
mock_strategy_class.dechunk.return_value = "Chunk 1 text. Chunk 2 is here."
|
165 |
+
|
166 |
+
result = builder.build(selected_ids, include_tables=False)
|
167 |
+
|
168 |
+
# Проверка вызова get_entities_by_ids
|
169 |
+
# (Не можем легко проверить без мокирования самого репозитория, но проверим косвенно через вызовы dechunk/build)
|
170 |
+
|
171 |
+
# Проверка вызова dechunk
|
172 |
+
mock_strategy_class.dechunk.assert_called_once()
|
173 |
+
call_args, _ = mock_strategy_class.dechunk.call_args
|
174 |
+
assert call_args[0] is test_repository
|
175 |
+
# Переданные сущности должны быть десериализованы и отсортированы
|
176 |
+
passed_entities = call_args[1]
|
177 |
+
assert len(passed_entities) == 2
|
178 |
+
assert all(isinstance(e, Chunk) for e in passed_entities)
|
179 |
+
assert passed_entities[0].id == chunk1_doc1.id
|
180 |
+
assert passed_entities[1].id == chunk2_doc1.id
|
181 |
+
|
182 |
+
# Проверка вызова TablesProcessor.build (не должен вызываться)
|
183 |
+
mock_tables_processor_build.assert_not_called()
|
184 |
+
|
185 |
+
# Проверка результата
|
186 |
+
expected = (
|
187 |
+
"## [Источник] - Document 1\n\n"
|
188 |
+
"### Текст\nChunk 1 text. Chunk 2 is here.\n\n"
|
189 |
+
)
|
190 |
+
assert result == expected
|
191 |
+
|
192 |
+
def test_build_with_tables(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, row1_table1: TableRowEntity):
|
193 |
+
"""Тест сборки с чанками и таблицами."""
|
194 |
+
builder = InjectionBuilder(repository=test_repository)
|
195 |
+
selected_ids = [chunk1_doc1.id, row1_table1.id]
|
196 |
+
|
197 |
+
mock_strategy_class.dechunk.return_value = "Chunk 1 text."
|
198 |
+
mock_tables_processor_build.return_value = "Table Row 1: [a, b]"
|
199 |
+
|
200 |
+
result = builder.build(selected_ids, include_tables=True)
|
201 |
+
|
202 |
+
mock_strategy_class.dechunk.assert_called_once()
|
203 |
+
# dechunk вызывается только с чанками
|
204 |
+
dechunk_args, _ = mock_strategy_class.dechunk.call_args
|
205 |
+
assert len(dechunk_args[1]) == 1
|
206 |
+
assert dechunk_args[1][0].id == chunk1_doc1.id
|
207 |
+
|
208 |
+
mock_tables_processor_build.assert_called_once()
|
209 |
+
# build вызывается со всеми сущностями группы
|
210 |
+
build_args, _ = mock_tables_processor_build.call_args
|
211 |
+
assert build_args[0] is test_repository
|
212 |
+
assert len(build_args[1]) == 2
|
213 |
+
build_entity_ids = {e.id for e in build_args[1]}
|
214 |
+
assert chunk1_doc1.id in build_entity_ids
|
215 |
+
assert row1_table1.id in build_entity_ids
|
216 |
+
|
217 |
+
expected = (
|
218 |
+
"## [Источник] - Document 1\n\n"
|
219 |
+
"### Текст\nChunk 1 text.\n\n"
|
220 |
+
"### Таблицы\nTable Row 1: [a, b]\n\n"
|
221 |
+
)
|
222 |
+
assert result == expected
|
223 |
+
|
224 |
+
def test_build_include_tables_false(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, row1_table1: TableRowEntity):
|
225 |
+
"""Тест сборки с include_tables=False."""
|
226 |
+
builder = InjectionBuilder(repository=test_repository)
|
227 |
+
selected_ids = [chunk1_doc1.id, row1_table1.id] # Передаем строку таблицы
|
228 |
+
|
229 |
+
mock_strategy_class.dechunk.return_value = "Chunk 1 text."
|
230 |
+
|
231 |
+
result = builder.build(selected_ids, include_tables=False)
|
232 |
+
|
233 |
+
# dechunk вызывается только с чанком
|
234 |
+
mock_strategy_class.dechunk.assert_called_once()
|
235 |
+
dechunk_args, _ = mock_strategy_class.dechunk.call_args
|
236 |
+
assert len(dechunk_args[1]) == 1
|
237 |
+
assert dechunk_args[1][0].id == chunk1_doc1.id
|
238 |
+
|
239 |
+
# TablesProcessor.build не должен вызываться
|
240 |
+
mock_tables_processor_build.assert_not_called()
|
241 |
+
|
242 |
+
expected = (
|
243 |
+
"## [Источник] - Document 1\n\n"
|
244 |
+
"### Текст\nChunk 1 text.\n\n"
|
245 |
+
# Секции таблиц нет
|
246 |
+
)
|
247 |
+
assert result == expected
|
248 |
+
|
249 |
+
def test_build_with_neighbors(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk2_doc1: Chunk):
|
250 |
+
"""Тест сборки с добавлением соседей."""
|
251 |
+
builder = InjectionBuilder(repository=test_repository)
|
252 |
+
selected_ids = [chunk1_doc1.id]
|
253 |
+
neighbors_distance = 1
|
254 |
+
|
255 |
+
# Мокируем get_neighboring_entities, чтобы убедиться, что он вызывается
|
256 |
+
with patch.object(test_repository, 'get_neighboring_entities', wraps=test_repository.get_neighboring_entities) as mock_get_neighbors:
|
257 |
+
mock_strategy_class.dechunk.return_value = "Chunk 1 text. Chunk 2 is here."
|
258 |
+
result = builder.build(selected_ids, neighbors_max_distance=neighbors_distance)
|
259 |
+
|
260 |
+
mock_get_neighbors.assert_called_once()
|
261 |
+
call_args, _ = mock_get_neighbors.call_args
|
262 |
+
# Первым аргументом должны быть десериализованные сущности из selected_ids
|
263 |
+
assert len(call_args[0]) == 1
|
264 |
+
assert isinstance(call_args[0][0], Chunk)
|
265 |
+
assert call_args[0][0].id == chunk1_doc1.id
|
266 |
+
# Второй аргумент - max_distance
|
267 |
+
assert call_args[1] == neighbors_distance
|
268 |
+
|
269 |
+
# Проверяем, что dechunk вызван с chunk1 и его соседом chunk2
|
270 |
+
mock_strategy_class.dechunk.assert_called_once()
|
271 |
+
dechunk_args, _ = mock_strategy_class.dechunk.call_args
|
272 |
+
assert len(dechunk_args[1]) == 2
|
273 |
+
dechunk_ids = {e.id for e in dechunk_args[1]}
|
274 |
+
assert chunk1_doc1.id in dechunk_ids
|
275 |
+
assert chunk2_doc1.id in dechunk_ids
|
276 |
+
|
277 |
+
expected = (
|
278 |
+
"## [Источник] - Document 1\n\n"
|
279 |
+
"### Текст\nChunk 1 text. Chunk 2 is here.\n\n"
|
280 |
+
)
|
281 |
+
assert result == expected
|
282 |
+
|
283 |
+
def test_build_multiple_docs_no_limit(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk):
|
284 |
+
"""Тест сборки сущностей из разных документов без лимита."""
|
285 |
+
builder = InjectionBuilder(repository=test_repository)
|
286 |
+
selected_ids = [chunk1_doc1.id, chunk4_doc2.id]
|
287 |
+
|
288 |
+
# Настроим возвращаемые значения для dechunk (вызывается дважды)
|
289 |
+
mock_strategy_class.dechunk.side_effect = [
|
290 |
+
"Chunk 1 text.", # Для doc1
|
291 |
+
"Chunk 4 from doc 2." # Для doc2
|
292 |
+
]
|
293 |
+
|
294 |
+
result = builder.build(selected_ids, include_tables=False)
|
295 |
+
|
296 |
+
assert mock_strategy_class.dechunk.call_count == 2
|
297 |
+
# TablesProcessor.build не должен вызываться
|
298 |
+
mock_tables_processor_build.assert_not_called()
|
299 |
+
|
300 |
+
# Порядок документов определяется дефолтными скорами (по убыванию индекса)
|
301 |
+
# chunk4_doc2 (score=2.0) > chunk1_doc1 (score=1.0)
|
302 |
+
expected = (
|
303 |
+
"## [Источник] - Document 2\n\n"
|
304 |
+
"### Текст\nChunk 4 from doc 2.\n\n"
|
305 |
+
"\n\n"
|
306 |
+
"## [Источник] - Document 1\n\n"
|
307 |
+
"### Текст\nChunk 1 text.\n\n"
|
308 |
+
)
|
309 |
+
assert result == expected
|
310 |
+
|
311 |
+
def test_build_multiple_docs_with_scores(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk):
|
312 |
+
"""Тест сборки сущностей из разных документов с заданными скорами."""
|
313 |
+
builder = InjectionBuilder(repository=test_repository)
|
314 |
+
selected_entities = [
|
315 |
+
test_repository.entities_by_id[chunk4_doc2.id], # doc2
|
316 |
+
test_repository.entities_by_id[chunk1_doc1.id] # doc1
|
317 |
+
]
|
318 |
+
scores = [0.5, 0.9] # doc1 > doc2
|
319 |
+
|
320 |
+
mock_strategy_class.dechunk.side_effect = [
|
321 |
+
"Chunk 1 text.", # doc1
|
322 |
+
"Chunk 4 from doc 2." # doc2
|
323 |
+
]
|
324 |
+
|
325 |
+
result = builder.build(selected_entities, scores=scores, include_tables=False)
|
326 |
+
|
327 |
+
# Проверяем порядок документов в результате (doc1 должен быть первым)
|
328 |
+
expected = (
|
329 |
+
"## [Источник] - Document 1\n\n"
|
330 |
+
"### Текст\nChunk 1 text.\n\n"
|
331 |
+
"\n\n"
|
332 |
+
"## [Источник] - Document 2\n\n"
|
333 |
+
"### Текст\nChunk 4 from doc 2.\n\n"
|
334 |
+
)
|
335 |
+
assert result == expected
|
336 |
+
|
337 |
+
def test_build_max_documents(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk, chunk4_doc2: Chunk):
|
338 |
+
"""Тест сборки с ограничением max_documents."""
|
339 |
+
builder = InjectionBuilder(repository=test_repository)
|
340 |
+
selected_ids = [chunk1_doc1.id, chunk4_doc2.id]
|
341 |
+
|
342 |
+
# doc2 (score 2.0) > doc1 (score 1.0)
|
343 |
+
mock_strategy_class.dechunk.return_value = "Chunk 4 from doc 2."
|
344 |
+
|
345 |
+
result = builder.build(selected_ids, max_documents=1, include_tables=False)
|
346 |
+
|
347 |
+
# Должен быть вызван dechunk только один раз для документа с наивысшим скором (doc2)
|
348 |
+
mock_strategy_class.dechunk.assert_called_once()
|
349 |
+
|
350 |
+
expected = (
|
351 |
+
"## [Источник] - Document 2\n\n"
|
352 |
+
"### Текст\nChunk 4 from doc 2.\n\n"
|
353 |
+
)
|
354 |
+
assert result == expected
|
355 |
+
|
356 |
+
def test_build_custom_prefix(self, test_repository: InMemoryEntityRepository, mock_strategy_class: MagicMock, chunk1_doc1: Chunk):
|
357 |
+
"""Тест сборки с кастомным префиксом документа."""
|
358 |
+
builder = InjectionBuilder(repository=test_repository)
|
359 |
+
selected_ids = [chunk1_doc1.id]
|
360 |
+
custom_prefix = "Source Doc: "
|
361 |
+
|
362 |
+
mock_strategy_class.dechunk.return_value = "Chunk 1 text."
|
363 |
+
|
364 |
+
result = builder.build(selected_ids, document_prefix=custom_prefix, include_tables=False)
|
365 |
+
|
366 |
+
expected = (
|
367 |
+
f"## {custom_prefix}Document 1\n\n"
|
368 |
+
"### Текст\nChunk 1 text.\n\n"
|
369 |
+
)
|
370 |
+
assert result == expected
|
371 |
+
|
372 |
+
def test_build_empty_entities(self, test_repository: InMemoryEntityRepository):
|
373 |
+
"""Тест сборки с пустым списком сущностей."""
|
374 |
+
builder = InjectionBuilder(repository=test_repository)
|
375 |
+
result = builder.build([])
|
376 |
+
assert result == ""
|
377 |
+
|
378 |
+
def test_build_unknown_ids(self, test_repository: InMemoryEntityRepository):
|
379 |
+
"""Тест сборки с неизвестными ID."""
|
380 |
+
builder = InjectionBuilder(repository=test_repository)
|
381 |
+
result = builder.build([uuid4(), uuid4()]) # Передаем несуществующие ID
|
382 |
+
assert result == ""
|
383 |
+
|
384 |
+
def test_build_no_strategy_for_doc(self, test_repository: InMemoryEntityRepository, mock_tables_processor_build: MagicMock, chunk1_doc1: Chunk):
|
385 |
+
"""Тест сборки, если у документа нет chunking_strategy_ref."""
|
386 |
+
# Убираем ссылку на стратегию у документа
|
387 |
+
doc1_entity = test_repository.entities_by_id[chunk1_doc1.owner_id]
|
388 |
+
original_ref = doc1_entity.chunking_strategy_ref
|
389 |
+
doc1_entity.chunking_strategy_ref = None
|
390 |
+
|
391 |
+
builder = InjectionBuilder(repository=test_repository)
|
392 |
+
selected_ids = [chunk1_doc1.id]
|
393 |
+
|
394 |
+
mock_tables_processor_build.return_value = "[Tables]"
|
395 |
+
|
396 |
+
# dechunk не должен вызываться
|
397 |
+
with patch.object(chunking_registry[FIXED_SIZE], 'dechunk') as mock_dechunk:
|
398 |
+
result = builder.build(selected_ids, include_tables=True)
|
399 |
+
mock_dechunk.assert_not_called()
|
400 |
+
|
401 |
+
# build для таблиц должен вызваться
|
402 |
+
mock_tables_processor_build.assert_called_once()
|
403 |
+
|
404 |
+
expected = (
|
405 |
+
"## [Источник] - Document 1\n\n"
|
406 |
+
# Секции Текст нет
|
407 |
+
"### Таблицы\n[Tables]\n\n"
|
408 |
+
)
|
409 |
+
assert result == expected
|
410 |
+
|
411 |
+
# Восстанавливаем ссылку
|
412 |
+
doc1_entity.chunking_strategy_ref = original_ref
|
lib/extractor/tests/custom_entity.py
CHANGED
@@ -2,105 +2,3 @@ from uuid import UUID
|
|
2 |
|
3 |
from ntr_text_fragmentation.models.linker_entity import (LinkerEntity,
|
4 |
register_entity)
|
5 |
-
|
6 |
-
|
7 |
-
@register_entity
|
8 |
-
class CustomEntity(LinkerEntity):
|
9 |
-
"""Пользовательский класс-наследник LinkerEntity для тестирования сериализации и десериализации."""
|
10 |
-
|
11 |
-
def __init__(
|
12 |
-
self,
|
13 |
-
id: UUID,
|
14 |
-
name: str,
|
15 |
-
text: str,
|
16 |
-
metadata: dict,
|
17 |
-
custom_field1: str,
|
18 |
-
custom_field2: int,
|
19 |
-
in_search_text: str | None = None,
|
20 |
-
source_id: UUID | None = None,
|
21 |
-
target_id: UUID | None = None,
|
22 |
-
number_in_relation: int | None = None,
|
23 |
-
type: str = "CustomEntity"
|
24 |
-
):
|
25 |
-
super().__init__(
|
26 |
-
id=id,
|
27 |
-
name=name,
|
28 |
-
text=text,
|
29 |
-
metadata=metadata,
|
30 |
-
in_search_text=in_search_text,
|
31 |
-
source_id=source_id,
|
32 |
-
target_id=target_id,
|
33 |
-
number_in_relation=number_in_relation,
|
34 |
-
type=type
|
35 |
-
)
|
36 |
-
self.custom_field1 = custom_field1
|
37 |
-
self.custom_field2 = custom_field2
|
38 |
-
|
39 |
-
def deserialize(self, entity: LinkerEntity) -> 'CustomEntity':
|
40 |
-
"""Реализация метода десериализации для кастомного класса."""
|
41 |
-
custom_field1 = entity.metadata.get('_custom_field1', '')
|
42 |
-
custom_field2 = entity.metadata.get('_custom_field2', 0)
|
43 |
-
|
44 |
-
# Создаем чистые метаданные без служебных полей
|
45 |
-
clean_metadata = {k: v for k, v in entity.metadata.items()
|
46 |
-
if not k.startswith('_')}
|
47 |
-
|
48 |
-
return CustomEntity(
|
49 |
-
id=entity.id,
|
50 |
-
name=entity.name,
|
51 |
-
text=entity.text,
|
52 |
-
in_search_text=entity.in_search_text,
|
53 |
-
metadata=clean_metadata,
|
54 |
-
source_id=entity.source_id,
|
55 |
-
target_id=entity.target_id,
|
56 |
-
number_in_relation=entity.number_in_relation,
|
57 |
-
custom_field1=custom_field1,
|
58 |
-
custom_field2=custom_field2
|
59 |
-
)
|
60 |
-
|
61 |
-
@classmethod
|
62 |
-
def deserialize(cls, entity: LinkerEntity) -> 'CustomEntity':
|
63 |
-
"""
|
64 |
-
Классовый метод для десериализации.
|
65 |
-
Необходим для работы с реестром классов.
|
66 |
-
|
67 |
-
Args:
|
68 |
-
entity: Сериализованная сущность
|
69 |
-
|
70 |
-
Returns:
|
71 |
-
Десериализованный экземпляр CustomEntity
|
72 |
-
"""
|
73 |
-
custom_field1 = entity.metadata.get('_custom_field1', '')
|
74 |
-
custom_field2 = entity.metadata.get('_custom_field2', 0)
|
75 |
-
|
76 |
-
# Создаем чистые метаданные без служебных полей
|
77 |
-
clean_metadata = {k: v for k, v in entity.metadata.items()
|
78 |
-
if not k.startswith('_')}
|
79 |
-
|
80 |
-
return CustomEntity(
|
81 |
-
id=entity.id,
|
82 |
-
name=entity.name,
|
83 |
-
text=entity.text,
|
84 |
-
in_search_text=entity.in_search_text,
|
85 |
-
metadata=clean_metadata,
|
86 |
-
source_id=entity.source_id,
|
87 |
-
target_id=entity.target_id,
|
88 |
-
number_in_relation=entity.number_in_relation,
|
89 |
-
custom_field1=custom_field1,
|
90 |
-
custom_field2=custom_field2
|
91 |
-
)
|
92 |
-
|
93 |
-
def __eq__(self, other):
|
94 |
-
"""Переопределяем метод сравнения для проверки равенства объектов."""
|
95 |
-
if not isinstance(other, CustomEntity):
|
96 |
-
return False
|
97 |
-
|
98 |
-
# Используем базовое сравнение из LinkerEntity, которое уже учитывает поля связи
|
99 |
-
base_equality = super().__eq__(other)
|
100 |
-
|
101 |
-
# Дополнительно проверяем кастомные поля
|
102 |
-
return (
|
103 |
-
base_equality
|
104 |
-
and self.custom_field1 == other.custom_field1
|
105 |
-
and self.custom_field2 == other.custom_field2
|
106 |
-
)
|
|
|
2 |
|
3 |
from ntr_text_fragmentation.models.linker_entity import (LinkerEntity,
|
4 |
register_entity)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/extractor/tests/models/test_linker_entity.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Unit-тесты для базового класса LinkerEntity и его механизма сериализации/десериализации.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import uuid
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
from uuid import UUID, uuid4
|
8 |
+
|
9 |
+
import pytest
|
10 |
+
from ntr_text_fragmentation.models import LinkerEntity, register_entity
|
11 |
+
from tests.custom_entity import \
|
12 |
+
CustomEntity # Используем существующий кастомный класс
|
13 |
+
|
14 |
+
|
15 |
+
# Фикстуры
|
16 |
+
@pytest.fixture
|
17 |
+
def base_entity() -> LinkerEntity:
|
18 |
+
"""Фикстура для базовой сущности."""
|
19 |
+
return LinkerEntity(id=uuid4(), name="Base Name", text="Base Text")
|
20 |
+
|
21 |
+
|
22 |
+
@pytest.fixture
|
23 |
+
def link_entity() -> LinkerEntity:
|
24 |
+
"""Фикстура для сущности-связи."""
|
25 |
+
return LinkerEntity(
|
26 |
+
id=uuid4(),
|
27 |
+
name="Link Name",
|
28 |
+
source_id=uuid4(),
|
29 |
+
target_id=uuid4(),
|
30 |
+
number_in_relation=1,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
@pytest.fixture
|
35 |
+
def custom_entity_instance() -> CustomEntity:
|
36 |
+
"""Фикстура для кастомной сущности."""
|
37 |
+
return CustomEntity(
|
38 |
+
id=uuid4(),
|
39 |
+
name="Custom Name",
|
40 |
+
text="Custom Text",
|
41 |
+
custom_field="custom_value",
|
42 |
+
metadata={"existing_meta": "meta_value"},
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
@pytest.fixture
|
47 |
+
def serialized_custom_entity(
|
48 |
+
custom_entity_instance: CustomEntity,
|
49 |
+
) -> LinkerEntity:
|
50 |
+
"""Фикстура для сериализованной кастомной сущности."""
|
51 |
+
return custom_entity_instance.serialize()
|
52 |
+
|
53 |
+
|
54 |
+
# Тесты
|
55 |
+
class TestLinkerEntity:
|
56 |
+
"""Тесты для класса LinkerEntity."""
|
57 |
+
|
58 |
+
def test_initialization_defaults(self):
|
59 |
+
"""Тест инициализации с значениями по умолчанию."""
|
60 |
+
entity = LinkerEntity()
|
61 |
+
assert isinstance(entity.id, UUID)
|
62 |
+
assert entity.name == ""
|
63 |
+
assert entity.text == ""
|
64 |
+
assert entity.metadata == {}
|
65 |
+
assert entity.in_search_text is None
|
66 |
+
assert entity.source_id is None
|
67 |
+
assert entity.target_id is None
|
68 |
+
assert entity.number_in_relation is None
|
69 |
+
assert entity.groupper is None
|
70 |
+
assert entity.type == "LinkerEntity" # Имя класса по умолчанию
|
71 |
+
|
72 |
+
def test_initialization_with_values(self, base_entity: LinkerEntity):
|
73 |
+
"""Тест инициализации с заданными значениями."""
|
74 |
+
entity_id = base_entity.id
|
75 |
+
assert base_entity.name == "Base Name"
|
76 |
+
assert base_entity.text == "Base Text"
|
77 |
+
assert base_entity.id == entity_id
|
78 |
+
|
79 |
+
def test_is_link(self, base_entity: LinkerEntity, link_entity: LinkerEntity):
|
80 |
+
"""Тест метода is_link()."""
|
81 |
+
assert not base_entity.is_link()
|
82 |
+
assert link_entity.is_link()
|
83 |
+
|
84 |
+
def test_owner_id_property(self, base_entity: LinkerEntity, link_entity: LinkerEntity):
|
85 |
+
"""Тест свойства owner_id."""
|
86 |
+
# У обычной сущности owner_id это target_id
|
87 |
+
owner_uuid = uuid4()
|
88 |
+
base_entity.target_id = owner_uuid
|
89 |
+
assert base_entity.owner_id == owner_uuid
|
90 |
+
|
91 |
+
# У связи нет owner_id
|
92 |
+
assert link_entity.owner_id is None
|
93 |
+
|
94 |
+
# Попытка установить owner_id для связи должна вызвать ошибку
|
95 |
+
with pytest.raises(ValueError, match="Связь не может иметь владельца"):
|
96 |
+
link_entity.owner_id = uuid4()
|
97 |
+
|
98 |
+
# Установка owner_id для обычной сущности
|
99 |
+
new_owner_id = uuid4()
|
100 |
+
base_entity.owner_id = new_owner_id
|
101 |
+
assert base_entity.target_id == new_owner_id
|
102 |
+
|
103 |
+
def test_str_representation(self, base_entity: LinkerEntity):
|
104 |
+
"""Тест строкового представления __str__."""
|
105 |
+
assert str(base_entity) == "Base Name: Base Text"
|
106 |
+
|
107 |
+
base_entity.in_search_text = "Search text representation"
|
108 |
+
assert str(base_entity) == "Search text representation"
|
109 |
+
|
110 |
+
def test_equality(self, base_entity: LinkerEntity):
|
111 |
+
"""Тест сравнения __eq__."""
|
112 |
+
entity_copy = LinkerEntity(
|
113 |
+
id=base_entity.id, name="Base Name", text="Base Text"
|
114 |
+
)
|
115 |
+
different_entity = LinkerEntity(name="Different Name")
|
116 |
+
|
117 |
+
assert base_entity == entity_copy
|
118 |
+
assert base_entity != different_entity
|
119 |
+
assert base_entity != "not an entity"
|
120 |
+
|
121 |
+
def test_equality_links(self, link_entity: LinkerEntity):
|
122 |
+
"""Тест сравнения связей."""
|
123 |
+
link_copy = LinkerEntity(
|
124 |
+
id=link_entity.id,
|
125 |
+
name="Link Name",
|
126 |
+
source_id=link_entity.source_id,
|
127 |
+
target_id=link_entity.target_id,
|
128 |
+
number_in_relation=1,
|
129 |
+
)
|
130 |
+
different_link = LinkerEntity(
|
131 |
+
id=link_entity.id,
|
132 |
+
name="Link Name",
|
133 |
+
source_id=uuid4(), # Другой source_id
|
134 |
+
target_id=link_entity.target_id,
|
135 |
+
number_in_relation=1,
|
136 |
+
)
|
137 |
+
non_link = LinkerEntity(id=link_entity.id)
|
138 |
+
|
139 |
+
assert link_entity == link_copy
|
140 |
+
assert link_entity != different_link
|
141 |
+
assert link_entity != non_link
|
142 |
+
|
143 |
+
# --- Тесты сериализации/десериализации ---
|
144 |
+
|
145 |
+
def test_serialize_base_entity(self, base_entity: LinkerEntity):
|
146 |
+
"""Тест сериализации базовой сущности."""
|
147 |
+
serialized = base_entity.serialize()
|
148 |
+
assert isinstance(serialized, LinkerEntity)
|
149 |
+
# Проверяем, что это не тот же самый объект, а копия базового типа
|
150 |
+
assert serialized is not base_entity
|
151 |
+
assert type(serialized) is LinkerEntity
|
152 |
+
assert serialized.id == base_entity.id
|
153 |
+
assert serialized.name == base_entity.name
|
154 |
+
assert serialized.text == base_entity.text
|
155 |
+
assert serialized.metadata == {} # Нет доп. полей
|
156 |
+
assert serialized.type == "LinkerEntity" # Сохраняем тип
|
157 |
+
|
158 |
+
def test_serialize_custom_entity(
|
159 |
+
self,
|
160 |
+
custom_entity_instance: CustomEntity,
|
161 |
+
serialized_custom_entity: LinkerEntity,
|
162 |
+
):
|
163 |
+
"""Тест сериализации кастомной сущности."""
|
164 |
+
serialized = serialized_custom_entity # Используем фикстуру
|
165 |
+
|
166 |
+
assert isinstance(serialized, LinkerEntity)
|
167 |
+
assert type(serialized) is LinkerEntity
|
168 |
+
assert serialized.id == custom_entity_instance.id
|
169 |
+
assert serialized.name == custom_entity_instance.name
|
170 |
+
assert serialized.text == custom_entity_instance.text
|
171 |
+
# Проверяем, что кастомное поле и исходные метаданные попали в metadata
|
172 |
+
assert "_custom_field" in serialized.metadata
|
173 |
+
assert serialized.metadata["_custom_field"] == "custom_value"
|
174 |
+
assert "existing_meta" in serialized.metadata
|
175 |
+
assert serialized.metadata["existing_meta"] == "meta_value"
|
176 |
+
# Тип должен быть именем кастомного класса
|
177 |
+
assert serialized.type == "CustomEntity"
|
178 |
+
|
179 |
+
def test_deserialize_custom_entity(
|
180 |
+
self, serialized_custom_entity: LinkerEntity
|
181 |
+
):
|
182 |
+
"""Тест десериализации в кастомный тип."""
|
183 |
+
# Используем класс CustomEntity для десериализации, так как он зарегистрирован
|
184 |
+
deserialized = LinkerEntity._deserialize(serialized_custom_entity)
|
185 |
+
|
186 |
+
assert isinstance(deserialized, CustomEntity)
|
187 |
+
assert deserialized.id == serialized_custom_entity.id
|
188 |
+
assert deserialized.name == serialized_custom_entity.name
|
189 |
+
assert deserialized.text == serialized_custom_entity.text
|
190 |
+
# Проверяем восстановление кастомного поля
|
191 |
+
assert deserialized.custom_field == "custom_value"
|
192 |
+
# Проверяем восстановление исходных метаданных
|
193 |
+
assert "existing_meta" in deserialized.metadata
|
194 |
+
assert deserialized.metadata["existing_meta"] == "meta_value"
|
195 |
+
assert deserialized.type == "CustomEntity" # Тип сохраняется
|
196 |
+
|
197 |
+
def test_deserialize_base_entity(self, base_entity: LinkerEntity):
|
198 |
+
"""Тест десериализации базовой сущности (должна вернуться сама)."""
|
199 |
+
serialized = base_entity.serialize() # Сериализуем базовую
|
200 |
+
deserialized = LinkerEntity._deserialize(serialized)
|
201 |
+
assert deserialized is serialized # Возвращается исходный объект LinkerEntity
|
202 |
+
assert type(deserialized) is LinkerEntity
|
203 |
+
|
204 |
+
def test_deserialize_unregistered_type(self):
|
205 |
+
"""Тест десериализации незарегистрированного типа (должен вернуться исходный объект)."""
|
206 |
+
unregistered_entity = LinkerEntity(id=uuid4(), type="UnregisteredType")
|
207 |
+
deserialized = LinkerEntity._deserialize(unregistered_entity)
|
208 |
+
assert deserialized is unregistered_entity
|
209 |
+
assert deserialized.type == "UnregisteredType"
|
210 |
+
|
211 |
+
def test_deserialize_to_me_on_custom_class(
|
212 |
+
self, serialized_custom_entity: LinkerEntity
|
213 |
+
):
|
214 |
+
"""Тест прямого вызова _deserialize_to_me на кастомном классе."""
|
215 |
+
# Вызываем метод десериализации непосредственно у CustomEntity
|
216 |
+
deserialized = CustomEntity._deserialize_to_me(serialized_custom_entity)
|
217 |
+
|
218 |
+
assert isinstance(deserialized, CustomEntity)
|
219 |
+
assert deserialized.id == serialized_custom_entity.id
|
220 |
+
assert deserialized.custom_field == "custom_value"
|
221 |
+
assert deserialized.metadata["existing_meta"] == "meta_value"
|
222 |
+
|
223 |
+
def test_deserialize_to_me_type_error(self):
|
224 |
+
"""Тест ош��бки TypeError в _deserialize_to_me при неверном типе данных."""
|
225 |
+
with pytest.raises(TypeError):
|
226 |
+
# Пытаемся десериализовать не LinkerEntity
|
227 |
+
CustomEntity._deserialize_to_me("not_an_entity") # type: ignore
|
228 |
+
|
229 |
+
def test_register_entity_decorator(self):
|
230 |
+
"""Тест работы декоратора @register_entity."""
|
231 |
+
|
232 |
+
@register_entity
|
233 |
+
@dataclass
|
234 |
+
class TempEntity(LinkerEntity):
|
235 |
+
temp_field: str = "temp"
|
236 |
+
type: str = "Temporary" # Явно указываем тип для регистрации
|
237 |
+
|
238 |
+
assert "Temporary" in LinkerEntity._entity_classes
|
239 |
+
assert LinkerEntity._entity_classes["Temporary"] is TempEntity
|
240 |
+
|
241 |
+
# Проверяем, что он десериализуется
|
242 |
+
instance = TempEntity(id=uuid4(), name="Temp instance", temp_field="value")
|
243 |
+
serialized = instance.serialize()
|
244 |
+
assert serialized.type == "Temporary"
|
245 |
+
deserialized = LinkerEntity._deserialize(serialized)
|
246 |
+
assert isinstance(deserialized, TempEntity)
|
247 |
+
assert deserialized.temp_field == "value"
|
248 |
+
|
249 |
+
# Удаляем временный класс из реестра, чтобы не влиять на другие тесты
|
250 |
+
del LinkerEntity._entity_classes["Temporary"]
|
251 |
+
assert "Temporary" not in LinkerEntity._entity_classes
|
routes/dataset.py
CHANGED
@@ -54,8 +54,7 @@ def try_create_default_dataset(dataset_service: DatasetService):
|
|
54 |
else:
|
55 |
dataset_service.create_dataset_from_directory(
|
56 |
is_default=True,
|
57 |
-
directory_with_documents=dataset_service.config.db_config.files.
|
58 |
-
directory_with_ready_dataset=dataset_service.config.db_config.files.start_path,
|
59 |
)
|
60 |
|
61 |
@router.get('/try_init_default_dataset')
|
|
|
54 |
else:
|
55 |
dataset_service.create_dataset_from_directory(
|
56 |
is_default=True,
|
57 |
+
directory_with_documents=dataset_service.config.db_config.files.documents_path,
|
|
|
58 |
)
|
59 |
|
60 |
@router.get('/try_init_default_dataset')
|
routes/entity.py
CHANGED
@@ -1,18 +1,23 @@
|
|
1 |
from typing import Annotated
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
from fastapi import APIRouter, Depends, HTTPException
|
5 |
from sqlalchemy.orm import Session
|
6 |
|
7 |
-
from common import auth
|
8 |
import common.dependencies as DI
|
|
|
9 |
from components.dbo.chunk_repository import ChunkRepository
|
10 |
from components.services.entity import EntityService
|
11 |
-
from schemas.entity import (
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
16 |
|
17 |
router = APIRouter(prefix="/entity", tags=["Entity"])
|
18 |
|
@@ -21,30 +26,30 @@ router = APIRouter(prefix="/entity", tags=["Entity"])
|
|
21 |
async def search_entities(
|
22 |
request: EntitySearchRequest,
|
23 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
24 |
-
current_user: Annotated[any, Depends(auth.get_current_user)]
|
25 |
) -> EntitySearchResponse:
|
26 |
"""
|
27 |
Поиск похожих сущностей по векторному сходству (только ID).
|
28 |
-
|
29 |
Args:
|
30 |
request: Параметры поиска
|
31 |
entity_service: Сервис для работы с сущностями
|
32 |
-
|
33 |
Returns:
|
34 |
Результаты поиска (ID и оценки), отсортированные по убыванию сходства
|
35 |
"""
|
36 |
try:
|
37 |
-
_, scores, ids = entity_service.
|
38 |
request.query,
|
39 |
request.dataset_id,
|
40 |
)
|
41 |
-
|
42 |
# Проверяем, что scores и ids - корректные numpy массивы
|
43 |
if not isinstance(scores, np.ndarray):
|
44 |
scores = np.array(scores)
|
45 |
if not isinstance(ids, np.ndarray):
|
46 |
ids = np.array(ids)
|
47 |
-
|
48 |
# Сортируем результаты по убыванию оценок
|
49 |
# Проверим, что массивы не пустые
|
50 |
if len(scores) > 0:
|
@@ -56,15 +61,14 @@ async def search_entities(
|
|
56 |
else:
|
57 |
sorted_scores = []
|
58 |
sorted_ids = []
|
59 |
-
|
60 |
return EntitySearchResponse(
|
61 |
scores=sorted_scores,
|
62 |
entity_ids=sorted_ids,
|
63 |
)
|
64 |
except Exception as e:
|
65 |
raise HTTPException(
|
66 |
-
status_code=500,
|
67 |
-
detail=f"Error during entity search: {str(e)}"
|
68 |
)
|
69 |
|
70 |
|
@@ -72,60 +76,60 @@ async def search_entities(
|
|
72 |
async def search_entities_with_text(
|
73 |
request: EntitySearchWithTextRequest,
|
74 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
75 |
-
current_user: Annotated[any, Depends(auth.get_current_user)]
|
76 |
) -> EntitySearchWithTextResponse:
|
77 |
"""
|
78 |
Поиск похожих сущностей по векторному сходству с возвратом текстов.
|
79 |
-
|
80 |
Args:
|
81 |
request: Параметры поиска
|
82 |
entity_service: Сервис для работы с сущностями
|
83 |
-
|
84 |
Returns:
|
85 |
Результаты поиска с текстами чанков, отсортированные по убыванию сходства
|
86 |
"""
|
87 |
try:
|
88 |
# Получаем результаты поиска
|
89 |
-
_, scores, entity_ids = entity_service.
|
90 |
-
request.query,
|
91 |
-
request.dataset_id
|
92 |
)
|
93 |
-
|
94 |
# Проверяем, что scores и entity_ids - корректные numpy массивы
|
95 |
if not isinstance(scores, np.ndarray):
|
96 |
scores = np.array(scores)
|
97 |
if not isinstance(entity_ids, np.ndarray):
|
98 |
entity_ids = np.array(entity_ids)
|
99 |
-
|
100 |
# Сортируем результаты по убыванию оценок
|
101 |
# Проверим, что массивы не пустые
|
102 |
if len(scores) > 0:
|
103 |
# Преобразуем индексы в список, чтобы избежать проблем с индексацией
|
104 |
sorted_indices = scores.argsort()[::-1].tolist()
|
105 |
sorted_scores = [float(scores[i]) for i in sorted_indices]
|
106 |
-
sorted_ids = [
|
107 |
-
|
108 |
# Получаем тексты чанков
|
109 |
-
chunks = entity_service.chunk_repository.
|
110 |
-
|
111 |
# Формируем ответ
|
112 |
return EntitySearchWithTextResponse(
|
113 |
chunks=[
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
119 |
for chunk, score in zip(chunks, sorted_scores)
|
120 |
]
|
121 |
)
|
122 |
else:
|
123 |
return EntitySearchWithTextResponse(chunks=[])
|
124 |
-
|
125 |
except Exception as e:
|
126 |
raise HTTPException(
|
127 |
-
status_code=500,
|
128 |
-
detail=f"Error during entity search with text: {str(e)}"
|
129 |
)
|
130 |
|
131 |
|
@@ -133,84 +137,36 @@ async def search_entities_with_text(
|
|
133 |
async def build_entity_text(
|
134 |
request: EntityTextRequest,
|
135 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
136 |
-
current_user: Annotated[any, Depends(auth.get_current_user)]
|
137 |
) -> EntityTextResponse:
|
138 |
"""
|
139 |
Сборка текста из сущностей.
|
140 |
-
|
141 |
Args:
|
142 |
request: Параметры сборки текста
|
143 |
entity_service: Сервис для работы с сущностями
|
144 |
-
|
145 |
Returns:
|
146 |
Собранный текст
|
147 |
"""
|
148 |
try:
|
149 |
-
|
150 |
-
entities = entity_service.chunk_repository.get_chunks_by_ids(request.entities)
|
151 |
-
|
152 |
-
if not entities:
|
153 |
raise HTTPException(
|
154 |
-
status_code=404,
|
155 |
-
detail="No entities found with provided IDs"
|
156 |
)
|
157 |
-
|
158 |
# Собираем текст
|
159 |
text = entity_service.build_text(
|
160 |
-
entities=entities,
|
161 |
chunk_scores=request.chunk_scores,
|
162 |
include_tables=request.include_tables,
|
163 |
max_documents=request.max_documents,
|
164 |
)
|
165 |
-
|
166 |
-
return EntityTextResponse(text=text)
|
167 |
-
except Exception as e:
|
168 |
-
raise HTTPException(
|
169 |
-
status_code=500,
|
170 |
-
detail=f"Error building entity text: {str(e)}"
|
171 |
-
)
|
172 |
|
173 |
-
|
174 |
-
@router.post("/neighbors", response_model=EntityNeighborsResponse)
|
175 |
-
async def get_neighboring_chunks(
|
176 |
-
request: EntityNeighborsRequest,
|
177 |
-
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
178 |
-
current_user: Annotated[any, Depends(auth.get_current_user)]
|
179 |
-
) -> EntityNeighborsResponse:
|
180 |
-
"""
|
181 |
-
Получение соседних чанков для заданных сущностей.
|
182 |
-
|
183 |
-
Args:
|
184 |
-
request: Параметры запроса соседей
|
185 |
-
entity_service: Сервис для работы с сущностями
|
186 |
-
|
187 |
-
Returns:
|
188 |
-
Список сущностей с соседями
|
189 |
-
"""
|
190 |
-
try:
|
191 |
-
# Получаем объекты LinkerEntity по ID
|
192 |
-
entities = entity_service.chunk_repository.get_chunks_by_ids(request.entities)
|
193 |
-
|
194 |
-
if not entities:
|
195 |
-
raise HTTPException(
|
196 |
-
status_code=404,
|
197 |
-
detail="No entities found with provided IDs"
|
198 |
-
)
|
199 |
-
|
200 |
-
# Получаем соседние чанки
|
201 |
-
entities_with_neighbors = entity_service.add_neighboring_chunks(
|
202 |
-
entities,
|
203 |
-
max_distance=request.max_distance,
|
204 |
-
)
|
205 |
-
|
206 |
-
# Преобразуем LinkerEntity в строки
|
207 |
-
return EntityNeighborsResponse(
|
208 |
-
entities=[str(entity.id) for entity in entities_with_neighbors]
|
209 |
-
)
|
210 |
except Exception as e:
|
211 |
raise HTTPException(
|
212 |
-
status_code=500,
|
213 |
-
detail=f"Error getting neighboring chunks: {str(e)}"
|
214 |
)
|
215 |
|
216 |
|
@@ -218,7 +174,7 @@ async def get_neighboring_chunks(
|
|
218 |
async def get_entity_info(
|
219 |
dataset_id: int,
|
220 |
db: Annotated[Session, Depends(DI.get_db)],
|
221 |
-
current_user: Annotated[any, Depends(auth.get_current_user)]
|
222 |
) -> dict:
|
223 |
"""
|
224 |
Получить информацию о сущностях в датасете.
|
@@ -231,40 +187,65 @@ async def get_entity_info(
|
|
231 |
Returns:
|
232 |
dict: Информация о сущностях
|
233 |
"""
|
|
|
234 |
chunk_repository = ChunkRepository(db)
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
# Собираем статистику
|
241 |
stats = {
|
242 |
-
"total_entities":
|
243 |
-
"
|
244 |
-
|
245 |
-
|
246 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
"entities_per_type": {
|
248 |
-
t: len([e for e in
|
249 |
-
for t in set(e.type for e in
|
250 |
-
}
|
251 |
}
|
252 |
-
|
253 |
-
# Примеры сущностей
|
254 |
examples = [
|
255 |
{
|
256 |
-
"id": str(e.id),
|
257 |
"name": e.name,
|
258 |
"type": e.type,
|
259 |
-
"has_embedding":
|
260 |
-
"embedding_shape":
|
|
|
|
|
|
|
|
|
261 |
"text_length": len(e.text),
|
262 |
-
"in_search_text_length": len(e.in_search_text) if e.in_search_text else 0
|
263 |
}
|
264 |
-
|
|
|
265 |
]
|
266 |
-
|
267 |
-
return {
|
268 |
-
"stats": stats,
|
269 |
-
"examples": examples
|
270 |
-
}
|
|
|
1 |
from typing import Annotated
|
2 |
+
from uuid import UUID
|
3 |
|
4 |
import numpy as np
|
5 |
from fastapi import APIRouter, Depends, HTTPException
|
6 |
from sqlalchemy.orm import Session
|
7 |
|
|
|
8 |
import common.dependencies as DI
|
9 |
+
from common import auth
|
10 |
from components.dbo.chunk_repository import ChunkRepository
|
11 |
from components.services.entity import EntityService
|
12 |
+
from schemas.entity import (
|
13 |
+
ChunkInfo,
|
14 |
+
EntitySearchRequest,
|
15 |
+
EntitySearchResponse,
|
16 |
+
EntitySearchWithTextRequest,
|
17 |
+
EntitySearchWithTextResponse,
|
18 |
+
EntityTextRequest,
|
19 |
+
EntityTextResponse,
|
20 |
+
)
|
21 |
|
22 |
router = APIRouter(prefix="/entity", tags=["Entity"])
|
23 |
|
|
|
26 |
async def search_entities(
|
27 |
request: EntitySearchRequest,
|
28 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
29 |
+
current_user: Annotated[any, Depends(auth.get_current_user)],
|
30 |
) -> EntitySearchResponse:
|
31 |
"""
|
32 |
Поиск похожих сущностей по векторному сходству (только ID).
|
33 |
+
|
34 |
Args:
|
35 |
request: Параметры поиска
|
36 |
entity_service: Сервис для работы с сущностями
|
37 |
+
|
38 |
Returns:
|
39 |
Результаты поиска (ID и оценки), отсортированные по убыванию сходства
|
40 |
"""
|
41 |
try:
|
42 |
+
_, scores, ids = entity_service.search_similar_old(
|
43 |
request.query,
|
44 |
request.dataset_id,
|
45 |
)
|
46 |
+
|
47 |
# Проверяем, что scores и ids - корректные numpy массивы
|
48 |
if not isinstance(scores, np.ndarray):
|
49 |
scores = np.array(scores)
|
50 |
if not isinstance(ids, np.ndarray):
|
51 |
ids = np.array(ids)
|
52 |
+
|
53 |
# Сортируем результаты по убыванию оценок
|
54 |
# Проверим, что массивы не пустые
|
55 |
if len(scores) > 0:
|
|
|
61 |
else:
|
62 |
sorted_scores = []
|
63 |
sorted_ids = []
|
64 |
+
|
65 |
return EntitySearchResponse(
|
66 |
scores=sorted_scores,
|
67 |
entity_ids=sorted_ids,
|
68 |
)
|
69 |
except Exception as e:
|
70 |
raise HTTPException(
|
71 |
+
status_code=500, detail=f"Error during entity search: {str(e)}"
|
|
|
72 |
)
|
73 |
|
74 |
|
|
|
76 |
async def search_entities_with_text(
|
77 |
request: EntitySearchWithTextRequest,
|
78 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
79 |
+
current_user: Annotated[any, Depends(auth.get_current_user)],
|
80 |
) -> EntitySearchWithTextResponse:
|
81 |
"""
|
82 |
Поиск похожих сущностей по векторному сходству с возвратом текстов.
|
83 |
+
|
84 |
Args:
|
85 |
request: Параметры поиска
|
86 |
entity_service: Сервис для работы с сущностями
|
87 |
+
|
88 |
Returns:
|
89 |
Результаты поиска с текстами чанков, отсортированные по убыванию сходства
|
90 |
"""
|
91 |
try:
|
92 |
# Получаем результаты поиска
|
93 |
+
_, scores, entity_ids = entity_service.search_similar_old(
|
94 |
+
request.query, request.dataset_id
|
|
|
95 |
)
|
96 |
+
|
97 |
# Проверяем, что scores и entity_ids - корректные numpy массивы
|
98 |
if not isinstance(scores, np.ndarray):
|
99 |
scores = np.array(scores)
|
100 |
if not isinstance(entity_ids, np.ndarray):
|
101 |
entity_ids = np.array(entity_ids)
|
102 |
+
|
103 |
# Сортируем результаты по убыванию оценок
|
104 |
# Проверим, что массивы не пустые
|
105 |
if len(scores) > 0:
|
106 |
# Преобразуем индексы в список, чтобы избежать проблем с индексацией
|
107 |
sorted_indices = scores.argsort()[::-1].tolist()
|
108 |
sorted_scores = [float(scores[i]) for i in sorted_indices]
|
109 |
+
sorted_ids = [UUID(entity_ids[i]) for i in sorted_indices]
|
110 |
+
|
111 |
# Получаем тексты чанков
|
112 |
+
chunks = entity_service.chunk_repository.get_entities_by_ids(sorted_ids)
|
113 |
+
|
114 |
# Формируем ответ
|
115 |
return EntitySearchWithTextResponse(
|
116 |
chunks=[
|
117 |
+
ChunkInfo(
|
118 |
+
id=str(chunk.id), # Преобразуем UUID в строку
|
119 |
+
text=chunk.text,
|
120 |
+
score=score,
|
121 |
+
type=chunk.type,
|
122 |
+
in_search_text=chunk.in_search_text,
|
123 |
+
)
|
124 |
for chunk, score in zip(chunks, sorted_scores)
|
125 |
]
|
126 |
)
|
127 |
else:
|
128 |
return EntitySearchWithTextResponse(chunks=[])
|
129 |
+
|
130 |
except Exception as e:
|
131 |
raise HTTPException(
|
132 |
+
status_code=500, detail=f"Error during entity search with text: {str(e)}"
|
|
|
133 |
)
|
134 |
|
135 |
|
|
|
137 |
async def build_entity_text(
|
138 |
request: EntityTextRequest,
|
139 |
entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
|
140 |
+
current_user: Annotated[any, Depends(auth.get_current_user)],
|
141 |
) -> EntityTextResponse:
|
142 |
"""
|
143 |
Сборка текста из сущностей.
|
144 |
+
|
145 |
Args:
|
146 |
request: Параметры сборки текста
|
147 |
entity_service: Сервис для работы с сущностями
|
148 |
+
|
149 |
Returns:
|
150 |
Собранный текст
|
151 |
"""
|
152 |
try:
|
153 |
+
if not request.entities:
|
|
|
|
|
|
|
154 |
raise HTTPException(
|
155 |
+
status_code=404, detail="No entities found with provided IDs"
|
|
|
156 |
)
|
157 |
+
|
158 |
# Собираем текст
|
159 |
text = entity_service.build_text(
|
160 |
+
entities=request.entities,
|
161 |
chunk_scores=request.chunk_scores,
|
162 |
include_tables=request.include_tables,
|
163 |
max_documents=request.max_documents,
|
164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
+
return EntityTextResponse(text=text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
except Exception as e:
|
168 |
raise HTTPException(
|
169 |
+
status_code=500, detail=f"Error building entity text: {str(e)}"
|
|
|
170 |
)
|
171 |
|
172 |
|
|
|
174 |
async def get_entity_info(
|
175 |
dataset_id: int,
|
176 |
db: Annotated[Session, Depends(DI.get_db)],
|
177 |
+
current_user: Annotated[any, Depends(auth.get_current_user)],
|
178 |
) -> dict:
|
179 |
"""
|
180 |
Получить информацию о сущностях в датасете.
|
|
|
187 |
Returns:
|
188 |
dict: Информация о сущностях
|
189 |
"""
|
190 |
+
# Создаем репозиторий, передавая sessionmaker
|
191 |
chunk_repository = ChunkRepository(db)
|
192 |
+
|
193 |
+
# Получаем общее количество сущностей
|
194 |
+
total_entities_count = chunk_repository.count_entities_by_dataset_id(dataset_id)
|
195 |
+
|
196 |
+
# Получаем сущности, готовые к поиску (с текстом и эмбеддингом)
|
197 |
+
searchable_entities, searchable_embeddings = (
|
198 |
+
chunk_repository.get_searching_entities(dataset_id)
|
199 |
+
)
|
200 |
+
|
201 |
+
# Проверка, найдены ли сущности, готовые к поиску
|
202 |
+
# Можно оставить проверку, чтобы не возвращать пустые примеры, если таких нет,
|
203 |
+
# но основная ошибка 404 должна базироваться на total_entities_count
|
204 |
+
if total_entities_count == 0:
|
205 |
+
raise HTTPException(
|
206 |
+
status_code=404, detail=f"No entities found for dataset {dataset_id}"
|
207 |
+
)
|
208 |
+
|
209 |
# Собираем статистику
|
210 |
stats = {
|
211 |
+
"total_entities": total_entities_count, # Реальное общее число
|
212 |
+
"searchable_entities": len(
|
213 |
+
searchable_entities
|
214 |
+
), # Число сущностей с текстом и эмбеддингом
|
215 |
+
"entities_with_embeddings": len(
|
216 |
+
[e for e in searchable_embeddings if e is not None]
|
217 |
+
),
|
218 |
+
"embedding_shapes": [
|
219 |
+
e.shape if e is not None else None for e in searchable_embeddings
|
220 |
+
],
|
221 |
+
"unique_embedding_shapes": set(
|
222 |
+
str(e.shape) if e is not None else None for e in searchable_embeddings
|
223 |
+
),
|
224 |
+
# Статистику по типам лучше считать на основе searchable_entities, т.к. для них есть объекты
|
225 |
+
"entity_types": set(e.type for e in searchable_entities),
|
226 |
"entities_per_type": {
|
227 |
+
t: len([e for e in searchable_entities if e.type == t])
|
228 |
+
for t in set(e.type for e in searchable_entities)
|
229 |
+
},
|
230 |
}
|
231 |
+
|
232 |
+
# Примеры сущностей берем из searchable_entities
|
233 |
examples = [
|
234 |
{
|
235 |
+
"id": str(e.id),
|
236 |
"name": e.name,
|
237 |
"type": e.type,
|
238 |
+
"has_embedding": searchable_embeddings[i] is not None,
|
239 |
+
"embedding_shape": (
|
240 |
+
str(searchable_embeddings[i].shape)
|
241 |
+
if searchable_embeddings[i] is not None
|
242 |
+
else None
|
243 |
+
),
|
244 |
"text_length": len(e.text),
|
245 |
+
"in_search_text_length": len(e.in_search_text) if e.in_search_text else 0,
|
246 |
}
|
247 |
+
# Берем примеры из сущностей, готовых к поиску
|
248 |
+
for i, e in enumerate(searchable_entities[:5])
|
249 |
]
|
250 |
+
|
251 |
+
return {"stats": stats, "examples": examples}
|
|
|
|
|
|
routes/llm.py
CHANGED
@@ -2,21 +2,20 @@ import json
|
|
2 |
import logging
|
3 |
import os
|
4 |
from typing import Annotated, AsyncGenerator, List, Optional
|
5 |
-
from uuid import UUID
|
6 |
|
7 |
-
from common import auth
|
8 |
-
from components.services.dialogue import DialogueService
|
9 |
-
from fastapi.responses import StreamingResponse
|
10 |
-
|
11 |
-
from components.services.dataset import DatasetService
|
12 |
-
from components.services.entity import EntityService
|
13 |
from fastapi import APIRouter, Depends, HTTPException
|
|
|
14 |
|
15 |
import common.dependencies as DI
|
16 |
-
from common
|
17 |
-
from
|
|
|
|
|
18 |
from components.llm.deepinfra_api import DeepInfraApi
|
19 |
from components.llm.utils import append_llm_response_to_history
|
|
|
|
|
|
|
20 |
from components.services.llm_config import LLMConfigService
|
21 |
from components.services.llm_prompt import LlmPromptService
|
22 |
|
@@ -71,13 +70,16 @@ def insert_search_results_to_message(
|
|
71 |
return False
|
72 |
|
73 |
def try_insert_search_results(
|
74 |
-
chat_request: ChatRequest, search_results: str, entities: List[str]
|
75 |
) -> bool:
|
|
|
76 |
for msg in reversed(chat_request.history):
|
77 |
if msg.role == "user" and not msg.searchResults:
|
78 |
-
msg.searchResults = search_results
|
79 |
-
msg.searchEntities = entities
|
80 |
-
|
|
|
|
|
81 |
return False
|
82 |
|
83 |
def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest:
|
@@ -132,21 +134,25 @@ async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prom
|
|
132 |
dataset = dataset_service.get_current_dataset()
|
133 |
if dataset is None:
|
134 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
138 |
search_results_event = {
|
139 |
"event": "search_results",
|
140 |
"data": {
|
141 |
"text": text_chunks,
|
142 |
-
"ids": chunk_ids
|
143 |
}
|
144 |
}
|
145 |
yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n"
|
146 |
|
147 |
# new_message = f'<search-results>\n{text_chunks}\n</search-results>\n{last_query.content}'
|
148 |
|
149 |
-
try_insert_search_results(request,
|
150 |
except Exception as e:
|
151 |
logger.error(f"Error in SSE chat stream while searching: {str(e)}", stack_info=True)
|
152 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|
@@ -245,9 +251,12 @@ async def chat(
|
|
245 |
if dataset is None:
|
246 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
247 |
logger.info(f"qe_result.search_query: {qe_result.search_query}")
|
248 |
-
|
|
|
|
|
|
|
249 |
|
250 |
-
chunks = entity_service.chunk_repository.
|
251 |
|
252 |
logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
|
253 |
logger.info(f"scores: {scores[:3]}...{scores[-3:]}")
|
|
|
2 |
import logging
|
3 |
import os
|
4 |
from typing import Annotated, AsyncGenerator, List, Optional
|
|
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from fastapi import APIRouter, Depends, HTTPException
|
7 |
+
from fastapi.responses import StreamingResponse
|
8 |
|
9 |
import common.dependencies as DI
|
10 |
+
from common import auth
|
11 |
+
from common.configuration import Configuration
|
12 |
+
from components.llm.common import (ChatRequest, LlmParams, LlmPredictParams,
|
13 |
+
Message)
|
14 |
from components.llm.deepinfra_api import DeepInfraApi
|
15 |
from components.llm.utils import append_llm_response_to_history
|
16 |
+
from components.services.dataset import DatasetService
|
17 |
+
from components.services.dialogue import DialogueService
|
18 |
+
from components.services.entity import EntityService
|
19 |
from components.services.llm_config import LLMConfigService
|
20 |
from components.services.llm_prompt import LlmPromptService
|
21 |
|
|
|
70 |
return False
|
71 |
|
72 |
def try_insert_search_results(
|
73 |
+
chat_request: ChatRequest, search_results: List[str], entities: List[List[str]]
|
74 |
) -> bool:
|
75 |
+
i = 0
|
76 |
for msg in reversed(chat_request.history):
|
77 |
if msg.role == "user" and not msg.searchResults:
|
78 |
+
msg.searchResults = search_results[i]
|
79 |
+
msg.searchEntities = entities[i]
|
80 |
+
i += 1
|
81 |
+
if i == len(search_results):
|
82 |
+
return True
|
83 |
return False
|
84 |
|
85 |
def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest:
|
|
|
134 |
dataset = dataset_service.get_current_dataset()
|
135 |
if dataset is None:
|
136 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
137 |
+
previous_entities = [msg.searchEntities for msg in request.history if msg.searchEntities is not None]
|
138 |
+
previous_entities, chunk_ids, scores = entity_service.search_similar(qe_result.search_query,
|
139 |
+
dataset.id, previous_entities)
|
140 |
+
text_chunks = entity_service.build_text(chunk_ids, scores)
|
141 |
+
all_text_chunks = [text_chunks] + [entity_service.build_text(entities) for entities in previous_entities]
|
142 |
+
all_entities = [chunk_ids] + previous_entities
|
143 |
+
|
144 |
search_results_event = {
|
145 |
"event": "search_results",
|
146 |
"data": {
|
147 |
"text": text_chunks,
|
148 |
+
"ids": chunk_ids
|
149 |
}
|
150 |
}
|
151 |
yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n"
|
152 |
|
153 |
# new_message = f'<search-results>\n{text_chunks}\n</search-results>\n{last_query.content}'
|
154 |
|
155 |
+
try_insert_search_results(request, all_text_chunks, all_entities)
|
156 |
except Exception as e:
|
157 |
logger.error(f"Error in SSE chat stream while searching: {str(e)}", stack_info=True)
|
158 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|
|
|
251 |
if dataset is None:
|
252 |
raise HTTPException(status_code=400, detail="Dataset not found")
|
253 |
logger.info(f"qe_result.search_query: {qe_result.search_query}")
|
254 |
+
previous_entities = [msg.searchEntities for msg in request.history]
|
255 |
+
previous_entities, chunk_ids, scores = entity_service.search_similar(
|
256 |
+
qe_result.search_query, dataset.id, previous_entities
|
257 |
+
)
|
258 |
|
259 |
+
chunks = entity_service.chunk_repository.get_entities_by_ids(chunk_ids)
|
260 |
|
261 |
logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
|
262 |
logger.info(f"scores: {scores[:3]}...{scores[-3:]}")
|