generic-chatbot-backend / components /nmd /faiss_vector_search.py
muryshev's picture
update
86c402d
raw
history blame
2.38 kB
import logging
import faiss
import numpy as np
from common.configuration import DataBaseConfiguration
from common.constants import DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor
# Module-level logger named after this module, shared by everything below.
logger = logging.getLogger(__name__)
class FaissVectorSearch:
    """Nearest-neighbour search over precomputed embeddings using FAISS.

    The index is a flat inner-product index (``faiss.IndexFlatIP``) built
    once at construction time from the ``ids_to_embeddings`` mapping.
    """

    def __init__(
        self,
        model: EmbeddingExtractor,
        ids_to_embeddings: dict[str, np.ndarray],
        config: DataBaseConfiguration,
    ):
        """Store the configuration and build the FAISS index.

        Args:
            model: Extractor used to embed the query text at search time.
            ids_to_embeddings: Mapping from an external identifier to its
                embedding vector. Insertion order defines the row order
                of the index.
            config: Configuration object; the metadata path and the number
                of neighbours to retrieve are read from it.
        """
        self.model = model
        self.config = config
        self.path_to_metadata = config.faiss.path_to_metadata

        # When ranging is enabled the neighbour count comes from the ranker
        # settings; otherwise from the vector-search settings.
        if self.config.ranker.use_ranging:
            self.k_neighbors = config.ranker.k_neighbors
        else:
            self.k_neighbors = config.search.vector_search.k_neighbors

        # Row position inside the FAISS index -> external identifier.
        self.index_to_id = dict(enumerate(ids_to_embeddings))
        self.__create_index(ids_to_embeddings)

    def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
        """Build the inner-product FAISS index from the given embeddings.

        Sets ``self.index`` to ``None`` when there are no embeddings, so
        that ``search_vectors`` can return an empty result without
        touching FAISS.

        Args:
            ids_to_embeddings: Mapping from identifier to embedding vector.
        """
        if not ids_to_embeddings:
            self.index = None
            return

        # FAISS only accepts C-contiguous float32 matrices; convert up front
        # so float64 (or non-contiguous) embeddings do not fail in ``add``.
        embeddings = np.ascontiguousarray(
            np.array(list(ids_to_embeddings.values())),
            dtype=np.float32,
        )
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)

    def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Search the index for the vectors closest to the query.

        Args:
            query: Query text to search for.

        Returns:
            A tuple of three arrays:
                - the query embedding as returned by the model
                  (the original documentation states shape
                  ``(1, embedding_size)``);
                - similarity scores of the found vectors, higher is better.
                  These are inner products; they equal cosine similarity
                  only if both stored and query vectors are normalized
                  (NOTE(review): confirm stored embeddings are normalized);
                - identifiers of the found vectors, aligned with the scores.

            All three arrays are empty when the index holds no vectors.
            Fewer than ``k_neighbors`` results are returned when the index
            contains fewer vectors than requested.
        """
        logger.info("Searching vectors in index for query: %s", query)
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))

        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
        similarities, indexes = self.index.search(query_embeds, self.k_neighbors)

        # FAISS pads the label array with -1 when it finds fewer than k
        # neighbours; drop those slots instead of failing the id lookup.
        positions = indexes[0]
        found = positions >= 0
        ids = [self.index_to_id[int(position)] for position in positions[found]]
        return query_embeds, similarities[0][found], np.array(ids)