First push
Browse files- app.py +126 -0
- models/HybridGNet2IGSC.py +200 -0
- models/modelUtils.py +69 -0
- requirements.txt +5 -0
- utils/utils.py +103 -0
- weights/weights.pt +3 -0
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,126 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import gradio as gr
         | 
| 3 | 
            +
            import cv2 
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from models.HybridGNet2IGSC import Hybrid 
         | 
| 6 | 
            +
            from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
         | 
| 7 | 
            +
            import scipy.sparse as sp
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def getDenseMask(landmarks):
         | 
| 12 | 
            +
                RL = landmarks[0:44]
         | 
| 13 | 
            +
                LL = landmarks[44:94]
         | 
| 14 | 
            +
                H = landmarks[94:]
         | 
| 15 | 
            +
                
         | 
| 16 | 
            +
                img = np.zeros([1024,1024], dtype = 'uint8')
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                RL = RL.reshape(-1, 1, 2).astype('int')
         | 
| 19 | 
            +
                LL = LL.reshape(-1, 1, 2).astype('int')
         | 
| 20 | 
            +
                H = H.reshape(-1, 1, 2).astype('int')
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                img = cv2.drawContours(img, [RL], -1, 1, -1)
         | 
| 23 | 
            +
                img = cv2.drawContours(img, [LL], -1, 1, -1)
         | 
| 24 | 
            +
                img = cv2.drawContours(img, [H], -1, 2, -1)
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
                return img
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def drawOnTop(img, landmarks):
         | 
| 30 | 
            +
                output = getDenseMask(landmarks)
         | 
| 31 | 
            +
                
         | 
| 32 | 
            +
                image = np.zeros([1024, 1024, 3])
         | 
| 33 | 
            +
                image[:,:,0] = img + 0.3 * (output == 1).astype('float') - 0.1 * (output == 2).astype('float')
         | 
| 34 | 
            +
                image[:,:,1] = img + 0.3 * (output == 2).astype('float') - 0.1 * (output == 1).astype('float') 
         | 
| 35 | 
            +
                image[:,:,2] = img - 0.1 * (output == 1).astype('float') - 0.2 * (output == 2).astype('float') 
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                image = np.clip(image, 0, 1)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return image
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            def loadModel(device):
         | 
| 43 | 
            +
                A, AD, D, U = genMatrixesLungsHeart()
         | 
| 44 | 
            +
                N1 = A.shape[0]
         | 
| 45 | 
            +
                N2 = AD.shape[0]
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                A = sp.csc_matrix(A).tocoo()
         | 
| 48 | 
            +
                AD = sp.csc_matrix(AD).tocoo()
         | 
| 49 | 
            +
                D = sp.csc_matrix(D).tocoo()
         | 
| 50 | 
            +
                U = sp.csc_matrix(U).tocoo()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                D_ = [D.copy()]
         | 
| 53 | 
            +
                U_ = [U.copy()]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                config = {}
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
         | 
| 58 | 
            +
                A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_, D_, U_))
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                config['latents'] = 64
         | 
| 63 | 
            +
                config['inputsize'] = 1024
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                f = 32
         | 
