# NOTE(review): removed non-Python export artifacts ("Spaces:" / "Running"
# status lines) that preceded the actual source.
import pickle as pkl

import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
# Pascal-Person-Part label graph (7 classes) as a dict-of-lists
# (node -> neighbour list), consumed by preprocess_adj via
# nx.from_dict_of_lists.
pascal_graph = {0: [0],
                1: [1, 2],
                2: [1, 2, 3, 5],
                3: [2, 3, 4],
                4: [3, 4],
                5: [2, 5, 6],
                6: [5, 6]}
# CIHP label graph (20 classes) as a dict-of-lists (node -> neighbour list),
# consumed by preprocess_adj via nx.from_dict_of_lists.
cihp_graph = {0: [],
              1: [2, 13],
              2: [1, 13],
              3: [14, 15],
              4: [13],
              5: [6, 7, 9, 10, 11, 12, 14, 15],
              6: [5, 7, 10, 11, 14, 15, 16, 17],
              7: [5, 6, 9, 10, 11, 12, 14, 15],
              8: [16, 17, 18, 19],
              9: [5, 7, 10, 16, 17, 18, 19],
              10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
              11: [5, 6, 7, 10, 13],
              12: [5, 7, 10, 16, 17],
              13: [1, 2, 4, 10, 11],
              14: [3, 5, 6, 7, 10],
              15: [3, 5, 6, 7, 10],
              16: [6, 8, 9, 10, 12, 18],
              17: [6, 8, 9, 10, 12, 19],
              18: [8, 9, 16],
              19: [8, 9, 17]}
# ATR label graph (18 classes) as a dict-of-lists (node -> neighbour list),
# consumed by preprocess_adj via nx.from_dict_of_lists.
atr_graph = {0: [],
             1: [2, 11],
             2: [1, 11],
             3: [11],
             4: [5, 6, 7, 11, 14, 15, 17],
             5: [4, 6, 7, 8, 12, 13],
             6: [4, 5, 7, 8, 9, 10, 12, 13],
             7: [4, 11, 12, 13, 14, 15],
             8: [5, 6],
             9: [6, 12],
             10: [6, 13],
             11: [1, 2, 3, 4, 7, 14, 15, 17],
             12: [5, 6, 7, 9],
             13: [5, 6, 7, 10],
             14: [4, 7, 11, 16],
             15: [4, 7, 11, 16],
             16: [14, 15],
             17: [4, 11]}
