# Source: sapiens-pose/external/cv/mmcv/ops/roiaware_pool3d.py
# Commit 28c256d by rawalkhirodkar ("Add initial commit").
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple, Union
import mmengine
import torch
from torch import nn as nn
from torch.autograd import Function
from ..utils import ext_loader
# Load the compiled ``_ext`` extension, requesting the forward and backward
# RoI-aware pooling ops that RoIAwarePool3dFunction calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'],
)
class RoIAwarePool3d(nn.Module):
    """Pool per-point features into a fixed voxel grid inside each 3D RoI.

    This is the RoI-aware point cloud pooling from
    `PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_: each proposal box is
    divided into a grid of voxels and the features of the points inside each
    voxel are reduced to a single vector.

    Args:
        out_size (int or tuple): Voxel grid resolution of the output, either
            a single ``n`` or three values ``[n1, n2, n3]``.
        max_pts_per_voxel (int, optional): Maximum number of points recorded
            per voxel. Default: 128.
        mode (str, optional): Reduction used inside each voxel, ``'max'`` or
            ``'avg'``. Default: ``'max'``.
    """

    def __init__(self,
                 out_size: Union[int, tuple],
                 max_pts_per_voxel: int = 128,
                 mode: str = 'max'):
        super().__init__()
        self.out_size = out_size
        self.max_pts_per_voxel = max_pts_per_voxel
        assert mode in ('max', 'avg')
        # The autograd function / extension receive the pooling method as an
        # integer code: 0 selects max pooling, 1 selects average pooling.
        self.mode = 0 if mode == 'max' else 1

    def forward(self, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor) -> torch.Tensor:
        """Apply RoI-aware pooling.

        Args:
            rois (torch.Tensor): [N, 7] boxes in LiDAR coordinates, where
                (x, y, z) is the bottom center of each box.
            pts (torch.Tensor): [npoints, 3] coordinates of the input points.
            pts_feature (torch.Tensor): [npoints, C] features of the input
                points.

        Returns:
            torch.Tensor: Pooled features of shape
            [N, out_x, out_y, out_z, C].
        """
        pool_cfg = (self.out_size, self.max_pts_per_voxel, self.mode)
        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature, *pool_cfg)
class RoIAwarePool3dFunction(Function):
    """Autograd wrapper around the ``roiaware_pool3d`` extension ops.

    ``forward`` pools point features into per-RoI voxel grids and stores the
    point-to-voxel index buffer and the argmax buffer on ``ctx``; ``backward``
    passes them to the backward op to produce the gradient for
    ``pts_feature``. No other input receives a gradient.
    """

    @staticmethod
    def _parse_out_size(
            out_size: Union[int, tuple, list]) -> Tuple[int, int, int]:
        """Normalize ``out_size`` into an ``(out_x, out_y, out_z)`` triple.

        Args:
            out_size (int or sequence of int): A single int used for all
                three axes, or three ints such as ``(n1, n2, n3)`` or
                ``[n1, n2, n3]``.

        Returns:
            tuple[int, int, int]: Grid resolution along x, y and z.

        Raises:
            AssertionError: If a sequence does not hold exactly three ints.
        """
        if isinstance(out_size, int):
            return out_size, out_size, out_size
        # Accept any sequence, including the documented list form
        # ``[n1, n2, n3]``, by converting it to a tuple before validating.
        out_size = tuple(out_size)
        assert len(out_size) == 3
        assert all(isinstance(n, int) for n in out_size)
        return out_size

    @staticmethod
    def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
                pts_feature: torch.Tensor, out_size: Union[int, tuple],
                max_pts_per_voxel: int, mode: int) -> torch.Tensor:
        """
        Args:
            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
                (x, y, z) is the bottom center of rois.
            pts (torch.Tensor): [npoints, 3], coordinates of input points.
            pts_feature (torch.Tensor): [npoints, C], features of input points.
            out_size (int or tuple or list): The size of output features. n
                or [n1, n2, n3] (a list or a tuple of three ints).
            max_pts_per_voxel (int): The maximum number of points per voxel
                (size of the last axis of the point-index buffer).
            mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
                pool).

        Returns:
            torch.Tensor: Pooled features whose shape is
            [N, out_x, out_y, out_z, C].
        """
        out_x, out_y, out_z = RoIAwarePool3dFunction._parse_out_size(
            out_size)

        # backward() already hands ``grad_out.contiguous()`` to the extension;
        # do the same for the forward inputs so a strided view is never
        # passed in (presumably the op indexes raw memory -- TODO confirm).
        rois = rois.contiguous()
        pts = pts.contiguous()
        pts_feature = pts_feature.contiguous()

        num_rois = rois.shape[0]
        num_channels = pts_feature.shape[-1]
        num_pts = pts.shape[0]

        # Output buffer, filled in place by the extension.
        pooled_features = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels))
        # int32 buffer shaped like the output; written by the forward op and
        # read back by the backward op.
        argmax = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
        # int32 per-voxel point-index buffer with room for
        # ``max_pts_per_voxel`` entries per voxel.
        pts_idx_of_voxels = pts_feature.new_zeros(
            (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
            dtype=torch.int)

        ext_module.roiaware_pool3d_forward(
            rois,
            pts,
            pts_feature,
            argmax,
            pts_idx_of_voxels,
            pooled_features,
            pool_method=mode)

        # These buffers are intermediates (neither inputs nor outputs), so
        # they are kept as a plain ctx attribute for backward().
        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
                                            num_pts, num_channels)
        return pooled_features

    @staticmethod
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Tuple[None, None, torch.Tensor, None, None, None]:
        """Scatter the output gradient back onto the point features.

        Args:
            grad_out (torch.Tensor): Gradient w.r.t. the pooled features,
                [N, out_x, out_y, out_z, C].

        Returns:
            tuple: One entry per ``forward`` input. Only ``pts_feature`` gets
            a gradient, of shape [npoints, C]; all others are ``None``.
        """
        ret = ctx.roiaware_pool3d_for_backward
        pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret

        grad_in = grad_out.new_zeros((num_pts, num_channels))
        ext_module.roiaware_pool3d_backward(
            pts_idx_of_voxels,
            argmax,
            grad_out.contiguous(),
            grad_in,
            pool_method=mode)

        return None, None, grad_in, None, None, None