# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Sequence, Tuple, Union

import mmengine
import torch
from torch import nn as nn
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])


class RoIAwarePool3d(nn.Module):
    """Encode the geometry-specific features of each 3D proposal.

    Please refer to PartA2 for more details.

    Args:
        out_size (int or Sequence[int]): The size of output features. Either
            a single int ``n`` (used for all three axes) or a length-3
            sequence of ints such as ``(n1, n2, n3)`` or ``[n1, n2, n3]``.
        max_pts_per_voxel (int, optional): The maximum number of points per
            voxel. Default: 128.
        mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
            Default: 'max'.
    """

    def __init__(self,
                 out_size: Union[int, Sequence[int]],
                 max_pts_per_voxel: int = 128,
                 mode: str = 'max'):
        super().__init__()

        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel
        assert mode in ['max', 'avg'], \
            f"mode must be 'max' or 'avg', but got {mode!r}"
        # The extension op takes the pooling method as an integer code.
        pool_mapping = {'max': 0, 'avg': 1}
        self.mode = pool_mapping[mode]

    def forward(self, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor) -> torch.Tensor:
        """
        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3], coordinates of input points.
            pts_feature (torch.Tensor): [npoints, C], features of input points.

        Returns:
            torch.Tensor: Pooled features whose shape is
            [N, out_x, out_y, out_z, C].
        """

        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
                                            self.out_size,
                                            self.max_pts_per_voxel, self.mode)


class RoIAwarePool3dFunction(Function):
    """Autograd function wrapping the ``roiaware_pool3d`` extension ops.

    Only ``pts_feature`` receives a gradient; ``rois``, ``pts`` and the
    configuration arguments get ``None`` in ``backward``.
    """

    @staticmethod
    def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor,
                out_size: Union[int, Sequence[int]],
                max_pts_per_voxel: int, mode: int) -> torch.Tensor:
        """
        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3], coordinates of input points.
            pts_feature (torch.Tensor): [npoints, C], features of input points.
            out_size (int or Sequence[int]): The size of output features.
                Either a single int ``n`` or a length-3 sequence of ints
                ``(n1, n2, n3)`` / ``[n1, n2, n3]``.
            max_pts_per_voxel (int): The maximum number of points per voxel;
                it sizes the last dimension of the per-voxel point-index
                buffer handed to the extension op.
            mode (int): Pooling method of RoIAware, 0 (max pool) or
                1 (average pool).

        Returns:
            torch.Tensor: Pooled features whose shape is
            [N, out_x, out_y, out_z, C].
        """
        if isinstance(out_size, int):
            out_x = out_y = out_z = out_size
        else:
            assert len(out_size) == 3, \
                f'out_size must have 3 elements, but got {len(out_size)}'
            # Accept any sequence (tuple or list), matching the documented
            # "[n1, n2, n3]" form; a tuple was previously the only option.
            assert mmengine.is_seq_of(out_size, int), \
                f'out_size must contain only ints, but got {out_size!r}'
            out_x, out_y, out_z = out_size

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]

        # Zero-initialised buffers; the extension op fills them in place
        # (pooled_features is returned without any further assignment).
        pooled_features = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels))
        argmax = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
        pts_idx_of_voxels = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
            dtype=torch.int)

        ext_module.roiaware_pool3d_forward(
            rois,
            pts,
            pts_feature,
            argmax,
            pts_idx_of_voxels,
            pooled_features,
            pool_method=mode)

        # The index buffers are intermediates (neither inputs nor outputs),
        # so they are stashed on ctx directly for use in backward.
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Tuple[None, None, torch.Tensor, None, None, None]:
        """Compute the gradient w.r.t. ``pts_feature`` via the extension op.

        Returns one entry per ``forward`` input (after ``ctx``); only
        ``pts_feature`` receives a tensor, of shape [npoints, C].
        """
        ret = ctx.roiaware_pool3d_for_backward
        pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

        grad_in = grad_out.new_zeros((num_pts, num_channels))
        # The op is given a contiguous copy of the incoming gradient.
        ext_module.roiaware_pool3d_backward(
            pts_idx_of_voxels,
            argmax,
            grad_out.contiguous(),
            grad_in,
            pool_method=mode)

        return None, None, grad_in, None, None, None