# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])


class MaskedConv2dFunction(Function):

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride=1):
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def forward(ctx,
                features: torch.Tensor,
                mask: torch.Tensor,
                weight: torch.nn.Parameter,
                bias: torch.nn.Parameter,
                padding: int = 0,
                stride: int = 1) -> torch.Tensor:
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Only stride 1 is supported in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        if features.device.type == 'npu':
            # On Ascend NPU, run a dense convolution and zero out the
            # positions where the mask is non-positive.
            import torch_npu
            output = torch_npu.npu_conv2d(
                features,
                weight,
                bias,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
                dilation=(1, 1),
                groups=1)
            if mask.size()[1:] != output.size()[2:]:
                raise ValueError(
                    'The mask is inconsistent with the shape of output_conv.')
            mask = mask > 0
            mask = mask.type(output.dtype)
            output = output * mask
            return output

        batch_size = features.size(0)
        # Standard convolution output size (dilation is fixed to 1).
        out_h = int(
            math.floor(
                torch.true_divide((features.size(2) + 2 * pad_h -
                                   (kernel_h - 1) - 1), stride_h) + 1))
        out_w = int(
            math.floor(
                torch.true_divide((features.size(3) + 2 * pad_w -
                                   (kernel_w - 1) - 1), stride_w) + 1))
        # Spatial coordinates at which the convolution is actually evaluated.
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
            # Gather the input patches at the masked positions into columns.
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)
            # GEMM over the gathered columns: bias + weight @ data_col.
            # (Keyword-default addmm replaces the deprecated positional
            # beta/alpha form; both used beta=1, alpha=1.)
            masked_output = torch.addmm(bias[:, None],
                                        weight.view(out_channel, -1),
                                        data_col)
            # Scatter the column results back to their spatial positions.
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        # Not differentiable: one ``None`` per forward input
        # (features, mask, weight, bias, padding, stride).
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply


class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits from the official Conv2d.

    The masked forward doesn't implement the backward function and
    currently only supports a stride of 1.
""" def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, ...]], stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True): super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: if mask is None: # fallback to the normal Conv2d return super().forward(input) else: return masked_conv2d(input, mask, self.weight, self.bias, self.padding)