# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['ball_query_forward', 'stack_ball_query_forward'])


class BallQuery(Function):
    """Find nearby points in spherical space.

    Two input layouts are supported: batched ``(B, N, 3)`` tensors, and
    stacked ``(N1 + N2 ..., 3)`` tensors accompanied by per-batch point
    counts. The output is an integer index tensor and is therefore marked
    non-differentiable.
    """

    @staticmethod
    def forward(
            ctx,
            min_radius: float,
            max_radius: float,
            sample_num: int,
            xyz: torch.Tensor,
            center_xyz: torch.Tensor,
            xyz_batch_cnt: Optional[torch.Tensor] = None,
            center_xyz_batch_cnt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            min_radius (float): minimum radius of the balls. Note that in
                stacked mode this value is only checked against
                ``max_radius``; it is not passed to the stacked kernel.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
                or stacked input (N1 + N2 ..., 3).
            center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
                query, or stacked input (M1 + M2 ..., 3).
            xyz_batch_cnt: (batch_size): Stacked input xyz coordinates nums
                in each batch, just like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            center_xyz_batch_cnt: (batch_size): Stacked centers coordinates
                nums in each batch, just like (M1, M2, ...). Defaults to
                None. New in version 1.7.0.

        Stacked mode is used only when BOTH ``xyz_batch_cnt`` and
        ``center_xyz_batch_cnt`` are given; otherwise the batched kernel is
        used and ``xyz`` must be 3-dimensional.

        Returns:
            torch.Tensor: int32 tensor with the indices of the features
            that form the query balls. Shape is (B, npoint, nsample) in
            batched mode and (M1 + M2 ..., nsample) in stacked mode.
        """
        assert center_xyz.is_contiguous(), 'center_xyz must be contiguous'
        assert xyz.is_contiguous(), 'xyz must be contiguous'
        assert min_radius < max_radius, \
            'min_radius must be smaller than max_radius'

        if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
            # Stacked layout: one output row per stacked center point.
            assert xyz_batch_cnt.dtype == torch.int, \
                'xyz_batch_cnt must have dtype torch.int'
            assert center_xyz_batch_cnt.dtype == torch.int, \
                'center_xyz_batch_cnt must have dtype torch.int'
            idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
                                       dtype=torch.int32)
            # The stacked kernel is called without min_radius.
            ext_module.stack_ball_query_forward(
                center_xyz,
                center_xyz_batch_cnt,
                xyz,
                xyz_batch_cnt,
                idx,
                max_radius=max_radius,
                nsample=sample_num,
            )
        else:
            # Batched layout: xyz is (B, N, 3), center_xyz is (B, npoint, 3).
            B, N, _ = xyz.size()
            npoint = center_xyz.size(1)
            idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
            ext_module.ball_query_forward(
                center_xyz,
                xyz,
                idx,
                b=B,
                n=N,
                m=npoint,
                min_radius=min_radius,
                max_radius=max_radius,
                nsample=sample_num)

        # Skipped when torch is actually the parrots backend.
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None) -> Tuple[None, ...]:
        """Return no gradient for any of the forward inputs.

        Autograd expects one entry per ``forward`` argument (excluding
        ``ctx``), so seven ``None`` values are returned: one each for
        ``min_radius``, ``max_radius``, ``sample_num``, ``xyz``,
        ``center_xyz``, ``xyz_batch_cnt`` and ``center_xyz_batch_cnt``.
        """
        return None, None, None, None, None, None, None


ball_query = BallQuery.apply