# (removed non-source web-page artifacts "Spaces:" / "Running" that preceded the file header)
# UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
# Github source: https://github.com/AILab-CVC/UniRepLKNet
# Licensed under The Apache License 2.0 License [see LICENSE for details]
# Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases
# https://github.com/DingXiaoH/RepLKNet-pytorch
# https://github.com/facebookresearch/ConvNeXt
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from timm.models.layers import trunc_normal_, DropPath, to_2tuple | |
from timm.models.registry import register_model | |
from functools import partial | |
import torch.utils.checkpoint as checkpoint | |
# Optional dependency: huggingface_hub enables convenient checkpoint downloads.
# Catch ImportError specifically (the previous bare `except:` would also have
# swallowed KeyboardInterrupt/SystemExit).
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None  # install huggingface_hub if you would like to download models conveniently from huggingface

# Flags recording which OpenMMLab frameworks are importable in this environment.
has_mmdet = False
has_mmseg = False
# =============== for the ease of directly using this file in MMSegmentation and MMDetection.
# =============== ignore the following two segments of code if you do not plan to do so
# =============== delete one of the following two segments if you get a confliction
try:
    from mmseg.models.builder import BACKBONES as seg_BACKBONES
    from mmseg.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmseg = True
except ImportError:
    get_root_logger = None
    _load_checkpoint = None
# try:
#     from mmdet.models.builder import BACKBONES as det_BACKBONES
#     from mmdet.utils import get_root_logger
#     from mmcv.runner import _load_checkpoint
#     has_mmdet = True
# except ImportError:
#     get_root_logger = None
#     _load_checkpoint = None
# ===========================================================================================
class GRNwithNHWC(nn.Module):
    """Global Response Normalization (GRN) layer for NHWC inputs.

    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808).
    This formulation is more efficient than the reference implementation
    (https://github.com/facebookresearch/ConvNeXt-V2).
    Inputs are assumed to have shape (N, H, W, C).
    """

    def __init__(self, dim, use_bias=True):
        super().__init__()
        self.use_bias = use_bias
        # gamma starts at zero so the layer is initially an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        if self.use_bias:
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # L2 norm over the spatial dimensions -> shape (N, 1, 1, C)
        response = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divisive normalization of each channel's response by the channel mean.
        scale = response / (response.mean(dim=-1, keepdim=True) + 1e-6)
        out = (self.gamma * scale + 1) * x
        return out + self.beta if self.use_bias else out
class NCHWtoNHWC(nn.Module):
    """Permute a (N, C, H, W) tensor to (N, H, W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Move the channel axis to the last position.
        return x.permute(0, 2, 3, 1)
class NHWCtoNCHW(nn.Module):
    """Permute a (N, H, W, C) tensor back to (N, C, H, W)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Move the channel axis back to position 1.
        return x.permute(0, 3, 1, 2)
