# wizzseen's picture
# Upload 948 files
# 8a6df40 verified
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import torch
# Adjacency lists for a 7-node label graph (node ids 0-6), in the
# dict-of-lists form consumed by preprocess_adj().
# Every node lists itself among its neighbours, i.e. the graph as written
# already contains self-loops.
# NOTE(review): presumably the nodes are the PASCAL-Person-Part classes
# (0 = background) -- the label names are not visible here, confirm.
pascal_graph = {0:[0],
1:[1, 2],
2:[1, 2, 3, 5],
3:[2, 3, 4],
4:[3, 4],
5:[2, 5, 6],
6:[5, 6]}
# Adjacency lists for a 20-node label graph (node ids 0-19), in the
# dict-of-lists form consumed by preprocess_adj().
# Node 0 has no neighbours and no node lists itself.
# NOTE(review): presumably the nodes are the CIHP parsing classes
# (0 = background) -- the label names are not visible here, confirm.
cihp_graph = {0: [],
1: [2, 13],
2: [1, 13],
3: [14, 15],
4: [13],
5: [6, 7, 9, 10, 11, 12, 14, 15],
6: [5, 7, 10, 11, 14, 15, 16, 17],
7: [5, 6, 9, 10, 11, 12, 14, 15],
8: [16, 17, 18, 19],
9: [5, 7, 10, 16, 17, 18, 19],
10:[5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
11:[5, 6, 7, 10, 13],
12:[5, 7, 10, 16, 17],
13:[1, 2, 4, 10, 11],
14:[3, 5, 6, 7, 10],
15:[3, 5, 6, 7, 10],
16:[6, 8, 9, 10, 12, 18],
17:[6, 8, 9, 10, 12, 19],
18:[8, 9, 16],
19:[8, 9, 17]}
# Adjacency lists for an 18-node label graph (node ids 0-17), in the
# dict-of-lists form consumed by preprocess_adj().
# Node 0 has no neighbours and no node lists itself.
# NOTE(review): the lists are not perfectly mirrored (e.g. nodes 5 and 6
# list 7, but 7 lists neither of them) -- confirm whether that is intended.
# NOTE(review): presumably the nodes are the ATR parsing classes
# (0 = background) -- the label names are not visible here, confirm.
atr_graph = {0: [],
1: [2, 11],
2: [1, 11],
3: [11],
4: [5, 6, 7, 11, 14, 15, 17],
5: [4, 6, 7, 8, 12, 13],
6: [4,5,7,8,9,10,12,13],
7: [4,11,12,13,14,15],
8: [5,6],
9: [6, 12],
10:[6, 13],
11:[1,2,3,4,7,14,15,17],
12:[5,6,7,9],
13:[5,6,7,10],
14:[4,7,11,16],
15:[4,7,11,16],
16:[14,15],
17:[4,11],
}
# Binary 7 x 20 matrix (7 rows of 20 zeros/ones).
# NOTE(review): presumably rows index the 7 pascal_graph nodes and columns
# the 20 cihp_graph nodes, with 1 marking a hand-assigned correspondence
# between the two label sets -- confirm against the code that consumes it.
cihp2pascal_adj = np.array([[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
# Dense 7 x 20 float matrix (7 rows of 20 values).
# NOTE(review): presumably precomputed similarity scores between label
# names (the "nlp" suffix suggests word-embedding similarity), with rows
# indexing the 7 pascal_graph nodes and columns the 20 cihp_graph nodes --
# the generating code is not visible here, confirm before relying on this.
cihp2pascal_nlp_adj = \
np.array([[ 1., 0.35333052, 0.32727194, 0.17418084, 0.18757584,
0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579,
0.33064262, 0.28923404, 0.27986573, 0.4211553 , 0.36915778,
0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332],
[ 0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527,
0.54747773, 0.3670523 , 0.3901983 , 0.27519473, 0.3433325 ,
0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925,
0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375 ],
[ 0.16373166, 0.21663339, 0.3053872 , 0.28377612, 0.1372435 ,
0.4448808 , 0.29479995, 0.31092595, 0.22703953, 0.33983576,
0.75778818, 0.2619818 , 0.37069392, 0.35184867, 0.49877512,
0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309],
[ 0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749,
0.76749079, 0.35139853, 0.3787411 , 0.28411737, 0.35155421,
0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737,
0.8342413 , 0.70732188, 0.72752501, 0.60327325, 0.61431337],
[ 0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617,
0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379,
0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993,
0.84966274, 0.7033991 , 0.73449378, 0.56649608, 0.58888791],
[ 0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676,
0.78566833, 0.42237487, 0.41048549, 0.37933812, 0.46542516,
0.62444759, 0.3274493 , 0.49466009, 0.49314658, 0.71244233,
0.71497003, 0.8234787 , 0.83566589, 0.62597135, 0.62626812],
[ 0.3011378 , 0.31775977, 0.42922647, 0.36896257, 0.17597556,
0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419,
0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109,
0.73678399, 0.82620388, 0.84942166, 0.5943811 , 0.60607602]])
# Dense 7 x 18 float matrix (7 rows of 18 values).
# NOTE(review): presumably precomputed similarity scores between label
# names (the "nlp" suffix suggests word-embedding similarity), with rows
# indexing the 7 pascal_graph nodes and columns the 18 atr_graph nodes --
# the generating code is not visible here, confirm before relying on this.
pascal2atr_nlp_adj = \
np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778,
0.41377746, 0.32006291, 0.28923404],
[ 0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773,
0.34819325, 0.3433325 , 0.26603942, 0.45162929, 0.59538781,
0.6071375 , 0.63882953, 0.63395791, 0.65344337, 0.68042925,
0.69368576, 0.44354613, 0.32771333],
[ 0.16373166, 0.21663339, 0.3053872 , 0.1372435 , 0.4448808 ,
0.37069392, 0.33983576, 0.26563416, 0.35443504, 0.32517741,
0.32377309, 0.35184867, 0.51853277, 0.52517541, 0.49877512,
0.49979437, 0.21750868, 0.2619818 ],
[ 0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079,
0.40585111, 0.35155421, 0.28271333, 0.52684576, 0.60327325,
0.61431337, 0.51189218, 0.70732188, 0.72752501, 0.82042737,
0.8342413 , 0.40137029, 0.31141718],
[ 0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208,
0.37000453, 0.32072379, 0.27268885, 0.47426719, 0.56649608,
0.58888791, 0.52600859, 0.7033991 , 0.73449378, 0.82483993,
0.84966274, 0.37830796, 0.24894776],
[ 0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833,
0.49466009, 0.46542516, 0.32662614, 0.55780359, 0.62597135,
0.62626812, 0.49314658, 0.8234787 , 0.83566589, 0.71244233,
0.71497003, 0.41223219, 0.3274493 ],
[ 0.3011378 , 0.31775977, 0.42922647, 0.17597556, 0.72214655,
0.46271161, 0.43818419, 0.3192333 , 0.50979216, 0.5943811 ,
0.60607602, 0.51150476, 0.82620388, 0.84942166, 0.72318109,
0.73678399, 0.39259827, 0.26762545]])
# Dense 20 x 18 float matrix (20 rows of 18 values).
# NOTE(review): presumably precomputed similarity scores between label
# names (the "nlp" suffix suggests word-embedding similarity), with rows
# indexing the 20 cihp_graph nodes and columns the 18 atr_graph nodes --
# the generating code is not visible here, confirm before relying on this.
cihp2atr_nlp_adj = np.array([[ 1., 0.35333052, 0.32727194, 0.18757584, 0.40608522,
0.27986573, 0.23893579, 0.27600672, 0.30964391, 0.36865639,
0.41500332, 0.4211553 , 0.32485771, 0.37248222, 0.36915778,
0.41377746, 0.32006291, 0.28923404],
[ 0.35333052, 1. , 0.39206695, 0.42143438, 0.4736689 ,
0.47139544, 0.51999208, 0.38354847, 0.45628529, 0.46514124,
0.50083501, 0.4310595 , 0.39371443, 0.4319752 , 0.42938598,
0.46384034, 0.44833757, 0.6153155 ],
[ 0.32727194, 0.39206695, 1. , 0.32836702, 0.52603065,
0.39543695, 0.3622627 , 0.43575346, 0.33866223, 0.45202552,
0.48421 , 0.53669903, 0.47266611, 0.50925436, 0.42286557,
0.45403656, 0.37221304, 0.40999322],
[ 0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317,
0.34218382, 0.48253904, 0.22084125, 0.41335728, 0.52437572,
0.5191713 , 0.33576117, 0.44230914, 0.44250678, 0.44330833,
0.43887264, 0.50693611, 0.39278795],
[ 0.18757584, 0.42143438, 0.32836702, 1. , 0.35030067,
0.30110947, 0.41055555, 0.34338879, 0.34336307, 0.37704433,
0.38810141, 0.34702081, 0.24171562, 0.25433078, 0.24696241,
0.2570884 , 0.4465962 , 0.45263213],
[ 0.40608522, 0.4736689 , 0.52603065, 0.35030067, 1. ,
0.54372584, 0.58300258, 0.56674191, 0.555266 , 0.66599594,
0.68567555, 0.55716359, 0.62997328, 0.65638548, 0.61219615,
0.63183318, 0.54464151, 0.44293752],
[ 0.37503981, 0.50675565, 0.4761106 , 0.37561813, 0.60419403,
0.77912403, 0.64595517, 0.85939662, 0.46037144, 0.52348817,
0.55875094, 0.37741886, 0.455671 , 0.49434392, 0.38479954,
0.41804074, 0.47285709, 0.57236283],
[ 0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903,
0.50257274, 0.52591451, 0.4283053 , 0.39991808, 0.42327211,
0.42853819, 0.42071825, 0.41240559, 0.42259136, 0.38125352,
0.3868255 , 0.47604934, 0.51811717],
[ 0.22598555, 0.5053299 , 0.36301185, 0.38002282, 0.49700941,
0.45625243, 0.62876479, 0.4112051 , 0.33944371, 0.48322639,
0.50318714, 0.29207815, 0.38801966, 0.41119094, 0.29199072,
0.31021029, 0.41594871, 0.54961962],
[ 0.23893579, 0.51999208, 0.3622627 , 0.41055555, 0.58300258,
0.68874251, 1. , 0.56977937, 0.49918447, 0.48484363,
0.51615925, 0.41222306, 0.49535971, 0.53134951, 0.3807616 ,
0.41050298, 0.48675801, 0.51112664],
[ 0.33064262, 0.306412 , 0.60679935, 0.25592294, 0.58738706,
0.40379627, 0.39679161, 0.33618385, 0.39235148, 0.45474013,
0.4648476 , 0.59306762, 0.58976007, 0.60778661, 0.55400397,
0.56551297, 0.3698029 , 0.33860535],
[ 0.28923404, 0.6153155 , 0.40999322, 0.45263213, 0.44293752,
0.60359359, 0.51112664, 0.46578181, 0.45656936, 0.38142307,
0.38525582, 0.33327223, 0.35360175, 0.36156453, 0.3384992 ,
0.34261229, 0.49297863, 1. ],
[ 0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584,
1. , 0.68874251, 0.67765588, 0.48690078, 0.44010641,
0.44921156, 0.32321099, 0.48311542, 0.4982002 , 0.39378102,
0.40297733, 0.45309735, 0.60359359],
[ 0.4211553 , 0.4310595 , 0.53669903, 0.34702081, 0.55716359,
0.32321099, 0.41222306, 0.25721705, 0.36633509, 0.5397475 ,
0.56429928, 1. , 0.55796926, 0.58842844, 0.57930828,
0.60410597, 0.41615326, 0.33327223],
[ 0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615,
0.39378102, 0.3807616 , 0.28089866, 0.48450394, 0.77400821,
0.68813814, 0.57930828, 0.8856886 , 0.81673412, 1. ,
0.92279623, 0.46969152, 0.3384992 ],
[ 0.41377746, 0.46384034, 0.45403656, 0.2570884 , 0.63183318,
0.40297733, 0.41050298, 0.332879 , 0.48799542, 0.69231828,
0.77015091, 0.60410597, 0.79788484, 0.88232104, 0.92279623,
1. , 0.45685017, 0.34261229],
[ 0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328,
0.48311542, 0.49535971, 0.32477932, 0.51486622, 0.79353556,
0.69768738, 0.55796926, 1. , 0.92373745, 0.8856886 ,
0.79788484, 0.47883134, 0.35360175],
[ 0.37248222, 0.4319752 , 0.50925436, 0.25433078, 0.65638548,
0.4982002 , 0.53134951, 0.38057074, 0.52403969, 0.72035243,
0.78711147, 0.58842844, 0.92373745, 1. , 0.81673412,
0.88232104, 0.47109935, 0.36156453],
[ 0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594,
0.44010641, 0.48484363, 0.39636574, 0.50175258, 1. ,
0.91320249, 0.5397475 , 0.79353556, 0.72035243, 0.77400821,
0.69231828, 0.59087008, 0.38142307],
[ 0.41500332, 0.50083501, 0.48421, 0.38810141, 0.68567555,
0.44921156, 0.51615925, 0.45156472, 0.50438158, 0.91320249,
1., 0.56429928, 0.69768738, 0.78711147, 0.68813814,
0.77015091, 0.57698754, 0.38525582]])
def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix.

    Computes ``D^-1/2 . A^T . D^-1/2`` where ``D`` is the diagonal matrix of
    the row sums of ``A``; for a symmetric ``A`` this is the usual GCN
    normalization ``D^-1/2 . A . D^-1/2``.

    Rows whose sum is zero (isolated nodes) or negative have no finite
    ``sum ** -0.5``; their scaling factor is set to 0 so the corresponding
    rows/columns of the result are zero instead of inf/NaN.

    Args:
        adj: Square adjacency matrix; anything ``scipy.sparse.coo_matrix``
            accepts (dense array, sparse matrix, nested lists).

    Returns:
        The normalized matrix as a ``scipy.sparse`` COO matrix.
    """
    adj = sp.coo_matrix(adj)
    row_sum = np.asarray(adj.sum(1)).flatten()
    # 0 ** -0.5 -> inf and negative ** -0.5 -> NaN are expected here and are
    # cleaned up right below, so keep NumPy from emitting RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        d_inv_sqrt = np.power(row_sum, -0.5)
    d_inv_sqrt[~np.isfinite(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
def preprocess_adj(adj):
    """Turn a dict-of-lists graph into a dense, normalized adjacency matrix.

    The graph is built with networkx, an identity matrix is added to its
    adjacency matrix (one extra self-connection per node), the sum is passed
    through ``normalize_adj`` and the result is returned in dense form.

    Args:
        adj: Mapping ``node -> list of neighbour nodes``.

    Returns:
        The dense form (``.todense()``) of the normalized adjacency matrix.
    """
    graph = nx.from_dict_of_lists(adj)
    # networkx hands back the adjacency in SciPy sparse form, not as a
    # NumPy array.
    adjacency = nx.adjacency_matrix(graph)
    with_self_loops = adjacency + sp.eye(adjacency.shape[0])
    return normalize_adj(with_self_loops).todense()
def row_norm(inputs):
    """Scale every item of ``inputs`` so that its elements sum to 1.

    ``inputs`` is iterated along its first axis (for a 2-D array: row by
    row) and each item is divided by the sum of all of its elements.  An
    item whose sum is zero cannot be normalized; it is returned as all
    zeros rather than as NaN/inf produced by a division by zero.

    Args:
        inputs: Iterable of array-like objects supporting ``.sum()`` and
            arithmetic with scalars (e.g. the rows of a NumPy array).

    Returns:
        A list with one normalized item per element of ``inputs``.
    """
    def _scale(row):
        total = row.sum()
        if total == 0:
            # ``row * 0.0`` keeps the shape and yields float zeros, matching
            # the float result the division would otherwise produce.
            return row * 0.0
        return row / total

    return [_scale(row) for row in inputs]
def normalize_adj_torch(adj):
    """Symmetrically normalize dense adjacency matrices: ``D^-1/2 . A . D^-1/2``.

    ``D`` is the diagonal matrix of the row sums of ``A``.  The
    normalization is applied independently to every square matrix stored in
    the last two dimensions of ``adj``, so a single ``(N, N)`` matrix and a
    batched ``(B, C, N, N)`` tensor are both handled, for every batch entry.

    Rows whose sum is zero (isolated nodes) or negative have no finite
    ``sum ** -0.5``; their scaling factor is set to 0 so the corresponding
    rows/columns of the result are zero instead of inf/NaN.

    Args:
        adj: Floating-point tensor whose last two dimensions hold square
            adjacency matrices.

    Returns:
        A new tensor with the same shape as ``adj``; ``adj`` itself is not
        modified.
    """
    row_sum = adj.sum(-1)
    d_inv_sqrt = row_sum.pow(-0.5)
    # A zero row sum gives inf and a negative one gives NaN; both must be
    # cleared, otherwise inf * 0 turns whole rows/columns into NaN.
    d_inv_sqrt = torch.where(
        torch.isfinite(d_inv_sqrt), d_inv_sqrt, torch.zeros_like(d_inv_sqrt)
    )
    # Multiplying by diag(d) on the left scales rows and on the right scales
    # columns, so broadcasting replaces the two matmuls with diagonal
    # matrices: result[..., i, j] = d[i] * adj[..., i, j] * d[j].
    return d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)
# def row_norm(adj):
if __name__ == '__main__':
    # Manual smoke check: print the row-normalized cihp2pascal_adj followed
    # by the untouched matrix.  row_norm returns a plain list, so the
    # result has no ``.shape`` attribute.
    normalized_rows = row_norm(cihp2pascal_adj)
    for shown in (normalized_rows, cihp2pascal_adj):
        print(shown)