# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from mmengine.logging import print_log from mmengine.registry import MODELS from mmengine.utils import deprecated_api_warning from torch import Tensor from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair, _single from mmcv.utils import IS_MLU_AVAILABLE from ..utils import ext_loader from .modulated_deform_conv import ModulatedDeformConv2dFunction ext_module = ext_loader.load_ext('_ext', [ 'deform_conv_forward', 'deform_conv_backward_input', 'deform_conv_backward_parameters' ]) class DeformConv2dFunction(Function): @staticmethod def symbolic(g, input, offset, weight, stride, padding, dilation, groups, deform_groups, bias=False, im2col_step=32): return g.op( 'mmcv::MMCVDeformConv2d', input, offset, weight, stride_i=stride, padding_i=padding, dilation_i=dilation, groups_i=groups, deform_groups_i=deform_groups, bias_i=bias, im2col_step_i=im2col_step) @staticmethod def _npu_backward(ctx, grad_output): input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ ctx.saved_tensors grad_input, grad_weight, grad_offset_all, grad_bias = \ torch.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[3], weight.shape[2]], stride=[1, 1, ctx.stride[0], ctx.stride[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]], dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], groups=ctx.groups, deformable_groups=ctx.deform_groups, modulated=True) grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp) return grad_input, grad_offset, grad_weight, \ None, None, None, None, None, None, None @staticmethod def forward(ctx, input: Tensor, offset: Tensor, weight: Tensor, stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, deform_groups: int = 1, bias: bool = False, im2col_step: int = 32) -> Tensor: if input is not None and input.dim() != 4: raise ValueError( f'Expected 4D tensor as input, got {input.dim()}D tensor \ instead.') assert bias is False, 'Only support bias is False.' ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) ctx.groups = groups ctx.deform_groups = deform_groups ctx.im2col_step = im2col_step ctx.device = input.device.type # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; # amp won't cast the type of model (float32), but "offset" is cast # to float16 by nn.Conv2d automatically, leading to the type # mismatch with input (when it is float32) or weight. # The flag for whether to use fp16 or amp is the type of "offset", # we cast weight and input to temporarily support fp16 and amp # whatever the pytorch version is. input = input.type_as(offset) weight = weight.type_as(input) if ctx.device == 'npu': mask_shape, _ = torch.chunk(offset, 2, dim=1) mask = torch.ones_like(mask_shape).to(input.device) bias = input.new_empty(0) output = ModulatedDeformConv2dFunction._npu_forward( ctx, input, offset, mask, weight, bias) return output ctx.save_for_backward(input, offset, weight) output = input.new_empty([ int(i) for i in DeformConv2dFunction._output_size(ctx, input, weight) ]) ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones cur_im2col_step = min(ctx.im2col_step, input.size(0)) assert (input.size(0) % cur_im2col_step ) == 0, 'batch size must be divisible by im2col_step' ext_module.deform_conv_forward( input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1], kW=weight.size(3), kH=weight.size(2), dW=ctx.stride[1], dH=ctx.stride[0], padW=ctx.padding[1], padH=ctx.padding[0], dilationW=ctx.dilation[1], dilationH=ctx.dilation[0], group=ctx.groups, deformable_group=ctx.deform_groups, im2col_step=cur_im2col_step) return output @staticmethod @once_differentiable def backward( ctx, grad_output: Tensor ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None, None, None, None, None, None, None]: if ctx.device == 'npu': return DeformConv2dFunction._npu_backward(ctx, grad_output) input, offset, weight = ctx.saved_tensors grad_input = grad_offset = grad_weight = None cur_im2col_step = min(ctx.im2col_step, input.size(0)) assert (input.size(0) % cur_im2col_step ) == 0, 'batch size must be divisible by im2col_step' grad_output = grad_output.contiguous() if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: grad_input = torch.zeros_like(input) grad_offset = torch.zeros_like(offset) ext_module.deform_conv_backward_input( input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0], kW=weight.size(3), kH=weight.size(2), dW=ctx.stride[1], dH=ctx.stride[0], padW=ctx.padding[1], padH=ctx.padding[0], dilationW=ctx.dilation[1], dilationH=ctx.dilation[0], group=ctx.groups, deformable_group=ctx.deform_groups, im2col_step=cur_im2col_step) if ctx.needs_input_grad[2]: grad_weight = torch.zeros_like(weight) ext_module.deform_conv_backward_parameters( input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1], kW=weight.size(3), kH=weight.size(2), dW=ctx.stride[1], dH=ctx.stride[0], padW=ctx.padding[1], padH=ctx.padding[0], dilationW=ctx.dilation[1], dilationH=ctx.dilation[0], group=ctx.groups, deformable_group=ctx.deform_groups, scale=1, im2col_step=cur_im2col_step) return grad_input, grad_offset, grad_weight, \ None, None, None, None, None, None, None @staticmethod def _output_size(ctx, input, weight): channels = weight.size(0) output_size = (input.size(0), channels) for d in range(input.dim() - 2): in_size = input.size(d + 2) pad = ctx.padding[d] kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 stride_ = ctx.stride[d] output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) if not all(map(lambda s: s > 0, output_size)): raise ValueError( 'convolution input is too small (output would be ' + 'x'.join(map(str, output_size)) + ')') return output_size deform_conv2d = DeformConv2dFunction.apply class DeformConv2d(nn.Module): r"""Deformable 2D convolution. Applies a deformable 2D convolution over an input signal composed of several input planes. DeformConv2d was described in the paper `Deformable Convolutional Networks `_ Note: The argument ``im2col_step`` was added in version 1.3.17, which means number of samples processed by the ``im2col_cuda_kernel`` per call. It enables users to define ``batch_size`` and ``im2col_step`` more flexibly and solved `issue mmcv#1440 `_. Args: in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size(int, tuple): Size of the convolving kernel. stride(int, tuple): Stride of the convolution. Default: 1. padding (int or tuple): Zero-padding added to both sides of the input. Default: 0. dilation (int or tuple): Spacing between kernel elements. Default: 1. groups (int): Number of blocked connections from input. channels to output channels. Default: 1. deform_groups (int): Number of deformable group partitions. bias (bool): If True, adds a learnable bias to the output. Default: False. im2col_step (int): Number of samples processed by im2col_cuda_kernel per call. It will work when ``batch_size`` > ``im2col_step``, but ``batch_size`` must be divisible by ``im2col_step``. Default: 32. `New in version 1.3.17.` """ @deprecated_api_warning({'deformable_groups': 'deform_groups'}, cls_name='DeformConv2d') def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, Tuple[int, ...]] = 0, dilation: Union[int, Tuple[int, ...]] = 1, groups: int = 1, deform_groups: int = 1, bias: bool = False, im2col_step: int = 32) -> None: super().__init__() assert not bias, \ f'bias={bias} is not supported in DeformConv2d.' assert in_channels % groups == 0, \ f'in_channels {in_channels} cannot be divisible by groups {groups}' assert out_channels % groups == 0, \ f'out_channels {out_channels} cannot be divisible by groups \ {groups}' self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) self.padding = _pair(padding) self.dilation = _pair(dilation) self.groups = groups self.deform_groups = deform_groups self.im2col_step = im2col_step # enable compatibility with nn.Conv2d self.transposed = False self.output_padding = _single(0) # only weight, no bias self.weight = nn.Parameter( torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) self.reset_parameters() def reset_parameters(self): # switch the initialization of `self.weight` to the standard kaiming # method described in `Delving deep into rectifiers: Surpassing # human-level performance on ImageNet classification` - He, K. et al. # (2015), using a uniform distribution nn.init.kaiming_uniform_(self.weight, nonlinearity='relu') def forward(self, x: Tensor, offset: Tensor) -> Tensor: """Deformable Convolutional forward function. Args: x (Tensor): Input feature, shape (B, C_in, H_in, W_in) offset (Tensor): Offset for deformable convolution, shape (B, deform_groups*kernel_size[0]*kernel_size[1]*2, H_out, W_out), H_out, W_out are equal to the output's. An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. The spatial arrangement is like: .. code:: text (x0, y0) (x1, y1) (x2, y2) (x3, y3) (x4, y4) (x5, y5) (x6, y6) (x7, y7) (x8, y8) Returns: Tensor: Output of the layer. """ # To fix an assert error in deform_conv_cuda.cpp:128 # input image is smaller than kernel input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) < self.kernel_size[1]) if input_pad: pad_h = max(self.kernel_size[0] - x.size(2), 0) pad_w = max(self.kernel_size[1] - x.size(3), 0) x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0) offset = offset.contiguous() out = deform_conv2d(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deform_groups, False, self.im2col_step) if input_pad: out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() return out def __repr__(self): s = self.__class__.__name__ s += f'(in_channels={self.in_channels},\n' s += f'out_channels={self.out_channels},\n' s += f'kernel_size={self.kernel_size},\n' s += f'stride={self.stride},\n' s += f'padding={self.padding},\n' s += f'dilation={self.dilation},\n' s += f'groups={self.groups},\n' s += f'deform_groups={self.deform_groups},\n' # bias is not supported in DeformConv2d. s += 'bias=False)' return s @MODELS.register_module('DCN') class DeformConv2dPack(DeformConv2d): """A Deformable Conv Encapsulation that acts as normal Conv layers. The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`. The spatial arrangement is like: .. code:: text (x0, y0) (x1, y1) (x2, y2) (x3, y3) (x4, y4) (x5, y5) (x6, y6) (x7, y7) (x8, y8) Args: in_channels (int): Same as nn.Conv2d. out_channels (int): Same as nn.Conv2d. kernel_size (int or tuple[int]): Same as nn.Conv2d. stride (int or tuple[int]): Same as nn.Conv2d. padding (int or tuple[int]): Same as nn.Conv2d. dilation (int or tuple[int]): Same as nn.Conv2d. groups (int): Same as nn.Conv2d. bias (bool or str): If specified as `auto`, it will be decided by the norm_cfg. Bias will be set as True if norm_cfg is None, otherwise False. """ _version = 2 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.conv_offset = nn.Conv2d( self.in_channels, self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), dilation=_pair(self.dilation), bias=True) self.init_offset() def init_offset(self): self.conv_offset.weight.data.zero_() self.conv_offset.bias.data.zero_() def forward(self, x: Tensor) -> Tensor: # type: ignore offset = self.conv_offset(x) return deform_conv2d(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deform_groups, False, self.im2col_step) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) if version is None or version < 2: # the key is different in early versions # In version < 2, DeformConvPack loads previous benchmark models. if (prefix + 'conv_offset.weight' not in state_dict and prefix[:-1] + '_offset.weight' in state_dict): state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( prefix[:-1] + '_offset.weight') if (prefix + 'conv_offset.bias' not in state_dict and prefix[:-1] + '_offset.bias' in state_dict): state_dict[prefix + 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + '_offset.bias') if version is not None and version > 1: print_log( f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to ' 'version 2.', logger='current') super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) if IS_MLU_AVAILABLE: import torchvision from mmengine.utils import digit_version from torchvision.ops import deform_conv2d as tv_deform_conv2d @MODELS.register_module('DCN', force=True) class DeformConv2dPack_MLU(DeformConv2d): """This class is the DCN implementation of the MLU device. The MLU backend support of the operator has been implemented in torchvision. The mmcv registration mechanism is used for multiplexing here. The torchvision implementation of DCN is called. Args: in_channels (int): Same as nn.Conv2d. out_channels (int): Same as nn.Conv2d. kernel_size (int or tuple[int]): Same as nn.Conv2d. stride (int): Same as nn.Conv2d, while tuple is not supported. padding (int): Same as nn.Conv2d, while tuple is not supported. dilation (int): Same as nn.Conv2d, while tuple is not supported. groups (int): Same as nn.Conv2d. bias (bool or str): If specified as `auto`, it will be decided by the norm_cfg. Bias will be set as True if norm_cfg is None, otherwise False. im2col_step (int): Number of samples processed by im2col_cuda_kernel per call. It will work when ``batch_size`` > ``im2col_step``, but ``batch_size`` must be divisible by ``im2col_step``. Default: 32. `New in version 1.7.2. Currently not supported on MLU devices.` """ def __init__(self, *args, **kwargs): assert digit_version(torchvision.__version__) >= digit_version( '0.10.0a0'), 'the version of torchvision should be >= 0.10.0' super().__init__(*args, **kwargs) self.conv_offset = nn.Conv2d( self.in_channels, self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1], kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), dilation=_pair(self.dilation), bias=True) self.init_offset() def init_offset(self): self.conv_offset.weight.data.zero_() self.conv_offset.bias.data.zero_() def forward(self, x: Tensor) -> Tensor: # type: ignore cur_im2col_step = min(self.im2col_step, x.size(0)) assert (x.size(0) % cur_im2col_step ) == 0, 'batch size must be divisible by im2col_step' offset = self.conv_offset(x) x = x.type_as(offset) weight = self.weight.type_as(x) return tv_deform_conv2d(x, offset, weight, None, self.stride, self.padding, self.dilation)