Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from .torch_nn import BasicConv, batched_index_select | |
from .torch_edge import DenseDilatedKnnGraph, DilatedKnnGraph | |
import torch.nn.functional as F | |
class MRConv2d(nn.Module):
    """Max-Relative graph convolution for dense data.

    Reference: https://arxiv.org/abs/1904.03751

    Steps:

    1. For each node, take the relative features ``x_j - x_i`` gathered
       through ``edge_index`` and max-reduce them over the last dimension.
    2. Concatenate the result with the node's own features along dim 1.
    3. Project with a ``BasicConv`` that maps ``2 * in_channels`` to
       ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super().__init__()
        # Twice the input channels: forward() concatenates x with the
        # max-reduced relative features along dim 1.
        self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        """Aggregate relative neighbour features and project them.

        Args:
            x: node feature tensor. NOTE(review): presumably dense layout
                (B, C, N, 1); confirm against ``batched_index_select``.
            edge_index: ``edge_index[1]`` selects the ``x_i`` features and
                ``edge_index[0]`` the ``x_j`` features.

        Returns:
            ``self.nn(cat([x, max(x_j - x_i, dim=-1, keepdim=True)], dim=1))``.
        """
        centre = batched_index_select(x, edge_index[1])
        neighbour = batched_index_select(x, edge_index[0])
        # Max-relative aggregation; the argmax indices are not needed.
        relative_max, _ = (neighbour - centre).max(dim=-1, keepdim=True)
        return self.nn(torch.cat((x, relative_max), dim=1))
class EdgeConv2d(nn.Module):
    """Edge convolution for dense data.

    Steps:

    1. Describe every edge by ``[x_i, x_j - x_i]``, concatenated along dim 1.
    2. Pass the edge features through a ``BasicConv``; ``act``, ``norm`` and
       ``bias`` are forwarded to it.
    3. Max-reduce over the last dimension.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super().__init__()
        # Edge features stack x_i and x_j - x_i, hence twice the input channels.
        self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        """Compute per-edge features, transform them and max-pool them.

        Args:
            x: node feature tensor. NOTE(review): presumably dense layout
                (B, C, N, 1); confirm against ``batched_index_select``.
            edge_index: ``edge_index[1]`` selects the ``x_i`` features and
                ``edge_index[0]`` the ``x_j`` features.

        Returns:
            Max over the last dimension (kept with size 1) of
            ``self.nn(cat([x_i, x_j - x_i], dim=1))``.
        """
        centre = batched_index_select(x, edge_index[1])
        neighbour = batched_index_select(x, edge_index[0])
        edge_features = torch.cat([centre, neighbour - centre], dim=1)
        transformed = self.nn(edge_features)
        # Only the max values are returned; the argmax indices are dropped.
        pooled, _ = transformed.max(dim=-1, keepdim=True)
        return pooled
class GraphConv2d(nn.Module):
    """Graph convolution over an edge index supplied by the caller.

    ``conv`` chooses the operator:

    - ``'edge'`` uses ``EdgeConv2d``.
    - ``'mr'`` uses ``MRConv2d``.
    - Any other value raises ``NotImplementedError``.
    """

    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
        super().__init__()
        if conv == 'edge':
            layer_cls = EdgeConv2d
        elif conv == 'mr':
            layer_cls = MRConv2d
        else:
            raise NotImplementedError(f'conv:{conv} is not supported')
        self.gconv = layer_cls(in_channels, out_channels, act, norm, bias)

    def forward(self, x, edge_index):
        """Apply the selected graph convolution to ``x`` using ``edge_index``."""
        return self.gconv(x, edge_index)
class DynConv2d(GraphConv2d):
    """Graph convolution whose edge index is rebuilt from ``x`` on every call.

    A dilated k-NN graph builder produces the edge index. The parent
    ``GraphConv2d`` then applies the chosen convolution (``'edge'`` or
    ``'mr'``).
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__(in_channels, out_channels, conv, act, norm, bias)
        # Stored as attributes; not read elsewhere in this class.
        self.k = kernel_size
        self.d = dilation
        # 'matrix' selects the dense builder; any other value falls back to DilatedKnnGraph.
        graph_builder = DenseDilatedKnnGraph if knn == 'matrix' else DilatedKnnGraph
        self.dilated_knn_graph = graph_builder(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x):
        """Build the edge index from ``x``, then run the parent graph convolution."""
        neighbourhood = self.dilated_knn_graph(x)
        return super().forward(x, neighbourhood)
class PlainDynBlock2d(nn.Module):
    """Dynamic graph convolution block with no skip connection.

    Wraps a single ``DynConv2d`` that maps ``in_channels`` to ``in_channels``.
    """

    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )

    def forward(self, x):
        """Return the output of the wrapped dynamic convolution."""
        return self.body(x)
class ResDynBlock2d(nn.Module):
    """Dynamic graph convolution block with a scaled residual connection.

    Output is ``body(x) + x * res_scale``. The body is a ``DynConv2d`` that
    maps ``in_channels`` to ``in_channels``, so the two terms can be added.
    """

    def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )
        self.res_scale = res_scale

    def forward(self, x):
        """Add the scaled input to the body output."""
        # Evaluate the body first, then build the skip term.
        branch = self.body(x)
        return branch + x * self.res_scale
class DenseDynBlock2d(nn.Module):
    """Dynamic graph convolution block with dense (concatenating) connectivity.

    The input is concatenated with the ``DynConv2d`` output along dim 1.
    NOTE(review): the result presumably has ``in_channels + out_channels``
    channels; confirm against ``BasicConv``'s output width.
    """

    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
                 act='relu', norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
        super().__init__()
        self.body = DynConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            conv=conv,
            act=act,
            norm=norm,
            bias=bias,
            stochastic=stochastic,
            epsilon=epsilon,
            knn=knn,
        )

    def forward(self, x):
        """Return ``cat([x, body(x)], dim=1)``."""
        return torch.cat([x, self.body(x)], dim=1)