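"""Graph-transfer variants of DeepLabv3+ with an Xception backbone.

The base models attach a target graph-reasoning branch (feature map -> graph
nodes -> graph convolutions -> feature map) to the DeepLabv3+ decoder; the
transfer models add a source graph branch plus similarity- and transfer-graph
projections from source nodes onto target nodes.
"""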
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from networks import deeplab_xception, gcn, deeplab_xception_synBN
#######################
# base model
#######################
class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus):
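    """DeepLabv3+ (Xception backbone) with a target graph-reasoning branch: the
    decoder feature map is projected onto n_classes graph nodes, refined by three
    graph convolutions, and projected back before the semantic head."""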
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256):
        super(deeplab_xception_transfer_basemodel, self).__init__(nInputChannels=nInputChannels,
                                                                  n_classes=n_classes, os=os)
        ### source graph
        # (the source-graph branch is defined by the transfer subclasses below)
### target graph
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=n_classes)
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
hidden_layers=hidden_layers, nodes=n_classes
)
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
nn.ReLU(True)])
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
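
    # A minimal loading sketch (the checkpoint path is hypothetical; the state
    # dict is expected to come from a plain single-graph model whose graph keys
    # are remapped onto the 'source_*' modules above):
    #   net = deeplab_xception_transfer_basemodel()
    #   ckpt = torch.load('path/to/source_checkpoint.pth', map_location='cpu')
    #   net.load_source_model(ckpt)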
def get_target_parameter(self):
l = []
other = []
for name, k in self.named_parameters():
if 'target' in name or 'semantic' in name:
l.append(k)
else:
other.append(k)
return l, other
def get_semantic_parameter(self):
l = []
for name, k in self.named_parameters():
if 'semantic' in name:
l.append(k)
return l
def get_source_parameter(self):
l = []
for name, k in self.named_parameters():
if 'source' in name:
l.append(k)
return l
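
    # Sketch of staged training with the parameter groups (learning rates are
    # illustrative assumptions, not values from this file):
    #   target_params, backbone_params = net.get_target_parameter()
    #   optimizer = torch.optim.SGD(
    #       [{'params': target_params, 'lr': 1e-3},
    #        {'params': backbone_params, 'lr': 1e-4}],
    #       momentum=0.9)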
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph: project features to graph nodes, reason with three graph
        # convolutions, then project back and fuse with a skip connection
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
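
# Forward-pass sketch for the base models (the adjacency shape is an assumption;
# see gcn.GraphConvolution for the exact layout it expects):
#   net = deeplab_xception_transfer_basemodel().eval()
#   img = torch.rand(1, 3, 512, 512)
#   adj = torch.rand(1, 7, 7)        # n_classes x n_classes
#   out = net(img, adj1_target=adj)  # -> (1, 7, 512, 512)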
class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
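    """Same as deeplab_xception_transfer_basemodel, but projects graph features
    back to the feature map with the memory-saving gcn.Graph_to_Featuremaps_savemem."""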
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256):
        super(deeplab_xception_transfer_basemodel_savememory, self).__init__(nInputChannels=nInputChannels,
                                                                             n_classes=n_classes, os=os)
        ### target graph
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=n_classes)
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
hidden_layers=hidden_layers, nodes=n_classes
)
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
nn.ReLU(True)])
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_graph' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
def get_target_parameter(self):
l = []
other = []
for name, k in self.named_parameters():
if 'target' in name or 'semantic' in name:
l.append(k)
else:
other.append(k)
return l, other
def get_semantic_parameter(self):
l = []
for name, k in self.named_parameters():
if 'semantic' in name:
l.append(k)
return l
def get_source_parameter(self):
l = []
for name, k in self.named_parameters():
if 'source' in name:
l.append(k)
return l
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph: project features to graph nodes, reason with three graph
        # convolutions, then project back and fuse with a skip connection
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
class deeplab_xception_transfer_basemodel_synBN(deeplab_xception_synBN.DeepLabv3_plus):
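    """Target-graph base model built on the synchronized-BatchNorm DeepLabv3+
    backbone (deeplab_xception_synBN)."""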
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256):
        super(deeplab_xception_transfer_basemodel_synBN, self).__init__(nInputChannels=nInputChannels,
                                                                        n_classes=n_classes, os=os)
        ### source graph
        # (the source-graph branch is defined by the transfer subclasses below)
### target graph
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=n_classes)
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_2_fea = gcn.Graph_to_Featuremaps(input_channels=input_channels, output_channels=out_channels,
hidden_layers=hidden_layers, nodes=n_classes
)
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
nn.ReLU(True)])
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
def get_target_parameter(self):
l = []
other = []
for name, k in self.named_parameters():
if 'target' in name or 'semantic' in name:
l.append(k)
else:
other.append(k)
return l, other
def get_semantic_parameter(self):
l = []
for name, k in self.named_parameters():
if 'semantic' in name:
l.append(k)
return l
def get_source_parameter(self):
l = []
for name, k in self.named_parameters():
if 'source' in name:
l.append(k)
return l
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph: project features to graph nodes, reason with three graph
        # convolutions, then project back and fuse with a skip connection
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
class deeplab_xception_transfer_basemodel_synBN_savememory(deeplab_xception_synBN.DeepLabv3_plus):
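    """Synchronized-BatchNorm base model that uses the memory-saving
    graph-to-feature projection and adds BatchNorm to the skip convolution."""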
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128, out_channels=256):
        super(deeplab_xception_transfer_basemodel_synBN_savememory, self).__init__(nInputChannels=nInputChannels,
                                                                                   n_classes=n_classes, os=os)
        ### source graph
        # (the source-graph branch is defined by the transfer subclasses below)
### target graph
self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=n_classes)
self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(input_channels=input_channels, output_channels=out_channels,
hidden_layers=hidden_layers, nodes=n_classes
)
self.target_skip_conv = nn.Sequential(*[nn.Conv2d(input_channels, input_channels, kernel_size=1),
nn.BatchNorm2d(input_channels),
nn.ReLU(True)])
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
def get_target_parameter(self):
l = []
other = []
for name, k in self.named_parameters():
if 'target' in name or 'semantic' in name:
l.append(k)
else:
other.append(k)
return l, other
def get_semantic_parameter(self):
l = []
for name, k in self.named_parameters():
if 'semantic' in name:
l.append(k)
return l
def get_source_parameter(self):
l = []
for name, k in self.named_parameters():
if 'source' in name:
l.append(k)
return l
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph: project features to graph nodes, reason with three graph
        # convolutions, then project back and fuse with a skip connection
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
#######################
# transfer model
#######################
class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel):
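    """Transfer model: adds a source graph branch over source_classes nodes, a
    learned source-to-target transfer graph (gcn.Graph_trans), and a
    similarity-based projection; fc_graph fuses the three signals before each
    target graph convolution."""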
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128,
                 out_channels=256, transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection, self).__init__(
            nInputChannels=nInputChannels, n_classes=n_classes, os=os, input_channels=input_channels,
            hidden_layers=hidden_layers, out_channels=out_channels)
self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=source_classes)
self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
                                               adj=transfer_graph, begin_nodes=source_classes, end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)
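        # fc_graph above maps each fusion step's concatenation of (target graph,
        # similarity-projected source graph, transfer-graph projection), which is
        # 3 * hidden_layers wide, back down to hidden_layers.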
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph: three stages of graph convolution over the source adjacency
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        # note: source_graph_conv2 is applied twice and source_graph_conv3 is never
        # used; left unchanged to stay weight-compatible with existing checkpoints
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        # project each source stage onto the target node set through the learned
        # transfer graph
        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        graph = self.target_featuremap_2_graph(x)

        # graph combine 1: target nodes + similarity projection + transfer projection
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        graph = torch.cat((graph, source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # graph combine 2
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # graph combine 3
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # back to the feature map, fused with a 1x1 skip convolution
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
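
    # Transfer forward sketch: all three adjacencies are used (shapes below are
    # assumptions; gcn.Graph_trans defines the exact transfer-graph layout):
    #   adj1 = torch.rand(1, 7, 7)      # target x target
    #   adj2 = torch.rand(1, 20, 20)    # source x source
    #   adj3 = torch.rand(7, 20)        # target x source (assumed)
    #   out = net(img, adj1_target=adj1, adj2_source=adj2, adj3_transfer=adj3)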
    def similarity_trans(self, source, target):
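        """Attention-style projection of source node features onto target nodes:
        cosine similarity between L2-normalized target and source nodes, softmax
        over the source dimension, then a weighted sum of source features."""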
sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
sim = F.softmax(sim, dim=-1)
return torch.matmul(sim, source)
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
class deeplab_xception_transfer_projection_savemem(deeplab_xception_transfer_basemodel_savememory):
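    """Transfer model built on the memory-saving base class."""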
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128,
                 out_channels=256, transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection_savemem, self).__init__(
            nInputChannels=nInputChannels, n_classes=n_classes, os=os, input_channels=input_channels,
            hidden_layers=hidden_layers, out_channels=out_channels)
self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=source_classes)
self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
                                               adj=transfer_graph, begin_nodes=source_classes, end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph: three stages of graph convolution over the source adjacency
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        # note: source_graph_conv2 is applied twice and source_graph_conv3 is never
        # used; left unchanged to stay weight-compatible with existing checkpoints
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        # project each source stage onto the target node set through the learned
        # transfer graph
        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        graph = self.target_featuremap_2_graph(x)

        # graph combine 1: target nodes + similarity projection + transfer projection
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        graph = torch.cat((graph, source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # graph combine 2
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # graph combine 3
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # back to the feature map, fused with a 1x1 skip convolution
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
    def similarity_trans(self, source, target):
sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
sim = F.softmax(sim, dim=-1)
return torch.matmul(sim, source)
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
class deeplab_xception_transfer_projection_synBN_savemem(deeplab_xception_transfer_basemodel_synBN_savememory):
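    """Transfer model built on the synchronized-BatchNorm, memory-saving base."""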
    def __init__(self, nInputChannels=3, n_classes=7, os=16, input_channels=256, hidden_layers=128,
                 out_channels=256, transfer_graph=None, source_classes=20):
        super(deeplab_xception_transfer_projection_synBN_savemem, self).__init__(
            nInputChannels=nInputChannels, n_classes=n_classes, os=os, input_channels=input_channels,
            hidden_layers=hidden_layers, out_channels=out_channels)
self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(input_channels=input_channels, hidden_layers=hidden_layers,
nodes=source_classes)
self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(in_features=hidden_layers, out_features=hidden_layers,
                                               adj=transfer_graph, begin_nodes=source_classes, end_nodes=n_classes)
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)
    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        x, low_level_features = self.xception_features(input)
        # ASPP heads plus image-level pooling
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)

        # decoder: fuse with low-level features
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph: three stages of graph convolution over the source adjacency
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(source_graph, adj=adj2_source, relu=True)
        source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
        # note: source_graph_conv2 is applied twice and source_graph_conv3 is never
        # used; left unchanged to stay weight-compatible with existing checkpoints
        source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)

        # project each source stage onto the target node set through the learned
        # transfer graph
        source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
        source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)
        source_2_target_graph3_v5 = self.transpose_graph.forward(source_graph3, adj=adj3_transfer, relu=True)

        # target graph
        graph = self.target_featuremap_2_graph(x)

        # graph combine 1: target nodes + similarity projection + transfer projection
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        graph = torch.cat((graph, source_2_target_graph1.squeeze(0), source_2_target_graph1_v5.squeeze(0)), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # graph combine 2
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat((graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # graph combine 3
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat((graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1)
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # back to the feature map, fused with a 1x1 skip convolution
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        x = self.semantic(x)
        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x
    def similarity_trans(self, source, target):
sim = torch.matmul(F.normalize(target, p=2, dim=-1), F.normalize(source, p=2, dim=-1).transpose(-1, -2))
sim = F.softmax(sim, dim=-1)
return torch.matmul(sim, source)
    def load_source_model(self, state_dict):
        own_state = self.state_dict()
        new_state_dict = OrderedDict()
        for name, param in state_dict.items():
            name = name.replace('module.', '')
            # remap the plain 'graph' layers of a single-graph checkpoint onto
            # this model's 'source_*' modules
            if 'graph' in name and 'source' not in name and 'target' not in name and 'fc_' not in name and 'transpose_graph' not in name:
                if 'featuremap_2_graph' in name:
                    name = name.replace('featuremap_2_graph', 'source_featuremap_2_graph')
                else:
                    name = name.replace('graph', 'source_graph')
            # track every checkpoint key so the missing-key report below is accurate
            new_state_dict[name] = 0
            if name not in own_state:
                if 'num_batch' in name:
                    continue
                print('unexpected key "{}" in state_dict'.format(name))
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception:
                print('While copying the parameter named {}, whose dimensions in the model are'
                      ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                          name, own_state[name].size(), param.size()))
                continue

        missing = set(own_state.keys()) - set(new_state_dict.keys())
        if len(missing) > 0:
            print('missing keys in state_dict: "{}"'.format(missing))
# if __name__ == '__main__':
#     # Smoke-test sketch (adjacency shapes are assumptions; see the gcn module
#     # for the exact layouts expected by GraphConvolution and Graph_trans):
#     net = deeplab_xception_transfer_projection_savemem()
#     net.eval()
#     img = torch.rand((2, 3, 128, 128))
#     adj1 = torch.rand((1, 7, 7))
#     adj2 = torch.rand((1, 20, 20))
#     adj3 = torch.rand((7, 20))
#     out = net.forward(img, adj1_target=adj1, adj2_source=adj2, adj3_transfer=adj3)