File size: 2,208 Bytes
308de05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging

import faiss
import numpy as np

from common.constants import DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor

logger = logging.getLogger(__name__)


class FaissVectorSearch:
    def __init__(
        self,
        model: EmbeddingExtractor,
        ids_to_embeddings: dict[str, np.ndarray],
    ):
        self.model = model
        self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
        self.__create_index(ids_to_embeddings)

    def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
        """Создает индекс для векторного поиска."""
        if len(ids_to_embeddings) == 0:
            self.index = None
            return
        embeddings = np.array(list(ids_to_embeddings.values()))
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)

    def search_vectors(
        self,
        query: str,
        max_entities: int = 100,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Поиск векторов в индексе.

        Args:
            query: Строка, запрос для поиска.
            max_entities: Максимальное количество найденных сущностей.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
                - np.ndarray: Вектор запроса (1, embedding_size)
                - np.ndarray: Оценки косинусного сходства (чем больше, тем лучше)
                - np.ndarray: Идентификаторы найденных векторов
        """
        logger.info(f"Searching vectors in index for query: {query}")
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))
        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
        similarities, indexes = self.index.search(query_embeds, max_entities)
        ids = [self.index_to_id[index] for index in indexes[0]]
        return query_embeds, similarities[0], np.array(ids)