File size: 5,002 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

# Compiled extension providing the two kernels used by the non-NPU masked
# convolution path below.
ext_module = ext_loader.load_ext(
    '_ext',
    ['masked_im2col_forward', 'masked_col2im_forward'],
)


class MaskedConv2dFunction(Function):
    """Convolution evaluated only at the positions selected by a mask.

    Only a batch size of 1 and a stride of 1 are supported, and no gradient
    is computed: ``backward`` returns ``None`` for every input.
    """

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride=1):
        # ONNX export: emit a single custom ``mmcv::MMCVMaskedConv2d`` node.
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def forward(ctx,
                features: torch.Tensor,
                mask: torch.Tensor,
                weight: torch.nn.Parameter,
                bias: Optional[torch.nn.Parameter],
                padding: Union[int, Tuple[int, int]] = 0,
                stride: Union[int, Tuple[int, int]] = 1) -> torch.Tensor:
        """Run the masked convolution.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            features: Input tensor of shape ``(1, C_in, H, W)``.
            mask: Tensor of shape ``(1, H, W)``; positions with ``mask > 0``
                are computed.
            weight: Kernel of shape ``(C_out, C_in, kernel_h, kernel_w)``.
            bias: Per-output-channel bias of shape ``(C_out,)``, or ``None``.
            padding: Zero padding, an int or ``(pad_h, pad_w)``.
            stride: Must be 1 (or ``(1, 1)``).

        Returns:
            Tensor of shape ``(1, C_out, out_h, out_w)``. It is
            zero-initialised and only the masked positions are filled in.

        Raises:
            ValueError: If the stride is not 1, or if the output map is
                smaller than the mask.
        """
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Stride could only be 1 in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        if features.device.type == 'npu':
            # NPU path: run a dense convolution, then zero out every position
            # outside the mask.
            import torch_npu
            output = torch_npu.npu_conv2d(
                features,
                weight,
                bias,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
                dilation=(1, 1),
                groups=1)
            if mask.size()[1:] != output.size()[2:]:
                raise ValueError(
                    'The mask is inconsistent with the shape of output_conv.')
            mask = mask > 0
            mask = mask.type(output.dtype)
            output = output * mask
            return output

        batch_size = features.size(0)
        # Standard convolution output size (dilation 1). Integer floor
        # division avoids the float round-trip of true_divide + math.floor.
        out_h = (features.size(2) + 2 * pad_h - kernel_h) // stride_h + 1
        out_w = (features.size(3) + 2 * pad_w - kernel_w) // stride_w + 1
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            # The mask coordinates range over (H, W), while
            # masked_col2im_forward is told the output is (out_h, out_w).
            # Reject a smaller output instead of scattering past its end.
            # NOTE(review): this mirrors the NPU branch's shape check.
            if out_h < mask.size(1) or out_w < mask.size(2):
                raise ValueError(
                    'The mask is larger than the output of masked_conv2d; '
                    'use padding that preserves the spatial size.')
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
            # One column of unfolded input patches per masked position.
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)
            weight_2d = weight.reshape(out_channel, -1)
            if bias is None:
                masked_output = torch.mm(weight_2d, data_col)
            else:
                # bias[:, None] broadcasts the bias across every column.
                masked_output = torch.addmm(bias[:, None], weight_2d,
                                            data_col)
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        # Gradients are not implemented. Return one None per forward input:
        # features, mask, weight, bias, padding and stride. When stride was
        # not passed, autograd drops the surplus trailing None.
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply


class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    Without a mask, ``forward`` behaves exactly like ``nn.Conv2d``. With a
    mask, the convolution is routed through ``masked_conv2d``. That path
    doesn't implement the backward function and only supports ``stride=1``,
    ``dilation=1`` and ``groups=1``. Other settings raise ``ValueError``
    instead of being silently ignored.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: Union[int, Tuple[int, ...]] = 1,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)

    def forward(self,
                input: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the convolution, optionally restricted to ``mask``.

        Args:
            input: Input feature map.
            mask: Optional mask; when given, only positions where it is
                positive are computed by ``masked_conv2d``.

        Returns:
            The convolution output.

        Raises:
            ValueError: If a mask is given and the layer uses a stride,
                dilation or group count that the masked path cannot honour.
        """
        if mask is None:  # fallback to the normal Conv2d
            return super().forward(input)
        # The masked path is called without stride/dilation/groups, so any
        # non-default value would otherwise be silently dropped.
        if any(s != 1 for s in self.stride):
            raise ValueError(
                'MaskedConv2d only supports stride=1 when a mask is given.')
        if any(d != 1 for d in self.dilation):
            raise ValueError(
                'MaskedConv2d only supports dilation=1 when a mask is given.')
        if self.groups != 1:
            raise ValueError(
                'MaskedConv2d only supports groups=1 when a mask is given.')
        return masked_conv2d(input, mask, self.weight, self.bias,
                             self.padding)