# Spaces: Running
class Retriever:
    """Look up image names by nearest-neighbour search over their embeddings.

    Subclasses implement :meth:`load_embeddings`, which returns a mapping from
    image name to embedding vector. The vectors are stacked (row ``i`` belongs
    to the ``i``-th name) into an exact L2 Faiss index (``faiss.IndexFlatL2``),
    and :meth:`retrieve` maps search results back to image names.
    """

    def __init__(self, embeddings_path: str):
        """Load embeddings from *embeddings_path* and build the Faiss index.

        Args:
            embeddings_path: location passed unchanged to :meth:`load_embeddings`.

        Raises:
            ValueError: if no embeddings were loaded, or they cannot be
                stacked into a 2-D matrix (e.g. vectors of unequal length).
        """
        embeddings_by_name: Dict[str, np.ndarray] = self.load_embeddings(embeddings_path)
        if not embeddings_by_name:
            raise ValueError(f"no embeddings loaded from {embeddings_path!r}")
        names = list(embeddings_by_name)
        # Keep track of image names: row i of the index holds names[i].
        self.image_to_index: Dict[str, int] = {name: i for i, name in enumerate(names)}
        self.index_to_image: Dict[int, str] = dict(enumerate(names))
        # Build Faiss index. Faiss only accepts C-contiguous float32 data, so
        # convert explicitly rather than relying on the loader's dtype.
        self.embeddings = self._as_float32_matrix(
            [embeddings_by_name[name] for name in names]
        )
        self.dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(self.dim)
        self.index.add(self.embeddings)

    def load_embeddings(self, embeddings_path: str) -> Dict[str, np.ndarray]:
        """Load embeddings from a file.

        Subclasses must override this to return a mapping from image name to
        that image's embedding vector; all vectors must have the same length
        so they can be stacked into one matrix.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def retrieve(self, queries: np.ndarray, n_neighbors: int = 5) -> List[List[str]]:
        """Retrieve the names of the nearest indexed images for each query.

        Args:
            queries: query embeddings, one per row; a single 1-D vector is
                treated as one query. Converted to float32 for Faiss.
            n_neighbors: number of neighbours to return per query (>= 1).

        Returns:
            One list of image names per query, nearest first. A list is
            shorter than ``n_neighbors`` when the index holds fewer vectors.

        Raises:
            ValueError: if ``n_neighbors < 1`` or ``queries`` is not 1-D/2-D.
        """
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")
        query_matrix = self._as_float32_matrix(queries)
        _distances, ids = self.index.search(query_matrix, n_neighbors)
        return self._ids_to_names(ids)

    @staticmethod
    def _as_float32_matrix(vectors) -> np.ndarray:
        """Return *vectors* as a C-contiguous 2-D float32 array, as Faiss requires."""
        matrix = np.ascontiguousarray(vectors, dtype=np.float32)
        if matrix.ndim == 1:
            # A single vector: treat it as a batch of one.
            matrix = matrix.reshape(1, -1)
        if matrix.ndim != 2:
            raise ValueError(f"expected a 2-D array of vectors, got shape {matrix.shape}")
        return matrix

    def _ids_to_names(self, ids: np.ndarray) -> List[List[str]]:
        """Map rows of Faiss result ids to image names, skipping -1 padding."""
        # Faiss fills missing neighbours with id -1, which is not a valid row.
        return [[self.index_to_image[int(i)] for i in row if i >= 0] for row in ids]