Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from mmcv.cnn import ConvModule, build_plugin_layer | |
| from mmcv.runner import BaseModule, Sequential | |
| import mmocr.utils as utils | |
| from mmocr.models.builder import BACKBONES | |
| from mmocr.models.textrecog.layers import BasicBlock | |
class ResNet(BaseModule):
    """Configurable ResNet backbone: a conv stem followed by residual stages.

    Each stage is a ``Sequential`` of blocks registered as ``layer{i}``
    (1-based). Optional plugins can be inserted before and/or after every
    stage, see :meth:`_make_stage_plugins`.

    Args:
        in_channels (int): Number of channels of input image tensor.
        stem_channels (int | list[int]): Channels of each stem layer. E.g.,
            [64, 128] stands for 64 and 128 channels in the first and second
            stem layers. A single int builds a one-layer stem.
        block_cfgs (dict): Config of the residual block. ``type`` selects the
            block class (only ``'BasicBlock'`` is supported); the remaining
            keys are forwarded to the block constructor.
        arch_layers (list[int]): Number of blocks for each stage.
        arch_channels (list[int]): Output channels for each stage.
        strides (Sequence[int] | Sequence[tuple]): Strides of the first block
            of each stage.
        out_indices (None | Sequence[int]): Indices of output stages. If not
            specified (or empty), only the last stage output is returned.
        plugins (list[dict], optional): Configs of stage plugins. See
            :meth:`_make_stage_plugins` for the expected format.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 block_cfgs,
                 arch_layers,
                 arch_channels,
                 strides,
                 out_indices=None,
                 plugins=None,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        # NOTE(review): the default ``init_cfg`` is a shared mutable list;
        # presumably BaseModule does not mutate it in place -- confirm.
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, int)
        assert isinstance(stem_channels, int) or utils.is_type_list(
            stem_channels, int)
        assert utils.is_type_list(arch_layers, int)
        assert utils.is_type_list(arch_channels, int)
        assert utils.is_type_list(strides, tuple) or utils.is_type_list(
            strides, int)
        assert len(arch_layers) == len(arch_channels) == len(strides)
        assert out_indices is None or isinstance(out_indices, (list, tuple))

        self.out_indices = out_indices
        # Builds ``self.stem_layers`` and initialises ``self.inplanes``.
        self._make_stem_layer(in_channels, stem_channels)
        self.num_stages = len(arch_layers)
        self.use_plugins = False
        self.arch_channels = arch_channels
        # Attribute names of the registered stages, in forward order.
        self.res_layers = []
        if plugins is not None:
            # One list of plugin attribute names per stage.
            self.plugin_ahead_names = []
            self.plugin_after_names = []
            self.use_plugins = True

        for i, num_blocks in enumerate(arch_layers):
            stride = strides[i]
            channel = arch_channels[i]

            if self.use_plugins:
                self._make_stage_plugins(plugins, stage_idx=i)

            res_layer = self._make_layer(
                block_cfgs=block_cfgs,
                inplanes=self.inplanes,
                planes=channel,
                blocks=num_blocks,
                stride=stride,
            )
            # The next stage consumes this stage's output channels.
            self.inplanes = channel
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

    def _make_layer(self, block_cfgs, inplanes, planes, blocks, stride):
        """Build one residual stage.

        Args:
            block_cfgs (dict): Block config; ``type`` must be
                ``'BasicBlock'``, other keys are passed to the block.
            inplanes (int): Input channels of the stage.
            planes (int): Output channels of the stage.
            blocks (int): Number of blocks in the stage.
            stride (int | tuple): Stride of the first block.

        Returns:
            Sequential: The blocks of the stage.

        Raises:
            ValueError: If ``block_cfgs['type']`` is not supported.
        """
        # Work on a copy so the caller's config keeps its ``type`` key.
        block_cfgs_ = block_cfgs.copy()
        block_type = block_cfgs_.pop('type')
        if block_type != 'BasicBlock':
            raise ValueError(f'{block_type} not implement yet')
        block = BasicBlock

        if isinstance(stride, int):
            stride = (stride, stride)

        # A 1x1 projection is needed on the shortcut whenever the first
        # block changes the spatial size or the channel count.
        downsample = None
        if stride[0] != 1 or stride[1] != 1 or inplanes != planes:
            downsample = ConvModule(
                inplanes,
                planes,
                1,
                stride,
                norm_cfg=dict(type='BN'),
                act_cfg=None)

        layers = [
            block(
                inplanes,
                planes,
                stride=stride,
                downsample=downsample,
                **block_cfgs_)
        ]
        # Remaining blocks keep the channel count and use default stride.
        layers.extend(
            block(planes, planes, **block_cfgs_) for _ in range(1, blocks))

        return Sequential(*layers)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Build the stem and set ``self.inplanes`` to its output channels.

        Args:
            in_channels (int): Channels of the input image tensor.
            stem_channels (int | list[int]): Output channels of each
                3x3 conv-BN-ReLU stem layer.
        """
        if isinstance(stem_channels, int):
            stem_channels = [stem_channels]

        stem_layers = []
        for channels in stem_channels:
            stem_layers.append(
                ConvModule(
                    in_channels,
                    channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')))
            in_channels = channels

        self.stem_layers = Sequential(*stem_layers)
        self.inplanes = stem_channels[-1]

    def _make_stage_plugins(self, plugins, stage_idx):
        """Make plugins for ResNet ``stage_idx`` th stage.

        Currently we support inserting ``nn.Maxpooling`` and
        ``mmcv.cnn.Convmodule`` into the backbone. Originally designed
        for ResNet31-like architectures.

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type="Maxpooling", arg=(2,2)),
            ...          stages=(True, True, False, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(type="Maxpooling", arg=(2,1)),
            ...          stages=(False, False, True, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(
            ...              type='ConvModule',
            ...              kernel_size=3,
            ...              stride=1,
            ...              padding=1,
            ...              norm_cfg=dict(type='BN'),
            ...              act_cfg=dict(type='ReLU')),
            ...          stages=(True, True, True, True),
            ...          position='after_stage')]

        Suppose ``stage_idx=1``, the structure of stage would be:

        .. code-block:: none

            Maxpooling -> A set of Basicblocks -> ConvModule

        The built plugins are registered as submodules and their attribute
        names are appended to ``self.plugin_ahead_names[stage_idx]`` or
        ``self.plugin_after_names[stage_idx]``; nothing is returned.

        Args:
            plugins (list[dict]): List of plugins cfg to build. Each dict
                holds ``cfg``, ``position`` (``'before_stage'`` or
                ``'after_stage'``) and optionally ``stages``; a missing
                ``stages`` applies the plugin to every stage.
            stage_idx (int): Index of stage to build.

        Raises:
            ValueError: If ``position`` is neither ``'before_stage'`` nor
                ``'after_stage'``.
        """
        # NOTE(review): plugins on both sides are built with the stage's
        # *output* channel count. A before-stage plugin actually receives
        # the previous stage's output, so this presumably only suits
        # channel-agnostic layers (e.g. pooling) there -- confirm before
        # using channel-dependent plugins with ``before_stage``.
        in_channels = self.arch_channels[stage_idx]
        self.plugin_ahead_names.append([])
        self.plugin_after_names.append([])

        for plugin in plugins:
            # Copy so popping keys does not alter the caller's config.
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            position = plugin.pop('position', None)
            assert stages is None or len(stages) == self.num_stages

            # ``stages=None`` enables the plugin for all stages.
            if stages is not None and not stages[stage_idx]:
                continue

            if position == 'before_stage':
                postfix = f'_before_stage_{stage_idx+1}'
                names = self.plugin_ahead_names[stage_idx]
            elif position == 'after_stage':
                postfix = f'_after_stage_{stage_idx+1}'
                names = self.plugin_after_names[stage_idx]
            else:
                raise ValueError('uncorrect plugin position')

            name, layer = build_plugin_layer(
                plugin['cfg'],
                postfix,
                in_channels=in_channels,
                out_channels=in_channels)
            names.append(name)
            self.add_module(name, layer)

    def forward_plugin(self, x, plugin_name):
        """Apply the named plugins to ``x`` one after another.

        Args:
            x (Tensor): Input feature.
            plugin_name (list[str]): Attribute names of the plugins, in
                application order.

        Returns:
            Tensor: Output of the last plugin, or ``x`` itself when
            ``plugin_name`` is empty.
        """
        out = x
        for name in plugin_name:
            # Feed the running output forward so the plugins are chained
            # instead of each one restarting from ``x``.
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """
        Args:
            x (Tensor): Image tensor of shape :math:`(N, 3, H, W)`.

        Returns:
            Tensor or tuple[Tensor]: Feature tensor of the last stage, or a
            tuple of the stage outputs selected by ``out_indices`` when
            ``out_indices`` is specified and non-empty.
        """
        x = self.stem_layers(x)

        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            if self.use_plugins:
                x = self.forward_plugin(x, self.plugin_ahead_names[i])
                x = res_layer(x)
                x = self.forward_plugin(x, self.plugin_after_names[i])
            else:
                x = res_layer(x)
            if self.out_indices and i in self.out_indices:
                outs.append(x)

        return tuple(outs) if self.out_indices else x