import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.parameter import Parameter
from collections import OrderedDict


class SeparableConv2d(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                               groups=inplanes, bias=bias)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
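
# A depthwise separable convolution factorizes a standard convolution into a
# per-channel spatial convolution (depthwise, groups=inplanes) followed by a
# 1x1 pointwise convolution that mixes channels. For a 3x3 kernel this cuts
# the weight count from inplanes*planes*9 to inplanes*9 + inplanes*planes.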


def fixed_padding(inputs, kernel_size, rate):
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs
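
# fixed_padding emulates TensorFlow-style 'SAME' padding for a dilated conv:
# a kernel of size k with dilation r covers k + (k - 1) * (r - 1) input pixels,
# so e.g. k=3, r=2 gives an effective size of 5 and a total padding of 4,
# split as 2 pixels on each side.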


class SeparableConv2d_aspp(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
        super(SeparableConv2d_aspp, self).__init__()
        self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                                   groups=inplanes, bias=bias)
        self.depthwise_bn = nn.BatchNorm2d(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
        self.pointwise_bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
        x = self.depthwise(x)
        x = self.depthwise_bn(x)
        x = self.relu(x)
        x = self.pointwise(x)
        x = self.pointwise_bn(x)
        x = self.relu(x)
        return x
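
# Unlike SeparableConv2d above, this variant inserts BatchNorm and ReLU after
# both the depthwise and the pointwise stage, following the DeepLabv3+ design
# of adding BN/ReLU between the two halves of a separable convolution.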


class Decoder_module(nn.Module):
    def __init__(self, inplanes, planes, rate=1):
        super(Decoder_module, self).__init__()
        self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, 3, stride=1, dilation=rate, padding=1)

    def forward(self, x):
        x = self.atrous_convolution(x)
        return x


class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            raise RuntimeError('use ASPP_module_rate0 for rate=1')
        kernel_size = 3
        padding = rate
        self.atrous_convolution = SeparableConv2d_aspp(inplanes, planes, kernel_size, stride=1,
                                                       dilation=rate, padding=padding)

    def forward(self, x):
        x = self.atrous_convolution(x)
        return x
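
# For a 3x3 convolution with dilation r, padding=r keeps the spatial size
# unchanged: the effective kernel is 2r + 1 wide, so r pixels of padding on
# each side exactly compensate.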


class ASPP_module_rate0(nn.Module):
    def __init__(self, inplanes, planes, rate=1):
        super(ASPP_module_rate0, self).__init__()
        if rate != 1:
            raise RuntimeError('use ASPP_module for rates > 1')
        kernel_size = 1
        padding = 0
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes, eps=1e-5, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)
        return self.relu(x)


class SeparableConv2d_same(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, padding=0):
        super(SeparableConv2d_same, self).__init__()
        self.depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                                   groups=inplanes, bias=bias)
        self.depthwise_bn = nn.BatchNorm2d(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
        self.pointwise_bn = nn.BatchNorm2d(planes)

    def forward(self, x):
        x = fixed_padding(x, self.depthwise.kernel_size[0], rate=self.depthwise.dilation[0])
        x = self.depthwise(x)
        x = self.depthwise_bn(x)
        x = self.pointwise(x)
        x = self.pointwise_bn(x)
        return x


class Block(nn.Module):
    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
        super(Block, self).__init__()
        if planes != inplanes or stride != 1:
            # 1x1 projection on the residual path whenever the channel count
            # or spatial resolution changes.
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm2d(planes)
        else:
            self.skip = None
        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = inplanes
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
            filters = planes
        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
        if not start_with_relu:
            rep = rep[1:]
        if stride != 1:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=2, dilation=dilation))
        if is_last:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=1, dilation=dilation))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x
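
# Block is a pre-activation residual unit built from separable convs: the
# main path applies `reps` separable 3x3 convolutions (optionally growing the
# channel count first) and downsamples with a final stride-2 conv when
# stride != 1, while the skip path uses a 1x1 projection whenever the output
# shape differs from the input.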


class Block2(nn.Module):
    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False):
        super(Block2, self).__init__()
        if planes != inplanes or stride != 1:
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.BatchNorm2d(planes)
        else:
            self.skip = None
        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = inplanes
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
            filters = planes
        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
        if not start_with_relu:
            rep = rep[1:]
        # Block2 is only ever instantiated with stride=2; forward() assumes
        # block2_lastconv exists, so stride=1 is not supported.
        if stride != 1:
            self.block2_lastconv = nn.Sequential(self.relu,
                                                 SeparableConv2d_same(planes, planes, 3, stride=2, dilation=dilation))
        if is_last:
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)
        low_middle = x.clone()
        x1 = self.block2_lastconv(x)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x1 += skip
        return x1, low_middle
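
# Block2 differs from Block in that it also returns the feature map taken
# just before the final stride-2 conv; DeepLabv3+ uses this stride-4,
# 256-channel tensor as the low-level feature that feeds the decoder.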


class Xception(nn.Module):
    """
    Modified Aligned Xception
    """
    def __init__(self, inplanes=3, os=16, pretrained=False):
        super(Xception, self).__init__()
        if os == 16:
            entry_block3_stride = 2
            middle_block_rate = 1
            exit_block_rates = (1, 2)
        elif os == 8:
            entry_block3_stride = 1
            middle_block_rate = 2
            exit_block_rates = (2, 4)
        else:
            raise NotImplementedError
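
        # The output stride (os) controls how much the backbone downsamples:
        # os=16 keeps a stride-2 conv in block3, while os=8 removes that
        # stride and compensates with larger dilation rates in the middle and
        # exit flows so the receptive field stays comparable.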
        # Entry flow
        self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
        self.block2 = Block2(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True)

        # Middle flow
        self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)
        self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_rate, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_rates[0],
                             start_with_relu=True, grow_first=False, is_last=True)

        self.conv3 = SeparableConv2d_aspp(1024, 1536, 3, stride=1, dilation=exit_block_rates[1], padding=exit_block_rates[1])
        self.conv4 = SeparableConv2d_aspp(1536, 1536, 3, stride=1, dilation=exit_block_rates[1], padding=exit_block_rates[1])
        self.conv5 = SeparableConv2d_aspp(1536, 2048, 3, stride=1, dilation=exit_block_rates[1], padding=exit_block_rates[1])

        # Init weights (disabled)
        # self.__init_weight()

        # Load pretrained model
        if pretrained:
            self.__load_xception_pretrained()

    def forward(self, x):
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x, low_level_feat = self.block2(x)
        x = self.block3(x)

        # Middle flow
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)
        x = self.block13(x)
        x = self.block14(x)
        x = self.block15(x)
        x = self.block16(x)
        x = self.block17(x)
        x = self.block18(x)
        x = self.block19(x)

        # Exit flow
        x = self.block20(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.relu(x)
        return x, low_level_feat
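
    # With os=16, forward() returns the stride-16, 2048-channel backbone
    # feature together with the stride-4, 256-channel low-level feature from
    # block2 (e.g. 32x32 and 128x128 maps for a 512x512 input).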

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def __load_xception_pretrained(self):
        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                if 'pointwise' in k:
                    # Pointwise weights are stored as 2-D tensors in the
                    # checkpoint; expand them to 4-D (out, in, 1, 1).
                    v = v.unsqueeze(-1).unsqueeze(-1)
                if k.startswith('block12'):
                    # The pretrained exit block maps onto block20 here.
                    model_dict[k.replace('block12', 'block20')] = v
                elif k.startswith('block11'):
                    # Replicate the last pretrained middle block into the
                    # extra middle blocks (block11-block19) of this deeper
                    # variant.
                    model_dict[k] = v
                    for i in range(12, 20):
                        model_dict[k.replace('block11', 'block{}'.format(i))] = v
                elif k.startswith('conv3'):
                    model_dict[k] = v
                elif k.startswith('bn3'):
                    model_dict[k] = v
                    model_dict[k.replace('bn3', 'bn4')] = v
                elif k.startswith('conv4'):
                    model_dict[k.replace('conv4', 'conv5')] = v
                elif k.startswith('bn4'):
                    model_dict[k.replace('bn4', 'bn5')] = v
                else:
                    model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


class DeepLabv3_plus(nn.Module):
    def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()

        # Atrous backbone
        self.xception_features = Xception(nInputChannels, os, pretrained)

        # ASPP
        if os == 16:
            rates = [1, 6, 12, 18]
        elif os == 8:
            # rates would be [1, 12, 24, 36], but os=8 is not supported here
            raise NotImplementedError
        else:
            raise NotImplementedError

        self.aspp1 = ASPP_module_rate0(2048, 256, rate=rates[0])
        self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
        self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
        self.aspp4 = ASPP_module(2048, 256, rate=rates[3])

        self.relu = nn.ReLU()

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             nn.BatchNorm2d(256),
                                             nn.ReLU())

        self.concat_projection_conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.concat_projection_bn1 = nn.BatchNorm2d(256)

        # adopt [1x1, 48] for channel reduction.
        self.feature_projection_conv1 = nn.Conv2d(256, 48, 1, bias=False)
        self.feature_projection_bn1 = nn.BatchNorm2d(48)

        self.decoder = nn.Sequential(Decoder_module(304, 256),
                                     Decoder_module(256, 256))
        self.semantic = nn.Conv2d(256, n_classes, kernel_size=1, stride=1)
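
    # Channel bookkeeping: the five ASPP branches each emit 256 channels, so
    # the concat projection sees 5 * 256 = 1280 inputs; the decoder input is
    # the projected 256-channel ASPP output concatenated with the 48-channel
    # low-level projection, i.e. 256 + 48 = 304.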

    def forward(self, input):
        x, low_level_features = self.xception_features(input)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)
        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x

    def freeze_bn(self):
        for m in self.xception_features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def freeze_totally_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def freeze_aspp_bn(self):
        for aspp in (self.aspp1, self.aspp2, self.aspp3, self.aspp4):
            for m in aspp.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()

    def learnable_parameters(self):
        layer_features_BN = []
        layer_features = []
        layer_aspp = []
        layer_projection = []
        layer_decoder = []
        layer_other = []
        for name, para in self.named_parameters():
            if 'xception' in name:
                if 'bn' in name or 'downsample.1.weight' in name or 'downsample.1.bias' in name:
                    layer_features_BN.append(para)
                else:
                    layer_features.append(para)
            elif 'aspp' in name:
                layer_aspp.append(para)
            elif 'projection' in name:
                layer_projection.append(para)
            elif 'decode' in name:
                layer_decoder.append(para)
            elif 'global' not in name:
                layer_other.append(para)
        return layer_features_BN, layer_features, layer_aspp, layer_projection, layer_decoder, layer_other

    def get_backbone_para(self):
        layer_features = []
        other_features = []
        for name, para in self.named_parameters():
            if 'xception' in name:
                layer_features.append(para)
            else:
                other_features.append(para)
        return layer_features, other_features

    def train_fixbn(self, mode=True, freeze_bn=True, freeze_bn_affine=False):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        super(DeepLabv3_plus, self).train(mode)
        if freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            if freeze_bn_affine:
                print("Freezing Weight/Bias of BatchNorm2D.")
        if freeze_bn:
            for m in self.xception_features.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
            # The BatchNorm layers in the ASPP branches, global_avg_pool, and
            # the two projection heads could be frozen the same way if needed.

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style initialization: std = sqrt(2 / fan_out).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def load_state_dict_new(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')  # strip DataParallel prefix
            new_state_dict[name] = 0  # record every checkpoint key we saw
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue
        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))


def get_1x_lr_params(model):
    """
    This generator returns all the parameters of the net except for
    the last classification layer. Note that for each batchnorm layer,
    requires_grad is set to False in deeplab_resnet.py, therefore this function
    does not return any batchnorm parameter.
    """
    b = [model.xception_features]
    for module in b:
        for p in module.parameters():
            if p.requires_grad:
                yield p


def get_10x_lr_params(model):
    """
    This generator returns all the parameters of the head layers of the net,
    which perform the pixel-wise classification.
    """
    b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.global_avg_pool,
         model.concat_projection_conv1, model.concat_projection_bn1,
         model.feature_projection_conv1, model.feature_projection_bn1,
         model.decoder, model.semantic]
    for module in b:
        for p in module.parameters():
            if p.requires_grad:
                yield p
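
# A typical use of the two generators (a sketch; the learning rates and SGD
# hyper-parameters below are illustrative, not values from this file):
#
#   optimizer = torch.optim.SGD(
#       [{'params': get_1x_lr_params(model), 'lr': 0.007},
#        {'params': get_10x_lr_params(model), 'lr': 0.07}],
#       momentum=0.9, weight_decay=5e-4)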


if __name__ == "__main__":
    model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=False, _print=True)
    model.eval()
    image = torch.randn(1, 3, 512, 512) * 255
    with torch.no_grad():
        output = model(image)
    print(output.size())
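    # A quick sanity check (sketch): report the total parameter count.
    n_params = sum(p.numel() for p in model.parameters())
    print('parameters: {:.1f}M'.format(n_params / 1e6))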