""" PyTorch EfficientDet model | |
Based on official Tensorflow version at: https://github.com/google/automl/tree/master/efficientdet | |
Paper: https://arxiv.org/abs/1911.09070 | |
Hacked together by Ross Wightman | |
""" | |
import torch
import torch.nn as nn
import logging
import math
from collections import OrderedDict
from typing import List, Callable
from functools import partial

from timm import create_model
from timm.models.layers import create_conv2d, drop_path, create_pool2d, Swish, get_act_layer

from .config import get_fpn_config, set_config_writeable, set_config_readonly

_DEBUG = False

_ACT_LAYER = Swish


class SequentialList(nn.Sequential):
    """ This module exists to work around torchscript typing issues list -> list"""
    def __init__(self, *args):
        super(SequentialList, self).__init__(*args)

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        for module in self:
            x = module(x)
        return x


class ConvBnAct2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding='', bias=False,
                 norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
        super(ConvBnAct2d, self).__init__()
        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias)
        self.bn = None if norm_layer is None else norm_layer(out_channels)
        self.act = None if act_layer is None else act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x


class SeparableConv2d(nn.Module):
    """ Separable Conv
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER):
        super(SeparableConv2d, self).__init__()
        self.conv_dw = create_conv2d(
            in_channels, int(in_channels * channel_multiplier), kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)
        self.conv_pw = create_conv2d(
            int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
        self.bn = None if norm_layer is None else norm_layer(out_channels)
        self.act = None if act_layer is None else act_layer(inplace=True)

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.conv_pw(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x


class ResampleFeatureMap(nn.Sequential):
    def __init__(self, in_channels, out_channels, reduction_ratio=1., pad_type='', pooling_type='max',
                 norm_layer=nn.BatchNorm2d, apply_bn=False, conv_after_downsample=False, redundant_bias=False):
        super(ResampleFeatureMap, self).__init__()
        pooling_type = pooling_type or 'max'
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.reduction_ratio = reduction_ratio
        self.conv_after_downsample = conv_after_downsample

        conv = None
        if in_channels != out_channels:
            conv = ConvBnAct2d(
                in_channels, out_channels, kernel_size=1, padding=pad_type,
                norm_layer=norm_layer if apply_bn else None,
                bias=not apply_bn or redundant_bias, act_layer=None)

        if reduction_ratio > 1:
            stride_size = int(reduction_ratio)
            if conv is not None and not self.conv_after_downsample:
                self.add_module('conv', conv)
            self.add_module(
                'downsample',
                create_pool2d(
                    pooling_type, kernel_size=stride_size + 1, stride=stride_size, padding=pad_type))
            if conv is not None and self.conv_after_downsample:
                self.add_module('conv', conv)
        else:
            if conv is not None:
                self.add_module('conv', conv)
            if reduction_ratio < 1:
                scale = int(1 // reduction_ratio)
                self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale))

    # def forward(self, x):
    #     # here for debugging only
    #     assert x.shape[1] == self.in_channels
    #     if self.reduction_ratio > 1:
    #         if hasattr(self, 'conv') and not self.conv_after_downsample:
    #             x = self.conv(x)
    #         x = self.downsample(x)
    #         if hasattr(self, 'conv') and self.conv_after_downsample:
    #             x = self.conv(x)
    #     else:
    #         if hasattr(self, 'conv'):
    #             x = self.conv(x)
    #         if self.reduction_ratio < 1:
    #             x = self.upsample(x)
    #     return x
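

# NOTE (illustration, not from the original source): reduction_ratio is target_reduction / input_reduction.
# Resampling a stride-8 (P3) map down to stride-32 (P5) gives reduction_ratio = 32 / 8 = 4, i.e. a max pool
# with kernel_size=5, stride=4; going the other way gives reduction_ratio = 8 / 32 = 0.25, i.e. nearest
# upsampling by a factor of int(1 // 0.25) = 4.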


class FpnCombine(nn.Module):
    def __init__(self, feature_info, fpn_config, fpn_channels, inputs_offsets, target_reduction, pad_type='',
                 pooling_type='max', norm_layer=nn.BatchNorm2d, apply_bn_for_resampling=False,
                 conv_after_downsample=False, redundant_bias=False, weight_method='attn'):
        super(FpnCombine, self).__init__()
        self.inputs_offsets = inputs_offsets
        self.weight_method = weight_method

        self.resample = nn.ModuleDict()
        for idx, offset in enumerate(inputs_offsets):
            in_channels = fpn_channels
            if offset < len(feature_info):
                in_channels = feature_info[offset]['num_chs']
                input_reduction = feature_info[offset]['reduction']
            else:
                node_idx = offset - len(feature_info)
                input_reduction = fpn_config.nodes[node_idx]['reduction']
            reduction_ratio = target_reduction / input_reduction
            self.resample[str(offset)] = ResampleFeatureMap(
                in_channels, fpn_channels, reduction_ratio=reduction_ratio, pad_type=pad_type,
                pooling_type=pooling_type, norm_layer=norm_layer, apply_bn=apply_bn_for_resampling,
                conv_after_downsample=conv_after_downsample, redundant_bias=redundant_bias)

        if weight_method == 'attn' or weight_method == 'fastattn':
            self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True)  # WSM
        else:
            self.edge_weights = None

    def forward(self, x: List[torch.Tensor]):
        dtype = x[0].dtype
        nodes = []
        for offset, resample in zip(self.inputs_offsets, self.resample.values()):
            input_node = x[offset]
            input_node = resample(input_node)
            nodes.append(input_node)

        if self.weight_method == 'attn':
            normalized_weights = torch.softmax(self.edge_weights.to(dtype=dtype), dim=0)
            out = torch.stack(nodes, dim=-1) * normalized_weights
        elif self.weight_method == 'fastattn':
            edge_weights = nn.functional.relu(self.edge_weights.to(dtype=dtype))
            weights_sum = torch.sum(edge_weights)
            out = torch.stack(
                [(nodes[i] * edge_weights[i]) / (weights_sum + 0.0001) for i in range(len(nodes))], dim=-1)
        elif self.weight_method == 'sum':
            out = torch.stack(nodes, dim=-1)
        else:
            raise ValueError('unknown weight_method {}'.format(self.weight_method))
        out = torch.sum(out, dim=-1)
        return out
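

# NOTE (illustration, not from the original source): the weight_method options implement the feature fusion
# variants from the EfficientDet paper. With input features I_i and learned edge weights w_i:
#   'attn'     -> O = sum_i softmax(w)_i * I_i
#   'fastattn' -> O = sum_i (relu(w_i) / (sum_j relu(w_j) + eps)) * I_i   (fast normalized fusion, eps=1e-4 here)
#   'sum'      -> O = sum_i I_i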


class Fnode(nn.Module):
    """ A simple wrapper used in place of nn.Sequential for torchscript typing
    Handles input type List[Tensor] -> output type Tensor
    """
    def __init__(self, combine: nn.Module, after_combine: nn.Module):
        super(Fnode, self).__init__()
        self.combine = combine
        self.after_combine = after_combine

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        return self.after_combine(self.combine(x))


class BiFpnLayer(nn.Module):
    def __init__(self, feature_info, fpn_config, fpn_channels, num_levels=5, pad_type='',
                 pooling_type='max', norm_layer=nn.BatchNorm2d, act_layer=_ACT_LAYER,
                 apply_bn_for_resampling=False, conv_after_downsample=True, conv_bn_relu_pattern=False,
                 separable_conv=True, redundant_bias=False):
        super(BiFpnLayer, self).__init__()
        self.num_levels = num_levels
        self.conv_bn_relu_pattern = False

        self.feature_info = []
        self.fnode = nn.ModuleList()
        for i, fnode_cfg in enumerate(fpn_config.nodes):
            logging.debug('fnode {} : {}'.format(i, fnode_cfg))
            reduction = fnode_cfg['reduction']
            combine = FpnCombine(
                feature_info, fpn_config, fpn_channels, tuple(fnode_cfg['inputs_offsets']),
                target_reduction=reduction, pad_type=pad_type, pooling_type=pooling_type, norm_layer=norm_layer,
                apply_bn_for_resampling=apply_bn_for_resampling, conv_after_downsample=conv_after_downsample,
                redundant_bias=redundant_bias, weight_method=fnode_cfg['weight_method'])

            after_combine = nn.Sequential()
            conv_kwargs = dict(
                in_channels=fpn_channels, out_channels=fpn_channels, kernel_size=3, padding=pad_type,
                bias=False, norm_layer=norm_layer, act_layer=act_layer)
            if not conv_bn_relu_pattern:
                conv_kwargs['bias'] = redundant_bias
                conv_kwargs['act_layer'] = None
                after_combine.add_module('act', act_layer(inplace=True))
            after_combine.add_module(
                'conv', SeparableConv2d(**conv_kwargs) if separable_conv else ConvBnAct2d(**conv_kwargs))

            self.fnode.append(Fnode(combine=combine, after_combine=after_combine))
            self.feature_info.append(dict(num_chs=fpn_channels, reduction=reduction))

        self.feature_info = self.feature_info[-num_levels::]

    def forward(self, x: List[torch.Tensor]):
        for fn in self.fnode:
            x.append(fn(x))
        return x[-self.num_levels::]
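

# NOTE (illustration, not from the original source): each Fnode fuses a subset of the running feature list
# (backbone levels plus previously computed nodes) and appends its output, so after iterating over all nodes
# the last num_levels entries (e.g. P3..P7 with the default 5 levels) are the refined maps passed on to the
# next BiFpnLayer or to the heads.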


class BiFpn(nn.Module):
    def __init__(self, config, feature_info):
        super(BiFpn, self).__init__()
        self.num_levels = config.num_levels
        norm_layer = config.norm_layer or nn.BatchNorm2d
        if config.norm_kwargs:
            norm_layer = partial(norm_layer, **config.norm_kwargs)
        act_layer = get_act_layer(config.act_type) or _ACT_LAYER
        fpn_config = config.fpn_config or get_fpn_config(
            config.fpn_name, min_level=config.min_level, max_level=config.max_level)

        self.resample = nn.ModuleDict()
        for level in range(config.num_levels):
            if level < len(feature_info):
                in_chs = feature_info[level]['num_chs']
                reduction = feature_info[level]['reduction']
            else:
                # Adds a coarser level by downsampling the last feature map
                reduction_ratio = 2
                self.resample[str(level)] = ResampleFeatureMap(
                    in_channels=in_chs,
                    out_channels=config.fpn_channels,
                    pad_type=config.pad_type,
                    pooling_type=config.pooling_type,
                    norm_layer=norm_layer,
                    reduction_ratio=reduction_ratio,
                    apply_bn=config.apply_bn_for_resampling,
                    conv_after_downsample=config.conv_after_downsample,
                    redundant_bias=config.redundant_bias,
                )
                in_chs = config.fpn_channels
                reduction = int(reduction * reduction_ratio)
                feature_info.append(dict(num_chs=in_chs, reduction=reduction))

        self.cell = SequentialList()
        for rep in range(config.fpn_cell_repeats):
            logging.debug('building cell {}'.format(rep))
            fpn_layer = BiFpnLayer(
                feature_info=feature_info,
                fpn_config=fpn_config,
                fpn_channels=config.fpn_channels,
                num_levels=config.num_levels,
                pad_type=config.pad_type,
                pooling_type=config.pooling_type,
                norm_layer=norm_layer,
                act_layer=act_layer,
                separable_conv=config.separable_conv,
                apply_bn_for_resampling=config.apply_bn_for_resampling,
                conv_after_downsample=config.conv_after_downsample,
                conv_bn_relu_pattern=config.conv_bn_relu_pattern,
                redundant_bias=config.redundant_bias,
            )
            self.cell.add_module(str(rep), fpn_layer)
            feature_info = fpn_layer.feature_info

    def forward(self, x: List[torch.Tensor]):
        for resample in self.resample.values():
            x.append(resample(x[-1]))
        x = self.cell(x)
        return x
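

# NOTE (illustration, not from the original source): with the usual 3 backbone feature maps (strides 8/16/32)
# and config.num_levels=5, the resample dict above synthesizes the extra P6 and P7 inputs by stride-2
# downsampling of the previous coarsest map in forward(), before the stacked BiFpnLayer cells are applied.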


class HeadNet(nn.Module):
    def __init__(self, config, num_outputs):
        super(HeadNet, self).__init__()
        self.num_levels = config.num_levels
        self.bn_level_first = getattr(config, 'head_bn_level_first', False)
        norm_layer = config.norm_layer or nn.BatchNorm2d
        if config.norm_kwargs:
            norm_layer = partial(norm_layer, **config.norm_kwargs)
        act_layer = get_act_layer(config.act_type) or _ACT_LAYER

        # Build convolution repeats
        conv_fn = SeparableConv2d if config.separable_conv else ConvBnAct2d
        conv_kwargs = dict(
            in_channels=config.fpn_channels, out_channels=config.fpn_channels, kernel_size=3,
            padding=config.pad_type, bias=config.redundant_bias, act_layer=None, norm_layer=None)
        self.conv_rep = nn.ModuleList([conv_fn(**conv_kwargs) for _ in range(config.box_class_repeats)])

        # Build batchnorm repeats. There is a unique batchnorm per feature level for each repeat.
        # This can be organized with repeats first or feature levels first in module lists; the original models
        # and weights were set up with repeats first, levels first is required for efficient torchscript usage.
        self.bn_rep = nn.ModuleList()
        if self.bn_level_first:
            for _ in range(self.num_levels):
                self.bn_rep.append(nn.ModuleList([
                    norm_layer(config.fpn_channels) for _ in range(config.box_class_repeats)]))
        else:
            for _ in range(config.box_class_repeats):
                self.bn_rep.append(nn.ModuleList([
                    nn.Sequential(OrderedDict([('bn', norm_layer(config.fpn_channels))]))
                    for _ in range(self.num_levels)]))

        self.act = act_layer(inplace=True)

        # Prediction (output) layer. Has bias with special init reqs, see init fn.
        num_anchors = len(config.aspect_ratios) * config.num_scales
        predict_kwargs = dict(
            in_channels=config.fpn_channels, out_channels=num_outputs * num_anchors, kernel_size=3,
            padding=config.pad_type, bias=True, norm_layer=None, act_layer=None)
        self.predict = conv_fn(**predict_kwargs)

    def toggle_bn_level_first(self):
        """ Toggle the batchnorm layers between feature level first vs repeat first access pattern
        Limitations in torchscript require feature levels to be iterated over first.
        This function can be used to allow loading weights in the original order, and then toggle before
        jit scripting the model.
        """
        with torch.no_grad():
            new_bn_rep = nn.ModuleList()
            for i in range(len(self.bn_rep[0])):
                bn_first = nn.ModuleList()
                for r in self.bn_rep.children():
                    m = r[i]
                    # NOTE original rep first model def has extra Sequential container with 'bn', this was
                    # flattened in the level first definition.
                    bn_first.append(m[0] if isinstance(m, nn.Sequential) else nn.Sequential(OrderedDict([('bn', m)])))
                new_bn_rep.append(bn_first)
            self.bn_level_first = not self.bn_level_first
            self.bn_rep = new_bn_rep
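
    # NOTE (usage sketch, not from the original source): load pretrained weights with the default
    # repeats-first layout, then call toggle_bn_level_first() on the heads (or
    # EfficientDet.toggle_head_bn_level_first()) before torch.jit.script(model), so the level-first
    # _forward_level_first() path is used by the scripted module.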

    def _forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        outputs = []
        for level in range(self.num_levels):
            x_level = x[level]
            for conv, bn in zip(self.conv_rep, self.bn_rep):
                x_level = conv(x_level)
                x_level = bn[level](x_level)  # this is not allowed in torchscript
                x_level = self.act(x_level)
            outputs.append(self.predict(x_level))
        return outputs

    def _forward_level_first(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        outputs = []
        for level, bn_rep in enumerate(self.bn_rep):  # iterating over first bn dim first makes TS happy
            x_level = x[level]
            for conv, bn in zip(self.conv_rep, bn_rep):
                x_level = conv(x_level)
                x_level = bn(x_level)
                x_level = self.act(x_level)
            outputs.append(self.predict(x_level))
        return outputs

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        if self.bn_level_first:
            return self._forward_level_first(x)
        else:
            return self._forward(x)


def _init_weight(m, n='', ):
    """ Weight initialization as per Tensorflow official implementations.
    """

    def _fan_in_out(w, groups=1):
        dimensions = w.dim()
        if dimensions < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
        num_input_fmaps = w.size(1)
        num_output_fmaps = w.size(0)
        receptive_field_size = 1
        if w.dim() > 2:
            receptive_field_size = w[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
        fan_out //= groups
        return fan_in, fan_out

    def _glorot_uniform(w, gain=1, groups=1):
        fan_in, fan_out = _fan_in_out(w, groups)
        gain /= max(1., (fan_in + fan_out) / 2.)  # fan avg
        limit = math.sqrt(3.0 * gain)
        w.data.uniform_(-limit, limit)

    def _variance_scaling(w, gain=1, groups=1):
        fan_in, fan_out = _fan_in_out(w, groups)
        gain /= max(1., fan_in)  # fan in
        # gain /= max(1., (fan_in + fan_out) / 2.)  # fan
        # should it be normal or trunc normal? using normal for now since no good trunc in PT
        # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
        # std = math.sqrt(gain) / .87962566103423978
        # w.data.trunc_normal(std=std)
        std = math.sqrt(gain)
        w.data.normal_(std=std)

    if isinstance(m, SeparableConv2d):
        if 'box_net' in n or 'class_net' in n:
            _variance_scaling(m.conv_dw.weight, groups=m.conv_dw.groups)
            _variance_scaling(m.conv_pw.weight)
            if m.conv_pw.bias is not None:
                if 'class_net.predict' in n:
                    m.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
                else:
                    m.conv_pw.bias.data.zero_()
        else:
            _glorot_uniform(m.conv_dw.weight, groups=m.conv_dw.groups)
            _glorot_uniform(m.conv_pw.weight)
            if m.conv_pw.bias is not None:
                m.conv_pw.bias.data.zero_()
    elif isinstance(m, ConvBnAct2d):
        if 'box_net' in n or 'class_net' in n:
            m.conv.weight.data.normal_(std=.01)
            if m.conv.bias is not None:
                if 'class_net.predict' in n:
                    m.conv.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
                else:
                    m.conv.bias.data.zero_()
        else:
            _glorot_uniform(m.conv.weight)
            if m.conv.bias is not None:
                m.conv.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        # looks like all bn init the same?
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
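

# NOTE (explanatory sketch, not from the original source): the -math.log((1 - 0.01) / 0.01) bias fill above
# sets the class_net prediction bias so that sigmoid(bias) = 0.01, i.e. every anchor starts at ~1% foreground
# probability. This is the focal-loss prior initialization from RetinaNet, also used by EfficientDet.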


def _init_weight_alt(m, n='', ):
    """ Weight initialization alternative, based on EfficientNet backbone init w/ class bias addition
    NOTE: this will likely be removed after some experimentation
    """
    if isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            if 'class_net.predict' in n:
                m.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
            else:
                m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()


def get_feature_info(backbone):
    if isinstance(backbone.feature_info, Callable):
        # old accessor for timm versions <= 0.1.30, efficientnet and mobilenetv3 and related nets only
        feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction'])
                        for i, f in enumerate(backbone.feature_info())]
    else:
        # new feature info accessor, timm >= 0.2, all models supported
        feature_info = backbone.feature_info.get_dicts(keys=['num_chs', 'reduction'])
    return feature_info


class EfficientDet(nn.Module):
    def __init__(self, config, pretrained_backbone=True, alternate_init=False):
        super(EfficientDet, self).__init__()
        self.config = config
        set_config_readonly(self.config)
        self.backbone = create_model(
            config.backbone_name, features_only=True, out_indices=(2, 3, 4),
            pretrained=pretrained_backbone, **config.backbone_args)
        feature_info = get_feature_info(self.backbone)
        self.fpn = BiFpn(self.config, feature_info)
        self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
        self.box_net = HeadNet(self.config, num_outputs=4)

        for n, m in self.named_modules():
            if 'backbone' not in n:
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)

    def reset_head(self, num_classes=None, aspect_ratios=None, num_scales=None, alternate_init=False):
        reset_class_head = False
        reset_box_head = False
        set_config_writeable(self.config)
        if num_classes is not None:
            reset_class_head = True
            self.config.num_classes = num_classes
        if aspect_ratios is not None:
            reset_box_head = True
            self.config.aspect_ratios = aspect_ratios
        if num_scales is not None:
            reset_box_head = True
            self.config.num_scales = num_scales
        set_config_readonly(self.config)

        if reset_class_head:
            self.class_net = HeadNet(self.config, num_outputs=self.config.num_classes)
            for n, m in self.class_net.named_modules(prefix='class_net'):
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)
        if reset_box_head:
            self.box_net = HeadNet(self.config, num_outputs=4)
            for n, m in self.box_net.named_modules(prefix='box_net'):
                if alternate_init:
                    _init_weight_alt(m, n)
                else:
                    _init_weight(m, n)

    def toggle_head_bn_level_first(self):
        """ Toggle the head batchnorm layers between being accessed with feature_level first vs repeat first
        """
        self.class_net.toggle_bn_level_first()
        self.box_net.toggle_bn_level_first()

    def forward(self, x):
        x = self.backbone(x)
        x = self.fpn(x)
        x_class = self.class_net(x)
        x_box = self.box_net(x)
        return x_class, x_box
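

# NOTE (usage sketch, not from the original source): a typical way to build and run the model, assuming the
# companion config module provides a helper such as get_efficientdet_config (the exact name may differ here):
#
#     from .config import get_efficientdet_config
#     config = get_efficientdet_config('tf_efficientdet_d0')
#     model = EfficientDet(config, pretrained_backbone=True)
#     model.eval()
#     class_out, box_out = model(torch.randn(1, 3, 512, 512))  # lists of per-level class / box predictions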