from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def _to_output_strided_layers(convolution_def, output_stride):
    """Turn an architecture table into per-layer settings for a target output stride.

    ``convolution_def`` is a sequence of ``(conv_type, inp, outp, stride)``
    entries.  Layers are visited in order while tracking the product of the
    strides applied so far.  While that product differs from ``output_stride``
    a layer keeps its own stride and gets a dilation rate of 1.  Once the
    product equals ``output_stride``, every later layer is forced to stride 1
    and instead gets a dilation rate equal to the product of the strides that
    were suppressed on the layers before it.  If ``output_stride`` is never
    hit exactly, every layer keeps its original stride.

    Returns:
        A list with one dict per layer holding ``block_id`` (position in the
        table), ``conv_type``, ``inp``, ``outp``, ``stride``, ``rate`` and
        ``output_stride`` (the accumulated stride after this layer).
    """
    layers = []
    accumulated_stride = 1
    dilation = 1
    for block_id, spec in enumerate(convolution_def):
        conv_type, inp, outp, stride = spec[0], spec[1], spec[2], spec[3]
        if accumulated_stride != output_stride:
            # Target not reached yet: keep this layer's own stride.
            layer_stride, layer_rate = stride, 1
            accumulated_stride *= stride
        else:
            # Target reached: replace further striding by dilation so the
            # feature-map resolution stops shrinking.
            layer_stride, layer_rate = 1, dilation
            dilation *= stride
        layers.append({
            'block_id': block_id,
            'conv_type': conv_type,
            'inp': inp,
            'outp': outp,
            'stride': layer_stride,
            'rate': layer_rate,
            'output_stride': accumulated_stride,
        })
    return layers


def _get_padding(kernel_size, stride, dilation):
    """Return the symmetric padding for a (possibly strided/dilated) conv.

    Equals ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.  With
    stride 1 and an odd kernel this preserves the spatial size.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + stride - 1) // 2


class InputConv(nn.Module):
    """Single ``k x k`` convolution followed by ReLU6.

    Used as the first entry of the architecture tables in this file.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super().__init__()
        pad = _get_padding(k, stride, dilation)
        self.conv = nn.Conv2d(
            inp, outp,
            kernel_size=k,
            stride=stride,
            padding=pad,
            dilation=dilation,
        )

    def forward(self, x):
        out = self.conv(x)
        return F.relu6(out)