#================== This function decides which conv implementation (the native or iGEMM) to use
# Note that iGEMM large-kernel conv impl will be used if
#   - you attempt to do so (attempt_use_lk_impl=True), and
#   - it has been installed (follow https://github.com/AILab-CVC/UniRepLKNet), and
#   - the conv layer is depth-wise, stride = 1, non-dilated, kernel_size > 5, and padding == kernel_size // 2
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
               attempt_use_lk_impl=True):
    """Build a 2D conv, preferring the iGEMM large-kernel impl when eligible.

    The iGEMM path is taken only when `attempt_use_lk_impl` is True, the kernel
    is square with size > 5 and 'same' padding, the conv is depth-wise
    (in == out == groups), stride == 1 and dilation == 1, and the
    `depthwise_conv2d_implicit_gemm` package is importable. Otherwise a plain
    `nn.Conv2d` is returned.

    Fixes: catch ImportError specifically instead of a bare `except:`, and keep
    the iGEMM dispatch nested inside the attempt branch so the name
    `DepthWiseConv2dImplicitGEMM` is never referenced when unbound.
    """
    kernel_size = to_2tuple(kernel_size)
    if padding is None:
        # Default to 'same' padding for stride-1 convs.
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        padding = to_2tuple(padding)
    need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 \
        and padding == (kernel_size[0] // 2, kernel_size[1] // 2)

    if attempt_use_lk_impl and need_large_impl:
        print('---------------- trying to import iGEMM implementation for large-kernel conv')
        try:
            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
            print('---------------- found iGEMM implementation ')
        except ImportError:
            DepthWiseConv2dImplicitGEMM = None
            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
                and out_channels == groups and stride == 1 and dilation == 1:
            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
def get_bn(dim, use_sync_bn=False):
    """Return a batch-norm layer over `dim` channels.

    SyncBatchNorm when `use_sync_bn` is set (for small per-GPU batch sizes),
    otherwise a regular BatchNorm2d.
    """
    bn_cls = nn.SyncBatchNorm if use_sync_bn else nn.BatchNorm2d
    return bn_cls(dim)
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation Block proposed in SENet (https://arxiv.org/abs/1709.01507)
    We assume the inputs to this layer are (N, C, H, W).

    Args:
        input_channels: number of channels C of the input feature map.
        internal_neurons: bottleneck width of the squeeze/excite MLP.
    """
    def __init__(self, input_channels, internal_neurons):
        super(SEBlock, self).__init__()
        # 1x1 convs implement the two-layer MLP without flattening the tensor.
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels
        self.nonlinear = nn.ReLU(inplace=True)

    def forward(self, inputs):
        # Squeeze: global average pool -> (N, C, 1, 1)
        x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        x = self.down(x)
        x = self.nonlinear(x)
        x = self.up(x)
        # Excite: torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(x)
        # Per-channel gating of the input feature map.
        return inputs * x.view(-1, self.input_channels, 1, 1)
def fuse_bn(conv, bn):
    """Fold a BatchNorm that follows `conv` into an equivalent (kernel, bias) pair.

    Returns tensors such that conv followed by bn (in eval mode) equals a single
    convolution with the returned weight and bias.
    """
    bias = conv.bias if conv.bias is not None else 0
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std
    fused_kernel = conv.weight * scale.reshape(-1, 1, 1, 1)
    fused_bias = bn.bias + (bias - bn.running_mean) * scale
    return fused_kernel, fused_bias
def convert_dilated_to_nondilated(kernel, dilate_rate):
    """Expand a dilated conv kernel into its equivalent non-dilated (sparse) kernel.

    Transposed convolution with a 1x1 identity kernel and stride `dilate_rate`
    inserts (dilate_rate - 1) zeros between the original taps.
    """
    identity = torch.ones((1, 1, 1, 1))
    if kernel.size(1) == 1:
        # Depth-wise kernel: all channels can be expanded in one call.
        return F.conv_transpose2d(kernel, identity, stride=dilate_rate)
    # Dense or group-wise (but not depth-wise) kernel: expand each
    # input-channel slice separately and reassemble.
    expanded = [
        F.conv_transpose2d(kernel[:, i:i + 1, :, :], identity, stride=dilate_rate)
        for i in range(kernel.size(1))
    ]
    return torch.cat(expanded, dim=1)
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
    """Add a dilated small-kernel branch into the center of a large kernel.

    The dilated kernel is first expanded to its non-dilated equivalent of size
    dilated_r * (k - 1) + 1, then zero-padded to the large kernel's size and
    summed in.
    """
    large_size = large_kernel.size(2)
    small_size = dilated_kernel.size(2)
    equivalent_size = dilated_r * (small_size - 1) + 1
    equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
    # Symmetric padding that centers the equivalent kernel in the large one.
    pad = large_size // 2 - equivalent_size // 2
    return large_kernel + F.pad(equivalent_kernel, [pad] * 4)
class DilatedReparamBlock(nn.Module):
    """
    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
    We assume the inputs to this block are (N, C, H, W).

    Training-time structure: one large depth-wise conv (`lk_origin`) in parallel
    with several small depth-wise convs of different dilation rates, each
    followed by its own BN. `merge_dilated_branches()` folds everything into a
    single non-dilated large-kernel conv for deployment.
    """
    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
        super().__init__()
        # The large-kernel depth-wise conv. It carries a bias only in deploy
        # mode, where the BN statistics have already been folded in.
        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
                                    padding=kernel_size // 2, dilation=1, groups=channels, bias=deploy,
                                    attempt_use_lk_impl=attempt_use_lk_impl)
        self.attempt_use_lk_impl = attempt_use_lk_impl
        # Default settings. We did not tune them carefully. Different settings may work better.
        # Each (kernel_size k, dilation r) pair has an equivalent receptive
        # field of r * (k - 1) + 1, which never exceeds the large kernel size.
        if kernel_size == 17:
            self.kernel_sizes = [5, 9, 3, 3, 3]
            self.dilates = [1, 2, 4, 5, 7]
        elif kernel_size == 15:
            self.kernel_sizes = [5, 7, 3, 3, 3]
            self.dilates = [1, 2, 3, 5, 7]
        elif kernel_size == 13:
            self.kernel_sizes = [5, 7, 3, 3, 3]
            self.dilates = [1, 2, 3, 4, 5]
        elif kernel_size == 11:
            self.kernel_sizes = [5, 5, 3, 3, 3]
            self.dilates = [1, 2, 3, 4, 5]
        elif kernel_size == 9:
            self.kernel_sizes = [5, 5, 3, 3]
            self.dilates = [1, 2, 3, 4]
        elif kernel_size == 7:
            self.kernel_sizes = [5, 3, 3]
            self.dilates = [1, 2, 3]
        elif kernel_size == 5:
            self.kernel_sizes = [3, 3]
            self.dilates = [1, 2]
        else:
            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
        if not deploy:
            self.origin_bn = get_bn(channels, use_sync_bn)
            for k, r in zip(self.kernel_sizes, self.dilates):
                # padding = (r * (k - 1) + 1) // 2 gives 'same' padding for a
                # stride-1 conv with kernel k dilated by r.
                self.__setattr__('dil_conv_k{}_{}'.format(k, r),
                                 nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
                                           padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
                                           bias=False))
                self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))

    def forward(self, x):
        """Return the sum of the BN outputs of all parallel branches
        (or just the merged conv in deploy mode)."""
        if not hasattr(self, 'origin_bn'):  # deploy mode
            return self.lk_origin(x)
        out = self.origin_bn(self.lk_origin(x))
        for k, r in zip(self.kernel_sizes, self.dilates):
            conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
            bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
            out = out + bn(conv(x))
        return out

    def merge_dilated_branches(self):
        """Fold every (conv, BN) branch into `lk_origin` and drop the branches.

        After this call the block is in deploy mode and forward() uses only the
        merged large-kernel conv. No-op if branches were already merged.
        """
        if hasattr(self, 'origin_bn'):
            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
            for k, r in zip(self.kernel_sizes, self.dilates):
                conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r))
                bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r))
                branch_k, branch_b = fuse_bn(conv, bn)
                # Expand the dilated branch to a non-dilated kernel and add it
                # into the center of the large kernel; biases simply add up.
                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
                origin_b += branch_b
            merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
                                     padding=origin_k.size(2) // 2, dilation=1, groups=origin_k.size(0), bias=True,
                                     attempt_use_lk_impl=self.attempt_use_lk_impl)
            merged_conv.weight.data = origin_k
            merged_conv.bias.data = origin_b
            self.lk_origin = merged_conv
            self.__delattr__('origin_bn')
            for k, r in zip(self.kernel_sizes, self.dilates):
                self.__delattr__('dil_conv_k{}_{}'.format(k, r))
                self.__delattr__('dil_bn_k{}_{}'.format(k, r))
class UniRepLKNetBlock(nn.Module):
    """UniRepLKNet basic block: depth-wise (large-kernel) conv + BN + SE,
    followed by a GRN-equipped feed-forward part, in a residual connection
    with optional layer scale and drop path.

    Args:
        dim: number of input/output channels.
        kernel_size: depth-wise kernel size. 0 means no conv at all, >= 7 uses
            a DilatedReparamBlock, 3/5 use a plain depth-wise conv.
        drop_path: stochastic-depth rate.
        layer_scale_init_value: init value of the layer-scale parameter gamma;
            None or a non-positive value disables layer scale.
        deploy: build the inference-time structure (BN / GRN bias already folded).
        attempt_use_lk_impl: try to use the efficient iGEMM large-kernel impl.
        with_cp: use torch.utils.checkpoint on the residual branch to save memory.
        use_sync_bn: use SyncBatchNorm instead of BatchNorm2d.
        ffn_factor: channel expansion factor of the feed-forward part.
    """
    def __init__(self,
                 dim,
                 kernel_size,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 deploy=False,
                 attempt_use_lk_impl=True,
                 with_cp=False,
                 use_sync_bn=False,
                 ffn_factor=4):
        super().__init__()
        self.with_cp = with_cp
        if deploy:
            print('------------------------------- Note: deploy mode')
        if self.with_cp:
            print('****** note with_cp = True, reduce memory consumption but may slow down training ******')

        if kernel_size == 0:
            self.dwconv = nn.Identity()
        elif kernel_size >= 7:
            self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy,
                                              use_sync_bn=use_sync_bn,
                                              attempt_use_lk_impl=attempt_use_lk_impl)
        else:
            assert kernel_size in [3, 5]
            self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                                     dilation=1, groups=dim, bias=deploy,
                                     attempt_use_lk_impl=attempt_use_lk_impl)

        if deploy or kernel_size == 0:
            self.norm = nn.Identity()
        else:
            self.norm = get_bn(dim, use_sync_bn=use_sync_bn)

        self.se = SEBlock(dim, dim // 4)

        ffn_dim = int(ffn_factor * dim)
        self.pwconv1 = nn.Sequential(
            NCHWtoNHWC(),
            nn.Linear(dim, ffn_dim))
        self.act = nn.Sequential(
            nn.GELU(),
            GRNwithNHWC(ffn_dim, use_bias=not deploy))
        if deploy:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim),
                NHWCtoNCHW())
        else:
            # The trailing BN is folded into the linear layer at reparam time.
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim, bias=False),
                NHWCtoNCHW(),
                get_bn(dim, use_sync_bn=use_sync_bn))

        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if (not deploy) and layer_scale_init_value is not None \
                                                         and layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def compute_residual(self, x):
        """Residual branch: dwconv -> BN -> SE -> FFN, then layer scale and drop path."""
        y = self.se(self.norm(self.dwconv(x)))
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma.view(1, -1, 1, 1) * y
        return self.drop_path(y)

    def forward(self, inputs):

        def _f(x):
            return x + self.compute_residual(x)

        if self.with_cp and inputs.requires_grad:
            # Gradient checkpointing trades recomputation for activation memory.
            out = checkpoint.checkpoint(_f, inputs)
        else:
            out = _f(inputs)
        return out

    def reparameterize(self):
        """Fold all training-time structures (dilated branches, BNs, layer scale,
        GRN bias) into the deploy-time form. Call in eval mode only."""
        if hasattr(self.dwconv, 'merge_dilated_branches'):
            self.dwconv.merge_dilated_branches()
        if hasattr(self.norm, 'running_var'):
            std = (self.norm.running_var + self.norm.eps).sqrt()
            if hasattr(self.dwconv, 'lk_origin'):
                # Fold BN into the (already merged) large-kernel conv.
                self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1)
                self.dwconv.lk_origin.bias.data = self.norm.bias + (
                        self.dwconv.lk_origin.bias - self.norm.running_mean) * self.norm.weight / std
            else:
                # Plain depth-wise conv (kernel_size 3 or 5): build an equivalent
                # conv with bias and fold BN into it.
                # BUGFIX: pass stride/padding/dilation/groups as keyword arguments.
                # The previous positional call passed padding as `stride` and
                # groups as `padding`, producing a wrong convolution.
                conv = nn.Conv2d(self.dwconv.in_channels, self.dwconv.out_channels, self.dwconv.kernel_size,
                                 stride=self.dwconv.stride, padding=self.dwconv.padding,
                                 dilation=self.dwconv.dilation, groups=self.dwconv.groups, bias=True)
                conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view(-1, 1, 1, 1)
                conv.bias.data = self.norm.bias - self.norm.running_mean * self.norm.weight / std
                self.dwconv = conv
            self.norm = nn.Identity()
        if self.gamma is not None:
            final_scale = self.gamma.data
            self.gamma = None
        else:
            final_scale = 1
        if self.act[1].use_bias and len(self.pwconv2) == 3:
            # Fold the GRN bias, the trailing BN, and the layer scale into a
            # single linear layer.
            grn_bias = self.act[1].beta.data
            self.act[1].__delattr__('beta')
            self.act[1].use_bias = False
            linear = self.pwconv2[0]
            # Project the GRN bias through the linear weights so it becomes an
            # additive bias term of the linear layer.
            grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze()
            bn = self.pwconv2[2]
            std = (bn.running_var + bn.eps).sqrt()
            new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
            new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1)
            linear_bias = 0 if linear.bias is None else linear.bias.data
            linear_bias += grn_bias_projected_bias
            new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale
            self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1])
# Per-stage kernel-size configurations of the default model variants.
# Each inner tuple lists the depth-wise kernel size of every block in one of
# the four stages.
default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3),
                                          (13, 13),
                                          (13, 13, 13, 13, 13, 13),
                                          (13, 13))
default_UniRepLKNet_N_kernel_sizes = ((3, 3),
                                      (13, 13),
                                      (13, 13, 13, 13, 13, 13, 13, 13),
                                      (13, 13))
default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3),
                                      (13, 13, 13),
                                      (13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3),
                                      (13, 13, 13))
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3),
                                             (13, 13, 13),
                                             (13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3),
                                             (13, 13, 13))
# Per-stage depths (number of blocks) of the default variants.
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2)
UniRepLKNet_N_depths = (2, 2, 8, 2)
UniRepLKNet_T_depths = (3, 3, 18, 3)
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3)
# Map from a depths tuple to its default kernel-size configuration; used when
# UniRepLKNet is constructed with kernel_sizes=None.
default_depths_to_kernel_sizes = {
    UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes,
    UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes,
    UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes,
    UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes
}
class UniRepLKNet(nn.Module):
    r""" UniRepLKNet
    A PyTorch impl of UniRepLKNet.
    This fork additionally returns multi-scale stage features and two
    stem-level feature maps in 'features' output mode (see forward()).

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3)
        dims (int): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
            NOTE(review): not used anywhere in this __init__ — confirm intended.
        kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None.
        deploy (bool): deploy = True means using the inference structure. Default: False
        with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: False
        init_cfg (dict): weights to load. The easiest way to use UniRepLKNet with for OpenMMLab family. Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disabling the iGEMM impl. Default: True
        use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False
    """
    def __init__(self,
                 in_chans=3,
                 num_classes=None,
                 depths=(3, 3, 27, 3),
                 dims=(96, 192, 384, 768),
                 drop_path_rate=0.3,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 kernel_sizes=None,
                 deploy=False,
                 with_cp=False,
                 # NOTE(review): mutable dict default; also any non-None value
                 # switches the model into 'features' (downstream) mode below.
                 init_cfg=dict(type='uniform', a=0, b=1),
                 attempt_use_lk_impl=True,
                 use_sync_bn=False,
                 # NOTE(review): these two default to the *type* `int` and are
                 # never read in this class — confirm intended.
                 num_nuclei_classes=int,
                 num_tissue_classes=int,
                 **kwargs
                 ):
        super().__init__()

        depths = tuple(depths)
        if kernel_sizes is None:
            # Look up the default per-block kernel sizes for known depth configs.
            if depths in default_depths_to_kernel_sizes:
                print('=========== use default kernel size ')
                kernel_sizes = default_depths_to_kernel_sizes[depths]
            else:
                raise ValueError('no default kernel size settings for the given depths, '
                                 'please specify kernel sizes for each block, e.g., '
                                 '((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))')
        print(kernel_sizes)
        for i in range(4):
            assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths'
        self.with_cp = with_cp

        # Linearly increasing stochastic-depth rates across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        print('=========== drop path rates: ', dp_rates)

        # Stem: two stride-2 3x3 convs -> overall stride 4.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(nn.Sequential(
            nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"),
            nn.GELU(),
            nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")))

        # Three inter-stage downsampling layers (stride-2 conv + LayerNorm).
        for i in range(3):
            # NOTE(review): this branch is dead — i ranges over 0..2 and can
            # never equal 3. The stride-1 variant is therefore never built;
            # confirm whether a stride-1 final downsample was intended.
            if i == 3:
                self.downsample_layers.append(nn.Sequential(
                    nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=1, padding=1),
                    LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")
                ))
            else:
                self.downsample_layers.append(nn.Sequential(
                    nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1),
                    LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")))

        # Four stages of UniRepLKNetBlocks with the configured kernel sizes.
        self.stages = nn.ModuleList()
        cur = 0
        for i in range(4):
            main_stage = nn.Sequential(
                *[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j],
                                   layer_scale_init_value=layer_scale_init_value, deploy=deploy,
                                   attempt_use_lk_impl=attempt_use_lk_impl,
                                   with_cp=with_cp, use_sync_bn=use_sync_bn) for j in
                  range(depths[i])])
            self.stages.append(main_stage)
            cur += depths[i]

        last_channels = dims[-1]
        print(last_channels)

        self.for_pretrain = init_cfg is None
        self.for_downstream = not self.for_pretrain
        if self.for_downstream:
            assert num_classes is None
            self.output_mode = 'features'
            self.norm = nn.LayerNorm(last_channels, eps=1e-6)
            # Extra stride-1 conv on the raw input; its output is returned as a
            # stem-level feature in 'features' mode.
            self.conv = nn.Conv2d(in_chans, dims[0] // 4, kernel_size=3, stride=1, padding=1)
            norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
            # One channels-first LayerNorm per stage output (norm0..norm3).
            for i_layer in range(4):
                layer = norm_layer(dims[i_layer])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
            # NOTE(review): output dimension 19 is hard-coded here (the
            # num_nuclei_classes / num_tissue_classes arguments are ignored) —
            # confirm against the downstream task.
            self.head = nn.Linear(last_channels, 19)
        # NOTE(review): when init_cfg is None (pretrain config), output_mode /
        # norm / conv / head are never created, so forward() will raise
        # AttributeError — confirm intended usage.

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zero for biases.

        NOTE(review): this helper is never passed to self.apply() in this
        class, so it has no effect unless invoked externally.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.output_mode == 'logits':
            # Classification path: stages, then global average pooling + head.
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
            x = self.norm(x.mean([-2, -1]))  # global average pooling
            x = self.head(x)  # fully-connected classifier
            return x
        elif self.output_mode == 'features':
            outs = []
            input_feature = []
            # Two stem-level feature maps: the stride-1 conv on the raw input
            # and the first (stride-2) conv of the stem.
            input_feature.append(self.conv(x))
            input_feature.append(self.downsample_layers[0][0](x))
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
                # Normalized per-stage feature maps for downstream heads.
                outs.append(self.__getattr__(f'norm{stage_idx}')(x))
            # Also produce classification logits from the last stage output.
            logits = self.norm(x.mean([-2, -1]))
            logits = self.head(logits)
            return logits, outs, input_feature
        else:
            raise ValueError('Defined new output mode?')

    def reparameterize_unireplknet(self):
        """Switch every reparameterizable sub-module to its deploy-time form."""
        for m in self.modules():
            if hasattr(m, 'reparameterize'):
                m.reparameterize()
class LayerNorm(nn.Module):
    r""" LayerNorm implementation used in ConvNeXt.
    Supports two data formats: channels_last (default) or channels_first.
    channels_last expects inputs of shape (batch, height, width, channels);
    channels_first expects (batch, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)
        # Stored on the instance but not used by forward().
        self.reshape_last_to_first = reshape_last_to_first

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual layer norm over the channel dimension (dim 1).
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            var = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        # channels_last: defer to the fused functional implementation.
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
# For easy use as backbone in MMDetection framework. Ignore these lines if you do not use MMDetection
# NOTE(review): has_mmdet is hard-coded to False earlier in this file (its
# import block is commented out), so this first class is never defined here.
if has_mmdet:
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet configured as an MMDetection backbone (feature mode,
        no classification head, SyncBN enabled)."""
        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            # A checkpoint config is mandatory for downstream use.
            assert init_cfg is not None
            super().__init__(in_chans=3, num_classes=None, depths=depths, dims=dims,
                             drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value,
                             kernel_sizes=kernel_sizes, deploy=deploy, with_cp=with_cp,
                             init_cfg=init_cfg, attempt_use_lk_impl=attempt_use_lk_impl, use_sync_bn=True)

# For easy use as backbone in MMSegmentation framework. Ignore these lines if you do not use MMSegmentation
# NOTE(review): this class shares the name UniRepLKNetBackbone with the
# MMDetection variant above; if both guards were ever true, this definition
# would shadow the first.
if has_mmseg:
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet configured as an MMSegmentation backbone (feature mode,
        no classification head, SyncBN enabled)."""
        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            # A checkpoint config is mandatory for downstream use.
            assert init_cfg is not None
            super().__init__(in_chans=3, num_classes=None, depths=depths, dims=dims,
                             drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value,
                             kernel_sizes=kernel_sizes, deploy=deploy, with_cp=with_cp,
                             init_cfg=init_cfg, attempt_use_lk_impl=attempt_use_lk_impl, use_sync_bn=True)
# Fallback direct-download URLs, used only when huggingface_hub is unavailable.
model_urls = {
    # TODO: it seems that google drive does not support direct downloading with url? so where to upload the checkpoints other than huggingface? any suggestions?
}

# Checkpoint file names inside the 'DingXiaoH/UniRepLKNet' huggingface repo,
# keyed by '<model_name>_1k'.
huggingface_file_names = {
    "unireplknet_t_1k": "unireplknet_t_in1k_224_acc83.21.pth",
    "unireplknet_s_1k": "unireplknet_s_in1k_224_acc83.91.pth",
}
def load_with_key(model, key):
    """Load the pretrained checkpoint identified by `key` into `model`.

    Prefers downloading from the UniRepLKNet huggingface repo when
    huggingface_hub is installed; otherwise falls back to the URL table.
    """
    if hf_hub_download is None:
        checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[key], map_location="cpu", check_hash=True)
    else:
        # if huggingface hub is found, download from our huggingface repo
        cache_file = hf_hub_download(repo_id='DingXiaoH/UniRepLKNet', filename=huggingface_file_names[key])
        # NOTE(review): torch.load deserializes arbitrary pickled data — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(cache_file, map_location='cpu')
    # Some checkpoints wrap the state dict under a 'model' key.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint)
