Spaces:
Build error
Build error
File size: 1,248 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from torch import Tensor
def embed_similarity(key_embeds: Tensor,
                     ref_embeds: Tensor,
                     method: str = 'dot_product',
                     temperature: float = -1) -> Tensor:
    """Calculate pairwise feature similarity between two sets of embeddings.

    Args:
        key_embeds (Tensor): Key embeddings of shape (N1, C).
        ref_embeds (Tensor): Reference embeddings of shape (N2, C).
        method (str, optional): Method to calculate the similarity,
            options are 'dot_product' and 'cosine'. With 'cosine', both
            inputs are L2-normalized along dim 1 before the matrix
            product. Defaults to 'dot_product'.
        temperature (float, optional): Softmax temperature. The similarity
            is divided by this value only when it is greater than 0;
            otherwise no scaling is applied. Defaults to -1.

    Returns:
        Tensor: Similarity matrix of shape (N1, N2).

    Raises:
        AssertionError: If ``method`` is not one of the supported options.
    """
    assert method in ['dot_product', 'cosine'], (
        f"method must be 'dot_product' or 'cosine', got {method!r}")

    if method == 'cosine':
        key_embeds = F.normalize(key_embeds, p=2, dim=1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=1)

    similarity = torch.mm(key_embeds, ref_embeds.T)

    if temperature > 0:
        # Out-of-place division: an in-place `/=` raises for integer
        # embeddings because the float result cannot be written back into
        # an integer tensor. For floating-point inputs the values are the
        # same as with the in-place form.
        similarity = similarity / float(temperature)
    return similarity
|