sapiens-pose / external /cv /mmcv /ops /modulated_deform_conv.py
rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import IS_MLU_AVAILABLE
from ..utils import ext_loader
# Load the '_ext' extension through ext_loader, requesting the two
# modulated deformable convolution ops (forward and backward) that
# ModulatedDeformConv2dFunction calls on its non-NPU path.
ext_module = ext_loader.load_ext(
'_ext',
['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
class ModulatedDeformConv2dFunction(Function):
    """Autograd function implementing modulated deformable convolution.

    Both passes dispatch on ``input.device.type``: ``'npu'`` tensors go
    through ``torch.npu_deformable_conv2d`` /
    ``torch.npu_deformable_conv2dbk``; every other device goes through the
    ``modulated_deform_conv_forward`` / ``modulated_deform_conv_backward``
    ops of ``ext_module``.
    """

    @staticmethod
    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
                 dilation, groups, deform_groups):
        """Emit the ``mmcv::MMCVModulatedDeformConv2d`` graph node.

        ``bias`` is only added to the node inputs when it is not ``None``;
        the remaining arguments are attached as integer attributes.
        """
        input_tensors = [input, offset, mask, weight]
        if bias is not None:
            input_tensors.append(bias)
        return g.op(
            'mmcv::MMCVModulatedDeformConv2d',
            *input_tensors,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            groups_i=groups,
            deform_groups_i=deform_groups)

    @staticmethod
    def _calculate_sort_index(kernel_h, kernel_w, deformable_group):
        """Build the offset-channel permutations used on the NPU path.

        Args:
            kernel_h (int): Kernel height.
            kernel_w (int): Kernel width.
            deformable_group (int): Number of deformable groups.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(sort_index_fp,
            sort_index_bp)`` int tensors moved to the NPU.
            ``sort_index_fp`` lists all odd channel indices followed by all
            even ones; ``sort_index_bp`` is the inverse permutation.
        """
        split_num = deformable_group * 2 * kernel_h * kernel_w
        sort_index = list(range(split_num))
        sort_index_fp = sort_index[1::2] + sort_index[::2]
        # Invert the forward permutation: position of every original index.
        sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
        sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
        sort_index_fp = torch.IntTensor(sort_index_fp).npu()
        sort_index_bp = torch.IntTensor(sort_index_bp).npu()
        return sort_index_fp, sort_index_bp

    @staticmethod
    def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
        """Forward pass for tensors living on an NPU device.

        The offset channels are reordered with ``sort_index_fp`` and
        concatenated with ``mask`` along dim 1 before being handed to
        ``torch.npu_deformable_conv2d``. An empty ``bias`` (the placeholder
        created by :meth:`forward`) is passed on as ``None``.
        """
        _, _, kernel_h, kernel_w = weight.shape
        conv2d_bias = bias if len(bias) > 0 else None
        sort_index_fp, sort_index_bp = \
            ModulatedDeformConv2dFunction._calculate_sort_index(
                kernel_h, kernel_w, ctx.deform_groups)
        select_offset = offset.index_select(1, sort_index_fp)
        offset_all = torch.cat([select_offset, mask], dim=1)
        # NOTE(review): kernel_size is passed as [w, h] here and in
        # _npu_backward; confirm the order expected by the NPU op.
        output, offset_out = torch.npu_deformable_conv2d(
            input_tensor,
            weight,
            offset_all,
            conv2d_bias,
            kernel_size=[kernel_w, kernel_h],
            stride=[1, 1, ctx.stride[0], ctx.stride[1]],
            padding=[
                ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
            ],
            dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
            groups=ctx.groups,
            deformable_groups=ctx.deform_groups,
            modulated=True)
        # Tensors are only stashed when at least one of them needs a grad.
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input_tensor.requires_grad:
            ctx.save_for_backward(input_tensor, weight, offset_out, offset_all,
                                  sort_index_bp)
        return output

    @staticmethod
    def _npu_backward(ctx, grad_output):
        """Backward pass for tensors living on an NPU device.

        Returns:
            tuple: One entry per argument of :meth:`forward` (10 in total):
            gradients for ``input``, ``offset``, ``mask``, ``weight`` and
            ``bias`` (``None`` when the layer has no bias), followed by
            ``None`` for the five non-tensor arguments.
        """
        input_tensor, weight, offset_out, offset_all, sort_index_bp = \
            ctx.saved_tensors
        grad_input, grad_weight, grad_offset_all, grad_bias = \
            torch.npu_deformable_conv2dbk(
                input_tensor, grad_output, offset_out, weight, offset_all,
                kernel_size=[weight.shape[3], weight.shape[2]],
                stride=[1, 1, ctx.stride[0], ctx.stride[1]],
                padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
                         ctx.padding[1]],
                dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
                groups=ctx.groups, deformable_groups=ctx.deform_groups,
                modulated=True)
        # Undo the channel permutation applied in _npu_forward.
        grad_offset = grad_offset_all.index_select(1, sort_index_bp)
        # The channels after the offset channels hold the mask gradient.
        grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :]
        if not ctx.with_bias:
            grad_bias = None
        # forward() takes exactly 10 arguments, so return exactly 10 values,
        # matching backward().
        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def forward(ctx,
                input: torch.Tensor,
                offset: torch.Tensor,
                mask: torch.Tensor,
                weight: nn.Parameter,
                bias: Optional[nn.Parameter] = None,
                stride: int = 1,
                padding: int = 0,
                dilation: int = 1,
                groups: int = 1,
                deform_groups: int = 1) -> torch.Tensor:
        """Run the modulated deformable convolution.

        Args:
            input (torch.Tensor): Input tensor; must be 4D.
            offset (torch.Tensor): Sampling offsets. Its dtype decides the
                dtype that ``input``, ``weight``, ``bias`` and ``mask`` are
                cast to.
            mask (torch.Tensor): Modulation mask.
            weight (nn.Parameter): Convolution weight.
            bias (nn.Parameter, optional): Convolution bias. Defaults to
                None, in which case an empty placeholder tensor is used.
            stride (int): Stride, expanded with ``_pair``. Defaults to 1.
            padding (int): Padding, expanded with ``_pair``. Defaults to 0.
            dilation (int): Dilation, expanded with ``_pair``.
                Defaults to 1.
            groups (int): Number of groups. Defaults to 1.
            deform_groups (int): Number of deformable groups. Defaults to 1.

        Returns:
            torch.Tensor: The convolution output.

        Raises:
            ValueError: If ``input`` is not a 4D tensor, or if the computed
                output size is not strictly positive in every dimension.
        """
        if input is not None and input.dim() != 4:
            # Adjacent literals keep source indentation out of the message.
            raise ValueError(
                f'Expected 4D tensor as input, got {input.dim()}D tensor '
                'instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deform_groups = deform_groups
        ctx.with_bias = bias is not None
        ctx.device = input.device.type
        if not ctx.with_bias:
            bias = input.new_empty(0)  # fake tensor
        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
        # amp won't cast the type of model (float32), but "offset" is cast
        # to float16 by nn.Conv2d automatically, leading to the type
        # mismatch with input (when it is float32) or weight.
        # The flag for whether to use fp16 or amp is the type of "offset",
        # we cast weight and input to temporarily support fp16 and amp
        # whatever the pytorch version is.
        input = input.type_as(offset)
        weight = weight.type_as(input)
        bias = bias.type_as(input)  # type: ignore
        mask = mask.type_as(input)
        if ctx.device == 'npu':
            output = ModulatedDeformConv2dFunction._npu_forward(
                ctx, input, offset, mask, weight, bias)
            return output
        ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty([
            int(i) for i in ModulatedDeformConv2dFunction._output_size(
                ctx, input, weight)
        ])
        # Two scratch buffers handed to the extension op; the same objects
        # are passed again in backward().
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        ext_module.modulated_deform_conv_forward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            output,
            ctx._bufs[1],
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        """Compute gradients for every argument of :meth:`forward`.

        Returns:
            tuple: Gradients for ``input``, ``offset``, ``mask``, ``weight``
            and ``bias`` (``None`` when the layer has no bias), followed by
            ``None`` for the five non-tensor arguments.
        """
        if ctx.device == 'npu':
            return ModulatedDeformConv2dFunction._npu_backward(
                ctx, grad_output)
        input, offset, mask, weight, bias = ctx.saved_tensors
        # Zero-initialised gradient buffers passed to the extension op.
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        grad_output = grad_output.contiguous()
        ext_module.modulated_deform_conv_backward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            ctx._bufs[1],
            grad_input,
            grad_weight,
            grad_bias,
            grad_offset,
            grad_mask,
            grad_output,
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None
        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _output_size(ctx, input, weight):
        """Compute the output shape of the convolution.

        Args:
            ctx: Object exposing ``padding``, ``dilation`` and ``stride``
                sequences indexed per spatial dimension.
            input (torch.Tensor): Input tensor.
            weight (torch.Tensor): Convolution weight; ``weight.size(0)`` is
                used as the number of output channels.

        Returns:
            tuple: ``(input.size(0), weight.size(0), *spatial_sizes)``.

        Raises:
            ValueError: If any entry of the output size is not positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = ctx.padding[d]
            # Effective kernel extent once dilation is applied.
            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = ctx.stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(s > 0 for s in output_size):
            raise ValueError(
                'convolution input is too small (output would be ' +
                'x'.join(map(str, output_size)) + ')')
        return output_size
# Functional entry point: an alias for ModulatedDeformConv2dFunction.apply,
# used by the module classes below in their forward() methods.
modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
class ModulatedDeformConv2d(nn.Module):
    """Modulated deformable 2D convolution layer.

    The layer owns the convolution ``weight`` and an optional ``bias``;
    the sampling ``offset`` and modulation ``mask`` are supplied by the
    caller on every forward pass.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int]): Kernel size, expanded with
            ``_pair``.
        stride (int): Stride, expanded with ``_pair``. Defaults to 1.
        padding (int): Padding, expanded with ``_pair``. Defaults to 0.
        dilation (int): Dilation, expanded with ``_pair``. Defaults to 1.
        groups (int): Number of groups. Defaults to 1.
        deform_groups (int): Number of deformable groups. Defaults to 1.
        bias (bool or str): A truthy value creates a bias parameter;
            otherwise ``bias`` is registered as None. Defaults to True.
    """

    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
                            cls_name='ModulatedDeformConv2d')
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 deform_groups: int = 1,
                 bias: Union[bool, str] = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deform_groups = deform_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        # Storage is allocated here and filled in by init_weights().
        weight_shape = (out_channels, in_channels // groups,
                        *self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.init_weights()

    def init_weights(self):
        """Draw the weight uniformly from ``[-stdv, stdv]`` and zero the bias.

        ``stdv`` is ``1 / sqrt(in_channels * prod(kernel_size))``.
        """
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in = fan_in * extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is None:
            return
        self.bias.data.zero_()

    def forward(self, x: torch.Tensor, offset: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Apply the convolution to ``x`` with the given offset and mask."""
        conv_args = (self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups, self.deform_groups)
        return modulated_deform_conv2d(x, offset, mask, *conv_args)
