Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
""" | |
Скрипт для поиска по векторизованным документам через API. | |
Этот скрипт: | |
1. Считывает все документы из заданной папки с помощью UniversalParser | |
2. Чанкит каждый документ через Destructurer с fixed_size-стратегией | |
3. Векторизует поле in_search_text через BGE-модель | |
4. Поднимает FastAPI с двумя эндпоинтами: | |
- /search/entities - возвращает найденные сущности списком словарей | |
- /search/text - возвращает полноценный собранный текст | |
""" | |
import logging | |
import os | |
from pathlib import Path | |
from typing import Dict, List, Optional | |
import numpy as np | |
import pandas as pd | |
import torch | |
import uvicorn | |
from fastapi import FastAPI, Query | |
from ntr_fileparser import UniversalParser | |
from pydantic import BaseModel | |
from sklearn.metrics.pairwise import cosine_similarity | |
from transformers import AutoModel, AutoTokenizer | |
from ntr_text_fragmentation.chunking.specific_strategies.fixed_size_chunking import \ | |
FixedSizeChunkingStrategy | |
from ntr_text_fragmentation.core.destructurer import Destructurer | |
from ntr_text_fragmentation.core.entity_repository import \ | |
InMemoryEntityRepository | |
from ntr_text_fragmentation.core.injection_builder import InjectionBuilder | |
from ntr_text_fragmentation.models.linker_entity import LinkerEntity | |
# Константы | |
DOCS_FOLDER = "../data/docs" # Путь к папке с документами | |
MODEL_NAME = "BAAI/bge-m3" # Название модели для векторизации | |
BATCH_SIZE = 16 # Размер батча для векторизации | |
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" # Устройство для вычислений | |
MAX_ENTITIES = 100 # Максимальное количество возвращаемых сущностей | |
WORDS_PER_CHUNK = 50 # Количество слов в чанке для fixed_size стратегии | |
OVERLAP_WORDS = 25 # Количество слов перекрытия для fixed_size стратегии | |
# Пути к кэшированным файлам | |
CACHE_DIR = "../data/cache" # Путь к папке с кэшированными данными | |
ENTITIES_CSV = os.path.join(CACHE_DIR, "entities.csv") # Путь к CSV с сущностями | |
EMBEDDINGS_NPY = os.path.join(CACHE_DIR, "embeddings.npy") # Путь к массиву эмбеддингов | |
# Инициализация FastAPI | |
app = FastAPI(title="Документный поиск API", | |
description="API для поиска по векторизованным документам") | |
# Глобальные переменные для хранения данных | |
entities_df = None | |
entity_embeddings = None | |
model = None | |
tokenizer = None | |
entity_repository = None | |
injection_builder = None | |
class EntityResponse(BaseModel): | |
"""Модель ответа для сущностей.""" | |
id: str | |
name: str | |
text: str | |
type: str | |
score: float | |
doc_name: Optional[str] = None | |
metadata: Optional[Dict] = None | |
class TextResponse(BaseModel): | |
"""Модель ответа для собранного текста.""" | |
text: str | |
entities_count: int | |
class TextsResponse(BaseModel): | |
"""Модель ответа для списка текстов.""" | |
texts: List[str] | |
entities_count: int | |
def setup_logging() -> None: | |
"""Настройка логгирования.""" | |
logging.basicConfig( | |
level=logging.INFO, | |
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | |
) | |
def load_documents(folder_path: str) -> Dict: | |
""" | |
Загружает все документы из указанной папки. | |
Args: | |
folder_path: Путь к папке с документами | |
Returns: | |
Словарь {имя_файла: parsed_document} | |
""" | |
logging.info(f"Чтение документов из {folder_path}...") | |
parser = UniversalParser() | |
documents = {} | |
# Проверка существования папки | |
if not os.path.exists(folder_path): | |
logging.error(f"Папка {folder_path} не существует!") | |
return {} | |
for file_path in Path(folder_path).glob("**/*.docx"): | |
try: | |
doc_name = file_path.stem | |
logging.info(f"Обработка документа: {doc_name}") | |
documents[doc_name] = parser.parse_by_path(str(file_path)) | |
except Exception as e: | |
logging.error(f"Ошибка при чтении файла {file_path}: {e}") | |
logging.info(f"Загружено {len(documents)} документов.") | |
return documents | |
def process_documents(documents: Dict) -> List[LinkerEntity]: | |
""" | |
Обрабатывает документы, применяя fixed_size стратегию чанкинга. | |
Args: | |
documents: Словарь с распарсенными документами | |
Returns: | |
Список сущностей из всех документов | |
""" | |
logging.info("Применение fixed_size стратегии чанкинга ко всем документам...") | |
all_entities = [] | |
for doc_name, document in documents.items(): | |
try: | |
# Создаем Destructurer с fixed_size стратегией | |
destructurer = Destructurer( | |
document, | |
strategy_name="fixed_size", | |
words_per_chunk=WORDS_PER_CHUNK, | |
overlap_words=OVERLAP_WORDS | |
) | |
# Получаем сущности | |
doc_entities = destructurer.destructure() | |
# Добавляем имя документа в метаданные всех сущностей | |
for entity in doc_entities: | |
if not hasattr(entity, 'metadata') or entity.metadata is None: | |
entity.metadata = {} | |
entity.metadata['doc_name'] = doc_name | |
all_entities.extend(doc_entities) | |
logging.info(f"Документ {doc_name}: получено {len(doc_entities)} сущностей") | |
except Exception as e: | |
logging.error(f"Ошибка при обработке документа {doc_name}: {e}") | |
logging.info(f"Всего получено {len(all_entities)} сущностей из всех документов") | |
return all_entities | |
def entities_to_dataframe(entities: List[LinkerEntity]) -> pd.DataFrame: | |
""" | |
Преобразует список сущностей в DataFrame для удобной работы. | |
Args: | |
entities: Список сущностей | |
Returns: | |
DataFrame с данными сущностей | |
""" | |
data = [] | |
for entity in entities: | |
# Получаем имя документа из метаданных | |
doc_name = entity.metadata.get('doc_name', '') if hasattr(entity, 'metadata') and entity.metadata else '' | |
# Базовые поля для всех типов сущностей | |
entity_dict = { | |
"id": str(entity.id), | |
"type": entity.type, | |
"name": entity.name, | |
"text": entity.text, | |
"in_search_text": entity.in_search_text, | |
"doc_name": doc_name, | |
"source_id": entity.source_id if hasattr(entity, 'source_id') else None, | |
"target_id": entity.target_id if hasattr(entity, 'target_id') else None, | |
"metadata": entity.metadata if hasattr(entity, 'metadata') else {}, | |
} | |
data.append(entity_dict) | |
df = pd.DataFrame(data) | |
return df | |
def setup_model_and_tokenizer(): | |
""" | |
Инициализирует модель и токенизатор для векторизации. | |
Returns: | |
Кортеж (модель, токенизатор) | |
""" | |
global model, tokenizer | |
logging.info(f"Загрузка модели {MODEL_NAME} на устройство {DEVICE}...") | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) | |
model.eval() | |
return model, tokenizer | |
def _average_pool( | |
last_hidden_states: torch.Tensor, | |
attention_mask: torch.Tensor | |
) -> torch.Tensor: | |
""" | |
Расчёт усредненного эмбеддинга по всем токенам | |
Args: | |
last_hidden_states: Матрица эмбеддингов отдельных токенов | |
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены | |
Returns: | |
Усредненный эмбеддинг | |
""" | |
last_hidden = last_hidden_states.masked_fill( | |
~attention_mask[..., None].bool(), 0.0 | |
) | |
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | |
def get_embeddings(texts: List[str]) -> np.ndarray: | |
""" | |
Получает эмбеддинги для списка текстов. | |
Args: | |
texts: Список текстов для векторизации | |
Returns: | |
Массив эмбеддингов | |
""" | |
global model, tokenizer | |
# Проверяем, что модель и токенизатор инициализированы | |
if model is None or tokenizer is None: | |
model, tokenizer = setup_model_and_tokenizer() | |
all_embeddings = [] | |
for i in range(0, len(texts), BATCH_SIZE): | |
batch_texts = texts[i:i+BATCH_SIZE] | |
# Фильтруем None и пустые строки | |
batch_texts = [text for text in batch_texts if text] | |
if not batch_texts: | |
continue | |
# Токенизация с обрезкой и padding | |
encoding = tokenizer( | |
batch_texts, | |
padding=True, | |
truncation=True, | |
max_length=512, | |
return_tensors="pt" | |
).to(DEVICE) | |
# Получаем эмбеддинги с average pooling | |
with torch.no_grad(): | |
outputs = model(**encoding) | |
embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"]) | |
all_embeddings.append(embeddings.cpu().numpy()) | |
if not all_embeddings: | |
return np.array([]) | |
return np.vstack(all_embeddings) | |
def init_entity_repository_and_builder(entities: List[LinkerEntity]): | |
""" | |
Инициализирует хранилище сущностей и сборщик инъекций. | |
Args: | |
entities: Список сущностей | |
""" | |
global entity_repository, injection_builder | |
# Создаем хранилище сущностей | |
entity_repository = InMemoryEntityRepository(entities) | |
# Добавляем метод get_entity_by_id в InMemoryEntityRepository | |
# Это временное решение, в идеале нужно добавить этот метод в сам класс | |
def get_entity_by_id(self, entity_id): | |
"""Получает сущность по ID""" | |
for entity in self.entities: | |
if str(entity.id) == entity_id: | |
return entity | |
return None | |
# Добавляем метод в класс | |
InMemoryEntityRepository.get_entity_by_id = get_entity_by_id | |
# Создаем сборщик инъекций | |
injection_builder = InjectionBuilder(repository=entity_repository) | |
# Регистрируем стратегию | |
injection_builder.register_strategy("fixed_size", FixedSizeChunkingStrategy) | |
def search_entities(query: str, top_n: int = MAX_ENTITIES) -> List[Dict]: | |
""" | |
Ищет сущности по запросу на основе косинусной близости. | |
Args: | |
query: Поисковый запрос | |
top_n: Максимальное количество возвращаемых сущностей | |
Returns: | |
Список найденных сущностей с их скорами | |
""" | |
global entities_df, entity_embeddings | |
# Проверяем наличие данных | |
if entities_df is None or entity_embeddings is None: | |
logging.error("Данные не инициализированы. Запустите сначала prepare_data().") | |
return [] | |
# Векторизуем запрос | |
query_embedding = get_embeddings([query]) | |
if query_embedding.size == 0: | |
return [] | |
# Считаем косинусную близость | |
similarities = cosine_similarity(query_embedding, entity_embeddings)[0] | |
# Получаем индексы топ-N сущностей | |
top_indices = np.argsort(similarities)[-top_n:][::-1] | |
# Фильтруем сущности, которые используются для поиска | |
search_df = entities_df.copy() | |
search_df = search_df[search_df['in_search_text'].notna()] | |
# Если после фильтрации нет данных, возвращаем пустой список | |
if search_df.empty: | |
return [] | |
# Получаем топ-N сущностей | |
results = [] | |
for idx in top_indices: | |
if idx >= len(search_df): | |
continue | |
entity = search_df.iloc[idx] | |
similarity = similarities[idx] | |
# Создаем результат | |
result = { | |
"id": entity["id"], | |
"name": entity["name"], | |
"text": entity["text"], | |
"type": entity["type"], | |
"score": float(similarity), | |
"doc_name": entity["doc_name"], | |
"metadata": entity["metadata"] | |
} | |
results.append(result) | |
return results | |
async def api_search_entities( | |
query: str = Query(..., description="Поисковый запрос"), | |
limit: int = Query(MAX_ENTITIES, description="Максимальное количество результатов") | |
): | |
""" | |
Эндпоинт для поиска сущностей по запросу. | |
Args: | |
query: Поисковый запрос | |
limit: Максимальное количество результатов | |
Returns: | |
Список найденных сущностей | |
""" | |
results = search_entities(query, limit) | |
return results | |
async def api_search_text( | |
query: str = Query(..., description="Поисковый запрос"), | |
limit: int = Query(MAX_ENTITIES, description="Максимальное количество учитываемых сущностей") | |
): | |
""" | |
Эндпоинт для поиска и сборки полного текста по запросу. | |
Args: | |
query: Поисковый запрос | |
limit: Максимальное количество учитываемых сущностей | |
Returns: | |
Собранный текст и количество использованных сущностей | |
""" | |
global injection_builder | |
# Проверяем наличие сборщика инъекций | |
if injection_builder is None: | |
logging.error("Сборщик инъекций не инициализирован.") | |
return {"text": "", "entities_count": 0} | |
# Получаем найденные сущности | |
entity_results = search_entities(query, limit) | |
if not entity_results: | |
return {"text": "", "entities_count": 0} | |
# Получаем список ID сущностей | |
entity_ids = [str(result["id"]) for result in entity_results] | |
# Собираем текст, используя напрямую ID | |
try: | |
assembled_text = injection_builder.build(entity_ids) | |
print('Всё ок прошло вроде бы') | |
return {"text": assembled_text, "entities_count": len(entity_ids)} | |
except ImportError as e: | |
# Обработка ошибки импорта модулей для работы с изображениями | |
logging.error(f"Ошибка импорта при сборке текста: {e}") | |
# Альтернативная сборка текста без использования injection_builder | |
simple_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")]) | |
return {"text": simple_text, "entities_count": len(entity_ids)} | |
except Exception as e: | |
logging.error(f"Ошибка при сборке текста: {e}") | |
return {"text": "", "entities_count": 0} | |
async def api_search_texts( | |
query: str = Query(..., description="Поисковый запрос"), | |
limit: int = Query(MAX_ENTITIES, description="Максимальное количество результатов") | |
): | |
""" | |
Эндпоинт для поиска списка текстов сущностей по запросу. | |
Args: | |
query: Поисковый запрос | |
limit: Максимальное количество результатов | |
Returns: | |
Список текстов найденных сущностей и их количество | |
""" | |
# Получаем найденные сущности | |
entity_results = search_entities(query, limit) | |
if not entity_results: | |
return {"texts": [], "entities_count": 0} | |
# Извлекаем тексты из результатов | |
texts = [result["text"] for result in entity_results if result.get("text")] | |
return {"texts": texts, "entities_count": len(texts)} | |
async def api_search_text_test( | |
query: str = Query(..., description="Поисковый запрос"), | |
limit: int = Query(MAX_ENTITIES, description="Максимальное количество учитываемых сущностей") | |
): | |
""" | |
Тестовый эндпоинт для поиска и сборки текста с использованием подхода из test_chunking_visualization.py. | |
Args: | |
query: Поисковый запрос | |
limit: Максимальное количество учитываемых сущностей | |
Returns: | |
Собранный текст и количество использованных сущностей | |
""" | |
global entity_repository, injection_builder | |
# Проверяем наличие репозитория и сборщика инъекций | |
if entity_repository is None or injection_builder is None: | |
logging.error("Репозиторий или сборщик инъекций не инициализированы.") | |
return {"text": "", "entities_count": 0} | |
# Получаем найденные сущности | |
entity_results = search_entities(query, limit) | |
if not entity_results: | |
return {"text": "", "entities_count": 0} | |
try: | |
# Получаем объекты сущностей из репозитория по ID | |
entity_ids = [result["id"] for result in entity_results] | |
entities = [] | |
for entity_id in entity_ids: | |
entity = entity_repository.get_entity_by_id(entity_id) | |
if entity: | |
entities.append(entity) | |
logging.info(f"Найдено {len(entities)} объектов сущностей по ID") | |
if not entities: | |
logging.error("Не удалось найти сущности в репозитории") | |
# Собираем простой текст из результатов поиска | |
simple_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")]) | |
return {"text": simple_text, "entities_count": len(entity_results)} | |
# Собираем текст, как в test_chunking_visualization.py | |
assembled_text = injection_builder.build(entities) # Передаем сами объекты | |
return {"text": assembled_text, "entities_count": len(entities)} | |
except Exception as e: | |
logging.error(f"Ошибка при сборке текста: {e}", exc_info=True) | |
# Запасной вариант - просто соединяем тексты | |
fallback_text = "\n\n".join([result["text"] for result in entity_results if result.get("text")]) | |
return {"text": fallback_text, "entities_count": len(entity_results)} | |
def save_entities_to_csv(entities: List[LinkerEntity], csv_path: str) -> None: | |
""" | |
Сохраняет сущности в CSV файл. | |
Args: | |
entities: Список сущностей | |
csv_path: Путь для сохранения CSV файла | |
""" | |
logging.info(f"Сохранение {len(entities)} сущностей в {csv_path}") | |
# Создаем директорию, если она не существует | |
os.makedirs(os.path.dirname(csv_path), exist_ok=True) | |
# Преобразуем сущности в DataFrame и сохраняем | |
df = entities_to_dataframe(entities) | |
df.to_csv(csv_path, index=False) | |
logging.info(f"Сохранено {len(entities)} сущностей в {csv_path}") | |
def load_entities_from_csv(csv_path: str) -> List[LinkerEntity]: | |
""" | |
Загружает сущности из CSV файла. | |
Args: | |
csv_path: Путь к CSV файлу | |
Returns: | |
Список сущностей | |
""" | |
logging.info(f"Загрузка сущностей из {csv_path}") | |
if not os.path.exists(csv_path): | |
logging.error(f"Файл {csv_path} не найден") | |
return [] | |
df = pd.read_csv(csv_path) | |
entities = [] | |
for _, row in df.iterrows(): | |
# Обработка метаданных | |
metadata = row.get("metadata", {}) | |
if isinstance(metadata, str): | |
try: | |
metadata = eval(metadata) if metadata and not pd.isna(metadata) else {} | |
except: | |
metadata = {} | |
# Общие поля для всех типов сущностей | |
common_args = { | |
"id": row["id"], | |
"name": row["name"] if not pd.isna(row.get("name", "")) else "", | |
"text": row["text"] if not pd.isna(row.get("text", "")) else "", | |
"metadata": metadata, | |
"type": row["type"], | |
} | |
# Добавляем in_search_text, если он есть | |
if "in_search_text" in row and not pd.isna(row["in_search_text"]): | |
common_args["in_search_text"] = row["in_search_text"] | |
# Добавляем поля связи, если они есть | |
if "source_id" in row and not pd.isna(row["source_id"]): | |
common_args["source_id"] = row["source_id"] | |
common_args["target_id"] = row["target_id"] | |
if "number_in_relation" in row and not pd.isna(row["number_in_relation"]): | |
common_args["number_in_relation"] = int(row["number_in_relation"]) | |
entity = LinkerEntity(**common_args) | |
entities.append(entity) | |
logging.info(f"Загружено {len(entities)} сущностей из {csv_path}") | |
return entities | |
def save_embeddings(embeddings: np.ndarray, file_path: str) -> None: | |
""" | |
Сохраняет эмбеддинги в numpy файл. | |
Args: | |
embeddings: Массив эмбеддингов | |
file_path: Путь для сохранения файла | |
""" | |
logging.info(f"Сохранение эмбеддингов размером {embeddings.shape} в {file_path}") | |
# Создаем директорию, если она не существует | |
os.makedirs(os.path.dirname(file_path), exist_ok=True) | |
# Сохраняем эмбеддинги | |
np.save(file_path, embeddings) | |
logging.info(f"Эмбеддинги сохранены в {file_path}") | |
def load_embeddings(file_path: str) -> np.ndarray: | |
""" | |
Загружает эмбеддинги из numpy файла. | |
Args: | |
file_path: Путь к файлу | |
Returns: | |
Массив эмбеддингов | |
""" | |
logging.info(f"Загрузка эмбеддингов из {file_path}") | |
if not os.path.exists(file_path): | |
logging.error(f"Файл {file_path} не найден") | |
return np.array([]) | |
embeddings = np.load(file_path) | |
logging.info(f"Загружены эмбеддинги размером {embeddings.shape}") | |
return embeddings | |
def prepare_data(): | |
""" | |
Подготавливает все необходимые данные для API. | |
""" | |
global entities_df, entity_embeddings, entity_repository, injection_builder | |
# Проверяем наличие кэшированных данных | |
cache_exists = os.path.exists(ENTITIES_CSV) and os.path.exists(EMBEDDINGS_NPY) | |
if cache_exists: | |
logging.info("Найдены кэшированные данные, загружаем их") | |
# Загружаем сущности из CSV | |
entities = load_entities_from_csv(ENTITIES_CSV) | |
if not entities: | |
logging.error("Не удалось загрузить сущности из кэша, генерируем заново") | |
cache_exists = False | |
else: | |
# Преобразуем сущности в DataFrame | |
entities_df = entities_to_dataframe(entities) | |
# Загружаем эмбеддинги | |
entity_embeddings = load_embeddings(EMBEDDINGS_NPY) | |
if entity_embeddings.size == 0: | |
logging.error("Не удалось загрузить эмбеддинги из кэша, генерируем заново") | |
cache_exists = False | |
else: | |
# Инициализируем хранилище и сборщик | |
init_entity_repository_and_builder(entities) | |
logging.info("Данные успешно загружены из кэша") | |
# Если кэшированных данных нет или их не удалось загрузить, генерируем заново | |
if not cache_exists: | |
logging.info("Кэшированные данные не найдены или не могут быть загружены, обрабатываем документы") | |
# Загружаем и обрабатываем документы | |
documents = load_documents(DOCS_FOLDER) | |
if not documents: | |
logging.error(f"Не найдено документов в папке {DOCS_FOLDER}") | |
return | |
# Получаем сущности из всех документов | |
all_entities = process_documents(documents) | |
if not all_entities: | |
logging.error("Не получено сущностей из документов") | |
return | |
# Преобразуем сущности в DataFrame | |
entities_df = entities_to_dataframe(all_entities) | |
# Инициализируем хранилище и сборщик | |
init_entity_repository_and_builder(all_entities) | |
# Фильтруем только сущности для поиска | |
search_df = entities_df[entities_df['in_search_text'].notna()] | |
if search_df.empty: | |
logging.error("Нет сущностей для поиска с in_search_text") | |
return | |
# Векторизуем тексты сущностей | |
search_texts = search_df['in_search_text'].tolist() | |
entity_embeddings = get_embeddings(search_texts) | |
logging.info(f"Подготовлено {len(search_df)} сущностей для поиска") | |
logging.info(f"Размер эмбеддингов: {entity_embeddings.shape}") | |
# Сохраняем данные в кэш для последующего использования | |
save_entities_to_csv(all_entities, ENTITIES_CSV) | |
save_embeddings(entity_embeddings, EMBEDDINGS_NPY) | |
logging.info("Данные сохранены в кэш для последующего использования") | |
# Вывод итоговой информации (независимо от источника данных) | |
logging.info(f"Подготовка данных завершена. Готово к использованию {entity_embeddings.shape[0]} сущностей") | |
async def startup_event(): | |
"""Запускается при старте приложения.""" | |
setup_logging() | |
prepare_data() | |
def main(): | |
"""Основная функция для запуска скрипта вручную.""" | |
setup_logging() | |
prepare_data() | |
# Запуск Uvicorn сервера | |
uvicorn.run(app, host="0.0.0.0", port=8017) | |
if __name__ == "__main__": | |
main() |