rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from torch.autograd import Function
from ..utils import ext_loader
# Load the `_ext` extension through ext_loader, naming the two ops that
# BallQuery.forward calls: `ball_query_forward` (batched input) and
# `stack_ball_query_forward` (stacked input).
ext_module = ext_loader.load_ext(
'_ext', ['ball_query_forward', 'stack_ball_query_forward'])
class BallQuery(Function):
    """Find nearby points in spherical space.

    Two input layouts are supported and selected inside :meth:`forward`:

    - batched: ``xyz`` is (B, N, 3) and ``center_xyz`` is (B, npoint, 3);
    - stacked: all samples are concatenated along dim 0 and the per-sample
      sizes are given by ``xyz_batch_cnt`` / ``center_xyz_batch_cnt``.

    The op produces integer indices only, so it carries no gradient.
    """

    @staticmethod
    def forward(
            ctx,
            min_radius: float,
            max_radius: float,
            sample_num: int,
            xyz: torch.Tensor,
            center_xyz: torch.Tensor,
            xyz_batch_cnt: Optional[torch.Tensor] = None,
            center_xyz_batch_cnt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
                or stacked input (N1 + N2 ..., 3).
            center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
                query, or stacked input (M1 + M2 ..., 3).
            xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums in
                each batch, just like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
                nums in each batch, just like (M1, M2, ...). Defaults to None.
                New in version 1.7.0.

        Returns:
            torch.Tensor: int32 tensor with the indices of the features that
            form the query balls; (B, npoint, nsample) for batched input, or
            (M1 + M2 ..., nsample) for stacked input.

        Note:
            The stacked path is taken only when BOTH ``xyz_batch_cnt`` and
            ``center_xyz_batch_cnt`` are given; otherwise the inputs are
            treated as batched. On the stacked path only ``max_radius`` is
            passed to the extension op; ``min_radius`` is not forwarded.
        """
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius
        if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
            assert xyz_batch_cnt.dtype == torch.int
            assert center_xyz_batch_cnt.dtype == torch.int
            # One row of `sample_num` indices per stacked center point.
            idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
                                       dtype=torch.int32)
            ext_module.stack_ball_query_forward(
                center_xyz,
                center_xyz_batch_cnt,
                xyz,
                xyz_batch_cnt,
                idx,
                max_radius=max_radius,
                nsample=sample_num,
            )
        else:
            B, N, _ = xyz.size()
            npoint = center_xyz.size(1)
            idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
            ext_module.ball_query_forward(
                center_xyz,
                xyz,
                idx,
                b=B,
                n=N,
                m=npoint,
                min_radius=min_radius,
                max_radius=max_radius,
                nsample=sample_num)
        # The 'parrots' backend is skipped here; on regular PyTorch the
        # integer index output is flagged as carrying no gradient.
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(
        ctx,
        a=None
    ) -> Tuple[None, None, None, None, None, None, None]:
        """Return no gradient for any input.

        Autograd expects one gradient slot per argument of :meth:`forward`
        (excluding ``ctx``). ``forward`` takes seven arguments, so seven
        ``None`` values are returned.
        """
        return None, None, None, None, None, None, None
# Functional alias so callers can write `ball_query(...)` instead of
# `BallQuery.apply(...)`.
ball_query = BallQuery.apply