@MODELS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
    layers.

    Offsets and masks are predicted from the input by an internal
    ``conv_offset`` convolution, so :meth:`forward` only takes ``x``.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d, while tuple is not supported.
        padding (int): Same as nn.Conv2d, while tuple is not supported.
        dilation (int): Same as nn.Conv2d, while tuple is not supported.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        # Three channel groups per kernel location: two are used as
        # offsets and one as the mask (see forward()).
        offset_mask_channels = self.deform_groups * 3 * kernel_h * kernel_w
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            offset_mask_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True)
        # Run again now that conv_offset exists so it gets zeroed as well.
        self.init_weights()

    def init_weights(self) -> None:
        """Initialise the parent parameters and zero ``conv_offset``.

        ``conv_offset`` may not exist yet when this is invoked from the
        parent constructor, hence the ``hasattr`` guard.
        """
        super().init_weights()
        if not hasattr(self, 'conv_offset'):
            return
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Predict offset and mask from ``x`` and apply the convolution."""
        predicted = self.conv_offset(x)
        first, second, mask_logits = torch.chunk(predicted, 3, dim=1)
        offset = torch.cat((first, second), dim=1)
        mask = torch.sigmoid(mask_logits)
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load parameters, renaming legacy ``*_offset.*`` keys if present."""
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # the key is different in early versions
            # In version < 2, ModulatedDeformConvPack
            # loads previous benchmark models.
            for suffix in ('weight', 'bias'):
                current_key = prefix + 'conv_offset.' + suffix
                legacy_key = prefix[:-1] + '_offset.' + suffix
                if (current_key not in state_dict
                        and legacy_key in state_dict):
                    state_dict[current_key] = state_dict.pop(legacy_key)

        # NOTE(review): this message is emitted when the stored version is
        # already > 1 -- confirm that this is the intended trigger.
        if version is not None and version > 1:
            print_log(
                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
                'version 2.',
                logger='current')

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)
if IS_MLU_AVAILABLE:
    import torchvision
    from mmengine.utils import digit_version
    from torchvision.ops import deform_conv2d as tv_deform_conv2d

    @MODELS.register_module('DCNv2', force=True)
    class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
        """This class is the DCNv2 implementation of the MLU device.

        The MLU backend support of the operator has been implemented
        in torchvision. The mmcv registration mechanism is used for
        multiplexing here. The torchvision implementation of DCNv2 is called.

        Args:
            in_channels (int): Same as nn.Conv2d.
            out_channels (int): Same as nn.Conv2d.
            kernel_size (int or tuple[int]): Same as nn.Conv2d.
            stride (int): Same as nn.Conv2d, while tuple is not supported.
            padding (int): Same as nn.Conv2d, while tuple is not supported.
            dilation (int): Same as nn.Conv2d, while tuple is not supported.
            groups (int): Same as nn.Conv2d.
            bias (bool or str): If specified as `auto`, it will be decided by
                the norm_cfg. Bias will be set as True if norm_cfg is None,
                otherwise False.
        """

        def __init__(self, *args, **kwargs):
            installed = digit_version(torchvision.__version__)
            required = digit_version('0.10.0a0')
            assert installed >= required, \
                'the version of torchvision should be >= 0.10.0'
            super().__init__(*args, **kwargs)
            kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
            # Three channel groups per kernel location: two are used as
            # offsets and one as the mask (see forward()).
            offset_mask_channels = (
                self.deform_groups * 3 * kernel_h * kernel_w)
            self.conv_offset = nn.Conv2d(
                self.in_channels,
                offset_mask_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                bias=True)
            # Run again now that conv_offset exists so it gets zeroed too.
            self.init_weights()

        def init_weights(self):
            """Initialise the parent parameters and zero ``conv_offset``.

            ``conv_offset`` may not exist yet when this is invoked from the
            parent constructor, hence the ``hasattr`` guard.
            """
            super().init_weights()
            if not hasattr(self, 'conv_offset'):
                return
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

        def forward(self, x):
            """Predict offset and mask from ``x`` and call torchvision's
            ``deform_conv2d``."""
            predicted = self.conv_offset(x)
            first, second, mask_logits = torch.chunk(predicted, 3, dim=1)
            offset = torch.cat((first, second), dim=1)
            mask = torch.sigmoid(mask_logits)
            # Align dtypes with the predicted offset before the op call.
            x = x.type_as(offset)
            weight = self.weight.type_as(x)
            mask = mask.type_as(x)
            # NOTE(review): self.bias is passed without a dtype cast, unlike
            # weight and mask -- confirm this is intended for mixed precision.
            return tv_deform_conv2d(
                x,
                offset,
                weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                mask=mask)