# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmengine.logging import MMLogger
from mmengine.model import (BaseModule, ModuleList, Sequential, constant_init,
                            normal_init, trunc_normal_init)
from mmengine.model.weight_init import trunc_normal_
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.utils import _pair as to_2tuple

from mmdet.registry import MODELS
from ..layers import PatchEmbed, nchw_to_nlc, nlc_to_nchw


class MixFFN(BaseModule):
    """An implementation of MixFFN of PVT.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
            Defaults: False.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 use_conv=False,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        if use_conv:
            # 3x3 depth wise conv to provide positional encode information
            dw_conv = Conv2d(
                in_channels=feedforward_channels,
                out_channels=feedforward_channels,
                kernel_size=3,
                stride=1,
                padding=(3 - 1) // 2,
                bias=True,
                groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        # Resulting order: fc1, [dw_conv], activate, drop, fc2, drop.
        # ``pvt_convert`` relies on these indices (fc2 is 3 without the
        # depth-wise conv and 4 with it).
        layers = [fc1, activate, drop, fc2, drop]
        if use_conv:
            layers.insert(1, dw_conv)
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Apply the conv FFN and add the shortcut.

        Args:
            x (Tensor): Token sequence; it is reshaped to NCHW with
                ``hw_shape`` before the conv layers and flattened back
                afterwards.
            hw_shape (tuple): Spatial shape passed to ``nlc_to_nchw``.
            identity (Tensor, optional): Shortcut tensor. Falls back to
                ``x`` when None. Default: None.

        Returns:
            Tensor: ``identity`` plus the (dropped) FFN output.
        """
        out = nlc_to_nchw(x, hw_shape)
        out = self.layers(out)
        out = nchw_to_nlc(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


class SpatialReductionAttention(MultiheadAttention):
    """An implementation of Spatial Reduction Attention of PVT.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        init_cfg (obj:`mmengine.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv that shrinks the key/value feature map by
            # ``sr_ratio`` along each spatial axis.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            # Adjacent literals are concatenated verbatim, so each fragment
            # needs its own trailing space.
            warnings.warn('The legacy version of forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer support in the '
                          'future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def forward(self, x, hw_shape, identity=None):
        """Attend from the full-resolution tokens to the reduced ones.

        Args:
            x (Tensor): Input tokens, used as the query and (after the
                optional spatial reduction) as key/value.
            hw_shape (tuple): Spatial shape used to fold ``x`` back into a
                feature map for the reduction conv.
            identity (Tensor, optional): Shortcut tensor. Falls back to
                ``x`` when None. Default: None.

        Returns:
            Tensor: ``identity`` plus the (dropped) attention output.
        """
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_queries, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_queries, embed_dims) to num_queries_first
        # (num_queries ,batch, embed_dims), and recover ``attn_output``
        # from num_queries_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # No batch_first transposes here, unlike ``forward``.
        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        return identity + self.dropout_layer(self.proj_drop(out))


class PVTEncoderLayer(BaseModule):
    """Implements one encoder layer in PVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed.
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_conv_ffn=False,
                 init_cfg=None):
        super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = SpatialReductionAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        # The ret[0] of build_norm_layer is norm name.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            use_conv=use_conv_ffn,
            act_cfg=act_cfg)

    def forward(self, x, hw_shape):
        """Run pre-norm attention followed by the pre-norm FFN.

        Both sub-layers receive the un-normalized ``x`` as ``identity`` so
        the residual is taken around the norm.
        """
        x = self.attn(self.norm1(x), hw_shape, identity=x)
        x = self.ffn(self.norm2(x), hw_shape, identity=x)

        return x


