File size: 3,289 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

# Handle to the '_ext' extension module, loaded through the project's
# ext_loader. It must provide the two ops listed here:
# 'ball_query_forward' (batched (B, N, 3) inputs) and
# 'stack_ball_query_forward' (stacked inputs with per-sample counts).
# Both are called from BallQuery.forward.
ext_module = ext_loader.load_ext(
    '_ext', ['ball_query_forward', 'stack_ball_query_forward'])


class BallQuery(Function):
    """Find nearby points in spherical space.

    Autograd wrapper around the extension ops ``ball_query_forward``
    (batched input) and ``stack_ball_query_forward`` (stacked input). The
    result is an int32 index tensor, so the op has no meaningful gradient.
    """

    @staticmethod
    def forward(
            ctx,
            min_radius: float,
            max_radius: float,
            sample_num: int,
            xyz: torch.Tensor,
            center_xyz: torch.Tensor,
            xyz_batch_cnt: Optional[torch.Tensor] = None,
            center_xyz_batch_cnt: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the ball query.

        Stacked mode is selected only when BOTH ``xyz_batch_cnt`` and
        ``center_xyz_batch_cnt`` are given; otherwise the batched path is
        used.

        Args:
            min_radius (float): minimum radius of the balls. Only honoured by
                the batched path; the stacked op is not given this value, so
                stacked queries are bounded by ``max_radius`` alone.
            max_radius (float): maximum radius of the balls. Must be greater
                than ``min_radius``.
            sample_num (int): maximum number of features in the balls.
            xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features,
                or stacked input (N1 + N2 ..., 3). Must be contiguous.
            center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
                query, or stacked input (M1 + M2 ..., 3). Must be contiguous.
            xyz_batch_cnt (torch.Tensor, optional): (batch_size) int32 tensor
                with the number of stacked xyz points in each batch, just
                like (N1, N2, ...). Defaults to None.
                New in version 1.7.0.
            center_xyz_batch_cnt (torch.Tensor, optional): (batch_size) int32
                tensor with the number of stacked centers in each batch, just
                like (M1, M2, ...). Defaults to None.
                New in version 1.7.0.

        Returns:
            torch.Tensor: int32 tensor with the indices of the features that
            form the query balls, shaped (B, npoint, sample_num) in batched
            mode or (M1 + M2 ..., sample_num) in stacked mode.
        """
        assert center_xyz.is_contiguous(), 'center_xyz must be contiguous'
        assert xyz.is_contiguous(), 'xyz must be contiguous'
        assert min_radius < max_radius, (
            'min_radius ({}) must be smaller than max_radius ({})'.format(
                min_radius, max_radius))
        if xyz_batch_cnt is not None and center_xyz_batch_cnt is not None:
            assert xyz_batch_cnt.dtype == torch.int, \
                'xyz_batch_cnt must be an int32 tensor'
            assert center_xyz_batch_cnt.dtype == torch.int, \
                'center_xyz_batch_cnt must be an int32 tensor'
            # One row of sample_num indices per stacked center.
            idx = center_xyz.new_zeros((center_xyz.shape[0], sample_num),
                                       dtype=torch.int32)
            # NOTE: min_radius is not forwarded here; the stacked op only
            # receives max_radius.
            ext_module.stack_ball_query_forward(
                center_xyz,
                center_xyz_batch_cnt,
                xyz,
                xyz_batch_cnt,
                idx,
                max_radius=max_radius,
                nsample=sample_num,
            )
        else:
            B, N, _ = xyz.size()
            npoint = center_xyz.size(1)
            idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int32)
            ext_module.ball_query_forward(
                center_xyz,
                xyz,
                idx,
                b=B,
                n=N,
                m=npoint,
                min_radius=min_radius,
                max_radius=max_radius,
                nsample=sample_num)
        # Skipped when running under Parrots (torch.__version__ == 'parrots').
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None) -> Tuple[None, ...]:
        """Return no gradients: the output is an integer index tensor.

        Autograd expects exactly one gradient entry per ``forward`` input
        (excluding ``ctx``). ``forward`` takes seven inputs, so seven ``None``
        values are returned.
        """
        return None, None, None, None, None, None, None


# Functional entry point for BallQuery. Arguments follow BallQuery.forward:
# ball_query(min_radius, max_radius, sample_num, xyz, center_xyz,
#            xyz_batch_cnt=None, center_xyz_batch_cnt=None)
# and the call returns an int32 index tensor.
ball_query = BallQuery.apply