# PLA-Net: gcn_lib/dense/torch_vertex.py
import torch
from torch import nn

from .torch_nn import BasicConv, batched_index_select
from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph


class MRConv2d(nn.Module):
"""
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
"""
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
super(MRConv2d, self).__init__()
self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        # x: (B, C, N, 1); edge_index: (2, B, N, k), where edge_index[1]
        # indexes the center vertices and edge_index[0] their k neighbors.
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x, edge_index[0])
        # Max-relative aggregation: per channel, keep the largest
        # neighbor-minus-center difference over the k neighbors.
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        return self.nn(torch.cat([x, x_j], dim=1))
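
# Shape sketch for MRConv2d (an assumption: the dense helpers in this repo
# follow the deep_gcns_torch conventions). With batch B, channels C, N vertices
# and k neighbors: x is (B, C, N, 1) and edge_index is (2, B, N, k), so
# MRConv2d(16, 32) maps a (2, 16, 64, 1) tensor to (2, 32, 64, 1).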

class EdgeConv2d(nn.Module):
"""
Edge convolution layer (with activation, batch normalization) for dense data type
"""
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
super(EdgeConv2d, self).__init__()
self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        # Gather center (x_i) and neighbor (x_j) features for every edge.
        x_i = batched_index_select(x, edge_index[1])
        x_j = batched_index_select(x, edge_index[0])
        # EdgeConv: run the MLP on [x_i, x_j - x_i], then max-pool over neighbors.
        max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
        return max_value
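
# EdgeConv2d uses the same layout as MRConv2d above: (B, C, N, 1) in,
# (B, out_channels, N, 1) out. The difference is where the MLP sits: here it
# sees the per-edge pair [x_i, x_j - x_i] before the max over neighbors.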

class GraphConv2d(nn.Module):
"""
Static graph convolution layer
"""
def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
super(GraphConv2d, self).__init__()
if conv == 'edge':
self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
elif conv == 'mr':
self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
else:
raise NotImplementedError('conv:{} is not supported'.format(conv))

    def forward(self, x, edge_index):
        return self.gconv(x, edge_index)
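
# GraphConv2d is a thin dispatcher: GraphConv2d(16, 32, conv='mr') behaves like
# MRConv2d(16, 32), and conv='edge' (the default) selects EdgeConv2d; any other
# value raises NotImplementedError.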

class DynConv2d(GraphConv2d):
"""
Dynamic graph convolution layer
"""
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias)
self.k = kernel_size
self.d = dilation
if knn == 'matrix':
self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
else:
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x):
        # Rebuild the dilated kNN graph from the current features, then apply
        # the underlying static graph convolution.
        edge_index = self.dilated_knn_graph(x)
        return super(DynConv2d, self).forward(x, edge_index)
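
# Usage sketch for DynConv2d (tensor layout is an assumption carried over from
# the dense kNN helpers above). Unlike GraphConv2d, no edge_index is passed in;
# the graph is rebuilt from the features on every call:
#
#   conv = DynConv2d(in_channels=16, out_channels=32, kernel_size=9)
#   x = torch.rand(2, 16, 128, 1)   # (B, C, N, 1) point features
#   out = conv(x)                   # (2, 32, 128, 1)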

class PlainDynBlock2d(nn.Module):
"""
Plain Dynamic graph convolution block
"""
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
super(PlainDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
act, norm, bias, stochastic, epsilon, knn)

    def forward(self, x):
        return self.body(x)

class ResDynBlock2d(nn.Module):
"""
Residual Dynamic graph convolution block
"""
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1):
super(ResDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
act, norm, bias, stochastic, epsilon, knn)
self.res_scale = res_scale

    def forward(self, x):
        # Residual connection: add the (optionally scaled) input back in.
        return self.body(x) + x * self.res_scale

class DenseDynBlock2d(nn.Module):
"""
Dense Dynamic graph convolution block
"""
    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
                 act='relu', norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
super(DenseDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv,
act, norm, bias, stochastic, epsilon, knn)

    def forward(self, x):
        # DenseNet-style growth: concatenate the input with the new features
        # along the channel dimension.
        dense = self.body(x)
        return torch.cat((x, dense), 1)
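

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module). It assumes the
    # (B, C, N, 1) dense layout used throughout this file and the default
    # BasicConv / kNN behavior from this repo; adjust shapes if yours differ.
    B, C, N = 2, 16, 64
    x = torch.rand(B, C, N, 1)

    res_block = ResDynBlock2d(C, kernel_size=9)
    y = res_block(x)
    assert y.shape == x.shape  # residual blocks preserve the feature shape

    dense_block = DenseDynBlock2d(C, out_channels=32, kernel_size=9)
    z = dense_block(x)
    assert z.shape == (B, C + 32, N, 1)  # input concatenated with new features
    print("ok:", tuple(y.shape), tuple(z.shape))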