Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
def embed_similarity(key_embeds: Tensor,
                     ref_embeds: Tensor,
                     method: str = 'dot_product',
                     temperature: float = -1) -> Tensor:
    """Calculate the pairwise similarity between two sets of embeddings.

    Args:
        key_embeds (Tensor): Shape (N1, C).
        ref_embeds (Tensor): Shape (N2, C).
        method (str, optional): Method to calculate the similarity,
            options are 'dot_product' and 'cosine'. Defaults to
            'dot_product'.
        temperature (float, optional): Softmax temperature. The similarity
            is divided by this value only when it is strictly positive;
            zero or negative values disable the scaling. Defaults to -1.

    Returns:
        Tensor: Similarity matrix of shape (N1, N2).

    Raises:
        AssertionError: If ``method`` is not one of the supported options.
    """
    assert method in ['dot_product', 'cosine'], (
        f"method must be 'dot_product' or 'cosine', got {method!r}")

    if method == 'cosine':
        # Cosine similarity is the dot product of L2-normalised rows.
        key_embeds = F.normalize(key_embeds, p=2, dim=1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=1)

    similarity = torch.mm(key_embeds, ref_embeds.T)

    if temperature > 0:
        # Out-of-place division: an in-place `/=` cannot store a floating
        # result in an integer tensor and would raise for integer inputs.
        similarity = similarity / float(temperature)
    return similarity