Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.utils.dl_utils import TORCH_VERSION | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from torch.nn.modules.utils import _pair | |
| from ..utils import ext_loader | |
| ext_module = ext_loader.load_ext( | |
| '_ext', | |
| ['prroi_pool_forward', 'prroi_pool_backward', 'prroi_pool_coor_backward']) | |
| class PrRoIPoolFunction(Function): | |
| def symbolic(g, features, rois, output_size, spatial_scale): | |
| return g.op( | |
| 'mmcv::PrRoIPool', | |
| features, | |
| rois, | |
| pooled_height_i=int(output_size[0]), | |
| pooled_width_i=int(output_size[1]), | |
| spatial_scale_f=float(spatial_scale)) | |
| def forward(ctx, | |
| features: torch.Tensor, | |
| rois: torch.Tensor, | |
| output_size: Tuple, | |
| spatial_scale: float = 1.0) -> torch.Tensor: | |
| if features.dtype != torch.float32 or rois.dtype != torch.float32: | |
| raise ValueError('Precise RoI Pooling only takes float input, got ' | |
| f'{features.dtype()} for features and' | |
| f'{rois.dtype()} for rois.') | |
| pooled_height = int(output_size[0]) | |
| pooled_width = int(output_size[1]) | |
| spatial_scale = float(spatial_scale) | |
| features = features.contiguous() | |
| rois = rois.contiguous() | |
| output_shape = (rois.size(0), features.size(1), pooled_height, | |
| pooled_width) | |
| output = features.new_zeros(output_shape) | |
| params = (pooled_height, pooled_width, spatial_scale) | |
| ext_module.prroi_pool_forward( | |
| features, | |
| rois, | |
| output, | |
| pooled_height=params[0], | |
| pooled_width=params[1], | |
| spatial_scale=params[2]) | |
| ctx.params = params | |
| # everything here is contiguous. | |
| ctx.save_for_backward(features, rois, output) | |
| return output | |
| def backward( | |
| ctx, grad_output: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: | |
| features, rois, output = ctx.saved_tensors | |
| grad_input = grad_output.new_zeros(*features.shape) | |
| grad_coor = grad_output.new_zeros(*rois.shape) | |
| if features.requires_grad or TORCH_VERSION == 'parrots': | |
| grad_output = grad_output.contiguous() | |
| ext_module.prroi_pool_backward( | |
| grad_output, | |
| rois, | |
| grad_input, | |
| pooled_height=ctx.params[0], | |
| pooled_width=ctx.params[1], | |
| spatial_scale=ctx.params[2]) | |
| if rois.requires_grad or TORCH_VERSION == 'parrots': | |
| grad_output = grad_output.contiguous() | |
| ext_module.prroi_pool_coor_backward( | |
| output, | |
| grad_output, | |
| features, | |
| rois, | |
| grad_coor, | |
| pooled_height=ctx.params[0], | |
| pooled_width=ctx.params[1], | |
| spatial_scale=ctx.params[2]) | |
| return grad_input, grad_coor, None, None, None | |
| prroi_pool = PrRoIPoolFunction.apply | |
| class PrRoIPool(nn.Module): | |
| """The operation of precision RoI pooling. The implementation of PrRoIPool | |
| is modified from https://github.com/vacancy/PreciseRoIPooling/ | |
| Precise RoI Pooling (PrRoIPool) is an integration-based (bilinear | |
| interpolation) average pooling method for RoI Pooling. It avoids any | |
| quantization and has a continuous gradient on bounding box coordinates. | |
| It is: | |
| 1. different from the original RoI Pooling proposed in Fast R-CNN. PrRoI | |
| Pooling uses average pooling instead of max pooling for each bin and has a | |
| continuous gradient on bounding box coordinates. That is, one can take the | |
| derivatives of some loss function w.r.t the coordinates of each RoI and | |
| optimize the RoI coordinates. | |
| 2. different from the RoI Align proposed in Mask R-CNN. PrRoI Pooling uses | |
| a full integration-based average pooling instead of sampling a constant | |
| number of points. This makes the gradient w.r.t. the coordinates | |
| continuous. | |
| Args: | |
| output_size (Union[int, tuple]): h, w. | |
| spatial_scale (float, optional): scale the input boxes by this number. | |
| Defaults to 1.0. | |
| """ | |
| def __init__(self, | |
| output_size: Union[int, tuple], | |
| spatial_scale: float = 1.0): | |
| super().__init__() | |
| self.output_size = _pair(output_size) | |
| self.spatial_scale = float(spatial_scale) | |
| def forward(self, features: torch.Tensor, | |
| rois: torch.Tensor) -> torch.Tensor: | |
| """Forward function. | |
| Args: | |
| features (torch.Tensor): The feature map. | |
| rois (torch.Tensor): The RoI bboxes in [tl_x, tl_y, br_x, br_y] | |
| format. | |
| Returns: | |
| torch.Tensor: The pooled results. | |
| """ | |
| return prroi_pool(features, rois, self.output_size, self.spatial_scale) | |
| def __repr__(self): | |
| s = self.__class__.__name__ | |
| s += f'(output_size={self.output_size}, ' | |
| s += f'spatial_scale={self.spatial_scale})' | |
| return s | |