# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import Tensor


def embed_similarity(key_embeds: Tensor,
                     ref_embeds: Tensor,
                     method: str = 'dot_product',
                     temperature: float = -1) -> Tensor:
    """Calculate the pairwise feature similarity between two embedding sets.

    Args:
        key_embeds (Tensor): Shape (N1, C).
        ref_embeds (Tensor): Shape (N2, C).
        method (str, optional): Method to calculate the similarity, options
            are 'dot_product' and 'cosine'. With 'cosine', both inputs are
            L2-normalized along dim 1 before the matrix product.
            Defaults to 'dot_product'.
        temperature (float, optional): Softmax temperature. When greater
            than 0, the similarity matrix is divided by this value; any
            non-positive value disables the scaling. Defaults to -1.

    Returns:
        Tensor: Similarity matrix of shape (N1, N2).

    Raises:
        AssertionError: If ``method`` is not one of the supported options.
    """
    supported_methods = ('dot_product', 'cosine')
    if method not in supported_methods:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError for backward compatibility with existing callers.
        raise AssertionError(
            f'method must be one of {supported_methods}, got {method!r}')

    if method == 'cosine':
        key_embeds = F.normalize(key_embeds, p=2, dim=1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=1)

    similarity = torch.mm(key_embeds, ref_embeds.T)

    if temperature > 0:
        # Out-of-place division: an in-place ``/=`` cannot store the float
        # result back into an integer tensor and raises a RuntimeError.
        similarity = similarity / float(temperature)
    return similarity