def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained=False, in_22k_to_1k=False):
    """Optionally load pretrained weights into `model`.

    Backward-compatible fix: the unireplknet_* factory functions in this file
    call this with five positional arguments, but the previous signature
    accepted only three, so every factory call raised TypeError. The two extra
    flags now select 22K checkpoint keys and default to False.
    NOTE(review): only '*_1k' keys currently exist in huggingface_file_names /
    model_urls, so the 22K keys would fail at lookup time if requested.
    """
    if in_1k_pretrained:
        key = model_name + '_1k'
    elif in_22k_pretrained:
        key = model_name + '_22k'
    elif in_22k_to_1k:
        key = model_name + '_22k_to_1k'
    else:
        key = None
    if key:
        load_with_key(model, key)
def unireplknet_p(in_1k_pretrained=False, **kwargs):
    """Build UniRepLKNet-P: depths (2, 2, 6, 2), dims (64, 128, 256, 512).

    Fix: the previous call passed five positional arguments to
    initialize_with_pretrained, which accepts three, raising TypeError on
    every invocation.
    """
    model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(64, 128, 256, 512), **kwargs)
    initialize_with_pretrained(model, 'unireplknet_p', in_1k_pretrained)
    return model
def unireplknet_n(in_1k_pretrained=False, **kwargs):
    """Build UniRepLKNet-N: depths (2, 2, 8, 2), dims (80, 160, 320, 640).

    Fix: the previous call passed five positional arguments to
    initialize_with_pretrained, which accepts three, raising TypeError on
    every invocation.
    """
    model = UniRepLKNet(depths=UniRepLKNet_N_depths, dims=(80, 160, 320, 640), **kwargs)
    initialize_with_pretrained(model, 'unireplknet_n', in_1k_pretrained)
    return model
