""" PyTorch EfficientDet model Based on official Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet Paper: https://arxiv.org/abs/1911.09070 Hacked together by Ross Wightman """ import torch import torch.nn as nn import logging import math from collections import OrderedDict from typing import List, Callable from functools import partial from timm import create_model from timm.models.layers import create_conv2d, drop_path, create_pool2d, Swish, get_act_layer from .config import get_fpn_config, set_config_writeable, set_config_readonly _DEBUG = False _ACT_LAYER = Swish class SequentialList(nn.Sequential): """ This module exists to work around torchscript typing issues list -> list""" def __init__(self, *args): super(SequentialList, self).__init__(*args) def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: for module in self: x = module(x) return x class ConvBnAct2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding='', bias=False, norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER): super(ConvBnAct2d, self).__init__() self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias) self.bn = None if norm_layer is None else norm_layer(out_channels) self.act = None if act_layer is None else act_layer(inplace=True) def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.act is not None: x = self.act(x) return x class SeparableConv2d(nn.Module): """ Separable Conv """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER): super(SeparableConv2d, self).__init__() self.conv_dw = create_conv2d( in_channels, int(in_channels * channel_multiplier), kernel_size, stride=stride, dilation=dilation, padding=padding, depthwise=True) self.conv_pw = create_conv2d( int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) self.bn = None if norm_layer is None else norm_layer(out_channels) self.act = None if act_layer is None else act_layer(inplace=True) def forward(self, x): x = self.conv_dw(x) x = self.conv_pw(x) if self.bn is not None: x = self.bn(x) if self.act is not None: x = self.act(x) return x class ResampleFeatureMap(nn.Sequential): def __init__(self, in_channels, out_channels, reduction_ratio=1., pad_type='', pooling_type='max', norm_layer=nn.BatchNorm2d, apply_bn=False, conv_after_downsample=False, redundant_bias=False): super(ResampleFeatureMap, self).__init__() pooling_type = pooling_type or 'max' self.in_channels = in_channels self.out_channels = out_channels self.reduction_ratio = reduction_ratio self.conv_after_downsample = conv_after_downsample conv = None if in_channels != out_channels: conv = ConvBnAct2d( in_channels, out_channels, kernel_size=1, padding=pad_type, norm_layer=norm_layer if apply_bn else None, bias=not apply_bn or redundant_bias, act_layer=None) if reduction_ratio > 1: stride_size = int(reduction_ratio) if conv is not None and not self.conv_after_downsample: self.add_module('conv', conv) self.add_module( 'downsample', create_pool2d( pooling_type, kernel_size=stride_size + 1, stride=stride_size, padding=pad_type)) if conv is not None and self.conv_after_downsample: self.add_module('conv', conv) else: if conv is not None: self.add_module('conv', conv) if reduction_ratio < 1: scale = int(1 // reduction_ratio) self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale)) # def forward(self, x): # # here for debugging only # assert x.shape[1] == self.in_channels # if self.reduction_ratio > 1: # if hasattr(self, 'conv') and not self.conv_after_downsample: # x = self.conv(x) # x = self.downsample(x) # if hasattr(self, 'conv') and self.conv_after_downsample: # x = self.conv(x) # else: # if hasattr(self, 'conv'): # x = self.conv(x) # if self.reduction_ratio < 1: # x = self.upsample(x) # return x class FpnCombine(nn.Module): def __init__(self, feature_info, fpn_config, fpn_channels, inputs_offsets, target_reduction, pad_type='', pooling_type='max', norm_layer=nn.BatchNorm2d, apply_bn_for_resampling=False, conv_after_downsample=False, redundant_bias=False, weight_method='attn'): super(FpnCombine, self).__init__() self.inputs_offsets = inputs_offsets self.weight_method = weight_method self.resample = nn.ModuleDict() for idx, offset in enumerate(inputs_offsets): in_channels = fpn_channels if offset < len(feature_info): in_channels = feature_info[offset]['num_chs'] input_reduction = feature_info[offset]['reduction'] else: node_idx = offset - len(feature_info) input_reduction = fpn_config.nodes[node_idx]['reduction'] reduction_ratio = target_reduction / input_reduction self.resample[str(offset)] = ResampleFeatureMap( in_channels, fpn_channels, reduction_ratio=reduction_ratio, pad_type=pad_type, pooling_type=pooling_type, norm_layer=norm_layer, apply_bn=apply_bn_for_resampling, conv_after_downsample=conv_after_downsample, redundant_bias=redundant_bias) if weight_method == 'attn' or weight_method == 'fastattn': self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True) # WSM else: self.edge_weights = None def forward(self, x: List[torch.Tensor]): dtype = x[0].dtype nodes = [] for offset, resample in zip(self.inputs_offsets, self.resample.values()): input_node = x[offset] input_node = resample(input_node) nodes.append(input_node) if self.weight_method == 'attn': normalized_weights = torch.softmax(self.edge_weights.to(dtype=dtype), dim=0) out = torch.stack(nodes, dim=-1) * normalized_weights elif self.weight_method == 'fastattn': edge_weights = nn.functional.relu(self.edge_weights.to(dtype=dtype)) weights_sum = torch.sum(edge_weights) out = torch.stack( [(nodes[i] * edge_weights[i]) / (weights_sum + 0.0001) for i in range(len(nodes))], dim=-1) elif self.weight_method == 'sum': out = torch.stack(nodes, dim=-1) else: raise ValueError('unknown weight_method {}'.format(self.weight_method)) out = torch.sum(out, dim=-1) return out class Fnode(nn.Module): """ A simple wrapper used in place of nn.Sequential for torchscript typing Handles input type List[Tensor] -> output type Tensor """ def __init__(self, combine: nn.Module, after_combine: nn.Module): super(Fnode, self).__init__() self.combine = combine self.after_combine = after_combine def forward(self, x: List[torch.Tensor]) -> torch.Tensor: return self.after_combine(self.combine(x)) class BiFpnLayer(nn.Module): def __init__(self, feature_info, fpn_config, fpn_channels, num_levels=5, pad_type='', pooling_type='max', norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER, apply_bn_for_resampling=False, conv_after_downsample=True, conv_bn_relu_pattern=False, separable_conv=True, redundant_bias=False): super(BiFpnLayer, self).__init__() self.num_levels = num_levels self.conv_bn_relu_pattern = False self.feature_info = [] self.fnode = nn.ModuleList() for i, fnode_cfg in enumerate(fpn_config.nodes): logging.debug('fnode {} : {}'.format(i, fnode_cfg)) reduction = fnode_cfg['reduction'] combine = FpnCombine( feature_info, fpn_config, fpn_channels, tuple(fnode_cfg['inputs_offsets']), target_reduction=reduction, pad_type=pad_type, pooling_type=pooling_type, norm_layer=norm_layer, apply_bn_for_resampling=apply_bn_for_resampling, conv_after_downsample=conv_after_downsample, redundant_bias=redundant_bias, weight_method=fnode_cfg['weight_method']) after_combine = nn.Sequential() conv_kwargs = dict( in_channels=fpn_channels, out_channels=fpn_channels, kernel_size=3, padding=pad_type, bias=False, norm_layer=norm_layer, act_layer=act_layer) if not conv_bn_relu_pattern: conv_kwargs['bias'] = redundant_bias conv_kwargs['act_layer'] = None after_combine.add_module('act', act_layer(inplace=True)) after_combine.add_module( 'conv', SeparableConv2d(**conv_kwargs) if separable_conv else ConvBnAct2d(**conv_kwargs)) self.fnode.append(Fnode(combine=combine, after_combine=after_combine)) self.feature_info.append(dict(num_chs=fpn_channels, reduction=reduction)) self.feature_info = self.feature_info[-num_levels::] def forward(self, x: List[torch.Tensor]): for fn in self.fnode: x.append(fn(x)) return x[-self.num_levels::] class BiFpn(nn.Module): def __init__(self, config, feature_info): super(BiFpn, self).__init__() self.num_levels = config.num_levels norm_layer = config.norm_layer or nn.BatchNorm2d if config.norm_kwargs: norm_layer = partial(norm_layer, **config.norm_kwargs) act_layer = get_act_layer(config.act_type) or _ACT_LAYER fpn_config = config.fpn_config or get_fpn_config( config.fpn_name, min_level=config.min_level, max_level=config.max_level) self.resample = nn.ModuleDict() for level in range(config.num_levels): if level < len(feature_info): in_chs = feature_info[level]['num_chs'] reduction = feature_info[level]['reduction'] else: # Adds a coarser level by downsampling the last feature map reduction_ratio = 2 self.resample[str(level)] = ResampleFeatureMap( in_channels=in_chs, out_channels=config.fpn_channels, pad_type=config.pad_type, pooling_type=config.pooling_type, norm_layer=norm_layer, reduction_ratio=reduction_ratio, apply_bn=config.apply_bn_for_resampling, conv_after_downsample=config.conv_after_downsample, redundant_bias=config.redundant_bias, ) in_chs = config.fpn_channels reduction = int(reduction * reduction_ratio) feature_info.append(dict(num_chs=in_chs, reduction=reduction)) self.cell = SequentialList() for rep in range(config.fpn_cell_repeats): logging.debug('building cell {}'.format(rep)) fpn_layer = BiFpnLayer( feature_info=feature_info, fpn_config=fpn_config, fpn_channels=config.fpn_channels, num_levels=config.num_levels, pad_type=config.pad_type, pooling_type=config.pooling_type, norm_layer=norm_layer, act_layer=act_layer, separable_conv=config.separable_conv, apply_bn_for_resampling=config.apply_bn_for_resampling, conv_after_downsample=config.conv_after_downsample, conv_bn_relu_pattern=config.conv_bn_relu_pattern, redundant_bias=config.redundant_bias, ) self.cell.add_module(str(rep), fpn_layer) feature_info = fpn_layer.feature_info def forward(self, x: List[torch.Tensor]): for resample in self.resample.values(): x.append(resample(x[-1])) x = self.cell(x) return x class HeadNet(nn.Module): def __init__(self, config, num_outputs): super(HeadNet, self).__init__() self.num_levels = config.num_levels self.bn_level_first = getattr(config, 'head_bn_level_first', False) norm_layer = config.norm_layer or nn.BatchNorm2d if config.norm_kwargs: norm_layer = partial(norm_layer, **config.norm_kwargs) act_layer = get_act_layer(config.act_type) or _ACT_LAYER # Build convolution repeats conv_fn = SeparableConv2d if config.separable_conv else ConvBnAct2d conv_kwargs = dict( in_channels=config.fpn_channels, out_channels=config.fpn_channels, kernel_size=3, padding=config.pad_type, bias=config.redundant_bias, act_layer=None, norm_layer=None) self.conv_rep = nn.ModuleList([conv_fn(**conv_kwargs) for _ in range(config.box_class_repeats)]) # Build batchnorm repeats. There is a unique batchnorm per feature level for each repeat. # This can be organized with repeats first or feature levels first in module lists, the original models # and weights were setup with repeats first, levels first is required for efficient torchscript usage. self.bn_rep = nn.ModuleList() if self.bn_level_first: for _ in range(self.num_levels): self.bn_rep.append(nn.ModuleList([ norm_layer(config.fpn_channels) for _ in range(config.box_class_repeats)])) else: for _ in range(config.box_class_repeats): self.bn_rep.append(nn.ModuleList([ nn.Sequential(OrderedDict([('bn', norm_layer(config.fpn_channels))])) for _ in range(self.num_levels)])) self.act = act_layer(inplace=True) # Prediction (output) layer. Has bias with special init reqs, see init fn. num_anchors = len(config.aspect_ratios) * config.num_scales predict_kwargs = dict( in_channels=config.fpn_channels, out_channels=num_outputs * num_anchors, kernel_size=3, padding=config.pad_type, bias=True, norm_layer=None, act_layer=None) self.predict = conv_fn(**predict_kwargs) @torch.jit.ignore() def toggle_bn_level_first(self): """ Toggle the batchnorm layers between feature level first vs repeat first access pattern Limitations in torchscript require feature levels to be iterated over first. This function can be used to allow loading weights in the original order, and then toggle before jit scripting the model. """ with torch.no_grad(): new_bn_rep = nn.ModuleList() for i in range(len(self.bn_rep[0])): bn_first = nn.ModuleList() for r in self.bn_rep.children(): m = r[i] # NOTE original rep first model def has extra Sequential container with 'bn', this was # flattened in the level first definition. bn_first.append(m[0] if isinstance(m, nn.Sequential) else nn.Sequential(OrderedDict([('bn', m)]))) new_bn_rep.append(bn_first) self.bn_level_first = not self.bn_level_first self.bn_rep = new_bn_rep @torch.jit.ignore() def _forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: outputs = [] for level in range(self.num_levels): x_level = x[level] for conv, bn in zip(self.conv_rep, self.bn_rep): x_level = conv(x_level) x_level = bn[level](x_level) # this is not allowed in torchscript x_level = self.act(x_level) outputs.append(self.predict(x_level)) return outputs def _forward_level_first(self, x: List[torch.Tensor]) -> List[torch.Tensor]: outputs = [] for level, bn_rep in enumerate(self.bn_rep): # iterating over first bn dim first makes TS happy x_level = x[level] for conv, bn in zip(self.conv_rep, bn_rep): x_level = conv(x_level) x_level = bn(x_level) x_level = self.act(x_level) outputs.append(self.predict(x_level)) return outputs def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: if self.bn_level_first: return self._forward_level_first(x) else: return self._forward(x) def _init_weight(m, n='', ): """ Weight initialization as per Tensorflow official implementations. """ def _fan_in_out(w, groups=1): dimensions = w.dim() if dimensions < 2: raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") num_input_fmaps = w.size(1) num_output_fmaps = w.size(0) receptive_field_size = 1 if w.dim() > 2: receptive_field_size = w[0][0].numel() fan_in = num_input_fmaps * receptive_field_size fan_out = num_output_fmaps * receptive_field_size fan_out //= groups return fan_in, fan_out def _glorot_uniform(w, gain=1, groups=1): fan_in, fan_out = _fan_in_out(w, groups) gain /= max(1., (fan_in + fan_out) / 2.) # fan avg limit = math.sqrt(3.0 * gain) w.data.uniform_(-limit, limit) def _variance_scaling(w, gain=1, groups=1): fan_in, fan_out = _fan_in_out(w, groups) gain /= max(1., fan_in) # fan in # gain /= max(1., (fan_in + fan_out) / 2.) # fan # should it be normal or trunc normal? using normal for now since no good trunc in PT # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) # std = math.sqrt(gain) / .87962566103423978 # w.data.trunc_normal(std=std) std = math.sqrt(gain) w.data.normal_(std=std) if isinstance(m, SeparableConv2d): if 'box_net' in n or 'class_net' in n: _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups) _variance_scaling(m.conv_pw.weight) if m.conv_pw.bias is not None: if 'class_net.predict' in n: m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) else: m.conv_pw.bias.data.zero_() else: _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups) _glorot_uniform(m.conv_pw.weight) if m.conv_pw.bias is not None: m.conv_pw.bias.data.zero_() elif isinstance(m, ConvBnAct2d): if 'box_net' in n or 'class_net' in n: m.conv.weight.data.normal_(std=.01) if m.conv.bias is not None: if 'class_net.predict' in n: m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) else: m.conv.bias.data.zero_() else: _glorot_uniform(m.conv.weight) if m.conv.bias is not None: m.conv.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): # looks like all bn init the same? m.weight.data.fill_(1.0) m.bias.data.zero_() def _init_weight_alt(m, n='', ): """ Weight initialization alternative, based on EfficientNet bacbkone init w/ class bias addition NOTE: this will likely be removed after some experimentation """ if isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: if 'class_net.predict' in n: m.bias.data.fill_(-math.log((1 - 0.01) / 0.01)) else: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1.0) m.bias.data.zero_() def get_feature_info(backbone): if isinstance(backbone.feature_info, Callable): # old accessor for timm versions <= 0.1.30, efficientnet and mobilenetv3 and related nets only feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction']) for i, f in enumerate(backbone.feature_info())] else: # new feature info accessor, timm >= 0.2, all models supported feature_info = backbone.feature_info.get_dicts(keys=['num_chs', 'reduction']) return feature_info class EfficientDet(nn.Module): def __init__(self, config, pretrained_backbone=True, alternate_init=False): super(EfficientDet, self).__init__() self.config = config set_config_readonly(self.config) self.backbone = create_model( config.backbone_name, features_only=True, out_indices=(2, 3, 4), pretrained=pretrained_backbone, **config.backbone_args) feature_info = get_feature_info(self.backbone) self.fpn = BiFpn(self.config, feature_info) self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes) self.box_net = HeadNet(self.config, num_outputs=4) for n, m in self.named_modules(): if 'backbone' not in n: if alternate_init: _init_weight_alt(m, n) else: _init_weight(m, n) @torch.jit.ignore() def reset_head(self, num_classes=None, aspect_ratios=None, num_scales=None, alternate_init=False): reset_class_head = False reset_box_head = False set_config_writeable(self.config) if num_classes is not None: reset_class_head = True self.config.num_classes = num_classes if aspect_ratios is not None: reset_box_head = True self.config.aspect_ratios = aspect_ratios if num_scales is not None: reset_box_head = True self.config.num_scales = num_scales set_config_readonly(self.config) if reset_class_head: self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes) for n, m in self.class_net.named_modules(prefix='class_net'): if alternate_init: _init_weight_alt(m, n) else: _init_weight(m, n) if reset_box_head: self.box_net = HeadNet(self.config, num_outputs=4) for n, m in self.box_net.named_modules(prefix='box_net'): if alternate_init: _init_weight_alt(m, n) else: _init_weight(m, n) @torch.jit.ignore() def toggle_head_bn_level_first(self): """ Toggle the head batchnorm layers between being access with feature_level first vs repeat """ self.class_net.toggle_bn_level_first() self.box_net.toggle_bn_level_first() def forward(self, x): x = self.backbone(x) x = self.fpn(x) x_class = self.class_net(x) x_box = self.box_net(x) return x_class, x_box