# Binary (7 x 20) mapping matrix: row = Pascal-Person-Part class,
# column = CIHP class; 1 marks a correspondence between the two label sets.
cihp2pascal_adj = np.array(
    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
     [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
# Soft (7 x 20) Pascal->CIHP affinity weights (floats in (0, 1];
# presumably word-embedding similarities between label names — TODO confirm).
cihp2pascal_nlp_adj = \
    np.array([[1.,         0.35333052, 0.32727194, 0.17418084, 0.18757584,
               0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579,
               0.33064262, 0.28923404, 0.27986573, 0.4211553,  0.36915778,
               0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332],
              [0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527,
               0.54747773, 0.3670523,  0.3901983,  0.27519473, 0.3433325,
               0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925,
               0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375],
              [0.16373166, 0.21663339, 0.3053872,  0.28377612, 0.1372435,
               0.4448808,  0.29479995, 0.31092595, 0.22703953, 0.33983576,
               0.75778818, 0.2619818,  0.37069392, 0.35184867, 0.49877512,
               0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309],
              [0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749,
               0.76749079, 0.35139853, 0.3787411,  0.28411737, 0.35155421,
               0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737,
               0.8342413,  0.70732188, 0.72752501, 0.60327325, 0.61431337],
              [0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617,
               0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379,
               0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993,
               0.84966274, 0.7033991,  0.73449378, 0.56649608, 0.58888791],
              [0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676,
               0.78566833, 0.42237487, 0.41048549, 0.37933812, 0.46542516,
               0.62444759, 0.3274493,  0.49466009, 0.49314658, 0.71244233,
               0.71497003, 0.8234787,  0.83566589, 0.62597135, 0.62626812],
              [0.3011378,  0.31775977, 0.42922647, 0.36896257, 0.17597556,
               0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419,
               0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109,
               0.73678399, 0.82620388, 0.84942166, 0.5943811,  0.60607602]])
# Soft (7 x 18) Pascal->ATR affinity weights (floats in (0, 1];
# presumably word-embedding similarities between label names — TODO confirm).
pascal2atr_nlp_adj = \
    np.array([[1.,         0.35333052, 0.32727194, 0.18757584, 0.40608522,
               0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
               0.41500332, 0.4211553,  0.32485771, 0.37248222, 0.36915778,
               0.41377746, 0.32006291, 0.28923404],
              [0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773,
               0.34819325, 0.3433325,  0.26603942, 0.45162929, 0.59538781,
               0.6071375,  0.63882953, 0.63395791, 0.65344337, 0.68042925,
               0.69368576, 0.44354613, 0.32771333],
              [0.16373166, 0.21663339, 0.3053872,  0.1372435,  0.4448808,
               0.37069392, 0.33983576, 0.26563416, 0.35443504, 0.32517741,
               0.32377309, 0.35184867, 0.51853277, 0.52517541, 0.49877512,
               0.49979437, 0.21750868, 0.2619818],
              [0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079,
               0.40585111, 0.35155421, 0.28271333, 0.52684576, 0.60327325,
               0.61431337, 0.51189218, 0.70732188, 0.72752501, 0.82042737,
               0.8342413,  0.40137029, 0.31141718],
              [0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208,
               0.37000453, 0.32072379, 0.27268885, 0.47426719, 0.56649608,
               0.58888791, 0.52600859, 0.7033991,  0.73449378, 0.82483993,
               0.84966274, 0.37830796, 0.24894776],
              [0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833,
               0.49466009, 0.46542516, 0.32662614, 0.55780359, 0.62597135,
               0.62626812, 0.49314658, 0.8234787,  0.83566589, 0.71244233,
               0.71497003, 0.41223219, 0.3274493],
              [0.3011378,  0.31775977, 0.42922647, 0.17597556, 0.72214655,
               0.46271161, 0.43818419, 0.3192333,  0.50979216, 0.5943811,
               0.60607602, 0.51150476, 0.82620388, 0.84942166, 0.72318109,
               0.73678399, 0.39259827, 0.26762545]])
# Soft (20 x 18) CIHP->ATR affinity weights (floats in (0, 1];
# presumably word-embedding similarities between label names — TODO confirm).
cihp2atr_nlp_adj = \
    np.array([[1.,         0.35333052, 0.32727194, 0.18757584, 0.40608522,
               0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
               0.41500332, 0.4211553,  0.32485771, 0.37248222, 0.36915778,
               0.41377746, 0.32006291, 0.28923404],
              [0.35333052, 1.,         0.39206695, 0.42143438, 0.4736689,
               0.47139544, 0.51999208, 0.38354847, 0.45628529, 0.46514124,
               0.50083501, 0.4310595,  0.39371443, 0.4319752,  0.42938598,
               0.46384034, 0.44833757, 0.6153155],
              [0.32727194, 0.39206695, 1.,         0.32836702, 0.52603065,
               0.39543695, 0.3622627,  0.43575346, 0.33866223, 0.45202552,
               0.48421,    0.53669903, 0.47266611, 0.50925436, 0.42286557,
               0.45403656, 0.37221304, 0.40999322],
              [0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317,
               0.34218382, 0.48253904, 0.22084125, 0.41335728, 0.52437572,
               0.5191713,  0.33576117, 0.44230914, 0.44250678, 0.44330833,
               0.43887264, 0.50693611, 0.39278795],
              [0.18757584, 0.42143438, 0.32836702, 1.,         0.35030067,
               0.30110947, 0.41055555, 0.34338879, 0.34336307, 0.37704433,
               0.38810141, 0.34702081, 0.24171562, 0.25433078, 0.24696241,
               0.2570884,  0.4465962,  0.45263213],
              [0.40608522, 0.4736689,  0.52603065, 0.35030067, 1.,
               0.54372584, 0.58300258, 0.56674191, 0.555266,   0.66599594,
               0.68567555, 0.55716359, 0.62997328, 0.65638548, 0.61219615,
               0.63183318, 0.54464151, 0.44293752],
              [0.37503981, 0.50675565, 0.4761106,  0.37561813, 0.60419403,
               0.77912403, 0.64595517, 0.85939662, 0.46037144, 0.52348817,
               0.55875094, 0.37741886, 0.455671,   0.49434392, 0.38479954,
               0.41804074, 0.47285709, 0.57236283],
              [0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903,
               0.50257274, 0.52591451, 0.4283053,  0.39991808, 0.42327211,
               0.42853819, 0.42071825, 0.41240559, 0.42259136, 0.38125352,
               0.3868255,  0.47604934, 0.51811717],
              [0.22598555, 0.5053299,  0.36301185, 0.38002282, 0.49700941,
               0.45625243, 0.62876479, 0.4112051,  0.33944371, 0.48322639,
               0.50318714, 0.29207815, 0.38801966, 0.41119094, 0.29199072,
               0.31021029, 0.41594871, 0.54961962],
              [0.23893579, 0.51999208, 0.3622627,  0.41055555, 0.58300258,
               0.68874251, 1.,         0.56977937, 0.49918447, 0.48484363,
               0.51615925, 0.41222306, 0.49535971, 0.53134951, 0.3807616,
               0.41050298, 0.48675801, 0.51112664],
              [0.33064262, 0.306412,   0.60679935, 0.25592294, 0.58738706,
               0.40379627, 0.39679161, 0.33618385, 0.39235148, 0.45474013,
               0.4648476,  0.59306762, 0.58976007, 0.60778661, 0.55400397,
               0.56551297, 0.3698029,  0.33860535],
              [0.28923404, 0.6153155,  0.40999322, 0.45263213, 0.44293752,
               0.60359359, 0.51112664, 0.46578181, 0.45656936, 0.38142307,
               0.38525582, 0.33327223, 0.35360175, 0.36156453, 0.3384992,
               0.34261229, 0.49297863, 1.],
              [0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584,
               1.,         0.68874251, 0.67765588, 0.48690078, 0.44010641,
               0.44921156, 0.32321099, 0.48311542, 0.4982002,  0.39378102,
               0.40297733, 0.45309735, 0.60359359],
              [0.4211553,  0.4310595,  0.53669903, 0.34702081, 0.55716359,
               0.32321099, 0.41222306, 0.25721705, 0.36633509, 0.5397475,
               0.56429928, 1.,         0.55796926, 0.58842844, 0.57930828,
               0.60410597, 0.41615326, 0.33327223],
              [0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615,
               0.39378102, 0.3807616,  0.28089866, 0.48450394, 0.77400821,
               0.68813814, 0.57930828, 0.8856886,  0.81673412, 1.,
               0.92279623, 0.46969152, 0.3384992],
              [0.41377746, 0.46384034, 0.45403656, 0.2570884,  0.63183318,
               0.40297733, 0.41050298, 0.332879,   0.48799542, 0.69231828,
               0.77015091, 0.60410597, 0.79788484, 0.88232104, 0.92279623,
               1.,         0.45685017, 0.34261229],
              [0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328,
               0.48311542, 0.49535971, 0.32477932, 0.51486622, 0.79353556,
               0.69768738, 0.55796926, 1.,         0.92373745, 0.8856886,
               0.79788484, 0.47883134, 0.35360175],
              [0.37248222, 0.4319752,  0.50925436, 0.25433078, 0.65638548,
               0.4982002,  0.53134951, 0.38057074, 0.52403969, 0.72035243,
               0.78711147, 0.58842844, 0.92373745, 1.,         0.81673412,
               0.88232104, 0.47109935, 0.36156453],
              [0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594,
               0.44010641, 0.48484363, 0.39636574, 0.50175258, 1.,
               0.91320249, 0.5397475,  0.79353556, 0.72035243, 0.77400821,
               0.69231828, 0.59087008, 0.38142307],
              [0.41500332, 0.50083501, 0.48421,    0.38810141, 0.68567555,
               0.44921156, 0.51615925, 0.45156472, 0.50438158, 0.91320249,
               1.,         0.56429928, 0.69768738, 0.78711147, 0.68813814,
               0.77015091, 0.57698754, 0.38525582]])
def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix: D^-1/2 * A * D^-1/2.

    Args:
        adj: square adjacency matrix (dense array or scipy sparse).

    Returns:
        scipy.sparse.coo_matrix holding the normalized adjacency.
    """
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    # Zero-degree rows yield inf here; zero them so isolated nodes stay zero.
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
def preprocess_adj(adj):
    """Preprocess a dict-of-lists graph for a simple GCN model.

    Args:
        adj: graph as a dict-of-lists (e.g. ``pascal_graph``).

    Returns:
        Dense numpy matrix of the symmetrically normalized ``A + I``.
    """
    # Sparse adjacency matrix built from the dict-of-lists description.
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(adj))
    # Add self-loops before symmetric normalization (standard GCN recipe).
    adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0]))
    return adj_normalized.todense()
def row_norm(inputs):
    """Normalize each row of `inputs` so it sums to 1.

    Args:
        inputs: iterable of numeric numpy rows (e.g. a 2-D array).

    Returns:
        list of normalized rows (one entry per input row).
    """
    # NOTE: a row summing to 0 divides by zero, matching numpy semantics
    # (inf/nan with a warning); callers pass binary mapping matrices with
    # non-empty rows.
    return [x / x.sum() for x in inputs]
def normalize_adj_torch(adj):
    """Symmetrically normalize a torch adjacency: D^-1/2 * A * D^-1/2.

    Accepts either a 2-D matrix or a 4-D batch shaped (1, C, N, N); for 4-D
    input only batch index 0 is normalized, matching the original usage.

    Args:
        adj: torch tensor, 2-D (N, N) or 4-D (1, C, N, N).

    Returns:
        Tensor of the same shape with each adjacency normalized.
    """
    if len(adj.size()) == 4:
        out = torch.zeros(adj.size()).type_as(adj)
        for c in range(adj.size(1)):
            out[0, c, ...] = _sym_normalize_torch(adj[0, c])
        return out
    return _sym_normalize_torch(adj)


def _sym_normalize_torch(adj_item):
    """Normalize one 2-D adjacency, guarding zero-degree rows."""
    rowsum = adj_item.sum(1)
    d_inv_sqrt = rowsum.pow(-0.5)
    # Bug fix: 0 ** -0.5 is inf (not nan), which the old isnan mask missed
    # and which propagated inf/nan through the matmuls. Mask both.
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0
    d_inv_sqrt[torch.isnan(d_inv_sqrt)] = 0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
    return torch.matmul(torch.matmul(d_mat_inv_sqrt, adj_item), d_mat_inv_sqrt)
if __name__ == '__main__':
    # Quick sanity check: row-normalize the CIHP->Pascal mapping and show it.
    a = row_norm(cihp2pascal_adj)
    print(a)
    print(cihp2pascal_adj)