import torch
import torch.nn as nn
from torchvision.ops.misc import Conv2dNormActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

from .helpers.utils import make_divisible


def _default_strides():
    """Return a fresh copy of the default per-block stride mapping.

    A new dict is built on every call so that no two configs share (and
    accidentally mutate) the same default object.
    """
    return dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))


def _normalize_stride(stride):
    """Return ``stride`` as a tuple.

    Accepts a scalar (expanded to ``(stride, stride)``) or any iterable such
    as a tuple or a list. Lists matter because a stride that was written to
    JSON and read back is a list, and ``[1, 1] == (1, 1)`` is ``False``.
    """
    try:
        return tuple(stride)
    except TypeError:
        # Not iterable: treat it as a single value used for both dimensions.
        return (stride, stride)


def initialize_weights(m):
    """Initialize the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Modules of any other type
    are left untouched.

    @param m: the module to initialize
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class Block(nn.Module):
    """Inverted bottleneck block: 1x1 expansion, 3x3 depthwise, 1x1 projection.

    A residual shortcut is added only when ``in_channels == out_channels``.
    If the block is strided, the shortcut is average-pooled with the same
    stride before being added to the block output.

    @param in_channels: number of input channels
    @param out_channels: number of output channels
    @param expansion_rate: factor applied to ``in_channels`` to obtain the
        width of the expanded representation (rounded via ``make_divisible``)
    @param stride: stride of the depthwise convolution; a scalar, tuple or list
    """

    def __init__(self, in_channels, out_channels, expansion_rate, stride):
        super().__init__()
        # Normalize once so that 1, (1, 1) and [1, 1] are all recognized as
        # "no striding" below.
        stride = _normalize_stride(stride)
        exp_channels = make_divisible(in_channels * expansion_rate, 8)

        # create the three factorized convs that make up the inverted
        # bottleneck block
        exp_conv = Conv2dNormActivation(
            in_channels,
            exp_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )

        # depthwise convolution with possible stride
        depth_conv = Conv2dNormActivation(
            exp_channels,
            exp_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=exp_channels,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )

        # projection back to out_channels; no activation here, the ReLU is
        # applied after the (optional) residual addition in forward()
        proj_conv = Conv2dNormActivation(
            exp_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=None,
            inplace=False,
        )
        self.after_block_activation = nn.ReLU()

        self.use_shortcut = in_channels == out_channels
        if self.use_shortcut:
            if all(s == 1 for s in stride):
                # identity shortcut
                self.shortcut = nn.Sequential()
            else:
                # average pooling required for shortcut
                self.shortcut = nn.Sequential(
                    nn.AvgPool2d(kernel_size=3, stride=stride, padding=1),
                    nn.Sequential(),
                )

        self.block = nn.Sequential(exp_conv, depth_conv, proj_conv)

    def forward(self, x):
        """Apply the block, add the shortcut if enabled, then apply ReLU."""
        out = self.block(x)
        if self.use_shortcut:
            out = out + self.shortcut(x)
        return self.after_block_activation(out)


class NetworkConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of ``Network``.

    @param n_classes: number of the classes to predict
    @param in_channels: input channels to the network
    @param base_channels: number of channels after the input convolutions
    @param channels_multiplier: controls the increase in the width of the
        network after each stage
    @param expansion_rate: expansion rate in inverted bottleneck blocks
    @param n_blocks: number of blocks in each stage
    @param strides: mapping from block name (``"b<k>"``, counted over the
        whole network starting at 1) to that block's stride; blocks that are
        not listed use ``(1, 1)``. ``None`` selects
        ``dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))``.
    @param add_feats: if True, ``Network.forward`` also returns the feature
        map produced by the convolutional stages
    """

    def __init__(
        self,
        n_classes=10,
        in_channels=1,
        base_channels=32,
        channels_multiplier=2.3,
        expansion_rate=3.0,
        n_blocks=(3, 2, 1),
        strides=None,
        add_feats=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Build the default per instance instead of sharing one mutable dict
        # between every config created without an explicit ``strides``.
        if strides is None:
            strides = _default_strides()
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.channels_multiplier = channels_multiplier
        self.expansion_rate = expansion_rate
        self.n_blocks = n_blocks
        self.strides = strides
        self.add_feats = add_feats


class Network(PreTrainedModel):
    """Convolutional classifier built from stages of ``Block`` modules.

    Structure: two strided 3x3 input convolutions, ``len(n_blocks)`` stages
    of inverted bottleneck blocks, and a head made of a 1x1 convolution to
    ``n_classes`` channels, batch norm and global average pooling.

    @param config: a ``NetworkConfig`` instance
    """

    config_class = NetworkConfig

    def __init__(self, config):
        super().__init__(config)
        n_classes = config.n_classes
        in_channels = config.in_channels
        channels_multiplier = config.channels_multiplier
        expansion_rate = config.expansion_rate
        n_blocks = config.n_blocks
        strides = config.strides
        n_stages = len(n_blocks)
        self.add_feats = config.add_feats

        base_channels = make_divisible(config.base_channels, 8)
        # Entry 0 is the width after the input convs; entry k + 1 is the
        # output width of stage k.
        channels_per_stage = [base_channels] + [
            make_divisible(base_channels * channels_multiplier**stage_id, 8)
            for stage_id in range(n_stages)
        ]
        # Running block counter used by _make_stage to name blocks b1, b2, ...
        self.total_block_count = 0

        self.in_c = nn.Sequential(
            Conv2dNormActivation(
                in_channels,
                channels_per_stage[0] // 4,
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
            Conv2dNormActivation(
                channels_per_stage[0] // 4,
                channels_per_stage[0],
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
        )

        self.stages = nn.Sequential()
        for stage_id in range(n_stages):
            stage = self._make_stage(
                channels_per_stage[stage_id],
                channels_per_stage[stage_id + 1],
                n_blocks[stage_id],
                strides=strides,
                expansion_rate=expansion_rate,
            )
            self.stages.add_module(f"s{stage_id + 1}", stage)

        # Classification head: per-position class scores, then global pooling.
        self.feed_forward = nn.Sequential(
            nn.Conv2d(
                channels_per_stage[-1],
                n_classes,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(n_classes),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.apply(initialize_weights)

    def _make_stage(self, in_channels, out_channels, n_blocks, strides, expansion_rate):
        """Build one stage of ``n_blocks`` consecutive blocks.

        Blocks are named ``b<k>`` where ``k`` counts blocks across all stages
        (``self.total_block_count`` is advanced as a side effect). A block's
        stride is looked up by that name in ``strides`` and defaults to
        ``(1, 1)``. Only the first block changes the channel count.
        """
        stage = nn.Sequential()
        for _ in range(n_blocks):
            self.total_block_count += 1
            bname = f"b{self.total_block_count}"
            block = self._make_block(
                in_channels,
                out_channels,
                stride=strides.get(bname, (1, 1)),
                expansion_rate=expansion_rate,
            )
            stage.add_module(bname, block)
            in_channels = out_channels
        return stage

    def _make_block(self, in_channels, out_channels, stride, expansion_rate):
        """Create a single ``Block``."""
        return Block(in_channels, out_channels, expansion_rate, stride)

    def _forward_conv(self, x):
        """Run the input convolutions followed by all stages."""
        x = self.in_c(x)
        x = self.stages(x)
        return x

    def forward(self, x):
        """Compute class logits for ``x``.

        @param x: input batch fed to the first convolution
        @return: the logits, or ``(logits, features)`` when ``add_feats`` is
            set, where ``features`` is the output of the last stage
        """
        y = self._forward_conv(x)
        x = self.feed_forward(y)
        # drop the two singleton dimensions left by the (1, 1) adaptive pooling
        logits = x.squeeze(2).squeeze(2)
        if self.add_feats:
            return logits, y
        return logits


def get_model(
    n_classes=10,
    in_channels=1,
    base_channels=32,
    channels_multiplier=2.3,
    expansion_rate=3.0,
    n_blocks=(3, 2, 1),
    strides=None,
    add_feats=False,
):
    """
    @param n_classes: number of the classes to predict
    @param in_channels: input channels to the network, for audio it is by default 1
    @param base_channels: number of channels after in_conv
    @param channels_multiplier: controls the increase in the width of the network after each stage
    @param expansion_rate: determines the expansion rate in inverted bottleneck blocks
    @param n_blocks: number of blocks that should exist in each stage
    @param strides: mapping from block name to stride; if None,
        dict(b2=(1, 1), b3=(1, 2), b4=(2, 1)) is used
    @param add_feats: if True, the model's forward pass returns (logits, features)
    @return: full neural network model based on the specified configs
    """
    if strides is None:
        strides = _default_strides()

    model_config = {
        "n_classes": n_classes,
        "in_channels": in_channels,
        "base_channels": base_channels,
        "channels_multiplier": channels_multiplier,
        "expansion_rate": expansion_rate,
        "n_blocks": n_blocks,
        "strides": strides,
        "add_feats": add_feats,
    }

    return Network(NetworkConfig(**model_config))