# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.utils import is_tuple_of
from torch.autograd import Function

from ..utils import ext_loader

# Load the compiled extension that provides the C++/CUDA kernels used below.
ext_module = ext_loader.load_ext(
    '_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward'])


class RiRoIAlignRotatedFunction(Function):
    """Autograd function for rotation-invariant rotated RoI align."""

    @staticmethod
    def forward(ctx: Any,
                features: torch.Tensor,
                rois: torch.Tensor,
                out_size: Union[int, tuple],
                spatial_scale: float,
                num_samples: int = 0,
                num_orientations: int = 8,
                clockwise: bool = False) -> torch.Tensor:
        if isinstance(out_size, int):
            out_h = out_size
            out_w = out_size
        elif is_tuple_of(out_size, int):
            assert len(out_size) == 2
            out_h, out_w = out_size
        else:
            raise TypeError(
                f'"out_size" should be an integer or tuple of integers,'
                f' but got {out_size}')
        ctx.spatial_scale = spatial_scale
        ctx.num_samples = num_samples
        ctx.num_orientations = num_orientations
        ctx.clockwise = clockwise
        ctx.save_for_backward(rois)
        ctx.feature_size = features.size()

        batch_size, num_channels, _, _ = features.size()
        num_rois = rois.size(0)

        output = features.new_zeros(num_rois, num_channels, out_h, out_w)

        # The compiled C++/CUDA kernel fills `output` in place, pooling each
        # rotated RoI and aligning its orientation channels with the RoI angle.
        ext_module.riroi_align_rotated_forward(
            features,
            rois,
            output,
            pooled_height=out_h,
            pooled_width=out_w,
            spatial_scale=spatial_scale,
            num_samples=num_samples,
            num_orientations=num_orientations,
            clockwise=clockwise)
        return output

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Optional[Tuple[torch.Tensor, None, None, None, None, None, None]]:
        feature_size = ctx.feature_size
        spatial_scale = ctx.spatial_scale
        num_orientations = ctx.num_orientations
        clockwise = ctx.clockwise
        num_samples = ctx.num_samples
        rois = ctx.saved_tensors[0]
        assert feature_size is not None
        batch_size, num_channels, feature_h, feature_w = feature_size

        out_w = grad_output.size(3)
        out_h = grad_output.size(2)

        grad_input = None

        # Only the feature map needs a gradient; the remaining forward
        # arguments are non-tensor hyper-parameters.
        if ctx.needs_input_grad[0]:
            grad_input = rois.new_zeros(batch_size, num_channels, feature_h,
                                        feature_w)
            ext_module.riroi_align_rotated_backward(
                grad_output.contiguous(),
                rois,
                grad_input,
                pooled_height=out_h,
                pooled_width=out_w,
                spatial_scale=spatial_scale,
                num_samples=num_samples,
                num_orientations=num_orientations,
                clockwise=clockwise)

            return grad_input, None, None, None, None, None, None
        return None


riroi_align_rotated = RiRoIAlignRotatedFunction.apply


class RiRoIAlignRotated(nn.Module):
    """Rotation-invariant RoI align pooling layer for rotated proposals.

    It accepts a feature map of shape (N, C, H, W) and rois of shape (n, 6),
    where each roi is decoded as (batch_index, center_x, center_y, w, h,
    angle). The angle is in radians.

    The details are described in the paper `ReDet: A Rotation-equivariant
    Detector for Aerial Object Detection <https://arxiv.org/abs/2103.07733>`_.

    Args:
        out_size (tuple): Fixed size (h, w) of the pooled RoI output.
        spatial_scale (float): Scale factor that maps the input boxes to the
            coordinate frame of the feature map.
        num_samples (int): Number of input samples to take for each output
            sample. 0 means taking samples densely. Default: 0.
        num_orientations (int): Number of oriented channels. Default: 8.
        clockwise (bool): If True, the angle in each proposal follows a
            clockwise fashion in image space, otherwise, the angle is
            counterclockwise. Default: False.
    """

    def __init__(self,
                 out_size: tuple,
                 spatial_scale: float,
                 num_samples: int = 0,
                 num_orientations: int = 8,
                 clockwise: bool = False):
        super().__init__()

        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.num_samples = int(num_samples)
        self.num_orientations = int(num_orientations)
        self.clockwise = clockwise

    def forward(self, features: torch.Tensor,
                rois: torch.Tensor) -> torch.Tensor:
        return RiRoIAlignRotatedFunction.apply(features, rois, self.out_size,
                                               self.spatial_scale,
                                               self.num_samples,
                                               self.num_orientations,
                                               self.clockwise)
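

# A minimal usage sketch, assuming the '_ext' extension has been compiled with
# the riroi_align_rotated kernels and a CUDA device is available. Shapes and
# values are illustrative only; the channel count of the feature map is
# assumed to be a multiple of `num_orientations`.
#
#   riroi_align = RiRoIAlignRotated(
#       out_size=(7, 7),
#       spatial_scale=1.0,
#       num_samples=2,
#       num_orientations=8).cuda()
#   feats = torch.rand(2, 64, 32, 32, device='cuda')  # (N, C, H, W)
#   # Each RoI: (batch_index, center_x, center_y, w, h, angle in radians).
#   rois = torch.tensor([[0., 16., 16., 12., 8., 0.3],
#                        [1., 10., 20., 6., 6., -0.5]], device='cuda')
#   out = riroi_align(feats, rois)  # (num_rois, C, 7, 7)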