ngaggion committed
Commit e87a462 · 1 Parent(s): 681b570

First push

app.py ADDED
@@ -0,0 +1,126 @@
+ import numpy as np
+ import gradio as gr
+ import cv2
+
+ from models.HybridGNet2IGSC import Hybrid
+ from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
+ import scipy.sparse as sp
+ import torch
+
+
+ def getDenseMask(landmarks):
+     RL = landmarks[0:44]   # right lung contour
+     LL = landmarks[44:94]  # left lung contour
+     H = landmarks[94:]     # heart contour
+
+     img = np.zeros([1024, 1024], dtype='uint8')
+
+     RL = RL.reshape(-1, 1, 2).astype('int')
+     LL = LL.reshape(-1, 1, 2).astype('int')
+     H = H.reshape(-1, 1, 2).astype('int')
+
+     img = cv2.drawContours(img, [RL], -1, 1, -1)  # lungs filled with label 1
+     img = cv2.drawContours(img, [LL], -1, 1, -1)
+     img = cv2.drawContours(img, [H], -1, 2, -1)   # heart filled with label 2
+
+     return img
+
+
+ def drawOnTop(img, landmarks):
+     output = getDenseMask(landmarks)
+
+     image = np.zeros([1024, 1024, 3])  # tint lungs red and heart green over the grayscale image
+     image[:, :, 0] = img + 0.3 * (output == 1).astype('float') - 0.1 * (output == 2).astype('float')
+     image[:, :, 1] = img + 0.3 * (output == 2).astype('float') - 0.1 * (output == 1).astype('float')
+     image[:, :, 2] = img - 0.1 * (output == 1).astype('float') - 0.2 * (output == 2).astype('float')
+
+     image = np.clip(image, 0, 1)
+
+     return image
+
+
+ def loadModel(device):
+     A, AD, D, U = genMatrixesLungsHeart()
+     N1 = A.shape[0]   # nodes in the full-resolution graph
+     N2 = AD.shape[0]  # nodes in the downsampled graph
+
+     A = sp.csc_matrix(A).tocoo()
+     AD = sp.csc_matrix(AD).tocoo()
+     D = sp.csc_matrix(D).tocoo()
+     U = sp.csc_matrix(U).tocoo()
+
+     D_ = [D.copy()]
+     U_ = [U.copy()]
+
+     config = {}
+
+     config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
+     A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
+
+     A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_, D_, U_))
+
+     config['latents'] = 64
+     config['inputsize'] = 1024
+
+     f = 32
+     config['filters'] = [2, f, f, f, f//2, f//2, f//2]
+     config['skip_features'] = f
+
+     hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
+     hybrid.load_state_dict(torch.load("weights/bestMSE.pt", map_location=device))
+     hybrid.eval()
+
+     return hybrid
+
+
+ def pad_to_square(img):
+     h, w = img.shape[:2]
+
+     if h > w:
+         padw = (h - w)
+         auxw = padw % 2
+         img = np.pad(img, ((0, 0), (padw//2, padw//2 + auxw)), 'constant')
+
+         padh = 0
+         auxh = 0
+
+     else:
+         padh = (w - h)
+         auxh = padh % 2
+         img = np.pad(img, ((padh//2, padh//2 + auxh), (0, 0)), 'constant')
+
+         padw = 0
+         auxw = 0
+
+     return img, (padh, padw, auxh, auxw)
+
+
+ def preprocess(input_img):
+     img, padding = pad_to_square(input_img)
+
+     h, w = img.shape[:2]
+     if h != 1024 or w != 1024:
+         img = cv2.resize(img, (1024, 1024), interpolation=cv2.INTER_CUBIC)
+
+     return img, (h, w, padding)
+
+
+ def segment(input_img):
+     input_img = cv2.imread(input_img, 0) / 255.0  # read as grayscale and scale to [0, 1]
+
+     img, (h, w, padding) = preprocess(input_img)
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     hybrid = loadModel(device)  # note: weights are reloaded on every request
+
+     data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
+
+     with torch.no_grad():
+         output = hybrid(data)[0].cpu().numpy().reshape(-1, 2) * 1024
+
+     return drawOnTop(img, output)
+
+
+ if __name__ == "__main__":
+     demo = gr.Interface(segment, gr.Image(type="filepath"), "image")
+     demo.launch()
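
The Gradio demo wraps a single function, so it can be smoke-tested without the UI. A minimal sketch, assuming the repository root is the working directory and the checkpoint is in place; "example.png" is a placeholder path, not a file in this commit:

    # Hypothetical smoke test for app.segment; "example.png" is a placeholder.
    import cv2
    from app import segment

    overlay = segment("example.png")  # RGB float image in [0, 1], shape 1024x1024x3
    cv2.imwrite("overlay.png", (overlay[:, :, ::-1] * 255).astype("uint8"))  # RGB -> BGR for OpenCV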
models/HybridGNet2IGSC.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models.modelUtils import ChebConv, Pool, residualBlock
+ from torchvision.ops import roi_align
+
+ import numpy as np
+
+ class EncoderConv(nn.Module):
+     def __init__(self, latents=64, hw=32):
+         super(EncoderConv, self).__init__()
+
+         self.latents = latents
+         self.c = 4
+
+         self.size = self.c * np.array([2, 4, 8, 16, 32], dtype=np.intc)
+
+         self.maxpool = nn.MaxPool2d(2)
+
+         self.dconv_down1 = residualBlock(1, self.size[0])
+         self.dconv_down2 = residualBlock(self.size[0], self.size[1])
+         self.dconv_down3 = residualBlock(self.size[1], self.size[2])
+         self.dconv_down4 = residualBlock(self.size[2], self.size[3])
+         self.dconv_down5 = residualBlock(self.size[3], self.size[4])
+         self.dconv_down6 = residualBlock(self.size[4], self.size[4])
+
+         self.fc_mu = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents)
+         self.fc_logvar = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents)
+
+     def forward(self, x):
+         x = self.dconv_down1(x)
+         x = self.maxpool(x)
+
+         x = self.dconv_down2(x)
+         x = self.maxpool(x)
+
+         conv3 = self.dconv_down3(x)
+         x = self.maxpool(conv3)
+
+         conv4 = self.dconv_down4(x)
+         x = self.maxpool(conv4)
+
+         conv5 = self.dconv_down5(x)
+         x = self.maxpool(conv5)
+
+         conv6 = self.dconv_down6(x)
+
+         x = conv6.view(conv6.size(0), -1)  # flatten batch of multi-channel feature maps to a batch of feature vectors
+
+         x_mu = self.fc_mu(x)
+         x_logvar = self.fc_logvar(x)
+
+         return x_mu, x_logvar, conv6, conv5
+
+
+ class SkipBlock(nn.Module):
+     def __init__(self, in_filters, window):
+         super(SkipBlock, self).__init__()
+
+         self.window = window
+         self.graphConv_pre = ChebConv(in_filters, 2, 1, bias=False)
+
+     def lookup(self, pos, layer, salida=(1, 1)):
+         B = pos.shape[0]
+         N = pos.shape[1]
+         F = layer.shape[1]
+         h = layer.shape[-1]
+
+         ## Scale from [0,1] to [0, h]
+         pos = pos * h
+
+         _x1 = (self.window[0] // 2) * 1.0
+         _x2 = (self.window[0] // 2 + 1) * 1.0
+         _y1 = (self.window[1] // 2) * 1.0
+         _y2 = (self.window[1] // 2 + 1) * 1.0
+
+         boxes = []
+         for batch in range(0, B):
+             x1 = pos[batch, :, 0].reshape(-1, 1) - _x1
+             x2 = pos[batch, :, 0].reshape(-1, 1) + _x2
+             y1 = pos[batch, :, 1].reshape(-1, 1) - _y1
+             y2 = pos[batch, :, 1].reshape(-1, 1) + _y2
+
+             aux = torch.cat([x1, y1, x2, y2], axis=1)
+             boxes.append(aux)
+
+         skip = roi_align(layer, boxes, output_size=salida, aligned=True)
+         vista = skip.view([B, N, -1])
+
+         return vista
+
+     def forward(self, x, adj, conv_layer):
+         pos = self.graphConv_pre(x, adj)
+         skip = self.lookup(pos, conv_layer)
+
+         return torch.cat((x, skip, pos), axis=2), pos
+
+
+ class Hybrid(nn.Module):
+     def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
+         super(Hybrid, self).__init__()
+
+         self.config = config
+         hw = config['inputsize'] // 32
+         self.z = config['latents']
+         self.encoder = EncoderConv(latents=self.z, hw=hw)
+
+         self.downsample_matrices = downsample_matrices
+         self.upsample_matrices = upsample_matrices
+         self.adjacency_matrices = adjacency_matrices
+         self.kld_weight = 1e-5
+
+         n_nodes = config['n_nodes']
+         self.filters = config['filters']
+         self.K = 6
+         self.window = (3, 3)
+
+         # Build the decoder's fully connected layer
+         outshape = self.filters[-1] * n_nodes[-1]
+         self.dec_lin = torch.nn.Linear(self.z, outshape)
+
+         self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
+         self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
+         self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
+         self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
+         self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])
+
+         outsize1 = self.encoder.size[4]
+         outsize2 = self.encoder.size[4]
+
+         # Store the graph convolution layers
+         self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
+         self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)
+
+         self.SC_1 = SkipBlock(self.filters[4], self.window)
+
+         self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K)
+         self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
+
+         self.SC_2 = SkipBlock(self.filters[2], self.window)
+
+         self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K)
+         self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias=False)
+
+         self.pool = Pool()
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)
+
+
+     def sampling(self, mu, log_var):
+         std = torch.exp(0.5*log_var)
+         eps = torch.randn_like(std)
+         return eps.mul(std).add_(mu)
+
+
+     def forward(self, x):
+         self.mu, self.log_var, conv6, conv5 = self.encoder(x)
+
+         if self.training:
+             z = self.sampling(self.mu, self.log_var)
+         else:
+             z = self.mu
+
+         x = self.dec_lin(z)
+         x = F.relu(x)
+
+         x = x.reshape(x.shape[0], -1, self.filters[-1])
+
+         x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
+         x = self.normalization6u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
+         x = self.normalization5u(x)
+         x = F.relu(x)
+
+         x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6)
+
+         x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
+         x = self.normalization4u(x)
+         x = F.relu(x)
+
+         x = self.pool(x, self.upsample_matrices[0])
+
+         x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
+         x = self.normalization3u(x)
+         x = F.relu(x)
+
+         x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5)
+
+         x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
+         x = self.normalization2u(x)
+         x = F.relu(x)
+
+         x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices())  # no ReLU and no bias on the output layer
+
+         return x, pos1, pos2
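
Hybrid is a VAE-style convolutional encoder plus a spectral graph-convolution decoder: the two SkipBlocks predict intermediate landmark positions (pos1, pos2) and use roi_align to sample encoder features around them. A minimal shape check, assuming loadModel from app.py and the checkpoint on disk:

    # Sketch: one dummy forward pass; at inference the latent is the posterior mean.
    import torch
    from app import loadModel

    model = loadModel(torch.device("cpu"))
    with torch.no_grad():
        nodes, pos1, pos2 = model(torch.rand(1, 1, 1024, 1024))
    print(nodes.shape)  # expected (1, 120, 2): 44 + 50 + 26 normalized landmarks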
models/modelUtils.py ADDED
@@ -0,0 +1,69 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torch_geometric.nn.conv import MessagePassing
+ from torch_geometric.nn.conv.cheb_conv import ChebConv
+ from torch_geometric.nn.inits import zeros, normal
+
+ # We change the default initialization from zeros to a normal distribution
+ class ChebConv(ChebConv):
+     def reset_parameters(self):
+         for lin in self.lins:
+             normal(lin, mean=0, std=0.1)
+             #lin.reset_parameters()
+         normal(self.bias, mean=0, std=0.1)
+         #zeros(self.bias)
+
+
+ # Pooling from COMA: https://github.com/pixelite1201/pytorch_coma/blob/master/layers.py
+ class Pool(MessagePassing):
+     def __init__(self):
+         # source_to_target is the default value for flow, but is specified here for explicitness
+         super(Pool, self).__init__(flow='source_to_target')
+
+     def forward(self, x, pool_mat, dtype=None):
+         pool_mat = pool_mat.transpose(0, 1)
+         out = self.propagate(edge_index=pool_mat._indices(), x=x, norm=pool_mat._values(), size=pool_mat.size())
+         return out
+
+     def message(self, x_j, norm):
+         return norm.view(1, -1, 1) * x_j
+
+
+ class residualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         """
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             stride (int): Controls the stride.
+         """
+         super(residualBlock, self).__init__()
+
+         self.skip = nn.Sequential()
+
+         if stride != 1 or in_channels != out_channels:
+             self.skip = nn.Sequential(
+                 nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_channels, track_running_stats=False))
+         else:
+             self.skip = None
+
+         self.block = nn.Sequential(nn.BatchNorm2d(in_channels, track_running_stats=False),
+                                    nn.ReLU(inplace=True),
+                                    nn.Conv2d(in_channels, out_channels, 3, padding=1),
+                                    nn.BatchNorm2d(out_channels, track_running_stats=False),
+                                    nn.ReLU(inplace=True),
+                                    nn.Conv2d(out_channels, out_channels, 3, padding=1)
+                                    )
+
+     def forward(self, x):
+         identity = x
+         out = self.block(x)
+
+         if self.skip is not None:
+             identity = self.skip(x)
+
+         out += identity
+         out = F.relu(out)
+
+         return out
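
The residual block is pre-activation style (BN -> ReLU -> Conv), with a 1x1 convolution on the skip path only when the channel count or stride changes. A small sanity check, as a sketch:

    # Sketch: channel projection happens on the skip path; spatial size is preserved.
    import torch
    from models.modelUtils import residualBlock

    block = residualBlock(1, 8)
    print(block(torch.rand(2, 1, 64, 64)).shape)  # torch.Size([2, 8, 64, 64])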
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.0.1
+ torchvision==0.15.2
+ numpy==1.25.0
+ opencv-python==4.8.0.74
+ scipy==1.10.1
+ torch-geometric==2.3.0
+ gradio
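
Note that `torch` and `torch-geometric` replace the invalid `pytorch` and `pyg` entries (those are not the PyPI names), while `gradio` and `torchvision` are added because app.py and the model import them; torchvision 0.15.2 is the wheel that pairs with torch 2.0.1. A quick post-install check:

    # Sketch: verify the pinned environment imports cleanly.
    import cv2, gradio, numpy, scipy, torch, torch_geometric
    print(torch.__version__, torch_geometric.__version__, cv2.__version__)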
utils/utils.py ADDED
@@ -0,0 +1,104 @@
+ import numpy as np
+ import scipy.sparse as sp
+ import torch
+
+ def scipy_to_torch_sparse(scp_matrix):
+     values = scp_matrix.data
+     indices = np.vstack((scp_matrix.row, scp_matrix.col))
+     i = torch.LongTensor(indices)
+     v = torch.FloatTensor(values)
+     shape = scp_matrix.shape
+
+     sparse_tensor = torch.sparse_coo_tensor(i, v, torch.Size(shape))
+     return sparse_tensor
+
+ ## Adjacency Matrix
+ def mOrgan(N):
+     sub = np.zeros([N, N])
+     for i in range(0, N):
+         sub[i, i-1] = 1
+         sub[i, (i+1) % N] = 1
+     return sub
+
+ ## Downsampling Matrix
+ def mOrganD(N):
+     N2 = int(np.ceil(N/2))
+     sub = np.zeros([N2, N])
+
+     for i in range(0, N2):
+         if (2*i+1) == N:
+             sub[i, 2*i] = 1
+         else:
+             sub[i, 2*i] = 1/2
+             sub[i, 2*i+1] = 1/2
+
+     return sub
+
+ ## Upsampling Matrix
+ def mOrganU(N):
+     N2 = int(np.ceil(N/2))
+     sub = np.zeros([N, N2])
+
+     for i in range(0, N):
+         if i % 2 == 0:
+             sub[i, i//2] = 1
+         else:
+             sub[i, i//2] = 1/2
+             sub[i, (i//2 + 1) % N2] = 1/2
+
+     return sub
+
+ def genMatrixesLungsHeart():
+     RLUNG = 44
+     LLUNG = 50
+     HEART = 26
+
+     Asub1 = mOrgan(RLUNG)
+     Asub2 = mOrgan(LLUNG)
+     Asub3 = mOrgan(HEART)
+
+     ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
+     ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
+     ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
+
+     Dsub1 = mOrganD(RLUNG)
+     Dsub2 = mOrganD(LLUNG)
+     Dsub3 = mOrganD(HEART)
+
+     Usub1 = mOrganU(RLUNG)
+     Usub2 = mOrganU(LLUNG)
+     Usub3 = mOrganU(HEART)
+
+     p1 = RLUNG
+     p2 = p1 + LLUNG
+     p3 = p2 + HEART
+
+     p1_ = int(np.ceil(RLUNG / 2))
+     p2_ = p1_ + int(np.ceil(LLUNG / 2))
+     p3_ = p2_ + int(np.ceil(HEART / 2))
+
+     A = np.zeros([p3, p3])
+
+     A[:p1, :p1] = Asub1
+     A[p1:p2, p1:p2] = Asub2
+     A[p2:p3, p2:p3] = Asub3
+
+     AD = np.zeros([p3_, p3_])
+
+     AD[:p1_, :p1_] = ADsub1
+     AD[p1_:p2_, p1_:p2_] = ADsub2
+     AD[p2_:p3_, p2_:p3_] = ADsub3
+
+     D = np.zeros([p3_, p3])
+
+     D[:p1_, :p1] = Dsub1
+     D[p1_:p2_, p1:p2] = Dsub2
+     D[p2_:p3_, p2:p3] = Dsub3
+
+     U = np.zeros([p3, p3_])
+
+     U[:p1, :p1_] = Usub1
+     U[p1:p2, p1_:p2_] = Usub2
+     U[p2:p3, p2_:p3_] = Usub3
+
+     return A, AD, D, U
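
Each organ contour is a closed cycle, so mOrgan builds a ring (circulant) adjacency matrix, and mOrganD/mOrganU average or interpolate neighboring nodes to halve or double each contour. With 44 + 50 + 26 = 120 landmarks downsampled to 22 + 25 + 13 = 60, the block matrices should come out as:

    # Sketch: sanity-check the mesh matrix shapes.
    from utils.utils import genMatrixesLungsHeart

    A, AD, D, U = genMatrixesLungsHeart()
    print(A.shape, AD.shape, D.shape, U.shape)
    # expected: (120, 120) (60, 60) (60, 120) (120, 60)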
weights/weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b76bbb8c8ad9774cdf3ac81c9edf04bcc800b3c7f7eacf24ce7249038f3c640f
+ size 70083051