# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Tuple

import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])


class ChamferDistanceFunction(Function):
    """Autograd function implementing the 2D Chamfer Distance.

    It has been used in the paper "Oriented RepPoints for Aerial Object
    Detection" (CVPR 2022).

    Both passes delegate to the compiled ``_ext`` operators
    ``chamfer_distance_forward`` and ``chamfer_distance_backward``. The
    buffers allocated here are handed to those operators, which are
    expected to fill them in place.
    """

    @staticmethod
    def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
        """Compute the Chamfer distance between two batched point sets.

        Args:
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

                - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
                  shape (B, N).
                - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
                  shape (B, M).
                - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
                  with shape (B, N), which is used to compute the
                  gradient.
                - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
                  with shape (B, M), which is used to compute the
                  gradient.
        """
        batch_size, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        device = xyz1.device
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Allocate the outputs directly on the input's device instead of
        # building them on the host and copying them over. The distance
        # buffers follow the dtype of ``xyz1`` rather than the global
        # default dtype.
        # NOTE(review): assumes the extension reads the distance buffers
        # with the same scalar type as the points - confirm against the
        # kernel.
        dist1 = xyz1.new_zeros((batch_size, n))
        dist2 = xyz1.new_zeros((batch_size, m))
        idx1 = torch.zeros((batch_size, n), dtype=torch.int32, device=device)
        idx2 = torch.zeros((batch_size, m), dtype=torch.int32, device=device)

        ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
                                            idx2)
        # The contiguous inputs and the index buffers are passed to the
        # backward operator, so keep them for the backward pass.
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @once_differentiable
    def backward(ctx,
                 grad_dist1: Tensor,
                 grad_dist2: Tensor,
                 grad_idx1=None,
                 grad_idx2=None) -> Tuple[Tensor, Tensor]:
        """Compute the gradients with respect to both point sets.

        Args:
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1: Unused. Placeholder for the gradient slot of the
                ``idx1`` output.
            grad_idx2: Unused. Placeholder for the gradient slot of the
                ``idx2`` output.

        Returns:
            Tuple[Tensor, Tensor]:

                - grad_xyz1 (Tensor): Gradient of the point set with
                  shape (B, N, 2).
                - grad_xyz2 (Tensor): Gradient of the point set with
                  shape (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        device = grad_dist1.device
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()

        # Gradient buffers are created on the device of the incoming
        # gradient, with the shape and dtype of the corresponding input.
        grad_xyz1 = torch.zeros(
            xyz1.size(), dtype=xyz1.dtype, device=device)
        grad_xyz2 = torch.zeros(
            xyz2.size(), dtype=xyz2.dtype, device=device)

        ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                             grad_dist1, grad_dist2,
                                             grad_xyz1, grad_xyz2)
        return grad_xyz1, grad_xyz2


chamfer_distance = ChamferDistanceFunction.apply