Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
from mmengine.utils import digit_version | |
from torch import Tensor, nn | |
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} | |
def _corner_pool(x: Tensor, dim: int, flip: bool) -> Tensor: | |
size = x.size(dim) | |
output = x.clone() | |
ind = 1 | |
while ind < size: | |
if flip: | |
cur_start = 0 | |
cur_len = size - ind | |
next_start = ind | |
next_len = size - ind | |
else: | |
cur_start = ind | |
cur_len = size - ind | |
next_start = 0 | |
next_len = size - ind | |
# max_temp should be cloned for backward computation | |
max_temp = output.narrow(dim, cur_start, cur_len).clone() | |
cur_temp = output.narrow(dim, cur_start, cur_len) | |
next_temp = output.narrow(dim, next_start, next_len) | |
cur_temp[...] = torch.where(max_temp > next_temp, max_temp, next_temp) | |
ind = ind << 1 | |
return output | |
class CornerPool(nn.Module): | |
"""Corner Pooling. | |
Corner Pooling is a new type of pooling layer that helps a | |
convolutional network better localize corners of bounding boxes. | |
Please refer to `CornerNet: Detecting Objects as Paired Keypoints | |
<https://arxiv.org/abs/1808.01244>`_ for more details. | |
Code is modified from https://github.com/princeton-vl/CornerNet-Lite. | |
Args: | |
mode (str): Pooling orientation for the pooling layer | |
- 'bottom': Bottom Pooling | |
- 'left': Left Pooling | |
- 'right': Right Pooling | |
- 'top': Top Pooling | |
Returns: | |
Feature map after pooling. | |
""" | |
cummax_dim_flip = { | |
'bottom': (2, False), | |
'left': (3, True), | |
'right': (3, False), | |
'top': (2, True), | |
} | |
def __init__(self, mode: str): | |
super().__init__() | |
assert mode in self.cummax_dim_flip | |
self.mode = mode | |
def forward(self, x: Tensor) -> Tensor: | |
if (torch.__version__ != 'parrots' and | |
digit_version(torch.__version__) >= digit_version('1.5.0')): | |
dim, flip = self.cummax_dim_flip[self.mode] | |
if flip: | |
x = x.flip(dim) | |
pool_tensor, _ = torch.cummax(x, dim=dim) | |
if flip: | |
pool_tensor = pool_tensor.flip(dim) | |
return pool_tensor | |
else: | |
dim, flip = self.cummax_dim_flip[self.mode] | |
return _corner_pool(x, dim, flip) | |