def unireplknet_t(in_1k_pretrained=False, **kwargs):
    """Build UniRepLKNet-T: depths (3, 3, 18, 3), dims (80, 160, 320, 640).

    Fix: the previous call passed five positional arguments to
    initialize_with_pretrained, which accepts three, raising TypeError on
    every invocation.
    """
    model = UniRepLKNet(depths=UniRepLKNet_T_depths, dims=(80, 160, 320, 640), **kwargs)
    initialize_with_pretrained(model, 'unireplknet_t', in_1k_pretrained)
    return model
def unireplknet_s(in_1k_pretrained=False, in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build UniRepLKNet-S: depths (3, 3, 27, 3), dims (96, 192, 384, 768).

    Fix: initialize_with_pretrained accepts three parameters, so the previous
    five-argument call raised TypeError. Only ImageNet-1K weights are
    registered in huggingface_file_names / model_urls, so the 22K options are
    rejected explicitly instead of crashing on a missing key.
    """
    model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(96, 192, 384, 768), **kwargs)
    if in_22k_pretrained or in_22k_to_1k:
        raise NotImplementedError('no 22K checkpoints are registered for unireplknet_s')
    initialize_with_pretrained(model, 'unireplknet_s', in_1k_pretrained)
    return model
if __name__ == '__main__':
    # Test case showing the equivalency of Structural Re-parameterization:
    # a randomly initialized training-time block and its re-parameterized
    # deploy-time form should produce (numerically) identical outputs.
    x = torch.randn(2, 4, 19, 19)
    layer = UniRepLKNetBlock(4, kernel_size=13, attempt_use_lk_impl=False)
    for n, p in layer.named_parameters():
        if 'beta' in n:
            # Give the GRN beta a nonzero value so the bias-folding path is exercised.
            torch.nn.init.ones_(p)
        else:
            torch.nn.init.normal_(p)
    for n, p in layer.named_buffers():
        if 'running_var' in n:
            print('random init var')
            torch.nn.init.uniform_(p)
            # Shift variances away from zero.
            p.data += 2
        elif 'running_mean' in n:
            print('random init mean')
            torch.nn.init.uniform_(p)
    # Move the layer-scale gamma away from its tiny 1e-6 init.
    layer.gamma.data += 0.5
    # Eval mode: BN must use its running statistics for the merge to be exact.
    layer.eval()
    origin_y = layer(x)
    layer.reparameterize()
    eq_y = layer(x)
    print(layer)
    # Both difference printouts should be (numerically) zero.
    print(eq_y - origin_y)
    print((eq_y - origin_y).abs().sum() / origin_y.abs().sum())