cheng-hust's picture
Upload 91 files
e8861c0 verified
raw
history blame
6.72 kB
'''by lyuwenyu
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .common import get_activation, ConvNormLayer, FrozenBatchNorm2d
from src.core import register
__all__ = ['PResNet']
# Number of residual blocks in each of the four stages, keyed by network depth.
# PResNet looks the depth up here and hands one entry per stage to ``Blocks``.
ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    # Depth 152 is left disabled; no checkpoint URL is listed for it either.
    # 152: [3, 8, 36, 3],
}
# Pretrained checkpoint URL for each supported depth. PResNet reads this table
# when it is constructed with ``pretrained`` set. The (misspelled) name is kept
# as-is because other code refers to it by this exact identifier.
donwload_url = {
    18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
    34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
    50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
    101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
}
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 ``ConvNormLayer`` stages plus a skip path.

    Args:
        ch_in: input channel count, forwarded to ``ConvNormLayer``.
        ch_out: output channel count of both convolutions and of the block.
        stride: stride of the first 3x3 convolution; also used by the plain
            1x1 projection on the skip path.
        shortcut: when truthy the block input is added back unchanged;
            otherwise a projection module ``self.short`` is built and used.
        act: activation spec handed to ``ConvNormLayer`` / ``get_activation``;
            ``None`` makes the final activation an identity.
        variant: variant letter; ``'d'`` combined with ``stride == 2`` swaps
            the strided projection for average pooling followed by a
            stride-1 1x1 projection.
    """

    # Output channels are ``ch_out * expansion``.
    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        self.shortcut = shortcut

        # The skip projection is registered ahead of the main branch; the
        # order is kept so parameter and state_dict ordering do not change.
        if not shortcut:
            pooled_projection = variant == 'd' and stride == 2
            if pooled_projection:
                skip_layers = OrderedDict()
                skip_layers['pool'] = nn.AvgPool2d(2, 2, 0, ceil_mode=True)
                skip_layers['conv'] = ConvNormLayer(ch_in, ch_out, 1, 1)
                self.short = nn.Sequential(skip_layers)
            else:
                self.short = ConvNormLayer(ch_in, ch_out, 1, stride)

        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        # No activation here: it is applied after the residual addition.
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)

        if act is None:
            self.act = nn.Identity()
        else:
            self.act = get_activation(act)

    def forward(self, x):
        """Return ``act(branch2b(branch2a(x)) + skip(x))``."""
        main = self.branch2b(self.branch2a(x))
        skip = x if self.shortcut else self.short(x)
        return self.act(main + skip)
class BottleNeck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 ``ConvNormLayer`` stages plus a skip path.

    Args:
        ch_in: input channel count, forwarded to ``ConvNormLayer``.
        ch_out: width of the first two convolutions; the block emits
            ``ch_out * expansion`` channels.
        stride: block stride. Variant ``'a'`` applies it in the first 1x1
            convolution, every other variant in the 3x3 convolution.
        shortcut: when truthy the block input is added back unchanged;
            otherwise a projection module ``self.short`` is built and used.
        act: activation spec handed to ``ConvNormLayer`` / ``get_activation``;
            ``None`` makes the final activation an identity.
        variant: variant letter; ``'d'`` combined with ``stride == 2`` swaps
            the strided projection for average pooling followed by a
            stride-1 1x1 projection.
    """

    # Output channels are ``ch_out * expansion``.
    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        # Decide which of the first two convolutions carries the stride.
        stride1, stride2 = (stride, 1) if variant == 'a' else (1, stride)

        width = ch_out
        expanded = ch_out * self.expansion

        self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
        # No activation here: it is applied after the residual addition.
        self.branch2c = ConvNormLayer(width, expanded, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            pooled_projection = variant == 'd' and stride == 2
            if pooled_projection:
                skip_layers = OrderedDict()
                skip_layers['pool'] = nn.AvgPool2d(2, 2, 0, ceil_mode=True)
                skip_layers['conv'] = ConvNormLayer(ch_in, expanded, 1, 1)
                self.short = nn.Sequential(skip_layers)
            else:
                self.short = ConvNormLayer(ch_in, expanded, 1, stride)

        if act is None:
            self.act = nn.Identity()
        else:
            self.act = get_activation(act)

    def forward(self, x):
        """Return ``act(branch2c(branch2b(branch2a(x))) + skip(x))``."""
        main = self.branch2c(self.branch2b(self.branch2a(x)))
        skip = x if self.shortcut else self.short(x)
        return self.act(main + skip)
class Blocks(nn.Module):
    """One stage: ``count`` instances of ``block`` applied one after another.

    Args:
        block: block class to instantiate; must expose an ``expansion``
            attribute and accept ``(ch_in, ch_out, stride=, shortcut=,
            variant=, act=)``.
        ch_in: input channel count of the first block.
        ch_out: base channel count given to every block.
        count: number of blocks in the stage.
        stage_num: stage number; the first block uses stride 2 unless this
            is ``2``.
        act: activation spec forwarded to each block.
        variant: variant letter forwarded to each block.
    """

    def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
        super().__init__()

        # Only the leading block may downsample, and never in stage 2.
        leading_stride = 1 if stage_num == 2 else 2

        stage_blocks = []
        for position in range(count):
            is_first = position == 0
            stage_blocks.append(
                block(
                    ch_in,
                    ch_out,
                    stride=leading_stride if is_first else 1,
                    # The first block projects its input; later ones reuse it.
                    shortcut=not is_first,
                    variant=variant,
                    act=act,
                )
            )
            if is_first:
                # Every following block consumes the expanded output.
                ch_in = ch_out * block.expansion

        self.blocks = nn.ModuleList(stage_blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        for layer in self.blocks:
            x = layer(x)
        return x
@register
class PResNet(nn.Module):
    """ResNet-style backbone that returns the feature maps of selected stages.

    The network is a stem (``conv1``) followed by a 3x3/stride-2 max pool and
    up to four residual stages built from ``Blocks``.

    Args:
        depth: key into ``ResNet_cfg`` (18, 34, 50 or 101). Depths ``>= 50``
            use ``BottleNeck`` blocks, smaller ones ``BasicBlock``.
        variant: variant letter. ``'c'`` and ``'d'`` use a stem of three 3x3
            convolutions; anything else uses a single 7x7 convolution. The
            letter is also forwarded to every block.
        num_stages: number of residual stages to build.
        return_idx: indices of the stages whose outputs ``forward`` returns.
        act: activation spec forwarded to the stem and to every block.
        freeze_at: when ``>= 0``, the stem and the first
            ``min(freeze_at, num_stages)`` stages get ``requires_grad=False``.
        freeze_norm: when truthy, every ``nn.BatchNorm2d`` inside the model
            is replaced by a ``FrozenBatchNorm2d`` of the same feature count.
        pretrained: when truthy, the checkpoint listed in ``donwload_url``
            for ``depth`` is downloaded and loaded.

    Attributes:
        out_channels: channel count of each returned feature map.
        out_strides: nominal stride (4, 8, 16 or 32) of each returned map.

    Raises:
        KeyError: if ``depth`` is missing from ``ResNet_cfg`` (or from
            ``donwload_url`` when ``pretrained`` is set).
    """

    # NOTE: the list default for ``return_idx`` is kept as-is so the visible
    # signature (which ``register`` may inspect) does not change; __init__
    # copies it, so the shared default object is never stored or mutated.
    def __init__(
        self,
        depth,
        variant='d',
        num_stages=4,
        return_idx=[0, 1, 2, 3],
        act='relu',
        freeze_at=-1,
        freeze_norm=True,
        pretrained=False):
        super().__init__()

        block_nums = ResNet_cfg[depth]
        ch_in = 64

        # Each stem entry is [in_channels, out_channels, kernel, stride, name].
        if variant in ['c', 'd']:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]

        self.conv1 = nn.Sequential(OrderedDict([
            (_name, ConvNormLayer(c_in, c_out, k, s, act=act))
            for c_in, c_out, k, s, _name in conv_def
        ]))

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        _out_channels = [block.expansion * v for v in ch_out_list]
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            # Stages are numbered from 2; ``Blocks`` keeps stride 1 in stage 2.
            stage_num = i + 2
            self.res_layers.append(
                Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
            )
            ch_in = _out_channels[i]

        # Copy so that neither the shared default list nor a caller-owned
        # sequence is aliased by this instance.
        self.return_idx = list(return_idx)
        self.out_channels = [_out_channels[_i] for _i in self.return_idx]
        self.out_strides = [_out_strides[_i] for _i in self.return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        # Norm layers are swapped before the checkpoint is loaded, so the
        # loaded state dict fills the replacement modules.
        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            # map_location='cpu' lets the checkpoint deserialize on hosts
            # without the device it was saved from; load_state_dict then
            # copies the values into this module's own parameters.
            state = torch.hub.load_state_dict_from_url(donwload_url[depth], map_location='cpu')
            self.load_state_dict(state)
            print(f'Load PResNet{depth} state_dict')

    def _freeze_parameters(self, m: nn.Module):
        """Set ``requires_grad = False`` on every parameter of ``m``."""
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        """Recursively replace ``nn.BatchNorm2d`` modules under ``m``.

        Returns ``m`` itself, or a new ``FrozenBatchNorm2d(m.num_features)``
        when ``m`` is a ``nn.BatchNorm2d``. Children are replaced in place.
        """
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                # Only re-register children that were actually replaced.
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def forward(self, x):
        """Run the backbone and return the outputs of the selected stages.

        Returns:
            A list with the output of every stage whose index is in
            ``return_idx``, in stage order.
        """
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs