File size: 5,669 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor

from ..utils import ext_loader

# Load the '_ext' extension through ext_loader, naming the three ops that the
# wrapper functions in this file call: the "part" and "all" forwards (used by
# the CUDA wrappers) and the "cpu" forward (used by the CPU wrapper).
ext_module = ext_loader.load_ext('_ext', [
    'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
    'points_in_boxes_all_forward'
])


def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
    """Look up, for every point, the index of a box that contains it (CUDA).

    Args:
        points (torch.Tensor): [B, M, 3] point coordinates, [x, y, z] in
            LiDAR/DEPTH coordinate.
        boxes (torch.Tensor): [B, T, 7] boxes with num_valid_boxes <= T,
            encoded as [x, y, z, x_size, y_size, z_size, rz] in
            LiDAR/DEPTH coordinate, where (x, y, z) is the bottom center.

    Returns:
        torch.Tensor: Integer tensor of shape (B, M) with the box index of
        each point. Entries are initialised to -1 (background) before the
        op runs.
    """
    points_batch = points.shape[0]
    boxes_batch = boxes.shape[0]
    assert points_batch == boxes_batch, \
        'Points and boxes should have the same batch size, ' \
        f'but got {points_batch} and {boxes_batch}'

    box_dim = boxes.shape[2]
    assert box_dim == 7, \
        'boxes dimension should be 7, ' \
        f'but got unexpected shape {box_dim}'

    point_dim = points.shape[2]
    assert point_dim == 3, \
        'points dimension should be 3, ' \
        f'but got unexpected shape {point_dim}'

    num_batches, pts_per_batch, _ = points.shape

    # Output buffer on the same device as `points`, pre-filled with the
    # background index so untouched entries read as "in no box".
    result = points.new_full((num_batches, pts_per_batch), -1,
                             dtype=torch.int)

    # The cuda op allocates temporaries on the *current* device. When the
    # inputs were placed manually on a different device, the result comes
    # out wrong (see
    # https://github.com/open-mmlab/mmdetection3d/issues/305), so the
    # current device is switched to the one holding the tensors.
    target_device = points.get_device()
    assert target_device == boxes.get_device(), \
        'Points and boxes should be put on the same device'
    if target_device != torch.cuda.current_device():
        torch.cuda.set_device(target_device)

    boxes_contig = boxes.contiguous()
    points_contig = points.contiguous()
    ext_module.points_in_boxes_part_forward(boxes_contig, points_contig,
                                            result)

    return result


def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor:
    """Report, for every point, all boxes that contain it (CPU).

    This is the CPU counterpart of :meth:`points_in_boxes_all`.

    Args:
        points (torch.Tensor): [B, M, 3] point coordinates, [x, y, z] in
            LiDAR/DEPTH coordinate.
        boxes (torch.Tensor): [B, T, 7] boxes with num_valid_boxes <= T,
            encoded as [x, y, z, x_size, y_size, z_size, rz], where
            (x, y, z) is the bottom center.

    Returns:
        torch.Tensor: Integer tensor of shape (B, M, T) holding the
        point-in-box indicators. Entries start at 0 (background).
    """
    points_batch = points.shape[0]
    boxes_batch = boxes.shape[0]
    assert points_batch == boxes_batch, \
        'Points and boxes should have the same batch size, ' \
        f'but got {points_batch} and {boxes_batch}'

    box_dim = boxes.shape[2]
    assert box_dim == 7, \
        'boxes dimension should be 7, ' \
        f'but got unexpected shape {box_dim}'

    point_dim = points.shape[2]
    assert point_dim == 3, \
        'points dimension should be 3, ' \
        f'but got unexpected shape {point_dim}'

    num_batches, pts_per_batch, _ = points.shape
    boxes_per_batch = boxes.shape[1]

    # The op is given a (T, M) slice per sample, so the buffer is laid out
    # boxes-major here and transposed to points-major on return.
    indicators = points.new_zeros(
        (num_batches, boxes_per_batch, pts_per_batch), dtype=torch.int)

    for sample in range(num_batches):
        sample_boxes = boxes[sample].float().contiguous()
        sample_points = points[sample].float().contiguous()
        # indicators[sample] is a view, so the op writes into `indicators`.
        ext_module.points_in_boxes_cpu_forward(sample_boxes, sample_points,
                                               indicators[sample])

    return indicators.transpose(1, 2)


def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor:
    """Find all boxes in which each point is (CUDA).

    Args:
        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
        boxes (torch.Tensor): [B, T, 7],
            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
            (x, y, z) is the bottom center.

    Returns:
        torch.Tensor: Return the box indices of points with the shape of
        (B, M, T). Default background = 0.

    Raises:
        AssertionError: If the batch sizes of ``points`` and ``boxes``
            differ, if the last dimension of ``boxes`` is not 7 or that of
            ``points`` is not 3, or if the two tensors report different
            devices.
    """
    # Report the points batch size first and the boxes batch size second,
    # matching the wording "Points and boxes" of the message.
    assert points.shape[0] == boxes.shape[0], \
        'Points and boxes should have the same batch size, ' \
        f'but got {points.shape[0]} and {boxes.shape[0]}'
    assert boxes.shape[2] == 7, \
        'boxes dimension should be 7, ' \
        f'but got unexpected shape {boxes.shape[2]}'
    assert points.shape[2] == 3, \
        'points dimension should be 3, ' \
        f'but got unexpected shape {points.shape[2]}'
    batch_size, num_points, _ = points.shape
    num_boxes = boxes.shape[1]

    # new_zeros already yields the background value 0 everywhere, so no
    # additional fill is required.
    box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
                                       dtype=torch.int)

    # If the tensors 'points' or 'boxes' were manually put on a device
    # which is not the current device, some temporary variables will be
    # created on the current device in the cuda op, and the output will be
    # incorrect. Therefore, we force the current device to be the same as
    # the device of the tensors if it was not.
    # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
    # for the incorrect output before the fix.
    # Note: the current device is not switched back afterwards.
    points_device = points.get_device()
    assert points_device == boxes.get_device(), \
        'Points and boxes should be put on the same device'
    if torch.cuda.current_device() != points_device:
        torch.cuda.set_device(points_device)

    ext_module.points_in_boxes_all_forward(boxes.contiguous(),
                                           points.contiguous(),
                                           box_idxs_of_pts)

    return box_idxs_of_pts