Spaces:
Build error
Build error
File size: 2,626 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from mmengine.utils import digit_version
from torch import Tensor, nn
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
def _corner_pool(x: Tensor, dim: int, flip: bool) -> Tensor:
size = x.size(dim)
output = x.clone()
ind = 1
while ind < size:
if flip:
cur_start = 0
cur_len = size - ind
next_start = ind
next_len = size - ind
else:
cur_start = ind
cur_len = size - ind
next_start = 0
next_len = size - ind
# max_temp should be cloned for backward computation
max_temp = output.narrow(dim, cur_start, cur_len).clone()
cur_temp = output.narrow(dim, cur_start, cur_len)
next_temp = output.narrow(dim, next_start, next_len)
cur_temp[...] = torch.where(max_temp > next_temp, max_temp, next_temp)
ind = ind << 1
return output
class CornerPool(nn.Module):
    """Corner Pooling layer.

    Corner Pooling is a pooling layer that helps a convolutional network
    better localize the corners of bounding boxes. See `CornerNet:
    Detecting Objects as Paired Keypoints
    <https://arxiv.org/abs/1808.01244>`_ for details.

    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.

    Args:
        mode (str): Pooling orientation, one of:

            - 'bottom': Bottom Pooling (running max along dim 2, first to last)
            - 'left': Left Pooling (running max along dim 3, last to first)
            - 'right': Right Pooling (running max along dim 3, first to last)
            - 'top': Top Pooling (running max along dim 2, last to first)

    Returns:
        Feature map after pooling, with the same shape as the input.
    """

    # mode -> (dimension to scan, whether to scan from the end backwards).
    # NOTE(review): dims 2/3 presumably correspond to H/W of an
    # (N, C, H, W) input — confirm against callers.
    cummax_dim_flip = {
        'bottom': (2, False),
        'left': (3, True),
        'right': (3, False),
        'top': (2, True),
    }

    def __init__(self, mode: str):
        super().__init__()
        assert mode in self.cummax_dim_flip
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        """Apply directional running-max pooling to ``x``."""
        has_cummax = (
            torch.__version__ != 'parrots'
            and digit_version(torch.__version__) >= digit_version('1.5.0'))
        dim, flip = self.cummax_dim_flip[self.mode]
        if not has_cummax:
            # No torch.cummax here (parrots, or torch < 1.5.0): use the
            # pure-tensor fallback scan.
            return _corner_pool(x, dim, flip)
        # Reversing the axis turns a backward running max into a forward
        # one; the result is reversed back into place afterwards.
        source = x.flip(dim) if flip else x
        pooled = torch.cummax(source, dim=dim).values
        return pooled.flip(dim) if flip else pooled
|