class AbsolutePositionEmbedding(BaseModule):
    """An implementation of the absolute position embedding in PVT.

    Args:
        pos_shape (int | tuple[int]): The shape of the absolute position
            embedding. An int or a length-1 tuple is expanded to a pair.
        pos_dim (int): The dimension of the absolute position embedding.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(pos_shape, int):
            pos_shape = to_2tuple(pos_shape)
        elif isinstance(pos_shape, tuple):
            if len(pos_shape) == 1:
                pos_shape = to_2tuple(pos_shape[0])
            assert len(pos_shape) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pos_shape)}'
        # Any other sequence type is used as given, without a length check.
        self.pos_shape = pos_shape
        self.pos_dim = pos_dim

        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
        self.drop = nn.Dropout(p=drop_rate)

    def init_weights(self):
        """Initialize the embedding with a truncated normal (std=0.02)."""
        trunc_normal_(self.pos_embed, std=0.02)

    def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
        """Resize pos_embed weights.

        Resize pos_embed using bilinear interpolate method.

        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'bilinear'``.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'

        pos_h, pos_w = self.pos_shape
        # Keep only the trailing pos_h * pos_w tokens of the sequence.
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        # [1, L, C] -> [1, C, pos_h, pos_w] so it can be interpolated as an
        # image.
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
        pos_embed_weight = F.interpolate(
            pos_embed_weight, size=input_shape, mode=mode)
        # Back to a token sequence: [1, C, H, W] -> [1, H * W, C].
        pos_embed_weight = torch.flatten(pos_embed_weight,
                                         2).transpose(1, 2).contiguous()
        pos_embed = pos_embed_weight

        return pos_embed

    def forward(self, x, hw_shape, mode='bilinear'):
        """Add the embedding, resized to ``hw_shape``, then apply dropout."""
        pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
        return self.drop(x + pos_embed)