class SeperableConv(nn.Module):
    """Depthwise-separable convolution, each stage followed by ReLU6.

    A depthwise ``k x k`` conv (``groups == inp``) is followed by a 1x1
    pointwise conv that maps ``inp`` to ``outp`` channels.  The class name
    keeps its original spelling because the architecture tables (and
    possibly external code) refer to it.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super().__init__()
        # Creation order (depthwise, then pointwise) fixes state_dict order.
        self.depthwise = nn.Conv2d(
            inp, inp,
            kernel_size=k,
            stride=stride,
            padding=_get_padding(k, stride, dilation),
            dilation=dilation,
            groups=inp,
        )
        self.pointwise = nn.Conv2d(inp, outp, kernel_size=1, stride=1)

    def forward(self, x):
        spatial = F.relu6(self.depthwise(x))
        return F.relu6(self.pointwise(spatial))


# Supported ``model_id`` values.  The string values are presumably checkpoint
# names; only the keys are consulted in the code visible here.
MOBILENET_V1_CHECKPOINTS = {
    50: 'mobilenet_v1_050',
    75: 'mobilenet_v1_075',
    100: 'mobilenet_v1_100',
    101: 'mobilenet_v1_101',
}

# Width-1.0 layer table: (layer class, in channels, out channels, stride).
MOBILE_NET_V1_100 = [
    (InputConv, 3, 32, 2),
    (SeperableConv, 32, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 512, 2),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 1024, 2),
    (SeperableConv, 1024, 1024, 1)
]
# Width-0.75 layer table: (layer class, in channels, out channels, stride).
# NOTE(review): unlike MOBILE_NET_V1_100 there is no final stride-2 layer, so
# this table's natural stride is 16; presumably intentional (to match the
# pretrained weights) -- confirm before "fixing" it.
MOBILE_NET_V1_75 = [
    (InputConv, 3, 24, 2),
    (SeperableConv, 24, 48, 1),
    (SeperableConv, 48, 96, 2),
    (SeperableConv, 96, 96, 1),
    (SeperableConv, 96, 192, 2),
    (SeperableConv, 192, 192, 1),
    (SeperableConv, 192, 384, 2),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1)
]

# Width-0.5 layer table, same layout and same missing final stride-2 layer
# as MOBILE_NET_V1_75.
MOBILE_NET_V1_50 = [
    (InputConv, 3, 16, 2),
    (SeperableConv, 16, 32, 1),
    (SeperableConv, 32, 64, 2),
    (SeperableConv, 64, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1)
]


class MobileNetV1(nn.Module):
    """MobileNet V1 feature extractor with four 1x1 convolutional output heads.

    Args:
        model_id: One of the keys of ``MOBILENET_V1_CHECKPOINTS``
            (50, 75, 100 or 101); selects the layer table.
        output_stride: Target accumulated stride of the feature extractor.
            Layers past that stride are converted to stride 1 with dilation
            (see ``_to_output_strided_layers``).

    Raises:
        ValueError: If ``model_id`` is not a supported id.
    """

    def __init__(self, model_id, output_stride=16):
        super(MobileNetV1, self).__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown id silently fall through
        # to the width-1.0 table below.
        if model_id not in MOBILENET_V1_CHECKPOINTS:
            raise ValueError(
                'unsupported model_id %r, expected one of %s'
                % (model_id, sorted(MOBILENET_V1_CHECKPOINTS)))
        self.output_stride = output_stride

        if model_id == 50:
            arch = MOBILE_NET_V1_50
        elif model_id == 75:
            arch = MOBILE_NET_V1_75
        else:
            # Both 100 and 101 use the width-1.0 table.
            arch = MOBILE_NET_V1_100

        conv_def = _to_output_strided_layers(arch, output_stride)
        # Every layer uses a 3x3 kernel; names conv0, conv1, ... become the
        # submodule (and state_dict) keys inside ``features``.
        conv_list = [
            ('conv%d' % c['block_id'],
             c['conv_type'](c['inp'], c['outp'], 3,
                            stride=c['stride'], dilation=c['rate']))
            for c in conv_def]
        last_depth = conv_def[-1]['outp']

        self.features = nn.Sequential(OrderedDict(conv_list))
        # NOTE(review): channel counts look like PoseNet heads -- presumably
        # 17 keypoint heatmaps, 2x17 offsets and 2x16 skeleton-edge
        # displacements per direction; confirm against the decoding code.
        self.heatmap = nn.Conv2d(last_depth, 17, 1, 1)
        self.offset = nn.Conv2d(last_depth, 34, 1, 1)
        self.displacement_fwd = nn.Conv2d(last_depth, 32, 1, 1)
        self.displacement_bwd = nn.Conv2d(last_depth, 32, 1, 1)

    def forward(self, x):
        """Run the backbone and all four heads.

        Returns:
            Tuple ``(heatmap, offset, displacement_fwd, displacement_bwd)``
            with 17, 34, 32 and 32 channels respectively, all at the spatial
            size of the backbone output.  Only ``heatmap`` is passed through
            a sigmoid; the other three are raw conv outputs.
        """
        x = self.features(x)
        heatmap = torch.sigmoid(self.heatmap(x))
        offset = self.offset(x)
        displacement_fwd = self.displacement_fwd(x)
        displacement_bwd = self.displacement_bwd(x)
        return heatmap, offset, displacement_fwd, displacement_bwd