| 66 | 
            +
                config['filters'] = [2, f, f, f, f//2, f//2, f//2]
         | 
| 67 | 
            +
                config['skip_features'] = f
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
         | 
| 70 | 
            +
                hybrid.load_state_dict(torch.load("weights/bestMSE.pt"))
         | 
| 71 | 
            +
                hybrid.eval()
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
                return hybrid
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def pad_to_square(img):
         | 
| 77 | 
            +
                h, w = img.shape[:2]
         | 
| 78 | 
            +
                
         | 
| 79 | 
            +
                if h > w:
         | 
| 80 | 
            +
                    padw = (h - w) 
         | 
| 81 | 
            +
                    auxw = padw % 2
         | 
| 82 | 
            +
                    img = np.pad(img, ((0, 0), (padw//2, padw//2 + auxw)), 'constant')
         | 
| 83 | 
            +
                    
         | 
| 84 | 
            +
                    padh = 0
         | 
| 85 | 
            +
                    auxh = 0
         | 
| 86 | 
            +
                    
         | 
| 87 | 
            +
                else:
         | 
| 88 | 
            +
                    padh = (w - h) 
         | 
| 89 | 
            +
                    auxh = padh % 2
         | 
| 90 | 
            +
                    img = np.pad(img, ((padh//2, padh//2 + auxh), (0, 0)), 'constant')
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    padw = 0
         | 
| 93 | 
            +
                    auxw = 0
         | 
| 94 | 
            +
                    
         | 
| 95 | 
            +
                return img, (padh, padw, auxh, auxw)
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            def preprocess(input_img):
         | 
| 99 | 
            +
                img, padding = pad_to_square(input_img)
         | 
| 100 | 
            +
                
         | 
| 101 | 
            +
                h, w = img.shape[:2]
         | 
| 102 | 
            +
                if h != 1024 or w != 1024:
         | 
| 103 | 
            +
                    img = cv2.resize(img, (1024, 1024), interpolation = cv2.INTER_CUBIC)
         | 
| 104 | 
            +
                    
         | 
| 105 | 
            +
                return img, (h, w, padding)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                
         | 
| 108 | 
            +
            def segment(input_img):
         | 
| 109 | 
            +
                input_img = cv2.imread(input_img, 0) / 255.0
         | 
| 110 | 
            +
                
         | 
| 111 | 
            +
                img, (h, w, padding) = preprocess(input_img)    
         | 
| 112 | 
            +
                
         | 
| 113 | 
            +
                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         | 
| 114 | 
            +
                hybrid = loadModel(device)
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
                data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
         | 
| 117 | 
            +
                
         | 
| 118 | 
            +
                with torch.no_grad():
         | 
| 119 | 
            +
                    output = hybrid(data)[0].cpu().numpy().reshape(-1, 2) * 1024
         | 
| 120 | 
            +
                   
         | 
| 121 | 
            +
                return drawOnTop(img, output)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            if __name__ == "__main__":
         | 
| 125 | 
            +
                demo = gr.Interface(segment, gr.Image(type="filepath"), "image")
         | 
| 126 | 
            +
                demo.launch()
         | 
    	
        models/HybridGNet2IGSC.py
    ADDED
    
    | @@ -0,0 +1,200 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from models.modelUtils import ChebConv, Pool, residualBlock
         | 
| 5 | 
            +
            import torchvision.ops.roi_align as roi_align
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            class EncoderConv(nn.Module):
         | 
| 10 | 
            +
                def __init__(self, latents = 64, hw = 32):
         | 
| 11 | 
            +
                    super(EncoderConv, self).__init__()
         | 
| 12 | 
            +
                    
         | 
| 13 | 
            +
                    self.latents = latents
         | 
| 14 | 
            +
                    self.c = 4
         | 
| 15 | 
            +
                    
         | 
| 16 | 
            +
                    self.size = self.c * np.array([2,4,8,16,32], dtype = np.intc)
         | 
| 17 | 
            +
                    
         | 
| 18 | 
            +
                    self.maxpool = nn.MaxPool2d(2)
         | 
| 19 | 
            +
                    
         | 
| 20 | 
            +
                    self.dconv_down1 = residualBlock(1, self.size[0])
         | 
| 21 | 
            +
                    self.dconv_down2 = residualBlock(self.size[0], self.size[1])
         | 
| 22 | 
            +
                    self.dconv_down3 = residualBlock(self.size[1], self.size[2])
         | 
| 23 | 
            +
                    self.dconv_down4 = residualBlock(self.size[2], self.size[3])
         | 
| 24 | 
            +
                    self.dconv_down5 = residualBlock(self.size[3], self.size[4])
         | 
| 25 | 
            +
                    self.dconv_down6 = residualBlock(self.size[4], self.size[4])
         | 
| 26 | 
            +
                    
         | 
| 27 | 
            +
                    self.fc_mu = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents)
         | 
| 28 | 
            +
                    self.fc_logvar = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, x):
         | 
| 31 | 
            +
                    x = self.dconv_down1(x)
         | 
| 32 | 
            +
                    x = self.maxpool(x)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    x = self.dconv_down2(x)
         | 
| 35 | 
            +
                    x = self.maxpool(x)
         | 
| 36 | 
            +
                    
         | 
| 37 | 
            +
                    conv3 = self.dconv_down3(x)
         | 
| 38 | 
            +
                    x = self.maxpool(conv3)
         | 
| 39 | 
            +
                    
         | 
| 40 | 
            +
                    conv4 = self.dconv_down4(x)
         | 
| 41 | 
            +
                    x = self.maxpool(conv4)
         | 
| 42 | 
            +
                    
         | 
| 43 | 
            +
                    conv5 = self.dconv_down5(x)
         | 
| 44 | 
            +
                    x = self.maxpool(conv5)
         | 
| 45 | 
            +
                    
         | 
| 46 | 
            +
                    conv6 = self.dconv_down6(x)
         | 
| 47 | 
            +
                    
         | 
| 48 | 
            +
                    x = conv6.view(conv6.size(0), -1) # flatten batch of multi-channel feature maps to a batch of feature vectors
         | 
| 49 | 
            +
                    
         | 
| 50 | 
            +
                    x_mu = self.fc_mu(x)
         | 
| 51 | 
            +
                    x_logvar = self.fc_logvar(x)
         | 
| 52 | 
            +
                            
         | 
| 53 | 
            +
                    return x_mu, x_logvar, conv6, conv5
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            class SkipBlock(nn.Module):
         | 
| 57 | 
            +
                def __init__(self, in_filters, window):
         | 
| 58 | 
            +
                    super(SkipBlock, self).__init__()
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    self.window = window
         | 
| 61 | 
            +
                    self.graphConv_pre = ChebConv(in_filters, 2, 1, bias = False) 
         | 
| 62 | 
            +
                
         | 
| 63 | 
            +
                def lookup(self, pos, layer, salida = (1,1)):
         | 
| 64 | 
            +
                    B = pos.shape[0]
         | 
| 65 | 
            +
                    N = pos.shape[1]
         | 
| 66 | 
            +
                    F = layer.shape[1]
         | 
| 67 | 
            +
                    h = layer.shape[-1]
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    ## Scale from [0,1] to [0, h]
         | 
| 70 | 
            +
                    pos = pos * h
         | 
| 71 | 
            +
                    
         | 
| 72 | 
            +
                    _x1 = (self.window[0] // 2) * 1.0
         | 
| 73 | 
            +
                    _x2 = (self.window[0] // 2 + 1) * 1.0
         | 
| 74 | 
            +
                    _y1 = (self.window[1] // 2) * 1.0
         | 
| 75 | 
            +
                    _y2 = (self.window[1] // 2 + 1) * 1.0
         | 
| 76 | 
            +
                    
         | 
| 77 | 
            +
                    boxes = []
         | 
| 78 | 
            +
                    for batch in range(0, B):
         | 
| 79 | 
            +
                        x1 = pos[batch,:,0].reshape(-1, 1) - _x1
         | 
| 80 | 
            +
                        x2 = pos[batch,:,0].reshape(-1, 1) + _x2
         | 
| 81 | 
            +
                        y1 = pos[batch,:,1].reshape(-1, 1) - _y1
         | 
| 82 | 
            +
                        y2 = pos[batch,:,1].reshape(-1, 1) + _y2
         | 
| 83 | 
            +
                        
         | 
| 84 | 
            +
                        aux = torch.cat([x1, y1, x2, y2], axis = 1)            
         | 
| 85 | 
            +
                        boxes.append(aux)
         | 
| 86 | 
            +
                                
         | 
| 87 | 
            +
                    skip = roi_align(layer, boxes, output_size = salida, aligned=True)
         | 
| 88 | 
            +
                    vista = skip.view([B, N, -1])
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    return vista
         | 
| 91 | 
            +
                
         | 
| 92 | 
            +
                def forward(self, x, adj, conv_layer):
         | 
| 93 | 
            +
                    pos = self.graphConv_pre(x, adj)
         | 
| 94 | 
            +
                    skip = self.lookup(pos, conv_layer)
         | 
| 95 | 
            +
                    
         | 
| 96 | 
            +
                    return torch.cat((x, skip, pos), axis = 2), pos
         | 
| 97 | 
            +
                    
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
            class Hybrid(nn.Module):
         | 
| 100 | 
            +
                def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices):
         | 
| 101 | 
            +
                    super(Hybrid, self).__init__()
         | 
| 102 | 
            +
                    
         | 
| 103 | 
            +
                    self.config = config
         | 
| 104 | 
            +
                    hw = config['inputsize'] // 32
         | 
| 105 | 
            +
                    self.z = config['latents']
         | 
| 106 | 
            +
                    self.encoder = EncoderConv(latents = self.z, hw = hw)
         | 
| 107 | 
            +
                    
         | 
| 108 | 
            +
                    self.downsample_matrices = downsample_matrices
         | 
| 109 | 
            +
                    self.upsample_matrices = upsample_matrices
         | 
| 110 | 
            +
                    self.adjacency_matrices = adjacency_matrices
         | 
| 111 | 
            +
                    self.kld_weight = 1e-5
         | 
| 112 | 
            +
                            
         | 
| 113 | 
            +
                    n_nodes = config['n_nodes']
         | 
| 114 | 
            +
                    self.filters = config['filters']
         | 
| 115 | 
            +
                    self.K = 6
         | 
| 116 | 
            +
                    self.window = (3,3)
         | 
| 117 | 
            +
                    
         | 
| 118 | 
            +
                    # Genero la capa fully connected del decoder
         | 
| 119 | 
            +
                    outshape = self.filters[-1] * n_nodes[-1]          
         | 
| 120 | 
            +
                    self.dec_lin = torch.nn.Linear(self.z, outshape)
         | 
| 121 | 
            +
                            
         | 
| 122 | 
            +
                    self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1])
         | 
| 123 | 
            +
                    self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2])
         | 
| 124 | 
            +
                    self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3])
         | 
| 125 | 
            +
                    self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4])
         | 
