File size: 2,208 Bytes
57cf043
86c402d
57cf043
86c402d
57cf043
86c402d
57cf043
 
 
 
 
 
 
744a170
 
86c402d
57cf043
 
86c402d
 
57cf043
86c402d
744a170
86c402d
57cf043
 
86c402d
57cf043
86c402d
57cf043
 
744a170
 
 
 
 
57cf043
 
744a170
86c402d
 
744a170
86c402d
 
 
 
 
 
57cf043
 
 
 
 
744a170
86c402d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging

import faiss
import numpy as np

from common.constants import DO_NORMALIZATION
from components.embedding_extraction import EmbeddingExtractor

logger = logging.getLogger(__name__)


class FaissVectorSearch:
    def __init__(
        self,
        model: EmbeddingExtractor,
        ids_to_embeddings: dict[str, np.ndarray],
    ):
        self.model = model
        self.index_to_id = {i: id_ for i, id_ in enumerate(ids_to_embeddings.keys())}
        self.__create_index(ids_to_embeddings)

    def __create_index(self, ids_to_embeddings: dict[str, np.ndarray]):
        """Создает индекс для векторного поиска."""
        if len(ids_to_embeddings) == 0:
            self.index = None
            return
        embeddings = np.array(list(ids_to_embeddings.values()))
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)

    def search_vectors(
        self,
        query: str,
        max_entities: int = 100,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Поиск векторов в индексе.

        Args:
            query: Строка, запрос для поиска.
            max_entities: Максимальное количество найденных сущностей.

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: Кортеж из трех массивов:
                - np.ndarray: Вектор запроса (1, embedding_size)
                - np.ndarray: Оценки косинусного сходства (чем больше, тем лучше)
                - np.ndarray: Идентификаторы найденных векторов
        """
        logger.info(f"Searching vectors in index for query: {query}")
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))
        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
        similarities, indexes = self.index.search(query_embeds, max_entities)
        ids = [self.index_to_id[index] for index in indexes[0]]
        return query_embeds, similarities[0], np.array(ids)