generic-chatbot-backend / components /nmd /faiss_vector_search.py
muryshev's picture
init
57cf043
raw
history blame
1.74 kB
import logging
from typing import List
import numpy as np
import pandas as pd
import faiss
from common.constants import COLUMN_EMBEDDING
from common.constants import DO_NORMALIZATION
from common.configuration import DataBaseConfiguration
from components.embedding_extraction import EmbeddingExtractor
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class FaissVectorSearch:
    """Exact (brute-force) L2 nearest-neighbour search over precomputed embeddings.

    The index is built once, at construction time, from the ``COLUMN_EMBEDDING``
    column of the supplied DataFrame using ``faiss.IndexFlatL2``. Queries are
    embedded with the supplied ``EmbeddingExtractor`` and searched against it.
    """

    def __init__(
        self, model: EmbeddingExtractor, df: pd.DataFrame, config: DataBaseConfiguration
    ):
        """Store dependencies, choose the neighbour count and build the index.

        Args:
            model: Extractor used to embed incoming queries.
            df: Rows to index; must contain a ``COLUMN_EMBEDDING`` column whose
                cells are equal-length numeric vectors. An empty frame yields
                no index, and every search then returns empty arrays.
            config: Database configuration. ``config.faiss.path_to_metadata``
                is stored on the instance. The neighbour count comes from
                ``config.ranker.k_neighbors`` when ``config.ranker.use_ranging``
                is true, otherwise from ``config.search.vector_search.k_neighbors``.

        Raises:
            ValueError: if the embedding column does not form a 2-D matrix.
        """
        self.model = model
        self.config = config
        # Kept as a public attribute; it is not read anywhere in this class.
        self.path_to_metadata = config.faiss.path_to_metadata
        # When a ranker is enabled, retrieve the candidate count configured for it.
        if self.config.ranker.use_ranging:
            self.k_neighbors = config.ranker.k_neighbors
        else:
            self.k_neighbors = config.search.vector_search.k_neighbors
        self.__create_index(df)

    def __create_index(self, df: pd.DataFrame) -> None:
        """Build an exact L2 FAISS index from the embedding column of ``df``.

        Sets ``self.index`` to ``None`` when ``df`` has no rows. Vectors are added
        in row order, so FAISS ids are the 0-based positional row numbers of ``df``.

        Raises:
            ValueError: if the embeddings do not form a 2-D (rows x dim) matrix.
        """
        if len(df) == 0:
            self.index = None
            return
        # FAISS works on C-contiguous float32 matrices; older Python wrappers do
        # not convert other dtypes automatically, so convert explicitly here.
        # Only the embedding column is read, so the frame itself is not copied.
        embeddings = np.array(df[COLUMN_EMBEDDING].tolist(), dtype=np.float32)
        if embeddings.ndim != 2:
            raise ValueError(
                f"Expected a 2-D embedding matrix (rows x dim), got shape {embeddings.shape}"
            )
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(embeddings)

    def search_vectors(self, query: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Embed ``query`` and search the index for its nearest neighbours.

        Args:
            query: Text to search for.

        Returns:
            ``(query_embedding, distances, row_indexes)`` for the first query
            vector returned by the extractor. ``distances`` are those reported by
            ``IndexFlatL2.search`` (squared L2; smaller is closer) and
            ``row_indexes`` are positional row numbers of the indexed DataFrame.
            At most ``self.k_neighbors`` results are returned, fewer when the
            index holds fewer vectors. All three arrays are empty when no
            index was built.
        """
        logger.info("Searching vectors in index for query: %s", query)
        if self.index is None:
            return (np.array([]), np.array([]), np.array([]))
        query_embeds = self.model.query_embed_extraction(query, DO_NORMALIZATION)
        # FAISS pads results with label -1 (and a huge distance) when k exceeds the
        # number of stored vectors. A caller using -1 as a positional row index
        # would silently get the last row, so never ask for more than exist.
        k = min(self.k_neighbors, self.index.ntotal)
        scores, indexes = self.index.search(query_embeds, k)
        return query_embeds[0], scores[0], indexes[0]