Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Tuple | |
import torch | |
from torch import Tensor | |
from torch.autograd import Function | |
from torch.autograd.function import once_differentiable | |
from ..utils import ext_loader | |
# Bind the compiled `_ext` extension, requesting the forward and backward
# Chamfer distance kernels that ChamferDistanceFunction calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['chamfer_distance_forward', 'chamfer_distance_backward'],
)
class ChamferDistanceFunction(Function):
    """Autograd function implementing the 2D Chamfer Distance.

    It has been used in the paper `Oriented RepPoints for Aerial Object
    Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>`_.

    The numerical work is delegated to the ``chamfer_distance_forward`` and
    ``chamfer_distance_backward`` kernels exposed by ``ext_module``.
    """

    # ``torch.autograd.Function`` requires ``forward``/``backward`` to be
    # static methods; they are invoked through ``ChamferDistanceFunction.apply``.
    @staticmethod
    def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
        """Compute the Chamfer distance between two batched point sets.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

                - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
                  shape (B, N).
                - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
                  shape (B, M).
                - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
                  with shape (B, N), which is used to compute the gradient.
                - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
                  with shape (B, M), which is used to compute the gradient.
        """
        batch_size, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        device = xyz1.device
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Output buffers handed to the extension kernel. Allocate them
        # directly on the target device instead of on the CPU followed by
        # a host-to-device copy.
        dist1 = torch.zeros(batch_size, n, device=device)
        dist2 = torch.zeros(batch_size, m, device=device)
        idx1 = torch.zeros(batch_size, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batch_size, m, dtype=torch.int32, device=device)

        ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
                                            idx2)

        # The backward kernel needs both point sets and the index tensors.
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    # The backward pass runs an opaque extension kernel, so it cannot itself
    # be differentiated; ``once_differentiable`` makes a double backward raise
    # instead of silently yielding gradients without a graph.
    @staticmethod
    @once_differentiable
    def backward(ctx,
                 grad_dist1: Tensor,
                 grad_dist2: Tensor,
                 grad_idx1=None,
                 grad_idx2=None) -> Tuple[Tensor, Tensor]:
        """Back-propagate distance gradients to the two point sets.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1: Unused; present because ``forward`` also returns
                ``idx1``.
            grad_idx2: Unused; present because ``forward`` also returns
                ``idx2``.

        Returns:
            Tuple[Tensor, Tensor]:

                - grad_xyz1 (Tensor): Gradient of the point set with shape
                  (B, N, 2).
                - grad_xyz2 (Tensor): Gradient of the point set with shape
                  (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        device = grad_dist1.device
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()

        # Gradient buffers passed to the extension kernel, allocated
        # directly on the device of the incoming gradients.
        grad_xyz1 = torch.zeros(xyz1.size(), device=device)
        grad_xyz2 = torch.zeros(xyz2.size(), device=device)

        ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
                                             grad_dist1, grad_dist2,
                                             grad_xyz1, grad_xyz2)
        return grad_xyz1, grad_xyz2
# Functional entry point: ``chamfer_distance(xyz1, xyz2)`` runs the op
# through autograd via ``Function.apply``.
chamfer_distance = ChamferDistanceFunction.apply