# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from mmengine.utils import digit_version
from torch import Tensor, nn

# Integer code associated with each pooling orientation name.
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}


def _corner_pool(x: Tensor, dim: int, flip: bool) -> Tensor:
    """Directional running maximum built from basic tensor ops.

    Used as the fallback when ``torch.cummax`` is not used. The stride
    doubles on every pass, so each element is repeatedly combined with the
    element ``stride`` positions away until the stride reaches the size of
    ``dim``.

    Args:
        x (Tensor): Input feature map. It is cloned, not modified.
        dim (int): Dimension along which the maximum is propagated.
        flip (bool): If False, maxima propagate towards increasing indices
            of ``dim``; if True, towards decreasing indices.

    Returns:
        Tensor: A new tensor with the same shape as ``x``.
    """
    size = x.size(dim)
    output = x.clone()

    stride = 1
    while stride < size:
        # Both windows always have the same length; only their starting
        # offsets depend on the direction.
        span = size - stride
        if flip:
            dst_start, src_start = 0, stride
        else:
            dst_start, src_start = stride, 0

        dst = output.narrow(dim, dst_start, span)
        src = output.narrow(dim, src_start, span)
        # The destination window is written in place below, so keep a
        # detached copy of its values for the comparison (needed for the
        # backward computation).
        dst_snapshot = dst.clone()
        dst[...] = torch.where(dst_snapshot > src, dst_snapshot, src)

        stride <<= 1
    return output


class CornerPool(nn.Module):
    """Corner Pooling.

    Corner Pooling is a pooling layer that helps a convolutional network
    better localize corners of bounding boxes.

    Please refer to `CornerNet: Detecting Objects as Paired Keypoints
    <https://arxiv.org/abs/1808.01244>`_ for more details.

    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.

    Args:
        mode (str): Pooling orientation for the pooling layer

            - 'bottom': Bottom Pooling
            - 'left': Left Pooling
            - 'right': Right Pooling
            - 'top': Top Pooling

    Returns:
        Feature map after pooling.
    """

    # Maps each mode to (dimension to pool along, whether to pool in the
    # reversed direction). Dims 2 and 3 are used, so the input needs at
    # least four dimensions -- presumably (N, C, H, W); confirm with callers.
    cummax_dim_flip = {
        'bottom': (2, False),
        'left': (3, True),
        'right': (3, False),
        'top': (2, True),
    }

    def __init__(self, mode: str):
        super().__init__()
        assert mode in self.cummax_dim_flip
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        """Apply corner pooling to ``x`` in the configured orientation.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Pooled feature map with the same shape as ``x``.
        """
        dim, flip = self.cummax_dim_flip[self.mode]

        use_cummax = (
            torch.__version__ != 'parrots'
            and digit_version(torch.__version__) >= digit_version('1.5.0'))
        if not use_cummax:
            # Hand-rolled running maximum for the parrots backend and for
            # torch releases older than 1.5.0.
            return _corner_pool(x, dim, flip)

        # cummax only scans towards increasing indices, so reversed modes
        # flip the input first and flip the result back afterwards.
        if flip:
            x = x.flip(dim)
        pool_tensor, _ = torch.cummax(x, dim=dim)
        if flip:
            pool_tensor = pool_tensor.flip(dim)
        return pool_tensor