| 126 | 
            +
                    self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5])
         | 
| 127 | 
            +
                    
         | 
| 128 | 
            +
                    outsize1 = self.encoder.size[4]
         | 
| 129 | 
            +
                    outsize2 = self.encoder.size[4]  
         | 
| 130 | 
            +
                                 
         | 
| 131 | 
            +
                    # Guardo las capas de convoluciones en grafo
         | 
| 132 | 
            +
                    self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K)
         | 
| 133 | 
            +
                    self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K)       
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                    self.SC_1 = SkipBlock(self.filters[4], self.window)
         | 
| 136 | 
            +
                    
         | 
| 137 | 
            +
                    self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K)        
         | 
| 138 | 
            +
                    self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K)
         | 
| 139 | 
            +
                    
         | 
| 140 | 
            +
                    self.SC_2 = SkipBlock(self.filters[2], self.window)
         | 
| 141 | 
            +
                    
         | 
| 142 | 
            +
                    self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K)
         | 
| 143 | 
            +
                    self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False)
         | 
| 144 | 
            +
                            
         | 
| 145 | 
            +
                    self.pool = Pool()
         | 
| 146 | 
            +
                    
         | 
| 147 | 
            +
                    self.reset_parameters()
         | 
| 148 | 
            +
                    
         | 
