Spaces:
Sleeping
Sleeping
File size: 3,513 Bytes
799e642 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter, scatter_softmax
from torch_geometric.utils import degree
class GenMessagePassing(MessagePassing):
    """MessagePassing base class adding DeeperGCN-style generalized aggregators.

    Besides the aggregation schemes handled by ``MessagePassing`` itself, this
    class implements:

    * ``'softmax'``, ``'softmax_sg'``, ``'softmax_sum'``: per-target-node
      softmax-weighted sum of messages, with the softmax taken over
      ``inputs * t``.
    * ``'power'``, ``'power_sum'``: power mean of messages with exponent ``p``
      (messages and the mean are clamped to ``[1e-7, 10]``).

    The ``*_sum`` variants additionally multiply the result by
    ``deg ** sigmoid(y)``, where ``deg`` counts the messages each target node
    receives.

    Args:
        aggr: aggregation scheme name. Names not listed above are forwarded
            unchanged to ``MessagePassing``.
        t: initial softmax temperature factor (softmax family only).
        learn_t: make ``t`` a trainable parameter. Only honoured for
            ``'softmax'`` and ``'softmax_sum'``.
        p: initial power-mean exponent (power family only).
        learn_p: make ``p`` a trainable parameter.
        y: initial pre-sigmoid degree exponent (``*_sum`` variants only).
        learn_y: make ``y`` a trainable parameter.
    """

    _SOFTMAX_AGGRS = ('softmax_sg', 'softmax', 'softmax_sum')
    _POWER_AGGRS = ('power', 'power_sum')

    def __init__(self, aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False):
        if aggr in self._SOFTMAX_AGGRS:
            # Custom scheme: MessagePassing gets aggr=None, and the real scheme
            # name is stored on self.aggr for aggregate() to dispatch on.
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr
            if learn_t and aggr in ('softmax', 'softmax_sum'):
                self.learn_t = True
                self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
            else:
                self.learn_t = False
                self.t = t
            if aggr == 'softmax_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        elif aggr in self._POWER_AGGRS:
            super(GenMessagePassing, self).__init__(aggr=None)
            self.aggr = aggr
            if learn_p:
                self.p = torch.nn.Parameter(torch.Tensor([p]), requires_grad=True)
            else:
                self.p = p
            if aggr == 'power_sum':
                self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
        else:
            super(GenMessagePassing, self).__init__(aggr=aggr)

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate messages ``inputs`` into their target nodes ``index``.

        Args:
            inputs: message tensor, scattered along ``self.node_dim``.
            index: target-node index for each message.
            ptr: forwarded to ``MessagePassing.aggregate`` for built-in schemes.
            dim_size: number of output nodes (passed to ``scatter``/``degree``).

        Returns:
            The aggregated tensor.
        """
        if self.aggr in self._SOFTMAX_AGGRS:
            if self.learn_t:
                weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            else:
                # NOTE(review): with a fixed t this blocks gradients through the
                # softmax weights for 'softmax'/'softmax_sum' as well, not only
                # for 'softmax_sg' — confirm this is intended.
                with torch.no_grad():
                    weights = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            out = scatter(inputs * weights, index, dim=self.node_dim,
                          dim_size=dim_size, reduce='sum')
            if self.aggr == 'softmax_sum':
                out = self._scale_by_degree(out, index, dim_size)
            return out

        if self.aggr in self._POWER_AGGRS:
            # Presumably keeps pow() real and bounded for non-integer p.
            min_value, max_value = 1e-7, 1e1
            # Out-of-place clamps: the previous clamp_ mutated the caller's
            # message tensor and could break autograd if that tensor was saved
            # for backward by an earlier op.
            inputs = torch.clamp(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
                          dim_size=dim_size, reduce='mean')
            out = torch.clamp(out, min_value, max_value)
            out = torch.pow(out, 1 / self.p)
            if self.aggr == 'power_sum':
                out = self._scale_by_degree(out, index, dim_size)
            return out

        # Any other scheme was accepted by MessagePassing.__init__, so let the
        # base class aggregate it.
        return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)

    def _scale_by_degree(self, out, index, dim_size):
        """Multiply ``out`` by ``deg ** sigmoid(y)`` per target node."""
        # Kept as a module attribute, as before, presumably so callers can
        # inspect the current exponent — confirm before removing.
        self.sigmoid_y = torch.sigmoid(self.y)
        # unsqueeze(1) assumes node_dim == 0 and 2-D output — TODO confirm.
        degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
        return torch.pow(degrees, self.sigmoid_y) * out
class MsgNorm(torch.nn.Module):
    """Message normalization layer (MsgNorm).

    Each row of the message tensor is rescaled to unit ``p``-norm. It is then
    multiplied by the ``p``-norm of the matching row of the node features and
    by a scalar scale factor.

    Args:
        learn_msg_scale: whether the scale factor (initialised to 1.0) is
            trainable.
    """

    def __init__(self, learn_msg_scale=False):
        super().__init__()
        # Registered as a Parameter even when frozen, presumably so it is
        # still part of the state_dict.
        scale_init = torch.Tensor([1.0])
        self.msg_scale = torch.nn.Parameter(scale_init,
                                            requires_grad=learn_msg_scale)

    def forward(self, x, msg, p=2):
        """Return ``msg`` row-normalised, then scaled by ``||x||_p * msg_scale``.

        Args:
            x: node features; norms are taken along dim 1.
            msg: messages; normalised along dim 1.
            p: norm order used for both tensors.
        """
        feature_norm = x.norm(p=p, dim=1, keepdim=True)
        unit_msg = F.normalize(msg, p=p, dim=1)
        return unit_msg * feature_norm * self.msg_scale
|