import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter, scatter_softmax
from torch_geometric.utils import degree


class GenMessagePassing(MessagePassing):
    """Message passing layer with generalized softmax / power-mean aggregators.

    Supported ``aggr`` values:

    * ``'softmax_sg'``, ``'softmax'``, ``'softmax_sum'``: sum of the messages
      weighted by ``scatter_softmax(messages * t)`` over each target index.
    * ``'power'``, ``'power_sum'``: power mean of the messages with exponent
      ``p``, computed on messages clamped to ``[1e-7, 10]``.
    * Anything else is handed to ``MessagePassing.__init__``. ``aggregate``
      only dispatches ``'add'``, ``'mean'``, ``'max'`` and ``None`` to the base
      class and raises ``NotImplementedError`` for any other value.

    The ``*_sum`` variants additionally multiply the aggregate by
    ``deg ** sigmoid(y)``, where ``deg`` is the number of messages that share
    each target index.

    Args:
        aggr: Aggregation scheme (see above).
        t: Initial softmax scaling factor (softmax variants only).
        learn_t: Make ``t`` a trainable parameter. Only honoured for
            ``'softmax'`` and ``'softmax_sum'``. For ``'softmax_sg'``, ``t``
            always stays a fixed Python number.
        p: Initial power-mean exponent (power variants only).
        learn_p: Make ``p`` a trainable parameter instead of a fixed number.
        y: Initial degree-scaling logit (``*_sum`` variants only).
        learn_y: Whether the ``y`` parameter receives gradients.
    """

    def __init__(self, aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False):

        if aggr in ['softmax_sg', 'softmax', 'softmax_sum']:
            # Custom aggregation: register no built-in aggregator, then record
            # the scheme name ourselves so ``aggregate`` can dispatch on it.
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr

            if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'):
                self.learn_t = True
                self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
            else:
                self.learn_t = False
                self.t = t

            if aggr == 'softmax_sum':
                # Always a Parameter (so it is saved with the state dict);
                # only trainable when learn_y is set.
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)

        elif aggr in ['power', 'power_sum']:
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr

            if learn_p:
                self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True)
            else:
                self.p = p

            if aggr == 'power_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        else:
            super(GenMessagePassing, self).__init__(aggr=aggr)

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Reduce ``inputs`` into ``dim_size`` slots selected by ``index``.

        Args:
            inputs: Message tensor produced by ``message``. It is not modified.
            index: Target slot of each message along ``self.node_dim``.
            ptr: Forwarded unchanged to the base-class aggregation.
            dim_size: Number of output slots (``None`` lets ``scatter`` /
                ``degree`` infer it from ``index``).

        Returns:
            The aggregated tensor.

        Raises:
            NotImplementedError: If ``self.aggr`` is not a supported scheme.
        """
        if self.aggr in ['add', 'mean', 'max', None]:
            return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

        elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:

            if self.learn_t:
                weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            else:
                # NOTE(review): with a fixed t this also stops gradients through
                # the softmax weights for 'softmax' / 'softmax_sum', so they
                # behave like 'softmax_sg'; confirm this is intended.
                with torch.no_grad():
                    weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)

            out = scatter(inputs * weights, index, dim=self.node_dim,
                          dim_size=dim_size, reduce='sum')

            if self.aggr == 'softmax_sum':
                out = self._scale_by_degree(out, index, dim_size)

            return out

        elif self.aggr in ['power', 'power_sum']:
            min_value, max_value = 1e-7, 1e1
            # Clamp out of place: an in-place clamp_ would mutate the caller's
            # message tensor and can break backward when that tensor was saved
            # by an upstream op (e.g. the output of relu/sigmoid/exp).
            inputs = torch.clamp(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            # Keep the base of the 1/p root inside the same safe range.
            out = torch.clamp(out, min_value, max_value)
            out = torch.pow(out, 1 / self.p)

            if self.aggr == 'power_sum':
                out = self._scale_by_degree(out, index, dim_size)

            return out

        else:
            raise NotImplementedError('To be implemented')

    def _scale_by_degree(self, out, index, dim_size):
        """Return ``deg ** sigmoid(y) * out``; shared by the ``*_sum`` schemes."""
        # Stored on the module, as before, so callers can inspect the last value.
        self.sigmoid_y = torch.sigmoid(self.y)
        # unsqueeze(1) gives shape [num_slots, 1]; assumes `out` is 2-D with the
        # target slots along dim 0 -- TODO confirm for other node_dim layouts.
        degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
        return torch.pow(degrees, self.sigmoid_y) * out


class MsgNorm(torch.nn.Module):
    """Rescale ``msg`` to the magnitude of ``x``.

    ``forward`` normalizes ``msg`` along dim 1 and multiplies it by the p-norm
    of ``x`` along dim 1 and by the scalar ``msg_scale``.

    Args:
        learn_msg_scale: Whether ``msg_scale`` (initialised to 1.0) is trainable.
    """

    def __init__(self, learn_msg_scale=False):
        super(MsgNorm, self).__init__()

        self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]),
                                            requires_grad=learn_msg_scale)

    def forward(self, x, msg, p=2):
        """Return ``normalize(msg, p, dim=1) * ||x||_p(dim=1) * msg_scale``.

        Args:
            x: Reference tensor whose norm along dim 1 sets the magnitude.
            msg: Tensor to rescale; presumably row-aligned with ``x`` -- TODO
                confirm against callers.
            p: Order of the norm (default 2).

        Returns:
            The rescaled ``msg``.
        """
        msg = F.normalize(msg, p=p, dim=1)
        x_norm = x.norm(p=p, dim=1, keepdim=True)
        msg = msg * x_norm * self.msg_scale
        return msg