| 149 | 
            +
                def reset_parameters(self):
         | 
| 150 | 
            +
                    torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
                def sampling(self, mu, log_var):
         | 
| 154 | 
            +
                    std = torch.exp(0.5*log_var)
         | 
| 155 | 
            +
                    eps = torch.randn_like(std)
         | 
| 156 | 
            +
                    return eps.mul(std).add_(mu) 
         | 
| 157 | 
            +
                
         | 
| 158 | 
            +
                    
         | 
| 159 | 
            +
                def forward(self, x):
         | 
| 160 | 
            +
                    self.mu, self.log_var, conv6, conv5 = self.encoder(x)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    if self.training:
         | 
| 163 | 
            +
                        z = self.sampling(self.mu, self.log_var)
         | 
| 164 | 
            +
                    else:
         | 
| 165 | 
            +
                        z = self.mu
         | 
| 166 | 
            +
                        
         | 
| 167 | 
            +
                    x = self.dec_lin(z)
         | 
| 168 | 
            +
                    x = F.relu(x)
         | 
| 169 | 
            +
                    
         | 
| 170 | 
            +
                    x = x.reshape(x.shape[0], -1, self.filters[-1])
         | 
| 171 | 
            +
                    
         | 
| 172 | 
            +
                    x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices())
         | 
