# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.logging import MMLogger
from mmengine.model import (BaseModule, ModuleList, Sequential, constant_init,
normal_init, trunc_normal_init)
from mmengine.model.weight_init import trunc_normal_
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.utils import _pair as to_2tuple
from mmdet.registry import MODELS
from ..layers import PatchEmbed, nchw_to_nlc, nlc_to_nchw
class MixFFN(BaseModule):
"""An implementation of MixFFN of PVT.
The differences between MixFFN & FFN:
1. Use 1x1 Conv to replace Linear layer.
2. Introduce 3x3 depth-wise Conv to encode positional information.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
Default: None.
use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
Default: False.
init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
Default: None.
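Example:
    A minimal usage sketch; the tensor sizes below are illustrative
    rather than a required configuration:

    >>> import torch
    >>> ffn = MixFFN(embed_dims=64, feedforward_channels=256)
    >>> x = torch.rand(2, 7 * 7, 64)  # (batch, num_tokens, channels)
    >>> out = ffn(x, hw_shape=(7, 7))  # same shape as ``x``: (2, 49, 64)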
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
dropout_layer=None,
use_conv=False,
init_cfg=None):
super(MixFFN, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
activate = build_activation_layer(act_cfg)
in_channels = embed_dims
fc1 = Conv2d(
in_channels=in_channels,
out_channels=feedforward_channels,
kernel_size=1,
stride=1,
bias=True)
if use_conv:
# 3x3 depth-wise conv to provide positional encoding information
dw_conv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=(3 - 1) // 2,
bias=True,
groups=feedforward_channels)
fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=in_channels,
kernel_size=1,
stride=1,
bias=True)
drop = nn.Dropout(ffn_drop)
layers = [fc1, activate, drop, fc2, drop]
if use_conv:
layers.insert(1, dw_conv)
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
def forward(self, x, hw_shape, identity=None):
out = nlc_to_nchw(x, hw_shape)
out = self.layers(out)
out = nchw_to_nlc(out)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
class SpatialReductionAttention(MultiheadAttention):
"""An implementation of Spatial Reduction Attention of PVT.
This module is modified from MultiheadAttention which is a module from
mmcv.cnn.bricks.transformer.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut. Default: None.
batch_first (bool): If True, Key, Query and Value have shape of
(batch, n, embed_dim), otherwise (n, batch, embed_dim).
Default: True.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
Attention of PVT. Default: 1.
init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
Default: None.
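Example:
    A minimal usage sketch with illustrative sizes; with ``sr_ratio=8``
    the 56x56 key/value feature map is reduced to 7x7 before attention:

    >>> import torch
    >>> attn = SpatialReductionAttention(
    ...     embed_dims=64, num_heads=2, sr_ratio=8)
    >>> x = torch.rand(2, 56 * 56, 64)  # (batch, num_tokens, channels)
    >>> out = attn(x, hw_shape=(56, 56))  # same shape as ``x``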
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
dropout_layer=None,
batch_first=True,
qkv_bias=True,
norm_cfg=dict(type='LN'),
sr_ratio=1,
init_cfg=None):
super().__init__(
embed_dims,
num_heads,
attn_drop,
proj_drop,
batch_first=batch_first,
dropout_layer=dropout_layer,
bias=qkv_bias,
init_cfg=init_cfg)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=sr_ratio,
stride=sr_ratio)
# The ret[0] of build_norm_layer is the norm name.
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
from mmdet import digit_version, mmcv_version
if mmcv_version < digit_version('1.3.17'):
warnings.warn('The legacy version of the forward function in '
'SpatialReductionAttention is deprecated in '
'mmcv>=1.3.17 and will no longer be supported in '
'the future. Please upgrade your mmcv.')
self.forward = self.legacy_forward
def forward(self, x, hw_shape, identity=None):
x_q = x
if self.sr_ratio > 1:
x_kv = nlc_to_nchw(x, hw_shape)
x_kv = self.sr(x_kv)
x_kv = nchw_to_nlc(x_kv)
x_kv = self.norm(x_kv)
else:
x_kv = x
if identity is None:
identity = x_q
# Because the dataflow ('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_queries, batch,
# embed_dims), we should adjust the shape of the dataflow from
# batch_first (batch, num_queries, embed_dims) to num_queries_first
# (num_queries, batch, embed_dims), and recover ``attn_output``
# from num_queries_first back to batch_first.
if self.batch_first:
x_q = x_q.transpose(0, 1)
x_kv = x_kv.transpose(0, 1)
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
if self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))
def legacy_forward(self, x, hw_shape, identity=None):
"""multi head attention forward in mmcv version < 1.3.17."""
x_q = x
if self.sr_ratio > 1:
x_kv = nlc_to_nchw(x, hw_shape)
x_kv = self.sr(x_kv)
x_kv = nchw_to_nlc(x_kv)
x_kv = self.norm(x_kv)
else:
x_kv = x
if identity is None:
identity = x_q
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
return identity + self.dropout_layer(self.proj_drop(out))
class PVTEncoderLayer(BaseModule):
"""Implements one encoder layer in PVT.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The dropout rate for the attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default: 0.0.
qkv_bias (bool): Enable bias for qkv if True.
Default: True.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
Attention of PVT. Default: 1.
use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
Default: False.
init_cfg (dict, optional): Initialization config dict.
Default: None.
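Example:
    A minimal usage sketch; the sizes mirror the first stage of the
    default PVT configuration and are illustrative only:

    >>> import torch
    >>> layer = PVTEncoderLayer(
    ...     embed_dims=64, num_heads=1, feedforward_channels=512,
    ...     sr_ratio=8)
    >>> x = torch.rand(2, 56 * 56, 64)
    >>> out = layer(x, hw_shape=(56, 56))  # same shape as ``x``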
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
sr_ratio=1,
use_conv_ffn=False,
init_cfg=None):
super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)
# The ret[0] of build_norm_layer is the norm name.
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = SpatialReductionAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio)
# The ret[0] of build_norm_layer is the norm name.
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = MixFFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
use_conv=use_conv_ffn,
act_cfg=act_cfg)
def forward(self, x, hw_shape):
x = self.attn(self.norm1(x), hw_shape, identity=x)
x = self.ffn(self.norm2(x), hw_shape, identity=x)
return x
class AbsolutePositionEmbedding(BaseModule):
"""An implementation of the absolute position embedding in PVT.
Args:
pos_shape (int | tuple[int]): The shape of the absolute position
embedding.
pos_dim (int): The dimension of the absolute position embedding.
drop_rate (float): Probability of an element to be zeroed.
Default: 0.0.
init_cfg (dict, optional): Initialization config dict. Default: None.
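Example:
    A minimal usage sketch (illustrative sizes); the embedding is
    bilinearly resized when ``hw_shape`` differs from ``pos_shape``:

    >>> import torch
    >>> pos_embed = AbsolutePositionEmbedding(pos_shape=56, pos_dim=64)
    >>> x = torch.rand(2, 32 * 40, 64)
    >>> out = pos_embed(x, hw_shape=(32, 40))  # same shape as ``x``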
"""
def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
super().__init__(init_cfg=init_cfg)
if isinstance(pos_shape, int):
pos_shape = to_2tuple(pos_shape)
elif isinstance(pos_shape, tuple):
if len(pos_shape) == 1:
pos_shape = to_2tuple(pos_shape[0])
assert len(pos_shape) == 2, \
f'pos_shape should have length 1 or 2, ' \
f'but got {len(pos_shape)}'
self.pos_shape = pos_shape
self.pos_dim = pos_dim
self.pos_embed = nn.Parameter(
torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
self.drop = nn.Dropout(p=drop_rate)
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
"""Resize pos_embed weights.
Resize pos_embed using the bilinear interpolation method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
input_shape (tuple): Tuple for (downsampled input image height,
downsampled input image width).
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'bilinear'``.
Returns:
torch.Tensor: The resized pos_embed of shape [B, L_new, C].
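Example:
    An illustrative call resizing a 7x7 embedding to 14x14:

    >>> ape = AbsolutePositionEmbedding(pos_shape=7, pos_dim=64)
    >>> ape.resize_pos_embed(ape.pos_embed, (14, 14)).shape
    torch.Size([1, 196, 64])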
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = self.pos_shape
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
pos_embed_weight = F.interpolate(
pos_embed_weight, size=input_shape, mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight,
2).transpose(1, 2).contiguous()
pos_embed = pos_embed_weight
return pos_embed
def forward(self, x, hw_shape, mode='bilinear'):
pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
return self.drop(x + pos_embed)
@MODELS.register_module()
class PyramidVisionTransformer(BaseModule):
"""Pyramid Vision Transformer (PVT)
Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
Dense Prediction without Convolutions
<https://arxiv.org/pdf/2102.12122.pdf>`_.
Args:
pretrain_img_size (int | tuple[int]): The size of the input image
when pretraining. Default: 224.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): Embedding dimension. Default: 64.
num_stages (int): The number of stages. Default: 4.
num_layers (Sequence[int]): The number of encoder layers in each
transformer stage. Default: [3, 4, 6, 3].
num_heads (Sequence[int]): The number of attention heads in each
transformer stage. Default: [1, 2, 5, 8].
patch_sizes (Sequence[int]): The patch_size of each patch embedding.
Default: [4, 2, 2, 2].
strides (Sequence[int]): The stride of each patch embedding.
Default: [4, 2, 2, 2].
paddings (Sequence[int]): The padding of each patch embedding.
Default: [0, 0, 0, 0].
sr_ratios (Sequence[int]): The spatial reduction rate of each
transformer stage. Default: [8, 4, 2, 1].
out_indices (Sequence[int] | int): Output from which stages.
Default: (0, 1, 2, 3).
mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
embedding dim of each transformer stage.
Default: [8, 8, 4, 4].
qkv_bias (bool): Enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default: 0.0.
attn_drop_rate (float): The dropout rate for the attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Default: True.
norm_after_stage (bool): If True, apply a normalization layer after
each stage. Default: False.
use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
Default: False.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
pretrained (str, optional): Path to a pre-trained model. Default: None.
convert_weights (bool): The flag indicates whether the
pre-trained model is from the original repo. We may need
to convert some keys to make it compatible.
Default: True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
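Example:
    A minimal forward pass with the default (PVT-small-like)
    configuration; the input size is illustrative:

    >>> import torch
    >>> model = PyramidVisionTransformer()
    >>> x = torch.rand(1, 3, 224, 224)
    >>> outs = model(x)  # features at 1/4, 1/8, 1/16 and 1/32 resolution
    >>> [out.shape[1] for out in outs]
    [64, 128, 320, 512]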
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=64,
num_stages=4,
num_layers=[3, 4, 6, 3],
num_heads=[1, 2, 5, 8],
patch_sizes=[4, 2, 2, 2],
strides=[4, 2, 2, 2],
paddings=[0, 0, 0, 0],
sr_ratios=[8, 4, 2, 1],
out_indices=(0, 1, 2, 3),
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=True,
norm_after_stage=False,
use_conv_ffn=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN', eps=1e-6),
pretrained=None,
convert_weights=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.convert_weights = convert_weights
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'pretrain_img_size should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
self.init_cfg = init_cfg
else:
raise TypeError('pretrained must be a str or None')
self.embed_dims = embed_dims
self.num_stages = num_stages
self.num_layers = num_layers
self.num_heads = num_heads
self.patch_sizes = patch_sizes
self.strides = strides
self.sr_ratios = sr_ratios
assert num_stages == len(num_layers) == len(num_heads) \
== len(patch_sizes) == len(strides) == len(sr_ratios)
self.out_indices = out_indices
assert max(out_indices) < self.num_stages
self.pretrained = pretrained
# transformer encoder
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
] # stochastic depth decay rule
cur = 0
self.layers = ModuleList()
for i, num_layer in enumerate(num_layers):
embed_dims_i = embed_dims * num_heads[i]
patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims_i,
kernel_size=patch_sizes[i],
stride=strides[i],
padding=paddings[i],
bias=True,
norm_cfg=norm_cfg)
layers = ModuleList()
if use_abs_pos_embed:
pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
pos_embed = AbsolutePositionEmbedding(
pos_shape=pos_shape,
pos_dim=embed_dims_i,
drop_rate=drop_rate)
layers.append(pos_embed)
layers.extend([
PVTEncoderLayer(
embed_dims=embed_dims_i,
num_heads=num_heads[i],
feedforward_channels=mlp_ratios[i] * embed_dims_i,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[cur + idx],
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
sr_ratio=sr_ratios[i],
use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
])
in_channels = embed_dims_i
# The ret[0] of build_norm_layer is the norm name.
if norm_after_stage:
norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
else:
norm = nn.Identity()
self.layers.append(ModuleList([patch_embed, layers, norm]))
cur += num_layer
def init_weights(self):
logger = MMLogger.get_current_instance()
if self.init_cfg is None:
logger.warning(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training starts from scratch')
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(m, 0, math.sqrt(2.0 / fan_out))
elif isinstance(m, AbsolutePositionEmbedding):
m.init_weights()
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specifying `Pretrained` in ' \
f'`init_cfg` of ' \
f'{self.__class__.__name__}'
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
logger.warning(f'Load pre-trained model for '
f'{self.__class__.__name__} from original repo')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if self.convert_weights:
# The pvt backbones are not supported by mmpretrain, so we need
# to convert the pre-trained weights to match this
# implementation.
state_dict = pvt_convert(state_dict)
load_state_dict(self, state_dict, strict=False, logger=logger)
def forward(self, x):
outs = []
for i, layer in enumerate(self.layers):
x, hw_shape = layer[0](x)
for block in layer[1]:
x = block(x, hw_shape)
x = layer[2](x)
x = nlc_to_nchw(x, hw_shape)
if i in self.out_indices:
outs.append(x)
return outs
@MODELS.register_module()
class PyramidVisionTransformerV2(PyramidVisionTransformer):
"""Implementation of `PVTv2: Improved Baselines with Pyramid Vision
Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_."""
def __init__(self, **kwargs):
super(PyramidVisionTransformerV2, self).__init__(
patch_sizes=[7, 3, 3, 3],
paddings=[3, 1, 1, 1],
use_abs_pos_embed=False,
norm_after_stage=True,
use_conv_ffn=True,
**kwargs)
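# A minimal usage sketch for PyramidVisionTransformerV2. The settings below are
# illustrative and roughly follow the smallest PVTv2 variant; real configs pick
# ``embed_dims``/``num_layers`` per variant:
#
#   model = PyramidVisionTransformerV2(embed_dims=32, num_layers=[2, 2, 2, 2])
#   outs = model(torch.rand(1, 3, 224, 224))
#   # per-stage channel dims: [32, 64, 160, 256]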
def pvt_convert(ckpt):
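"""Convert the keys of a checkpoint from the original PVT repo to the
naming scheme used by this implementation."""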
new_ckpt = OrderedDict()
# Detect whether the checkpoint uses absolute position embeddings and
# convolutional FFNs; the q/kv linear weights are concatenated into
# in_proj weights further below.
use_abs_pos_embed = False
use_conv_ffn = False
for k in ckpt.keys():
if k.startswith('pos_embed'):
use_abs_pos_embed = True
if k.find('dwconv') >= 0:
use_conv_ffn = True
for k, v in ckpt.items():
if k.startswith('head'):
continue
if k.startswith('norm.'):
continue
if k.startswith('cls_token'):
continue
if k.startswith('pos_embed'):
stage_i = int(k.replace('pos_embed', ''))
new_k = k.replace(f'pos_embed{stage_i}',
f'layers.{stage_i - 1}.1.0.pos_embed')
if stage_i == 4 and v.size(1) == 50: # 1 (cls token) + 7 * 7
new_v = v[:, 1:, :] # remove cls token
else:
new_v = v
elif k.startswith('patch_embed'):
stage_i = int(k.split('.')[0].replace('patch_embed', ''))
new_k = k.replace(f'patch_embed{stage_i}',
f'layers.{stage_i - 1}.0')
new_v = v
if 'proj.' in new_k:
new_k = new_k.replace('proj.', 'projection.')
elif k.startswith('block'):
stage_i = int(k.split('.')[0].replace('block', ''))
layer_i = int(k.split('.')[1])
new_layer_i = layer_i + use_abs_pos_embed
new_k = k.replace(f'block{stage_i}.{layer_i}',
f'layers.{stage_i - 1}.1.{new_layer_i}')
new_v = v
if 'attn.q.' in new_k:
sub_item_k = k.replace('q.', 'kv.')
new_k = new_k.replace('q.', 'attn.in_proj_')
new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
elif 'attn.kv.' in new_k:
continue
elif 'attn.proj.' in new_k:
new_k = new_k.replace('proj.', 'attn.out_proj.')
elif 'attn.sr.' in new_k:
# The spatial-reduction conv keys keep the same name.
pass
elif 'mlp.' in new_k:
new_k = new_k.replace('mlp.', 'ffn.layers.')
if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
new_v = v.reshape((*v.shape, 1, 1))
new_k = new_k.replace('fc1.', '0.')
new_k = new_k.replace('dwconv.dwconv.', '1.')
if use_conv_ffn:
new_k = new_k.replace('fc2.', '4.')
else:
new_k = new_k.replace('fc2.', '3.')
elif k.startswith('norm'):
stage_i = int(k[4])
new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
new_v = v
else:
new_k = k
new_v = v
new_ckpt[new_k] = new_v
return new_ckpt
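# A hedged usage sketch for ``pvt_convert``: converting a checkpoint from the
# original PVT repo before loading it into an already built backbone. The file
# name and the ``model`` instance are placeholders:
#
#   ckpt = CheckpointLoader.load_checkpoint('pvt_weights.pth', map_location='cpu')
#   state_dict = ckpt.get('state_dict', ckpt.get('model', ckpt))
#   load_state_dict(model, pvt_convert(state_dict), strict=False)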