# PLA-Net: gcn_lib/sparse/torch_message.py
# Last commit by juliocesar-io: 799e642 ("Added initial app"), 3.51 kB.
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter, scatter_softmax
from torch_geometric.utils import degree
class GenMessagePassing(MessagePassing):
    """Message passing layer with generalized (softmax / power) aggregators.

    Besides the aggregators handled by ``MessagePassing`` itself (e.g.
    ``'add'``, ``'mean'``, ``'max'``), this class implements:

    * ``'softmax'``, ``'softmax_sg'``, ``'softmax_sum'``: messages are
      weighted by a per-node ``scatter_softmax`` of ``inputs * t`` and the
      weighted messages are summed.  ``'softmax_sg'`` never learns ``t``,
      and whenever ``t`` is not learned no gradient flows through the
      softmax weights.
    * ``'power'``, ``'power_sum'``: power mean
      ``mean(inputs ** p) ** (1 / p)``, with the inputs and the mean clamped
      to ``[1e-7, 10]`` for numerical stability.
    * The ``*_sum`` variants additionally multiply the result by
      ``deg ** sigmoid(y)``, where ``deg`` is the number of messages
      aggregated into each node (``degree(index)``).

    Args:
        aggr: One of the scheme names above, or any aggregator accepted by
            ``MessagePassing.__init__``.
        t: Softmax temperature multiplier (used by the softmax schemes).
        learn_t: Make ``t`` a trainable parameter. Only honored for
            ``'softmax'`` and ``'softmax_sum'``.
        p: Exponent of the power mean (used by the power schemes).
        learn_p: Make ``p`` a trainable parameter (power schemes only).
        y: Initial pre-sigmoid degree exponent (``*_sum`` schemes only).
        learn_y: Whether ``y`` receives gradients.
    """

    # Scheme names implemented by this class rather than by MessagePassing.
    _SOFTMAX_AGGRS = ('softmax_sg', 'softmax', 'softmax_sum')
    _POWER_AGGRS = ('power', 'power_sum')
    # Clamp bounds that keep pow / root of the power mean finite.
    _POWER_MIN = 1e-7
    _POWER_MAX = 1e1

    def __init__(self, aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False):
        if aggr in self._SOFTMAX_AGGRS:
            # Custom scheme: disable the base-class aggregator and remember
            # the scheme name for `aggregate`.
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr
            # 'softmax_sg' (stop-gradient) never learns t.
            if learn_t and aggr in ('softmax', 'softmax_sum'):
                self.learn_t = True
                self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
            else:
                self.learn_t = False
                self.t = t
            if aggr == 'softmax_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        elif aggr in self._POWER_AGGRS:
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr
            if learn_p:
                self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True)
            else:
                self.p = p
            if aggr == 'power_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        else:
            super(GenMessagePassing, self).__init__(aggr=aggr)

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate the messages ``inputs`` into the nodes given by ``index``.

        Args:
            inputs: Messages to aggregate, indexed along ``self.node_dim``.
            index: Destination node of each message.
            ptr: Optional compressed index; only forwarded to the base-class
                aggregator.
            dim_size: Number of output nodes.

        Returns:
            The aggregated tensor with ``dim_size`` entries along
            ``self.node_dim``.
        """
        if self.aggr in self._SOFTMAX_AGGRS:
            return self._softmax_aggregate(inputs, index, dim_size)
        if self.aggr in self._POWER_AGGRS:
            return self._power_aggregate(inputs, index, dim_size)
        # 'add', 'mean', 'max', None and any other scheme that
        # MessagePassing.__init__ accepted are handled by the base class.
        return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

    def _softmax_aggregate(self, inputs, index, dim_size):
        """Sum of messages weighted by a per-node softmax of ``inputs * t``."""
        if self.learn_t:
            weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
        else:
            # With a fixed t the softmax weights are treated as constants.
            # NOTE(review): this makes plain 'softmax' with learn_t=False
            # identical to 'softmax_sg' for gradients; confirm intended.
            with torch.no_grad():
                weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
        out = scatter(inputs * weights, index, dim=self.node_dim,
                      dim_size=dim_size, reduce='sum')
        if self.aggr == 'softmax_sum':
            out = self._scale_by_degree(out, index, dim_size)
        return out

    def _power_aggregate(self, inputs, index, dim_size):
        """Power mean ``mean(inputs ** p) ** (1 / p)`` of the messages."""
        # Clamp out of place so the caller's message tensor is not modified
        # and autograd never sees an in-place op on a saved/leaf tensor.
        inputs = torch.clamp(inputs, self._POWER_MIN, self._POWER_MAX)
        out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                      dim_size=dim_size, reduce='mean')
        # Nodes without messages have mean 0; clamping keeps the root finite.
        out = torch.clamp(out, self._POWER_MIN, self._POWER_MAX)
        out = torch.pow(out, 1 / self.p)
        if self.aggr == 'power_sum':
            out = self._scale_by_degree(out, index, dim_size)
        return out

    def _scale_by_degree(self, out, index, dim_size):
        """Return ``deg ** sigmoid(y) * out``, ``deg`` = messages per node.

        Side effect: stores ``sigmoid(y)`` on ``self.sigmoid_y``.
        """
        self.sigmoid_y = torch.sigmoid(self.y)
        # unsqueeze(1) assumes `out` is (num_nodes, features); TODO confirm.
        degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
        return torch.pow(degrees, self.sigmoid_y) * out
class MsgNorm(torch.nn.Module):
    """Normalize messages and rescale them to the norm of ``x``.

    Each row of ``msg`` is scaled to unit p-norm along dim 1. It is then
    multiplied by the matching row's p-norm of ``x`` and by the scalar
    parameter ``msg_scale``.

    Args:
        learn_msg_scale: Whether ``msg_scale`` (initialized to 1.0) is
            trainable.
    """

    def __init__(self, learn_msg_scale=False):
        super().__init__()
        initial_scale = torch.Tensor([1.0])
        self.msg_scale = torch.nn.Parameter(initial_scale,
                                            requires_grad=learn_msg_scale)

    def forward(self, x, msg, p=2):
        """Return ``normalize(msg) * ||x||_p * msg_scale`` (row-wise, dim 1).

        Assumes ``x`` and ``msg`` have matching (or broadcastable) leading
        dimensions; TODO confirm against callers.
        """
        unit_msg = F.normalize(msg, p=p, dim=1)
        x_row_norm = torch.norm(x, p=p, dim=1, keepdim=True)
        return unit_msg * x_row_norm * self.msg_scale