| 173 | 
            +
                    x = self.normalization6u(x)
         | 
| 174 | 
            +
                    x = F.relu(x)
         | 
| 175 | 
            +
                    
         | 
| 176 | 
            +
                    x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices())
         | 
| 177 | 
            +
                    x = self.normalization5u(x)
         | 
| 178 | 
            +
                    x = F.relu(x)
         | 
| 179 | 
            +
                    
         | 
| 180 | 
            +
                    x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6)
         | 
| 181 | 
            +
                    
         | 
| 182 | 
            +
                    x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices())
         | 
| 183 | 
            +
                    x = self.normalization4u(x)
         | 
| 184 | 
            +
                    x = F.relu(x)
         | 
| 185 | 
            +
                    
         | 
| 186 | 
            +
                    x = self.pool(x, self.upsample_matrices[0])
         | 
| 187 | 
            +
                    
         | 
| 188 | 
            +
                    x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices())
         | 
| 189 | 
            +
                    x = self.normalization3u(x)
         | 
| 190 | 
            +
                    x = F.relu(x)
         | 
| 191 | 
            +
                    
         | 
| 192 | 
            +
                    x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5)
         | 
| 193 | 
            +
                    
         | 
| 194 | 
            +
                    x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices())
         | 
| 195 | 
            +
                    x = self.normalization2u(x)
         | 
| 196 | 
            +
                    x = F.relu(x)
         | 
| 197 | 
            +
                    
         | 
| 198 | 
            +
                    x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices()) # Sin relu y sin bias
         | 
| 199 | 
            +
                    
         | 
| 200 | 
            +
                    return x, pos1, pos2
         | 
    	
        models/modelUtils.py
    ADDED
    
    | @@ -0,0 +1,69 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from torch_geometric.nn.conv import MessagePassing
         | 
| 2 | 
            +
            from torch_geometric.nn.conv.cheb_conv import ChebConv
         | 
| 3 | 
            +
            from torch_geometric.nn.inits import zeros, normal
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # We change the default initialization from zeros to a normal distribution
         | 
| 6 | 
            +
            class ChebConv(ChebConv):
         | 
| 7 | 
            +
                def reset_parameters(self):
         | 
| 8 | 
            +
                    for lin in self.lins:
         | 
| 9 | 
            +
                        normal(lin, mean = 0, std = 0.1)
         | 
| 10 | 
            +
                        #lin.reset_parameters()
         | 
| 11 | 
            +
                    normal(self.bias, mean = 0, std = 0.1)
         | 
| 12 | 
            +
                    #zeros(self.bias)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # Pooling from COMA: https://github.com/pixelite1201/pytorch_coma/blob/master/layers.py
         | 
| 15 | 
            +
            class Pool(MessagePassing):
         | 
| 16 | 
            +
                def __init__(self):
         | 
| 17 | 
            +
                    # source_to_target is the default value for flow, but is specified here for explicitness
         | 
| 18 | 
            +
                    super(Pool, self).__init__(flow='source_to_target')
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def forward(self, x, pool_mat,  dtype=None):
         | 
| 21 | 
            +
                    pool_mat = pool_mat.transpose(0, 1)
         | 
| 22 | 
            +
                    out = self.propagate(edge_index=pool_mat._indices(), x=x, norm=pool_mat._values(), size=pool_mat.size())
         | 
| 23 | 
            +
                    return out
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def message(self, x_j, norm):
         | 
| 26 | 
            +
                    return norm.view(1, -1, 1) * x_j
         | 
| 27 | 
            +
                
         | 
| 28 | 
            +
                
         | 
| 29 | 
            +
            import torch.nn as nn
         | 
| 30 | 
            +
            import torch.nn.functional as F
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            class residualBlock(nn.Module):
         | 
