import torch
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F
from networks import graph
class GraphConvolution(nn.Module):
def __init__(self,in_features,out_features,bias=False):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features,out_features))
if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter('bias',None)
self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
    def forward(self, input, adj=None, relu=False):
        # graph convolution: XW, optionally aggregated by the adjacency as A(XW)
        support = torch.matmul(input, self.weight)
        if adj is not None:
            output = torch.matmul(adj, support)
        else:
            output = support
        if self.bias is not None:
            output = output + self.bias
        if relu:
            return F.relu(output)
        return output
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
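# Minimal usage sketch: one GraphConvolution layer on a 7-node graph. The
# random adjacency below is a stand-in, not the project's real graph; feature
# sizes are illustrative assumptions.
def _demo_graph_convolution():
    gcn = GraphConvolution(128, 64)
    x = torch.randn(7, 128)                     # node embeddings: nodes x in_features
    adj = F.softmax(torch.randn(7, 7), dim=-1)  # stand-in row-normalised adjacency
    return gcn(x, adj=adj, relu=True)           # -> (7, 64)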
class Featuremaps_to_Graph(nn.Module):
def __init__(self,input_channels,hidden_layers,nodes=7):
super(Featuremaps_to_Graph, self).__init__()
        self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))         # pixel -> node assignment projection
        self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))  # pixel -> node feature projection
self.reset_parameters()
    def forward(self, input):
        n, c, h, w = input.size()
        input1 = input.view(n, c, h * w)
        input1 = input1.transpose(1, 2)  # n x hw x c
        # ---- feature maps to nodes ----
        fea_node = torch.matmul(input1, self.pre_fea)    # n x hw x nodes
        weight_node = torch.matmul(input1, self.weight)  # n x hw x hidden
        # soft pixel-to-node assignment, normalised over the node axis
        fea_node = F.softmax(fea_node, dim=-1)
        graph_node = F.relu(torch.matmul(fea_node.transpose(1, 2), weight_node))
        return graph_node  # n x nodes x hidden
    def reset_parameters(self):
        for ww in self.parameters():
            torch.nn.init.xavier_uniform_(ww)  # both parameters are 2-D matrices
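# Minimal usage sketch for Featuremaps_to_Graph: each pixel is softly assigned
# to the `nodes` classes via pre_fea, and node features are the assignment-
# weighted sums of per-pixel embeddings. Shapes are illustrative assumptions.
def _demo_featuremaps_to_graph():
    f2g = Featuremaps_to_Graph(input_channels=256, hidden_layers=128, nodes=7)
    fmap = torch.randn(2, 256, 32, 32)  # batch x channels x h x w
    return f2g(fmap)                    # -> (2, 7, 128): batch x nodes x hidden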
class Featuremaps_to_Graph_transfer(nn.Module):
def __init__(self,input_channels,hidden_layers,nodes=7, source_nodes=20):
super(Featuremaps_to_Graph_transfer, self).__init__()
self.pre_fea = Parameter(torch.FloatTensor(input_channels,nodes))
self.weight = Parameter(torch.FloatTensor(input_channels,hidden_layers))
        # note: nn.LeakyReLU's first positional argument is negative_slope, so
        # the in-place flag must be passed by keyword
        self.pre_fea_transfer = nn.Sequential(nn.Linear(source_nodes, source_nodes), nn.LeakyReLU(inplace=True),
                                              nn.Linear(source_nodes, nodes), nn.LeakyReLU(inplace=True))
self.reset_parameters()
    def forward(self, input, source_pre_fea):
        # adapt the source dataset's projection to this task's node count; the
        # result is written through .data, so pre_fea itself receives no gradient here
        self.pre_fea.data = self.pre_fea_learn(source_pre_fea)
        n, c, h, w = input.size()
        input1 = input.view(n, c, h * w)
        input1 = input1.transpose(1, 2)  # n x hw x c
        # ---- feature maps to nodes ----
        fea_node = torch.matmul(input1, self.pre_fea)    # n x hw x nodes
        weight_node = torch.matmul(input1, self.weight)  # n x hw x hidden
        # normalise the assignment over the node axis, as in Featuremaps_to_Graph
        fea_node = F.softmax(fea_node, dim=-1)
        graph_node = F.relu(torch.matmul(fea_node.transpose(1, 2), weight_node))
        return graph_node  # n x nodes x hidden
    def pre_fea_learn(self, input):
        pre_fea = self.pre_fea_transfer(input.unsqueeze(0)).squeeze(0)
        return self.pre_fea.data + pre_fea
    def reset_parameters(self):
        # called from __init__; Xavier for 2-D weights, zeros for the 1-D
        # Linear biases (xavier_uniform_ requires at least 2 dimensions)
        for ww in self.parameters():
            if ww.dim() > 1:
                torch.nn.init.xavier_uniform_(ww)
            else:
                torch.nn.init.zeros_(ww)
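# Minimal usage sketch for the transfer variant: a source model's pixel-to-node
# projection (channels x source_nodes) is mapped to this task's node count by
# the small MLP before projecting. All sizes here are illustrative assumptions.
def _demo_featuremaps_to_graph_transfer():
    f2g_t = Featuremaps_to_Graph_transfer(input_channels=256, hidden_layers=128,
                                          nodes=7, source_nodes=20)
    fmap = torch.randn(2, 256, 32, 32)
    source_pre_fea = torch.randn(256, 20)  # stand-in for a 20-class source pre_fea
    return f2g_t(fmap, source_pre_fea)     # -> (2, 7, 128)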
class Graph_to_Featuremaps(nn.Module):
    # projects graph-node features back onto the spatial feature map
def __init__(self,input_channels,output_channels,hidden_layers,nodes=7):
super(Graph_to_Featuremaps, self).__init__()
self.node_fea = Parameter(torch.FloatTensor(input_channels+hidden_layers,1))
self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
self.reset_parameters()
def reset_parameters(self):
for ww in self.parameters():
torch.nn.init.xavier_uniform_(ww)
    def forward(self, input, res_feature):
        '''
        :param input: 1 x batch x nodes x hidden_layer (batch x nodes x hidden is also accepted)
        :param res_feature: batch x channels x h x w
        :return: batch x output_channels x h x w
        '''
        batchi, channeli, hi, wi = res_feature.size()
        try:
            _, batch, nodes, hidden = input.size()
        except ValueError:
            # accept a 3-D input by adding the leading singleton dimension
            input = input.unsqueeze(0)
            _, batch, nodes, hidden = input.size()
        assert batch == batchi
        input1 = input.transpose(0, 1).expand(batch, hi * wi, nodes, hidden)
        res_feature_after_view = res_feature.view(batch, channeli, hi * wi).transpose(1, 2)
        res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch, hi * wi, nodes, channeli)
        new_fea = torch.cat((res_feature_after_view1, input1), dim=3)
        new_node = torch.matmul(new_fea, self.node_fea)  # batch x hw x nodes x 1
        new_weight = torch.matmul(input, self.weight)    # 1 x batch x nodes x output_channels
        new_node = new_node.view(batch, hi * wi, nodes)
        # normalise the pixel-to-node attention over the node axis
        new_node = F.softmax(new_node, dim=-1)
        feature_out = torch.matmul(new_node, new_weight)
        feature_out = feature_out.transpose(2, 3).contiguous().view(res_feature.size())
        return F.relu(feature_out)
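# Minimal usage sketch for Graph_to_Featuremaps: each pixel attends over the
# graph nodes and gathers their re-projected features. Note that the final
# .view(res_feature.size()) requires output_channels to equal the channel
# count of res_feature. Sizes are illustrative assumptions.
def _demo_graph_to_featuremaps():
    g2f = Graph_to_Featuremaps(input_channels=256, output_channels=256,
                               hidden_layers=128, nodes=7)
    nodes = torch.randn(1, 2, 7, 128)   # 1 x batch x nodes x hidden
    fmap = torch.randn(2, 256, 16, 16)  # batch x channels x h x w
    return g2f(nodes, fmap)             # -> (2, 256, 16, 16)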
class Graph_to_Featuremaps_savemem(nn.Module):
    # memory-saving variant of Graph_to_Featuremaps: mathematically equivalent,
    # but avoids materialising the large concatenated tensor (see forward)
def __init__(self, input_channels, output_channels, hidden_layers, nodes=7):
super(Graph_to_Featuremaps_savemem, self).__init__()
self.node_fea_for_res = Parameter(torch.FloatTensor(input_channels, 1))
self.node_fea_for_hidden = Parameter(torch.FloatTensor(hidden_layers, 1))
self.weight = Parameter(torch.FloatTensor(hidden_layers,output_channels))
self.reset_parameters()
def reset_parameters(self):
for ww in self.parameters():
torch.nn.init.xavier_uniform_(ww)
    def forward(self, input, res_feature):
        '''
        :param input: 1 x batch x nodes x hidden_layer (batch x nodes x hidden is also accepted)
        :param res_feature: batch x channels x h x w
        :return: batch x output_channels x h x w
        '''
        batchi, channeli, hi, wi = res_feature.size()
        try:
            _, batch, nodes, hidden = input.size()
        except ValueError:
            # accept a 3-D input by adding the leading singleton dimension
            input = input.unsqueeze(0)
            _, batch, nodes, hidden = input.size()
        assert batch == batchi
        input1 = input.transpose(0, 1).expand(batch, hi * wi, nodes, hidden)
        res_feature_after_view = res_feature.view(batch, channeli, hi * wi).transpose(1, 2)
        res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch, hi * wi, nodes, channeli)
        # split projection instead of concatenation:
        # cat([res, hidden], -1) @ node_fea == res @ node_fea_for_res + hidden @ node_fea_for_hidden
        new_node1 = torch.matmul(res_feature_after_view1, self.node_fea_for_res)
        new_node2 = torch.matmul(input1, self.node_fea_for_hidden)
        new_node = new_node1 + new_node2               # batch x hw x nodes x 1
        new_weight = torch.matmul(input, self.weight)  # 1 x batch x nodes x output_channels
        new_node = new_node.view(batch, hi * wi, nodes)
        # normalise the pixel-to-node attention over the node axis
        new_node = F.softmax(new_node, dim=-1)
        feature_out = torch.matmul(new_node, new_weight)
        feature_out = feature_out.transpose(2, 3).contiguous().view(res_feature.size())
        return F.relu(feature_out)
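# Quick numerical check of the identity behind the save-memory variant:
#   cat([a, b], -1) @ cat([w1, w2], 0) == a @ w1 + b @ w2,
# so the concatenated tensor never has to be materialised. Shapes below are
# illustrative assumptions.
def _demo_savemem_equivalence():
    a = torch.randn(4, 5, 256)  # stand-in for the expanded residual features
    b = torch.randn(4, 5, 128)  # stand-in for the expanded node features
    w1 = torch.randn(256, 1)
    w2 = torch.randn(128, 1)
    lhs = torch.matmul(torch.cat((a, b), dim=-1), torch.cat((w1, w2), dim=0))
    rhs = torch.matmul(a, w1) + torch.matmul(b, w2)
    return torch.allclose(lhs, rhs, atol=1e-5)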
class Graph_trans(nn.Module):
def __init__(self,in_features,out_features,begin_nodes=7,end_nodes=2,bias=False,adj=None):
super(Graph_trans, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features,out_features))
        if adj is not None:
            h, w = adj.size()
            assert (h == end_nodes) and (w == begin_nodes)
            # fixed transfer adjacency (the old torch.autograd.Variable API is deprecated)
            self.adj = adj.detach()
        else:
            self.adj = Parameter(torch.FloatTensor(end_nodes, begin_nodes))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if isinstance(self.adj, Parameter):
            torch.nn.init.xavier_uniform_(self.adj)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
    def forward(self, input, relu=False, adj_return=False, adj=None):
        support = torch.matmul(input, self.weight)
        if adj is None:
            adj = self.adj
        adj1 = self.norm_trans_adj(adj)
        output = torch.matmul(adj1, support)
        if adj_return:
            # cache a cosine-similarity adjacency built from the output nodes
            output1 = F.normalize(output, p=2, dim=-1)
            self.adj_mat = torch.matmul(output1, output1.transpose(-2, -1))
        if self.bias is not None:
            output = output + self.bias
        if relu:
            return F.relu(output)
        return output
    def get_adj_mat(self):
        # only meaningful after a forward pass with adj_return=True
        adj = graph.normalize_adj_torch(F.relu(self.adj_mat))
        return adj
def get_encode_adj(self):
return self.adj
    def norm_trans_adj(self, adj):
        # ReLU followed by row-wise softmax: each end node gets a non-negative,
        # normalised mixture over the begin nodes
        adj = F.relu(adj)
        r = F.softmax(adj, dim=-1)
        return r
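# Minimal usage sketch for Graph_trans: re-map a 7-node graph to a 2-node graph
# through a fixed end_nodes x begin_nodes transfer adjacency. The random
# adjacency and feature sizes are illustrative assumptions.
def _demo_graph_trans():
    adj = torch.rand(2, 7)   # end_nodes x begin_nodes
    gt = Graph_trans(128, 128, begin_nodes=7, end_nodes=2, adj=adj)
    x = torch.randn(7, 128)  # begin_nodes x in_features
    return gt(x, relu=True)  # -> (2, 128)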
if __name__ == '__main__':
    # smoke test: one GraphConvolution pass over random node features
    # (not named `graph`, which would shadow the imported networks.graph module)
    node_fea = torch.randn(7, 128)
    en = GraphConvolution(128, 128)
    a = en(node_fea)
    print(a)