Spaces:
Build error
Build error
File size: 5,002 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from ..utils import ext_loader
# Compiled kernels called by the non-NPU branch of MaskedConv2dFunction.
ext_module = ext_loader.load_ext('_ext', [
    'masked_im2col_forward',
    'masked_col2im_forward',
])
class MaskedConv2dFunction(Function):
    """Convolution evaluated only at the positions selected by a mask.

    Restrictions of the current implementation (enforced in ``forward``):
    batch size 1, stride 1, and no gradient support -- ``backward`` returns
    ``None`` for every input.
    """

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride=1):
        # ONNX export: emit one custom ``mmcv::MMCVMaskedConv2d`` node.
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def _out_size(in_size: int, pad: int, kernel: int, stride: int) -> int:
        """Return the output length of an undilated convolution on one axis."""
        return int((in_size + 2 * pad - (kernel - 1) - 1) // stride + 1)

    @staticmethod
    def _project(weight: torch.Tensor, bias: Optional[torch.Tensor],
                 data_col: torch.Tensor) -> torch.Tensor:
        """Apply the convolution weights (and optional bias) to im2col columns.

        Args:
            weight: Kernel of shape ``(C_out, C_in, k_h, k_w)``.
            bias: Bias of shape ``(C_out,)`` or ``None`` for no bias.
            data_col: Columns of shape ``(C_in * k_h * k_w, num_pixels)``.

        Returns:
            Tensor of shape ``(C_out, num_pixels)``.
        """
        weight_2d = weight.reshape(weight.size(0), -1)
        if bias is None:
            return torch.mm(weight_2d, data_col)
        # bias + weight_2d @ data_col; the old positional (beta, input,
        # alpha, ...) addmm overload no longer exists in current PyTorch.
        return torch.addmm(bias[:, None], weight_2d, data_col)

    @staticmethod
    def forward(ctx,
                features: torch.Tensor,
                mask: torch.Tensor,
                weight: torch.nn.Parameter,
                bias: Optional[torch.nn.Parameter],
                padding: int = 0,
                stride: int = 1) -> torch.Tensor:
        """Compute the convolution at masked positions only.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            features: Input of shape ``(1, C_in, H, W)``.
            mask: Tensor of shape ``(1, H, W)``; positions with ``mask > 0``
                are computed, all other outputs are left at zero.
            weight: Kernel of shape ``(C_out, C_in, k_h, k_w)``.
            bias: Bias of shape ``(C_out,)`` or ``None``.
            padding: Zero padding, an int or an ``(h, w)`` pair.
            stride: Must be 1 (an int or a ``(1, 1)`` pair).

        Returns:
            Tensor of shape ``(1, C_out, out_h, out_w)``.

        Raises:
            AssertionError: If ``features``/``mask`` violate the shape
                requirements above.
            ValueError: If ``stride`` is not 1, or (NPU only) if the mask
                shape differs from the dense convolution output shape.
        """
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Stride could only be 1 in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        if features.device.type == 'npu':
            # NPU: run a dense convolution, then zero the unmasked outputs.
            import torch_npu
            output = torch_npu.npu_conv2d(
                features,
                weight,
                bias,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
                dilation=(1, 1),
                groups=1)
            if mask.size()[1:] != output.size()[2:]:
                raise ValueError(
                    'The mask is inconsistent with the shape of output_conv.')
            return output * (mask > 0).type(output.dtype)

        batch_size = features.size(0)
        out_h = MaskedConv2dFunction._out_size(
            features.size(2), pad_h, kernel_h, stride_h)
        out_w = MaskedConv2dFunction._out_size(
            features.size(3), pad_w, kernel_w, stride_w)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)

        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        if mask_inds.numel() == 0:
            # Nothing selected: the all-zero output is already the answer.
            return output

        mask_h_idx = mask_inds[:, 0].contiguous()
        mask_w_idx = mask_inds[:, 1].contiguous()
        # One column of receptive-field values per selected pixel.
        data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                      mask_inds.size(0))
        ext_module.masked_im2col_forward(
            features,
            mask_h_idx,
            mask_w_idx,
            data_col,
            kernel_h=kernel_h,
            kernel_w=kernel_w,
            pad_h=pad_h,
            pad_w=pad_w)
        masked_output = MaskedConv2dFunction._project(weight, bias, data_col)
        # Write the per-pixel results back into the dense output tensor.
        ext_module.masked_col2im_forward(
            masked_output,
            mask_h_idx,
            mask_w_idx,
            output,
            height=out_h,
            width=out_w,
            channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        # Gradients are not implemented. One ``None`` per forward input
        # (features, mask, weight, bias, padding, stride); autograd drops
        # surplus trailing ``None`` values when ``apply`` got fewer args.
        return (None, ) * 6


# Functional entry point: masked_conv2d(features, mask, weight, bias,
# padding=0, stride=1).
masked_conv2d = MaskedConv2dFunction.apply
class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    The masked forward doesn't implement the backward function and only
    supports ``stride=1``, ``dilation=1`` and ``groups=1`` currently. Without
    a mask the layer behaves exactly like :class:`nn.Conv2d`, including for
    configurations the masked path does not support.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias)

    def forward(self,
                input: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the convolution, restricted to ``mask`` when one is given.

        Args:
            input: Input feature map.
            mask: Optional mask; when ``None`` the regular Conv2d is used.

        Returns:
            The convolution output.

        Raises:
            ValueError: If a mask is given but the layer uses a stride,
                dilation or group count the masked kernel cannot honour.
        """
        if mask is None:  # fallback to the normal Conv2d
            return super().forward(input)
        # The masked op only receives ``self.padding``; reject settings it
        # would otherwise silently ignore (giving wrong results).
        if any(s != 1 for s in self.stride):
            raise ValueError(
                f'masked_conv2d only supports stride 1, got {self.stride}.')
        if any(d != 1 for d in self.dilation):
            raise ValueError('masked_conv2d only supports dilation 1, '
                             f'got {self.dilation}.')
        if self.groups != 1:
            raise ValueError(
                f'masked_conv2d only supports groups=1, got {self.groups}.')
        return masked_conv2d(input, mask, self.weight, self.bias,
                             self.padding)
|