| 33 | 
            +
                def __init__(self, in_channels, out_channels, stride=1):
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    Args:
         | 
| 36 | 
            +
                      in_channels (int):  Number of input channels.
         | 
| 37 | 
            +
                      out_channels (int): Number of output channels.
         | 
| 38 | 
            +
                      stride (int):       Controls the stride.
         | 
| 39 | 
            +
                    """
         | 
| 40 | 
            +
                    super(residualBlock, self).__init__()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    self.skip = nn.Sequential()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    if stride != 1 or in_channels != out_channels:
         | 
| 45 | 
            +
                      self.skip = nn.Sequential(
         | 
| 46 | 
            +
                        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
         | 
| 47 | 
            +
                        nn.BatchNorm2d(out_channels, track_running_stats=False))
         | 
| 48 | 
            +
                    else:
         | 
| 49 | 
            +
                      self.skip = None
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.block = nn.Sequential(nn.BatchNorm2d(in_channels, track_running_stats=False),
         | 
| 52 | 
            +
                                               nn.ReLU(inplace=True),
         | 
| 53 | 
            +
                                               nn.Conv2d(in_channels, out_channels, 3, padding=1),
         | 
| 54 | 
            +
                                               nn.BatchNorm2d(out_channels, track_running_stats=False),
         | 
| 55 | 
            +
                                               nn.ReLU(inplace=True),
         | 
| 56 | 
            +
                                               nn.Conv2d(out_channels, out_channels, 3, padding=1)
         | 
| 57 | 
            +
                                               )   
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def forward(self, x):
         | 
| 60 | 
            +
                    identity = x
         | 
| 61 | 
            +
                    out = self.block(x)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    if self.skip is not None:
         | 
| 64 | 
            +
                        identity = self.skip(x)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    out += identity
         | 
| 67 | 
            +
                    out = F.relu(out)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    return out
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            pytorch==2.0.1
         | 
| 2 | 
            +
            numpy==1.25.0
         | 
| 3 | 
            +
            opencv-python==4.8.0.74
         | 
| 4 | 
            +
            scipy==1.10.1
         | 
| 5 | 
            +
            pyg=2.3.0
         | 
    	
        utils/utils.py
    ADDED
    
    | @@ -0,0 +1,103 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import scipy.sparse as sp
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def scipy_to_torch_sparse(scp_matrix):
         | 
| 6 | 
            +
                values = scp_matrix.data
         | 
| 7 | 
            +
                indices = np.vstack((scp_matrix.row, scp_matrix.col))
         | 
| 8 | 
            +
                i = torch.LongTensor(indices)
         | 
| 9 | 
            +
                v = torch.FloatTensor(values)
         | 
| 10 | 
            +
                shape = scp_matrix.shape
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
         | 
| 13 | 
            +
                return sparse_tensor
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            ## Adjacency Matrix
         | 
| 16 | 
            +
            def mOrgan(N):
         | 
| 17 | 
            +
                sub = np.zeros([N, N])
         | 
| 18 | 
            +
                for i in range(0, N):
         | 
| 19 | 
            +
                    sub[i, i-1] = 1
         | 
| 20 | 
            +
                    sub[i, (i+1)%N] = 1
         | 
| 21 | 
            +
                return sub
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            ## Downsampling Matrix
         | 
| 24 | 
            +
            def mOrganD(N):
         | 
| 25 | 
            +
                N2 = int(np.ceil(N/2))
         | 
| 26 | 
            +
                sub = np.zeros([N2, N])
         | 
| 27 | 
            +
                
         | 
| 28 | 
            +
                for i in range(0, N2):
         | 
| 29 | 
            +
                    if (2*i+1) == N:
         | 
| 30 | 
            +
                        sub[i, 2*i] = 1
         | 
| 31 | 
            +
                    else:
         | 
| 32 | 
            +
                        sub[i, 2*i] = 1/2
         | 
| 33 | 
            +
                        sub[i, 2*i+1] = 1/2
         | 
| 34 | 
            +
                        
         | 
| 35 | 
            +
                return sub
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def mOrganU(N):
         | 
| 38 | 
            +
                N2 = int(np.ceil(N/2))
         | 
| 39 | 
            +
                sub = np.zeros([N, N2])
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
                for i in range(0, N):
         | 
| 42 | 
            +
                    if i % 2 == 0:
         | 
| 43 | 
            +
                        sub[i, i//2] = 1
         | 
| 44 | 
            +
                    else:
         | 
| 45 | 
            +
                        sub[i, i//2] = 1/2
         | 
| 46 | 
            +
                        sub[i, (i//2 + 1) % N2] = 1/2
         | 
| 47 | 
            +
                        
         | 
| 48 | 
            +
                return sub
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def genMatrixesLungsHeart():       
         | 
| 51 | 
            +
                RLUNG = 44
         | 
| 52 | 
            +
                LLUNG = 50
         | 
| 53 | 
            +
                HEART = 26
         | 
| 54 | 
            +
                
         | 
| 55 | 
            +
                Asub1 = mOrgan(RLUNG)
         | 
| 56 | 
            +
                Asub2 = mOrgan(LLUNG)
         | 
| 57 | 
            +
                Asub3 = mOrgan(HEART)
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
         | 
| 60 | 
            +
                ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
         | 
| 61 | 
            +
                ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
         | 
| 62 | 
            +
                                
         | 
| 63 | 
            +
                Dsub1 = mOrganD(RLUNG)
         | 
| 64 | 
            +
                Dsub2 = mOrganD(LLUNG)
         | 
| 65 | 
            +
                Dsub3 = mOrganD(HEART)
         | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                Usub1 = mOrganU(RLUNG)
         | 
| 68 | 
            +
                Usub2 = mOrganU(LLUNG)
         | 
| 69 | 
            +
                Usub3 = mOrganU(HEART)
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
                p1 = RLUNG
         | 
| 72 | 
            +
                p2 = p1 + LLUNG
         | 
| 73 | 
            +
                p3 = p2 + HEART
         | 
| 74 | 
            +
                
         | 
| 75 | 
            +
                p1_ = int(np.ceil(RLUNG / 2))
         | 
| 76 | 
            +
                p2_ = p1_ + int(np.ceil(LLUNG / 2))
         | 
| 77 | 
            +
                p3_ = p2_ + int(np.ceil(HEART / 2))
         | 
| 78 | 
            +
                
         | 
| 79 | 
            +
                A = np.zeros([p3, p3])
         | 
| 80 | 
            +
                
         | 
| 81 | 
            +
                A[:p1, :p1] = Asub1
         | 
| 82 | 
            +
                A[p1:p2, p1:p2] = Asub2
         | 
| 83 | 
            +
                A[p2:p3, p2:p3] = Asub3
         | 
| 84 | 
            +
                
         | 
| 85 | 
            +
                AD = np.zeros([p3_, p3_])
         | 
| 86 | 
            +
                
         | 
| 87 | 
            +
                AD[:p1_, :p1_] = ADsub1
         | 
| 88 | 
            +
                AD[p1_:p2_, p1_:p2_] = ADsub2
         | 
| 89 | 
            +
                AD[p2_:p3_, p2_:p3_] = ADsub3
         | 
| 90 | 
            +
               
         | 
| 91 | 
            +
                D = np.zeros([p3_, p3])
         | 
| 92 | 
            +
                
         | 
| 93 | 
            +
                D[:p1_, :p1] = Dsub1
         | 
| 94 | 
            +
                D[p1_:p2_, p1:p2] = Dsub2
         | 
| 95 | 
            +
                D[p2_:p3_, p2:p3] = Dsub3
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
                U = np.zeros([p3, p3_])
         | 
| 98 | 
            +
                
         | 
| 99 | 
            +
                U[:p1, :p1_] = Usub1
         | 
| 100 | 
            +
                U[p1:p2, p1_:p2_] = Usub2
         | 
| 101 | 
            +
                U[p2:p3, p2_:p3_] = Usub3
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
                return A, AD, D, U
         | 
    	
        weights/weights.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:b76bbb8c8ad9774cdf3ac81c9edf04bcc800b3c7f7eacf24ce7249038f3c640f
         | 
| 3 | 
            +
            size 70083051
         | 
