Spaces:
Sleeping
Sleeping
File size: 2,208 Bytes
57cf043 86c402d 57cf043 86c402d 57cf043 86c402d 57cf043 744a170 86c402d 57cf043 86c402d 57cf043 86c402d 744a170 86c402d 57cf043 86c402d 57cf043 86c402d 57cf043 744a170 57cf043 744a170 86c402d 744a170 86c402d 57cf043 744a170 86c402d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import logging
import faiss
import numpy as np
from common.constants import DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor
logger = logging.getLogger(__name__)
class FaissVectorSearch:
def __init__(
self,
model: EmbeddingExtractor,
ids_to_embeddings: dict[str, np.ndarray],
):
self.model = model
self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
self.__create_index(ids_to_embeddings)
def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
"""Создает индекс для векторного поиска."""
if len(ids_to_embeddings) == 0:
self.index = None
return
embeddings = np.array(list(ids_to_embeddings.values()))
dim = embeddings.shape[1]
self.index = faiss.IndexFlatIP(dim)
self.index.add(embeddings)
def search_vectors(
self,
query: str,
max_entities: int = 100,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Поиск векторов в индексе.
Args:
query: Строка, запрос для поиска.
max_entities: Максимальное количество найденных сущностей.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
- np.ndarray: Вектор запроса (1, embedding_size)
- np.ndarray: Оценки косинусного сходства (чем больше, тем лучше)
- np.ndarray: Идентификаторы найденных векторов
"""
logger.info(f"Searching vectors in index for query: {query}")
if self.index is None:
return (np.array([]), np.array([]), np.array([]))
query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
similarities, indexes = self.index.search(query_embeds, max_entities)
ids = [self.index_to_id[index] for index in indexes[0]]
return query_embeds, similarities[0], np.array(ids)
|