@MODELS.register_module()
class PyramidVisionTransformer(BaseModule):
    """Pyramid Vision Transformer (PVT)

    Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
    Dense Prediction without Convolutions
    <https://arxiv.org/abs/2102.12122>`_.

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Defaults: 224.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 64.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 5, 8].
        patch_sizes (Sequence[int]): The patch_size of each patch embedding.
            Default: [4, 2, 2, 2].
        strides (Sequence[int]): The stride of each patch embedding.
            Default: [4, 2, 2, 2].
        paddings (Sequence[int]): The padding of each patch embedding.
            Default: [0, 0, 0, 0].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
            embedding dim of each transformer encode layer.
            Default: [8, 8, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults: True.
        norm_after_stage (bool): If True, apply a normalization layer built
            from ``norm_cfg`` at the end of every stage; otherwise an
            identity is used. Default: False.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        pretrained (str, optional): model pretrained path. Default: None.
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 5, 8],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 paddings=[0, 0, 0, 0],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratios=[8, 8, 4, 4],
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=True,
                 norm_after_stage=False,
                 use_conv_ffn=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 convert_weights=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.convert_weights = convert_weights
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            # NOTE: this is a plain dict, so later code must read it with
            # item access rather than attribute access.
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims

        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        assert num_stages == len(num_layers) == len(num_heads) \
               == len(patch_sizes) == len(strides) == len(sr_ratios)

        # The docstring allows a bare int; wrap it so that ``max`` below and
        # the membership test in ``forward`` both work.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages
        self.pretrained = pretrained

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=paddings[i],
                bias=True,
                norm_cfg=norm_cfg)

            layers = ModuleList()
            if use_abs_pos_embed:
                # Image size divided by the cumulative patch size up to this
                # stage. NOTE(review): with a tuple ``pretrain_img_size``
                # the division is handled by NumPy, so ``pos_shape`` is
                # presumably a length-2 array rather than a tuple.
                pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
                pos_embed = AbsolutePositionEmbedding(
                    pos_shape=pos_shape,
                    pos_dim=embed_dims_i,
                    drop_rate=drop_rate)
                layers.append(pos_embed)
            layers.extend([
                PVTEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratios[i] * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i],
                    use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
            ])
            # The next stage consumes this stage's output channels.
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            if norm_after_stage:
                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            else:
                norm = nn.Identity()
            # Per-stage layout read by ``forward``:
            # [patch embedding, blocks, trailing norm].
            self.layers.append(ModuleList([patch_embed, layers, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize from scratch or load the configured checkpoint.

        Without an ``init_cfg`` every Linear / LayerNorm / Conv2d /
        AbsolutePositionEmbedding sub-module is initialized in place.
        Otherwise the checkpoint named by ``init_cfg['checkpoint']`` is
        loaded (non-strictly), after key conversion when
        ``convert_weights`` is set.
        """
        logger = MMLogger.get_current_instance()
        if self.init_cfg is None:
            # ``Logger.warn`` is a deprecated alias that newer Python
            # versions no longer provide.
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
                elif isinstance(m, AbsolutePositionEmbedding):
                    m.init_weights()
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            # Item access works for both ConfigDict and the plain dict
            # created from the deprecated ``pretrained`` argument.
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'],
                logger=logger,
                map_location='cpu')
            logger.warning(f'Load pre-trained model for '
                           f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            if self.convert_weights:
                # Because pvt backbones are not supported by mmpretrain,
                # so we need to convert pre-trained weights to match this
                # implementation.
                state_dict = pvt_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

    def forward(self, x):
        """Run all stages and collect the requested feature maps.

        Args:
            x (Tensor): Input passed to the first patch embedding.

        Returns:
            list[Tensor]: One NCHW feature map per stage index listed in
            ``out_indices``, in stage order.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            # layer = [patch embedding, blocks, trailing norm]
            x, hw_shape = layer[0](x)

            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs


@MODELS.register_module()
class PyramidVisionTransformerV2(PyramidVisionTransformer):
    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
    Transformer <https://arxiv.org/abs/2106.13797>`_."""

    def __init__(self, **kwargs):
        super(PyramidVisionTransformerV2, self).__init__(
            patch_sizes=[7, 3, 3, 3],
            paddings=[3, 1, 1, 1],
            use_abs_pos_embed=False,
            norm_after_stage=True,
            use_conv_ffn=True,
            **kwargs)


def pvt_convert(ckpt):
    """Rename the keys of an original-repo PVT state dict for this module.

    Keys starting with ``head``, ``norm.`` or ``cls_token`` are dropped.
    The separate ``q`` and ``kv`` projection tensors are concatenated into
    a single ``in_proj_*`` tensor, and ``fc1``/``fc2`` weights gain two
    trailing singleton dims so they fit the 1x1 convs of ``MixFFN``.

    Args:
        ckpt (dict): Mapping from parameter name to tensor.

    Returns:
        OrderedDict: The converted mapping, in the input's iteration order.
    """
    new_ckpt = OrderedDict()
    # Process the concat between q linear weights and kv linear weights
    use_abs_pos_embed = any(k.startswith('pos_embed') for k in ckpt)
    use_conv_ffn = any('dwconv' in k for k in ckpt)
    for k, v in ckpt.items():
        if k.startswith('head'):
            continue
        if k.startswith('norm.'):
            continue
        if k.startswith('cls_token'):
            continue
        if k.startswith('pos_embed'):
            stage_i = int(k.replace('pos_embed', ''))
            new_k = k.replace(f'pos_embed{stage_i}',
                              f'layers.{stage_i - 1}.1.0.pos_embed')
            if stage_i == 4 and v.size(1) == 50:  # 1 (cls token) + 7 * 7
                new_v = v[:, 1:, :]  # remove cls token
            else:
                new_v = v
        elif k.startswith('patch_embed'):
            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
            new_k = k.replace(f'patch_embed{stage_i}',
                              f'layers.{stage_i - 1}.0')
            new_v = v
            if 'proj.' in new_k:
                new_k = new_k.replace('proj.', 'projection.')
        elif k.startswith('block'):
            stage_i = int(k.split('.')[0].replace('block', ''))
            layer_i = int(k.split('.')[1])
            # With an absolute position embedding, index 0 of the stage's
            # block list is the embedding, so encoder layers shift by one.
            new_layer_i = layer_i + use_abs_pos_embed
            new_k = k.replace(f'block{stage_i}.{layer_i}',
                              f'layers.{stage_i - 1}.1.{new_layer_i}')
            new_v = v
            if 'attn.q.' in new_k:
                sub_item_k = k.replace('q.', 'kv.')
                new_k = new_k.replace('q.', 'attn.in_proj_')
                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
            elif 'attn.kv.' in new_k:
                # Already merged into the matching ``q`` entry above.
                continue
            elif 'attn.proj.' in new_k:
                new_k = new_k.replace('proj.', 'attn.out_proj.')
            elif 'mlp.' in new_k:
                new_k = new_k.replace('mlp.', 'ffn.layers.')
                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
                    new_v = v.reshape((*v.shape, 1, 1))
                new_k = new_k.replace('fc1.', '0.')
                new_k = new_k.replace('dwconv.dwconv.', '1.')
                # fc2's index in the Sequential depends on whether the
                # depth-wise conv is present.
                if use_conv_ffn:
                    new_k = new_k.replace('fc2.', '4.')
                else:
                    new_k = new_k.replace('fc2.', '3.')
            # Any other block key (e.g. ``attn.sr.``) only needs the
            # ``block`` prefix rewrite done above.
        elif k.startswith('norm'):
            stage_i = int(k[4])
            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
            new_v = v
        else:
            new_k = k
            new_v = v
        new_ckpt[new_k] = new_v

    return new_ckpt