# Spaces: Sleeping  (page-status text that appears to have been captured with this file; not part of the module)
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from collections import OrderedDict | |
def _to_output_strided_layers(convolution_def, output_stride):
    """Expand an architecture table into per-layer stride/dilation settings.

    Layers keep their own stride until the accumulated stride equals
    ``output_stride``. From then on every layer is emitted with stride 1 and
    the stride it would have applied is multiplied into a running dilation
    rate that the *following* layers use.

    Args:
        convolution_def: Iterable of entries whose first four items are
            ``(conv_type, inp, outp, stride)``, in network order.
        output_stride: Accumulated stride at which striding stops. If the
            accumulated stride never equals this value exactly, no layer is
            switched to dilation.

    Returns:
        A list with one dict per entry, with the keys ``block_id``,
        ``conv_type``, ``inp``, ``outp``, ``stride``, ``rate`` and
        ``output_stride`` (the accumulated stride after that layer).
    """
    current_stride = 1
    rate = 1
    layers = []
    for block_id, entry in enumerate(convolution_def):
        conv_type, inp, outp, stride = entry[:4]

        if current_stride == output_stride:
            # Target stride reached: stop downsampling. This layer uses the
            # rate accumulated so far; its skipped stride only affects the
            # layers after it.
            layer_stride = 1
            layer_rate = rate
            rate *= stride
        else:
            layer_stride = stride
            layer_rate = 1
            current_stride *= stride

        layers.append({
            'block_id': block_id,
            'conv_type': conv_type,
            'inp': inp,
            'outp': outp,
            'stride': layer_stride,
            'rate': layer_rate,
            'output_stride': current_stride,
        })

    return layers
def _get_padding(kernel_size, stride, dilation):
    """Return the per-side padding used by the convolution layers.

    The value is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``,
    using integer (floor) division.

    Args:
        kernel_size: Kernel size along one dimension.
        stride: Convolution stride.
        dilation: Convolution dilation rate.

    Returns:
        The padding as an integer.
    """
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2
class InputConv(nn.Module):
    """A single 2-D convolution followed by ReLU6.

    Args:
        inp: Number of input channels.
        outp: Number of output channels.
        k: Kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation rate.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super(InputConv, self).__init__()
        # Padding is derived from kernel size, stride and dilation by the
        # module-level helper.
        self.conv = nn.Conv2d(
            inp, outp, k, stride,
            padding=_get_padding(k, stride, dilation),
            dilation=dilation)

    def forward(self, x):
        """Apply the convolution and clamp activations with ReLU6."""
        return F.relu6(self.conv(x))
class SeperableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ReLU6 is applied after each of the two convolutions.

    Args:
        inp: Number of input channels (also the depthwise group count).
        outp: Number of output channels produced by the pointwise step.
        k: Depthwise kernel size.
        stride: Depthwise stride.
        dilation: Depthwise dilation rate.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super(SeperableConv, self).__init__()
        # groups=inp gives every input channel its own filter.
        self.depthwise = nn.Conv2d(
            inp, inp, k, stride,
            padding=_get_padding(k, stride, dilation),
            dilation=dilation,
            groups=inp)
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(inp, outp, 1, 1)

    def forward(self, x):
        """Run the depthwise then pointwise convolution, each with ReLU6."""
        x = F.relu6(self.depthwise(x))
        x = F.relu6(self.pointwise(x))
        return x
# Accepted model ids mapped to their checkpoint names.
MOBILENET_V1_CHECKPOINTS = {
    50: 'mobilenet_v1_050',
    75: 'mobilenet_v1_075',
    100: 'mobilenet_v1_100',
    101: 'mobilenet_v1_101',
}
# Layer table for the widest variant. Each entry is
# (layer class, input channels, output channels, stride).
# The five stride-2 layers give a total native stride of 32.
MOBILE_NET_V1_100 = [
    (InputConv, 3, 32, 2),
    (SeperableConv, 32, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 512, 2),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 1024, 2),
    (SeperableConv, 1024, 1024, 1),
]
# Layer table for the 75 variant. Each entry is
# (layer class, input channels, output channels, stride).
# Only four layers use stride 2, so the total native stride is 16 and the
# final layers stay at 384 channels.
MOBILE_NET_V1_75 = [
    (InputConv, 3, 24, 2),
    (SeperableConv, 24, 48, 1),
    (SeperableConv, 48, 96, 2),
    (SeperableConv, 96, 96, 1),
    (SeperableConv, 96, 192, 2),
    (SeperableConv, 192, 192, 1),
    (SeperableConv, 192, 384, 2),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
]
# Layer table for the 50 variant. Each entry is
# (layer class, input channels, output channels, stride).
# Only four layers use stride 2, so the total native stride is 16 and the
# final layers stay at 256 channels.
MOBILE_NET_V1_50 = [
    (InputConv, 3, 16, 2),
    (SeperableConv, 16, 32, 1),
    (SeperableConv, 32, 64, 2),
    (SeperableConv, 64, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
]
class MobileNetV1(nn.Module):
    """MobileNetV1 feature stack with four 1x1 convolutional output heads.

    The feature stack is built from one of the module-level layer tables
    after it has been adjusted for ``output_stride``. The heads read the
    final feature map and produce 17 (heatmap), 34 (offset), 32 (forward
    displacement) and 32 (backward displacement) channels.

    Args:
        model_id: Key of ``MOBILENET_V1_CHECKPOINTS``. 50 and 75 select
            their own layer tables; every other accepted id (100 and 101)
            uses ``MOBILE_NET_V1_100``.
        output_stride: Accumulated stride at which layers switch from
            striding to dilation.

    Raises:
        ValueError: If ``model_id`` is not a known checkpoint id.
    """

    def __init__(self, model_id, output_stride=16):
        super(MobileNetV1, self).__init__()

        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would let an unknown id silently build the
        # width-100 network.
        if model_id not in MOBILENET_V1_CHECKPOINTS:
            raise ValueError(
                'model_id must be one of %s, got %r'
                % (sorted(MOBILENET_V1_CHECKPOINTS), model_id))

        self.output_stride = output_stride

        arch_by_id = {50: MOBILE_NET_V1_50, 75: MOBILE_NET_V1_75}
        arch = arch_by_id.get(model_id, MOBILE_NET_V1_100)

        conv_def = _to_output_strided_layers(arch, output_stride)
        blocks = OrderedDict()
        for c in conv_def:
            blocks['conv%d' % c['block_id']] = c['conv_type'](
                c['inp'], c['outp'], 3,
                stride=c['stride'], dilation=c['rate'])
        last_depth = conv_def[-1]['outp']

        self.features = nn.Sequential(blocks)
        self.heatmap = nn.Conv2d(last_depth, 17, 1, 1)
        self.offset = nn.Conv2d(last_depth, 34, 1, 1)
        self.displacement_fwd = nn.Conv2d(last_depth, 32, 1, 1)
        self.displacement_bwd = nn.Conv2d(last_depth, 32, 1, 1)

    def forward(self, x):
        """Run the feature stack and all four heads.

        Returns:
            A tuple ``(heatmap, offset, displacement_fwd, displacement_bwd)``.
            Only the heatmap is passed through a sigmoid; the other three
            are raw convolution outputs.
        """
        x = self.features(x)
        heatmap = torch.sigmoid(self.heatmap(x))
        offset = self.offset(x)
        displacement_fwd = self.displacement_fwd(x)
        displacement_bwd = self.displacement_bwd(x)
        return heatmap, offset, displacement_fwd, displacement_bwd