# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, MaxPool2d
from mmengine.model import BaseModule

from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


class RSB(BaseModule):
"""Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate | |
Local Representations for Multi-Person Pose Estimation" (ECCV 2020). | |
Args: | |
in_channels (int): Input channels of this block. | |
out_channels (int): Output channels of this block. | |
num_steps (int): Numbers of steps in RSB | |
stride (int): stride of the block. Default: 1 | |
downsample (nn.Module): downsample operation on identity branch. | |
Default: None. | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
expand_times (int): Times by which the in_channels are expanded. | |
Default:26. | |
res_top_channels (int): Number of channels of feature output by | |
ResNet_top. Default:64. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
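
    A minimal usage sketch (illustrative only): it shows that the block keeps
    the spatial size when ``stride=1``. The shape assumes the default
    ``expand_times=26`` and ``res_top_channels=64``, so each of the four step
    branches works on ``64 * 26 // 64 = 26`` channels.

    Example:
        >>> import torch
        >>> self = RSB(in_channels=64, out_channels=64, num_steps=4)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 64, 32, 32)
        >>> out = self.forward(inputs)
        >>> tuple(out.shape)
        (1, 64, 32, 32)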
""" | |
expansion = 1 | |

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_steps=4,
                 stride=1,
                 downsample=None,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 expand_times=26,
                 res_top_channels=64,
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        assert num_steps > 1
        self.in_channels = in_channels
        # Width of each step branch, scaled by expand_times relative to the
        # stem (ResNet_top) width.
        self.branch_channels = self.in_channels * expand_times
        self.branch_channels //= res_top_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample
        self.with_cp = with_cp
        self.norm_cfg = norm_cfg
        self.num_steps = num_steps
        self.conv_bn_relu1 = ConvModule(
            self.in_channels,
            self.num_steps * self.branch_channels,
            kernel_size=1,
            stride=self.stride,
            padding=0,
            norm_cfg=self.norm_cfg,
            inplace=False)
        # conv_bn_relu2_{i}_{j} is the j-th 3x3 conv of the i-th step branch.
        for i in range(self.num_steps):
            for j in range(i + 1):
                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
                self.add_module(
                    module_name,
                    ConvModule(
                        self.branch_channels,
                        self.branch_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
        self.conv_bn3 = ConvModule(
            self.num_steps * self.branch_channels,
            self.out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            norm_cfg=self.norm_cfg,
            inplace=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Forward function."""
        identity = x
        x = self.conv_bn_relu1(x)
        # Split the expanded feature into `num_steps` equal-width branches.
        spx = torch.split(x, self.branch_channels, 1)
        outputs = list()
        outs = list()
        for i in range(self.num_steps):
            outputs_i = list()
            outputs.append(outputs_i)
            for j in range(i + 1):
                # Branch i applies (i + 1) successive 3x3 convs; each step is
                # fused with the corresponding step of the previous branch.
                if j == 0:
                    inputs = spx[i]
                else:
                    inputs = outputs[i][j - 1]
                if i > j:
                    inputs = inputs + outputs[i - 1][j]
                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
                module_i_j = getattr(self, module_name)
                outputs[i].append(module_i_j(inputs))
            outs.append(outputs[i][i])
        # Concatenate the last feature of every branch and project back.
        out = torch.cat(tuple(outs), 1)
        out = self.conv_bn3(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = out + identity
        out = self.relu(out)

        return out


class Downsample_module(BaseModule):
"""Downsample module for RSN. | |
Args: | |
block (nn.Module): Downsample block. | |
num_blocks (list): Number of blocks in each downsample unit. | |
num_units (int): Numbers of downsample units. Default: 4 | |
has_skip (bool): Have skip connections from prior upsample | |
module or not. Default:False | |
num_steps (int): Number of steps in a block. Default:4 | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
in_channels (int): Number of channels of the input feature to | |
downsample module. Default: 64 | |
expand_times (int): Times by which the in_channels are expanded. | |
Default:26. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
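
    A minimal usage sketch (illustrative only): the first unit keeps the
    input resolution, and each later unit halves it while doubling the
    channel number; the outputs are returned from the lowest to the highest
    resolution. With ``has_skip=False`` the two skip arguments are unused,
    so ``None`` can be passed.

    Example:
        >>> import torch
        >>> self = Downsample_module(RSB, [2, 2, 2, 2], num_units=4)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 64, 128, 128)
        >>> outs = self.forward(inputs, None, None)
        >>> for out in outs:
        ...     print(tuple(out.shape))
        ...
        (1, 512, 16, 16)
        (1, 256, 32, 32)
        (1, 128, 64, 64)
        (1, 64, 128, 128)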
""" | |

    def __init__(self,
                 block,
                 num_blocks,
                 num_steps=4,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26,
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.has_skip = has_skip
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg
        self.layer1 = self._make_layer(
            block,
            in_channels,
            num_blocks[0],
            expand_times=expand_times,
            res_top_channels=in_channels)
        for i in range(1, num_units):
            module_name = f'layer{i + 1}'
            self.add_module(
                module_name,
                self._make_layer(
                    block,
                    in_channels * pow(2, i),
                    num_blocks[i],
                    stride=2,
                    expand_times=expand_times,
                    res_top_channels=in_channels))

    def _make_layer(self,
                    block,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_times=26,
                    res_top_channels=64):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = ConvModule(
                self.in_channels,
                out_channels * block.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = list()
        units.append(
            block(
                self.in_channels,
                out_channels,
                num_steps=self.num_steps,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg,
                expand_times=expand_times,
                res_top_channels=res_top_channels))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            units.append(
                block(
                    self.in_channels,
                    out_channels,
                    num_steps=self.num_steps,
                    expand_times=expand_times,
                    res_top_channels=res_top_channels))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        """Forward function."""
        out = list()
        for i in range(self.num_units):
            module_name = f'layer{i + 1}'
            module_i = getattr(self, module_name)
            x = module_i(x)
            if self.has_skip:
                x = x + skip1[i] + skip2[i]
            out.append(x)
        # Reverse so that the lowest-resolution feature comes first, matching
        # the order expected by the upsample module.
        out.reverse()

        return tuple(out)


class Upsample_unit(BaseModule):
"""Upsample unit for upsample module. | |
Args: | |
ind (int): Indicates whether to interpolate (>0) and whether to | |
generate feature map for the next hourglass-like module. | |
num_units (int): Number of units that form a upsample module. Along | |
with ind and gen_cross_conv, nm_units is used to decide whether | |
to generate feature map for the next hourglass-like module. | |
in_channels (int): Channel number of the skip-in feature maps from | |
the corresponding downsample unit. | |
unit_channels (int): Channel number in this unit. Default:256. | |
gen_skip: (bool): Whether or not to generate skips for the posterior | |
downsample module. Default:False | |
gen_cross_conv (bool): Whether to generate feature map for the next | |
hourglass-like module. Default:False | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
out_channels (in): Number of channels of feature output by upsample | |
module. Must equal to in_channels of downsample module. Default:64 | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
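
    A minimal usage sketch (illustrative only): the first unit (``ind=0``)
    skips interpolation, only projects its skip-in feature to
    ``unit_channels``, and, with ``gen_skip=True``, produces two skip
    features at the input channel width.

    Example:
        >>> import torch
        >>> self = Upsample_unit(0, 4, 512, unit_channels=256, gen_skip=True)
        >>> _ = self.eval()
        >>> x = torch.rand(1, 512, 16, 16)
        >>> out, skip1, skip2, cross_conv = self.forward(x, None)
        >>> tuple(out.shape), tuple(skip1.shape), tuple(skip2.shape)
        ((1, 256, 16, 16), (1, 512, 16, 16), (1, 512, 16, 16))
        >>> cross_conv is None
        True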
""" | |

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64,
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.num_units = num_units
        self.norm_cfg = norm_cfg
        self.in_skip = ConvModule(
            in_channels,
            unit_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=self.norm_cfg,
            act_cfg=None,
            inplace=True)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = ConvModule(
                unit_channels,
                unit_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = ConvModule(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

            self.out_skip2 = ConvModule(
                unit_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = ConvModule(
                unit_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

    def forward(self, x, up_x):
        """Forward function."""
        out = self.in_skip(x)

        if self.ind > 0:
            # Upsample the coarser feature to the current resolution and fuse
            # it with the skip-in feature.
            up_x = F.interpolate(
                up_x,
                size=(x.size(2), x.size(3)),
                mode='bilinear',
                align_corners=True)
            up_x = self.up_conv(up_x)
            out = out + up_x
        out = self.relu(out)

        skip1 = None
        skip2 = None
        if self.gen_skip:
            skip1 = self.out_skip1(x)
            skip2 = self.out_skip2(out)

        cross_conv = None
        if self.ind == self.num_units - 1 and self.gen_cross_conv:
            # Only the last (highest-resolution) unit produces the feature
            # fed to the next single-stage RSN.
            cross_conv = self.cross_conv(out)

        return out, skip1, skip2, cross_conv


class Upsample_module(BaseModule):
"""Upsample module for RSN. | |
Args: | |
unit_channels (int): Channel number in the upsample units. | |
Default:256. | |
num_units (int): Numbers of upsample units. Default: 4 | |
gen_skip (bool): Whether to generate skip for posterior downsample | |
module or not. Default:False | |
gen_cross_conv (bool): Whether to generate feature map for the next | |
hourglass-like module. Default:False | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
out_channels (int): Number of channels of feature output by upsample | |
module. Must equal to in_channels of downsample module. Default:64 | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
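
    A minimal usage sketch (illustrative only): the input is the tuple
    returned by Downsample_module (lowest resolution first), and every
    output feature has ``unit_channels`` channels.

    Example:
        >>> import torch
        >>> self = Upsample_module(unit_channels=256, num_units=4)
        >>> _ = self.eval()
        >>> feats = [
        ...     torch.rand(1, 512, 16, 16),
        ...     torch.rand(1, 256, 32, 32),
        ...     torch.rand(1, 128, 64, 64),
        ...     torch.rand(1, 64, 128, 128),
        ... ]
        >>> out, skip1, skip2, cross_conv = self.forward(feats)
        >>> for o in out:
        ...     print(tuple(o.shape))
        ...
        (1, 256, 16, 16)
        (1, 256, 32, 32)
        (1, 256, 64, 64)
        (1, 256, 128, 128)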
""" | |

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64,
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        # Channel numbers of the skip-in features, ordered from the lowest to
        # the highest resolution.
        self.in_channels = list()
        for i in range(num_units):
            self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                Upsample_unit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    out_channels=64))

    def forward(self, x):
        """Forward function."""
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            if i == 0:
                # The lowest-resolution feature has no coarser feature to
                # fuse with.
                outi, skip1_i, skip2_i, _ = module_i(x[i], None)
            elif i == self.num_units - 1:
                outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
            else:
                outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv


class Single_stage_RSN(BaseModule):
"""Single_stage Residual Steps Network. | |
Args: | |
unit_channels (int): Channel number in the upsample units. Default:256. | |
num_units (int): Numbers of downsample/upsample units. Default: 4 | |
gen_skip (bool): Whether to generate skip for posterior downsample | |
module or not. Default:False | |
gen_cross_conv (bool): Whether to generate feature map for the next | |
hourglass-like module. Default:False | |
has_skip (bool): Have skip connections from prior upsample | |
module or not. Default:False | |
num_steps (int): Number of steps in RSB. Default: 4 | |
num_blocks (list): Number of blocks in each downsample unit. | |
Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks) | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
in_channels (int): Number of channels of the feature from ResNet_Top. | |
Default: 64. | |
expand_times (int): Times by which the in_channels are expanded in RSB. | |
Default:26. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
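
    A minimal usage sketch (illustrative only): the first stage of an RSN has
    ``has_skip=False``, so ``None`` can be passed for both skip inputs; the
    cross-conv output keeps the ``in_channels`` width for the next stage.

    Example:
        >>> import torch
        >>> self = Single_stage_RSN(gen_skip=True, gen_cross_conv=True)
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 64, 128, 128)
        >>> out, skip1, skip2, cross_conv = self.forward(inputs, None, None)
        >>> for o in out:
        ...     print(tuple(o.shape))
        ...
        (1, 256, 16, 16)
        (1, 256, 32, 32)
        (1, 256, 64, 64)
        (1, 256, 128, 128)
        >>> tuple(cross_conv.shape)
        (1, 64, 128, 128)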
""" | |

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_steps=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26,
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__(init_cfg=init_cfg)
        assert len(num_blocks) == num_units
        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.num_steps = num_steps
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = Downsample_module(RSB, num_blocks, num_steps,
                                            num_units, has_skip, norm_cfg,
                                            in_channels, expand_times)
        self.upsample = Upsample_module(unit_channels, num_units, gen_skip,
                                        gen_cross_conv, norm_cfg, in_channels)

    def forward(self, x, skip1, skip2):
        """Forward function."""
        mid = self.downsample(x, skip1, skip2)
        out, skip1, skip2, cross_conv = self.upsample(mid)

        return out, skip1, skip2, cross_conv


class ResNet_top(BaseModule):
"""ResNet top for RSN. | |
Args: | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
channels (int): Number of channels of the feature output by ResNet_top. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None | |
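
    A minimal usage sketch (illustrative only): the 7x7 stride-2 convolution
    followed by a stride-2 max-pooling reduces the spatial size by a factor
    of 4.

    Example:
        >>> import torch
        >>> self = ResNet_top()
        >>> _ = self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> out = self.forward(inputs)
        >>> tuple(out.shape)
        (1, 64, 128, 128)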
""" | |

    def __init__(self, norm_cfg=dict(type='BN'), channels=64, init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        # 7x7 stride-2 conv followed by a stride-2 max-pool: 4x downsampling.
        self.top = nn.Sequential(
            ConvModule(
                3,
                channels,
                kernel_size=7,
                stride=2,
                padding=3,
                norm_cfg=norm_cfg,
                inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))

    def forward(self, img):
        """Forward function."""
        return self.top(img)


@MODELS.register_module()
class RSN(BaseBackbone):
"""Residual Steps Network backbone. Paper ref: Cai et al. "Learning | |
Delicate Local Representations for Multi-Person Pose Estimation" (ECCV | |
2020). | |
Args: | |
unit_channels (int): Number of Channels in an upsample unit. | |
Default: 256 | |
num_stages (int): Number of stages in a multi-stage RSN. Default: 4 | |
num_units (int): NUmber of downsample/upsample units in a single-stage | |
RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks) | |
num_blocks (list): Number of RSBs (Residual Steps Block) in each | |
downsample unit. Default: [2, 2, 2, 2] | |
num_steps (int): Number of steps in a RSB. Default:4 | |
norm_cfg (dict): dictionary to construct and config norm layer. | |
Default: dict(type='BN') | |
res_top_channels (int): Number of channels of feature from ResNet_top. | |
Default: 64. | |
expand_times (int): Times by which the in_channels are expanded in RSB. | |
Default:26. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: | |
``[ | |
dict(type='Kaiming', layer=['Conv2d']), | |
dict( | |
type='Constant', | |
val=1, | |
layer=['_BatchNorm', 'GroupNorm']), | |
dict( | |
type='Normal', | |
std=0.01, | |
layer=['Linear']), | |
]`` | |
Example: | |
>>> from mmpose.models import RSN | |
>>> import torch | |
>>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2]) | |
>>> self.eval() | |
>>> inputs = torch.rand(1, 3, 511, 511) | |
>>> level_outputs = self.forward(inputs) | |
>>> for level_output in level_outputs: | |
... for feature in level_output: | |
... print(tuple(feature.shape)) | |
... | |
(1, 256, 64, 64) | |
(1, 256, 128, 128) | |
(1, 256, 64, 64) | |
(1, 256, 128, 128) | |
""" | |

    def __init__(self,
                 unit_channels=256,
                 num_stages=4,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 num_steps=4,
                 norm_cfg=dict(type='BN'),
                 res_top_channels=64,
                 expand_times=26,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm']),
                     dict(type='Normal', std=0.01, layer=['Linear']),
                 ]):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__(init_cfg=init_cfg)
        self.unit_channels = unit_channels
        self.num_stages = num_stages
        self.num_units = num_units
        self.num_blocks = num_blocks
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg

        assert self.num_stages > 0
        assert self.num_steps > 1
        assert self.num_units > 1
        assert self.num_units == len(self.num_blocks)
        self.top = ResNet_top(norm_cfg=norm_cfg)
        self.multi_stage_rsn = nn.ModuleList([])
        for i in range(self.num_stages):
            # Only the first stage has no skip inputs from a previous stage;
            # only the last stage skips generating skip features and the
            # cross conv for a following stage.
            if i == 0:
                has_skip = False
            else:
                has_skip = True
            if i != self.num_stages - 1:
                gen_skip = True
                gen_cross_conv = True
            else:
                gen_skip = False
                gen_cross_conv = False
            self.multi_stage_rsn.append(
                Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
                                 unit_channels, num_units, num_steps,
                                 num_blocks, norm_cfg, res_top_channels,
                                 expand_times))

    def forward(self, x):
        """Model forward function."""
        out_feats = []
        skip1 = None
        skip2 = None
        x = self.top(x)
        for i in range(self.num_stages):
            # Each stage returns its multi-scale outputs, two sets of skip
            # features for the next stage's downsample path, and the
            # cross-conv feature `x` that feeds the next stage.
            out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2)
            out_feats.append(out)

        return out_feats