generic-chatbot-backend / components /search /faiss_vector_search.py
muryshev's picture
update
308de05
import logging
import faiss
import numpy as np
from common.constants import DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor
logger = logging.getLogger(__name__)
class FaissVectorSearch:
def __init__(
self,
model: EmbeddingExtractor,
ids_to_embeddings: dict[str, np.ndarray],
):
self.model = model
self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
self.__create_index(ids_to_embeddings)
def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
"""Создает индекс для векторного поиска."""
if len(ids_to_embeddings) == 0:
self.index = None
return
embeddings = np.array(list(ids_to_embeddings.values()))
dim = embeddings.shape[1]
self.index = faiss.IndexFlatIP(dim)
self.index.add(embeddings)
def search_vectors(
self,
query: str,
max_entities: int = 100,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Поиск векторов в индексе.
Args:
query: Строка, запрос для поиска.
max_entities: Максимальное количество найденных сущностей.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
- np.ndarray: Вектор запроса (1, embedding_size)
- np.ndarray: Оценки косинусного сходства (чем больше, тем лучше)
- np.ndarray: Идентификаторы найденных векторов
"""
logger.info(f"Searching vectors in index for query: {query}")
if self.index is None:
return (np.array([]), np.array([]), np.array([]))
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
similarities, indexes = self.index.search(query_embeds, max_entities)
ids = [self.index_to_id[index] for index in indexes[0]]
return query_embeds, similarities[0], np.array(ids)