wizzseen's picture
Upload 948 files
8a6df40 verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.parameter import Parameter
from collections import OrderedDict
class SeparableConv2d(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
super(SeparableConv2d, self).__init__()
self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
groups=inplanes, bias=bias)
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
def forward(self, x):
x = self.conv1(x)
x = self.pointwise(x)
return x
def fixed_padding(inputs, kernel_size, rate):
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
return padded_inputs
class SeparableConv2d_aspp(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
super(SeparableConv2d_aspp, self).__init__()
self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
groups=inplanes, bias=bias)
self.depthwise_bn = nn.BatchNorm2d(inplanes)
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
self.pointwise_bn = nn.BatchNorm2d(planes)
self.relu = nn.ReLU()
def forward(self, x):
# x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
x = self.depthwise(x)
x = self.depthwise_bn(x)
x = self.relu(x)
x = self.pointwise(x)
x = self.pointwise_bn(x)
x = self.relu(x)
return x
class Decoder_module(nn.Module):
def __init__(self, inplanes, planes, rate=1):
super(Decoder_module, self).__init__()
self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,padding=1)
def forward(self, x):
x = self.atrous_convolution(x)
return x
class ASPP_module(nn.Module):
def __init__(self, inplanes, planes, rate):
super(ASPP_module, self).__init__()
if rate == 1:
raise RuntimeError()
else:
kernel_size = 3
padding = rate
self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate,
padding=padding)
def forward(self, x):
x = self.atrous_convolution(x)
return x
class ASPP_module_rate0(nn.Module):
def __init__(self, inplanes, planes, rate=1):
super(ASPP_module_rate0, self).__init__()
if rate == 1:
kernel_size = 1
padding = 0
self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
stride=1, padding=padding, dilation=rate, bias=False)
self.bn = nn.BatchNorm2d(planes, eps=1e-5, affine=True)
self.relu = nn.ReLU()
else:
raise RuntimeError()
def forward(self, x):
x = self.atrous_convolution(x)
x = self.bn(x)
return self.relu(x)
class SeparableConv2d_same(nn.Module):
def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
super(SeparableConv2d_same, self).__init__()
self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
groups=inplanes, bias=bias)
self.depthwise_bn = nn.BatchNorm2d(inplanes)
self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
self.pointwise_bn = nn.BatchNorm2d(planes)
def forward(self, x):
x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
x = self.depthwise(x)
x = self.depthwise_bn(x)
x = self.pointwise(x)
x = self.pointwise_bn(x)
return x
class Block(nn.Module):
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
super(Block, self).__init__()
if planes != inplanes or stride != 1:
self.skip = nn.Conv2d(inplanes, planes, 1, stride=2, bias=False)
if is_last:
self.skip = nn.Conv2d(inplanes, planes, 1, stride=1, bias=False)
self.skipbn = nn.BatchNorm2d(planes)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = inplanes
if grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(planes))
filters = planes
for i in range(reps - 1):
rep.append(self.relu)
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(filters))
if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(planes))
if not start_with_relu:
rep = rep[1:]
if stride != 1:
rep.append(self.relu)
rep.append(SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation))
if is_last:
rep.append(self.relu)
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1,dilation=dilation))
self.rep = nn.Sequential(*rep)
def forward(self, inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
# print(x.size(),skip.size())
x += skip
return x
class Block2(nn.Module):
def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
super(Block2, self).__init__()
if planes != inplanes or stride != 1:
self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
self.skipbn = nn.BatchNorm2d(planes)
else:
self.skip = None
self.relu = nn.ReLU(inplace=True)
rep = []
filters = inplanes
if grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(planes))
filters = planes
for i in range(reps - 1):
rep.append(self.relu)
rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(filters))
if not grow_first:
rep.append(self.relu)
rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
# rep.append(nn.BatchNorm2d(planes))
if not start_with_relu:
rep = rep[1:]
if stride != 1:
self.block2_lastconv = nn.Sequential(*[self.relu,SeparableConv2d_same(planes, planes, 3, stride=2,dilation=dilation)])
if is_last:
rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
self.rep = nn.Sequential(*rep)
def forward(self, inp):
x = self.rep(inp)
low_middle = x.clone()
x1 = x
x1 = self.block2_lastconv(x1)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x1 += skip
return x1,low_middle
class Xception(nn.Module):
"""
Modified Alighed Xception
"""
def __init__(self, inplanes=3, os=16, pretrained=False):
super(Xception, self).__init__()
if os == 16:
entry_block3_stride = 2
middle_block_rate = 1
exit_block_rates = (1, 2)
elif os == 8:
entry_block3_stride = 1
middle_block_rate = 2
exit_block_rates = (2, 4)
else:
raise NotImplementedError
# Entry flow
self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64)
self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)
# Middle flow
self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
# Exit flow
self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
start_with_relu=True, grow_first=False, is_last=True)
self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
# self.bn3 = nn.BatchNorm2d(1536)
self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
# self.bn4 = nn.BatchNorm2d(1536)
self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1],padding=exit_block_rates[1])
# self.bn5 = nn.BatchNorm2d(2048)
# Init weights
# self.__init_weight()
# Load pretrained model
if pretrained:
self.__load_xception_pretrained()
def forward(self, x):
# Entry flow
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# print('conv1 ',x.size())
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.block1(x)
# print('block1',x.size())
# low_level_feat = x
x,low_level_feat = self.block2(x)
# print('block2',x.size())
x = self.block3(x)
# print('xception block3 ',x.size())
# Middle flow
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.block11(x)
x = self.block12(x)
x = self.block13(x)
x = self.block14(x)
x = self.block15(x)
x = self.block16(x)
x = self.block17(x)
x = self.block18(x)
x = self.block19(x)
# Exit flow
x = self.block20(x)
x = self.conv3(x)
# x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
# x = self.bn4(x)
x = self.relu(x)
x = self.conv5(x)
# x = self.bn5(x)
x = self.relu(x)
return x, low_level_feat
def __init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def __load_xception_pretrained(self):
pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
model_dict = {}
state_dict = self.state_dict()
for k, v in pretrain_dict.items():
if k in state_dict:
if 'pointwise' in k:
v = v.unsqueeze(-1).unsqueeze(-1)
if k.startswith('block12'):
model_dict[k.replace('block12', 'block20')] = v
elif k.startswith('block11'):
model_dict[k.replace('block11', 'block12')] = v
model_dict[k.replace('block11', 'block13')] = v
model_dict[k.replace('block11', 'block14')] = v
model_dict[k.replace('block11', 'block15')] = v
model_dict[k.replace('block11', 'block16')] = v
model_dict[k.replace('block11', 'block17')] = v
model_dict[k.replace('block11', 'block18')] = v
model_dict[k.replace('block11', 'block19')] = v
elif k.startswith('conv3'):
model_dict[k] = v
elif k.startswith('bn3'):
model_dict[k] = v
model_dict[k.replace('bn3', 'bn4')] = v
elif k.startswith('conv4'):
model_dict[k.replace('conv4', 'conv5')] = v
elif k.startswith('bn4'):
model_dict[k.replace('bn4', 'bn5')] = v
else:
model_dict[k] = v
state_dict.update(model_dict)
self.load_state_dict(state_dict)
class DeepLabv3_plus(nn.Module):
def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
if _print:
print("Constructing DeepLabv3+ model...")
print("Number of classes: {}".format(n_classes))
print("Output stride: {}".format(os))
print("Number of Input Channels: {}".format(nInputChannels))
super(DeepLabv3_plus, self).__init__()
# Atrous Conv
self.xception_features = Xception(nInputChannels, os, pretrained)
# ASPP
if os == 16:
rates = [1, 6, 12, 18]
elif os == 8:
rates = [1, 12, 24, 36]
raise NotImplementedError
else:
raise NotImplementedError
self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
self.aspp4 = ASPP_module(2048, 256, rate=rates[3])
self.relu = nn.ReLU()
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(2048, 256, 1, stride=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
self.concat_projection_bn1 = nn.BatchNorm2d(256)
# adopt [1x1, 48] for channel reduction.
self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
self.feature_projection_bn1 = nn.BatchNorm2d(48)
self.decoder = nn.Sequential(Decoder_module(304, 256),
Decoder_module(256, 256)
)
self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
def forward(self, input):
x, low_level_features = self.xception_features(input)
# print(x.size())
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.global_avg_pool(x)
x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
x = self.concat_projection_conv1(x)
x = self.concat_projection_bn1(x)
x = self.relu(x)
# print(x.size())
low_level_features = self.feature_projection_conv1(low_level_features)
low_level_features = self.feature_projection_bn1(low_level_features)
low_level_features = self.relu(low_level_features)
x = F.upsample(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
# print(low_level_features.size())
# print(x.size())
x = torch.cat((x, low_level_features), dim=1)
x = self.decoder(x)
x = self.semantic(x)
x = F.upsample(x, size=input.size()[2:], mode='bilinear', align_corners=True)
return x
def freeze_bn(self):
for m in self.xception_features.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def freeze_totally_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def freeze_aspp_bn(self):
for m in self.aspp1.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
for m in self.aspp2.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
for m in self.aspp3.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
for m in self.aspp4.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def learnable_parameters(self):
layer_features_BN = []
layer_features = []
layer_aspp = []
layer_projection =[]
layer_decoder = []
layer_other = []
model_para = list(self.named_parameters())
for name,para in model_para:
if 'xception' in name:
if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
layer_features_BN.append(para)
else:
layer_features.append(para)
# print (name)
elif 'aspp' in name:
layer_aspp.append(para)
elif 'projection' in name:
layer_projection.append(para)
elif 'decode' in name:
layer_decoder.append(para)
elif 'global' not in name:
layer_other.append(para)
return layer_features_BN,layer_features,layer_aspp,layer_projection,layer_decoder,layer_other
def get_backbone_para(self):
layer_features = []
other_features = []
model_para = list(self.named_parameters())
for name, para in model_para:
if 'xception' in name:
layer_features.append(para)
else:
other_features.append(para)
return layer_features, other_features
def train_fixbn(self, mode=True, freeze_bn=True, freeze_bn_affine=False):
r"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Returns:
Module: self
"""
super(DeepLabv3_plus, self).train(mode)
if freeze_bn:
print("Freezing Mean/Var of BatchNorm2D.")
if freeze_bn_affine:
print("Freezing Weight/Bias of BatchNorm2D.")
if freeze_bn:
for m in self.xception_features.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if freeze_bn_affine:
m.weight.requires_grad = False
m.bias.requires_grad = False
# for m in self.aspp1.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.aspp2.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.aspp3.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.aspp4.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.global_avg_pool.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.concat_projection_bn1.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
# for m in self.feature_projection_bn1.modules():
# if isinstance(m, nn.BatchNorm2d):
# m.eval()
# if freeze_bn_affine:
# m.weight.requires_grad = False
# m.bias.requires_grad = False
def __init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
# torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def load_state_dict_new(self, state_dict):
own_state = self.state_dict()
#for name inshop_cos own_state:
# print name
new_state_dict = OrderedDict()
for name, param in state_dict.items():
name = name.replace('module.','')
new_state_dict[name] = 0
if name not in own_state:
if 'num_batch' in name:
continue
print ('unexpected key "{}" in state_dict'
.format(name))
continue
# if isinstance(param, own_state):
if isinstance(param, Parameter):
# backwards compatibility for serialized parameters
param = param.data
try:
own_state[name].copy_(param)
except:
print('While copying the parameter named {}, whose dimensions in the model are'
' {} and whose dimensions in the checkpoint are {}, ...'.format(
name, own_state[name].size(), param.size()))
continue # i add inshop_cos 2018/02/01
# raise
# print 'copying %s' %name
# if isinstance(param, own_state):
# backwards compatibility for serialized parameters
own_state[name].copy_(param)
# print 'copying %s' %name
missing = set(own_state.keys()) - set(new_state_dict.keys())
if len(missing) > 0:
print('missing keys in state_dict: "{}"'.format(missing))
def get_1x_lr_params(model):
"""
This generator returns all the parameters of the net except for
the last classification layer. Note that for each batchnorm layer,
requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
any batchnorm parameter
"""
b = [model.xception_features]
for i in range(len(b)):
for k in b[i].parameters():
if k.requires_grad:
yield k
def get_10x_lr_params(model):
"""
This generator returns all the parameters for the last layer of the net,
which does the classification of pixel into classes
"""
b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
for j in range(len(b)):
for k in b[j].parameters():
if k.requires_grad:
yield k
if __name__ == "__main__":
model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
model.eval()
image = torch.randn(1, 3, 512, 512)*255
with torch.no_grad():
output = model.forward(image)
print(output.size())
# print(output)