wizzseen's picture
Upload 948 files
8a6df40 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
def _to_output_strided_layers(convolution_def, output_stride):
current_stride = 1
rate = 1
block_id = 0
buff = []
for c in convolution_def:
conv_type = c[0]
inp = c[1]
outp = c[2]
stride = c[3]
if current_stride == output_stride:
layer_stride = 1
layer_rate = rate
rate *= stride
else:
layer_stride = stride
layer_rate = 1
current_stride *= stride
buff.append({
'block_id': block_id,
'conv_type': conv_type,
'inp': inp,
'outp': outp,
'stride': layer_stride,
'rate': layer_rate,
'output_stride': current_stride
})
block_id += 1
return buff
def _get_padding(kernel_size, stride, dilation):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
class InputConv(nn.Module):
    """A single 2-D convolution followed by ReLU6.

    Args:
        inp: number of input channels.
        outp: number of output channels.
        k: kernel size (default 3).
        stride: convolution stride (default 1).
        dilation: convolution dilation (default 1).
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super().__init__()
        pad = _get_padding(k, stride, dilation)
        # The attribute name `conv` is part of the state_dict key; keep it.
        self.conv = nn.Conv2d(
            in_channels=inp,
            out_channels=outp,
            kernel_size=k,
            stride=stride,
            padding=pad,
            dilation=dilation)

    def forward(self, x):
        """Apply the convolution, then clamp activations with ReLU6."""
        convolved = self.conv(x)
        return F.relu6(convolved)
class SeperableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage uses ``groups=inp`` and keeps the channel count at
    ``inp``; the pointwise stage maps ``inp`` channels to ``outp``. ReLU6 is
    applied after each stage.

    Args:
        inp: number of input channels.
        outp: number of output channels.
        k: depthwise kernel size (default 3).
        stride: depthwise stride (default 1).
        dilation: depthwise dilation (default 1).
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super().__init__()
        # Attribute names `depthwise` / `pointwise` are state_dict keys.
        self.depthwise = nn.Conv2d(
            in_channels=inp,
            out_channels=inp,
            kernel_size=k,
            stride=stride,
            padding=_get_padding(k, stride, dilation),
            dilation=dilation,
            groups=inp)
        self.pointwise = nn.Conv2d(
            in_channels=inp, out_channels=outp, kernel_size=1, stride=1)

    def forward(self, x):
        """Run depthwise then pointwise, with ReLU6 after each stage."""
        for stage in (self.depthwise, self.pointwise):
            x = F.relu6(stage(x))
        return x
# Model ids accepted by MobileNetV1, mapped to a name string per id
# (presumably the checkpoint base name -- the values are not read in the
# visible part of this file). MobileNetV1 builds ids 100 and 101 from the
# same layer table.
MOBILENET_V1_CHECKPOINTS = {
    50: 'mobilenet_v1_050',
    75: 'mobilenet_v1_075',
    100: 'mobilenet_v1_100',
    101: 'mobilenet_v1_101',
}
# Layer table used by MobileNetV1 for every model id other than 50 and 75.
# Each row is (conv_type, inp, outp, stride), the layout that
# _to_output_strided_layers reads by index. The strides multiply to 32.
MOBILE_NET_V1_100 = [(InputConv, 3, 32, 2)] + [
    (SeperableConv, inp, outp, stride)
    for inp, outp, stride in (
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    )
]
# Layer table used by MobileNetV1 for model id 75.
# Each row is (conv_type, inp, outp, stride), the layout that
# _to_output_strided_layers reads by index. The strides multiply to 16:
# unlike MOBILE_NET_V1_100, the last two rows keep stride 1 and 384 channels.
MOBILE_NET_V1_75 = [(InputConv, 3, 24, 2)] + [
    (SeperableConv, inp, outp, stride)
    for inp, outp, stride in (
        (24, 48, 1),
        (48, 96, 2),
        (96, 96, 1),
        (96, 192, 2),
        (192, 192, 1),
        (192, 384, 2),
        (384, 384, 1),
        (384, 384, 1),
        (384, 384, 1),
        (384, 384, 1),
        (384, 384, 1),
        (384, 384, 1),
        (384, 384, 1),
    )
]
# Layer table used by MobileNetV1 for model id 50.
# Each row is (conv_type, inp, outp, stride), the layout that
# _to_output_strided_layers reads by index. The strides multiply to 16:
# unlike MOBILE_NET_V1_100, the last two rows keep stride 1 and 256 channels.
MOBILE_NET_V1_50 = [(InputConv, 3, 16, 2)] + [
    (SeperableConv, inp, outp, stride)
    for inp, outp, stride in (
        (16, 32, 1),
        (32, 64, 2),
        (64, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 256, 1),
        (256, 256, 1),
        (256, 256, 1),
        (256, 256, 1),
        (256, 256, 1),
        (256, 256, 1),
    )
]
class MobileNetV1(nn.Module):
    """Convolutional backbone with four 1x1 convolution output heads.

    The backbone is built from one of the module-level layer tables, chosen
    by ``model_id``, after the table is passed through
    ``_to_output_strided_layers`` with the requested ``output_stride``.

    Args:
        model_id: one of the keys of ``MOBILENET_V1_CHECKPOINTS``. 50 and 75
            select their own layer tables; every other accepted id uses
            ``MOBILE_NET_V1_100``.
        output_stride: target stride for the backbone (default 16). It is
            stored unchanged on ``self.output_stride``.

    Raises:
        AssertionError: if ``model_id`` is not a supported id.

    NOTE(review): if the selected table never reaches ``output_stride``
    (the 50 and 75 tables only multiply up to 16), no layer is dilated and
    ``self.output_stride`` still holds the requested value -- confirm that
    callers only pass reachable strides.
    """

    def __init__(self, model_id, output_stride=16):
        super().__init__()

        # Raised explicitly rather than via `assert`, which is removed under
        # `python -O` and would let an unknown id fall through to the 100
        # table. AssertionError is kept so the exception type callers may
        # already handle does not change.
        if model_id not in MOBILENET_V1_CHECKPOINTS:
            raise AssertionError(
                'unsupported model_id %r; expected one of %s'
                % (model_id, sorted(MOBILENET_V1_CHECKPOINTS)))

        self.output_stride = output_stride

        arch_by_id = {50: MOBILE_NET_V1_50, 75: MOBILE_NET_V1_75}
        arch = arch_by_id.get(model_id, MOBILE_NET_V1_100)
        conv_def = _to_output_strided_layers(arch, output_stride)

        # Layer names ('conv0', 'conv1', ...) become state_dict keys, so the
        # naming scheme must stay as it is. Every layer uses a 3x3 kernel.
        layers = OrderedDict()
        for layer in conv_def:
            layers['conv%d' % layer['block_id']] = layer['conv_type'](
                layer['inp'], layer['outp'], 3,
                stride=layer['stride'], dilation=layer['rate'])
        self.features = nn.Sequential(layers)

        # All heads are 1x1 convolutions over the final backbone channels.
        # Output channel counts are 17 / 34 / 32 / 32 -- presumably 17
        # keypoints, an (x, y) offset per keypoint, and 2 x 16 displacement
        # values per direction; TODO confirm against the decoder.
        last_depth = conv_def[-1]['outp']
        self.heatmap = nn.Conv2d(last_depth, 17, 1, 1)
        self.offset = nn.Conv2d(last_depth, 34, 1, 1)
        self.displacement_fwd = nn.Conv2d(last_depth, 32, 1, 1)
        self.displacement_bwd = nn.Conv2d(last_depth, 32, 1, 1)

    def forward(self, x):
        """Run the backbone and return the four head outputs.

        Returns:
            A tuple ``(heatmap, offset, displacement_fwd, displacement_bwd)``.
            Only ``heatmap`` is passed through a sigmoid; the other three are
            the raw convolution outputs.
        """
        features = self.features(x)
        heatmap = torch.sigmoid(self.heatmap(features))
        offset = self.offset(features)
        displacement_fwd = self.displacement_fwd(features)
        displacement_bwd = self.displacement_bwd(features)
        return heatmap, offset, displacement_fwd, displacement_bwd