import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric as tg
from torch_geometric.utils import remove_self_loops, add_self_loops

from .torch_nn import MLP, act_layer, norm_layer, BondEncoder, MM_BondEncoder
from .torch_edge import DilatedKnnGraph
from .torch_message import GenMessagePassing, MsgNorm


class GENConv(GenMessagePassing):
    """
    GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf
    SoftMax & PowerMean Aggregation

    Args:
        in_dim: input feature size; also the size edge features are encoded to.
        emb_dim: output feature size of the final MLP layer.
        args: configuration object; only ``args.advs`` is read here.
        aggr, t, learn_t, p, learn_p, y, learn_y: forwarded unchanged to
            ``GenMessagePassing``.
        msg_norm: if truthy, a ``MsgNorm`` module is applied to the
            aggregated message before the residual sum.
        learn_msg_scale: forwarded to ``MsgNorm``.
        encode_edge: if truthy, edge attributes are passed through an edge
            encoder before message passing.
        bond_encoder: with ``encode_edge``, selects ``BondEncoder`` (or
            ``MM_BondEncoder`` when ``args.advs`` is truthy) instead of a
            plain linear layer.
        edge_feat_dim: input size of the linear edge encoder; only used when
            ``encode_edge`` is set and ``bond_encoder`` is not.
        norm: normalisation name forwarded to the MLP.
        mlp_layers: number of linear layers in the MLP; hidden layers have
            width ``2 * in_dim``.
        eps: constant added to every encoded message.
    """

    def __init__(self, in_dim, emb_dim, args,
                 aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False,
                 msg_norm=False, learn_msg_scale=True,
                 encode_edge=False, bond_encoder=False,
                 edge_feat_dim=None,
                 norm='batch', mlp_layers=2,
                 eps=1e-7):
        super(GENConv, self).__init__(aggr=aggr,
                                      t=t, learn_t=learn_t,
                                      p=p, learn_p=learn_p,
                                      y=y, learn_y=learn_y)

        # in_dim -> (mlp_layers - 1) hidden layers of width 2*in_dim -> emb_dim
        channels_list = [in_dim] + [in_dim * 2] * (mlp_layers - 1) + [emb_dim]

        self.mlp = MLP(channels=channels_list,
                       norm=norm,
                       last_lin=True)

        self.msg_encoder = torch.nn.ReLU()
        self.eps = eps

        self.encode_edge = encode_edge
        self.bond_encoder = bond_encoder
        self.advs = args.advs

        if msg_norm:
            self.msg_norm = MsgNorm(learn_msg_scale=learn_msg_scale)
        else:
            self.msg_norm = None

        if self.encode_edge:
            if self.bond_encoder:
                if self.advs:
                    self.edge_encoder = MM_BondEncoder(emb_dim=in_dim)
                else:
                    self.edge_encoder = BondEncoder(emb_dim=in_dim)
            else:
                self.edge_encoder = torch.nn.Linear(edge_feat_dim, in_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """
        Propagate messages, add them to ``x`` as a residual and apply the MLP.

        ``edge_attr`` is encoded first only when ``encode_edge`` is set and
        ``edge_attr`` is not None; otherwise it is passed through as given.
        """
        if self.encode_edge and edge_attr is not None:
            edge_emb = self.edge_encoder(edge_attr)
        else:
            edge_emb = edge_attr

        m = self.propagate(edge_index, x=x, edge_attr=edge_emb)

        if self.msg_norm is not None:
            m = self.msg_norm(x, m)

        h = x + m
        out = self.mlp(h)

        return out

    def message(self, x_j, edge_attr=None):
        """
        Build the per-edge message: ReLU of ``x_j`` (plus ``edge_attr`` when
        given), shifted by ``eps``.
        """
        if edge_attr is not None:
            msg = x_j + edge_attr
        else:
            msg = x_j

        # NOTE(review): the eps shift presumably keeps messages strictly
        # positive for the softmax / power-mean aggregators - confirm.
        return self.msg_encoder(msg) + self.eps

    def update(self, aggr_out):
        """
        Return the aggregated messages unchanged.
        """
        return aggr_out


class MRConv(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
        super(MRConv, self).__init__()
        # The MLP consumes [x, aggregated relative features], hence 2 * in_channels.
        self.nn = MLP([in_channels*2, out_channels], act, norm, bias)
        self.aggr = aggr

    def forward(self, x, edge_index):
        """
        Aggregate ``x[edge_index[0]] - x[edge_index[1]]`` per ``edge_index[1]``
        node, concatenate the result with ``x`` and apply the MLP.
        """
        relative = (torch.index_select(x, 0, edge_index[0])
                    - torch.index_select(x, 0, edge_index[1]))
        # NOTE(review): `tg.utils.scatter_` only exists in older
        # torch_geometric releases - confirm the pinned version.
        x_j = tg.utils.scatter_(self.aggr, relative, edge_index[1], dim_size=x.shape[0])
        return self.nn(torch.cat([x, x_j], dim=1))


class EdgConv(tg.nn.EdgeConv):
    """
    Edge convolution layer (with activation, batch normalization)
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
        super(EdgConv, self).__init__(MLP([in_channels*2, out_channels], act, norm, bias), aggr)

    def forward(self, x, edge_index):
        """
        Delegate to ``tg.nn.EdgeConv.forward``.
        """
        return super(EdgConv, self).forward(x, edge_index)


class GATConv(nn.Module):
    """
    Graph Attention Convolution layer (with activation, batch normalization)

    ``out_channels`` is the size per attention head; the wrapped
    ``tg.nn.GATConv`` is built with its default head concatenation, so the
    layer output has ``heads * out_channels`` features.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8):
        super(GATConv, self).__init__()
        self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias)
        m = []
        if act:
            m.append(act_layer(act))
        if norm:
            # The heads are concatenated, so the normalisation layer must be
            # sized for heads * out_channels features, not out_channels.
            m.append(norm_layer(norm, out_channels * heads))
        self.unlinear = nn.Sequential(*m)

    def forward(self, x, edge_index):
        """
        Apply the attention convolution followed by activation / norm.
        """
        out = self.unlinear(self.gconv(x, edge_index))
        return out


class SAGEConv(tg.nn.SAGEConv):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper

    .. math::
        \mathbf{\hat{x}}_i &= \mathbf{\Theta} \cdot
        \mathrm{mean}_{j \in \mathcal{N(i) \cup \{ i \}}}(\mathbf{x}_j)

        \mathbf{x}^{\prime}_i &= \frac{\mathbf{\hat{x}}_i}
        {\| \mathbf{\hat{x}}_i \|_2}.

    This subclass differs from the formula above: each message is the
    neighbour feature (or ``x_j - x_i`` when ``relative`` is set) multiplied
    by ``self.weight``, and the update step applies ``nn`` to the
    concatenation of the node's own features and the aggregated messages.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        nn (torch.nn.Module): Module applied to ``cat((x, aggr_out), dim=1)``
            in :meth:`update`.
        norm (optional): Any value other than :obj:`None` (including
            :obj:`False`) makes the third positional argument of the parent
            constructor :obj:`True`. NOTE(review): that argument is
            presumably ``normalize`` in the torch_geometric version this
            code targets - confirm. (default: :obj:`True`)
        bias (bool, optional): Forwarded to the parent constructor.
            (default: :obj:`True`)
        relative (bool, optional): If set to :obj:`True`, messages are built
            from ``x_j - x_i`` instead of ``x_j``. (default: :obj:`False`)
        **kwargs (optional): Additional arguments forwarded to the parent
            constructor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nn,
                 norm=True,
                 bias=True,
                 relative=False,
                 **kwargs):
        self.relative = relative
        super(SAGEConv, self).__init__(in_channels, out_channels, norm is not None, bias, **kwargs)
        self.nn = nn

    def forward(self, x, edge_index, size=None):
        """
        Run message passing over ``edge_index``.

        When ``size`` is None, existing self loops are removed and one self
        loop per node is added. A 1-D ``x`` gets a trailing dimension.
        """
        if size is None:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        return self.propagate(edge_index, size=size, x=x)

    def message(self, x_i, x_j):
        """
        Multiply ``x_j`` (or ``x_j - x_i`` when ``relative``) by ``self.weight``.
        """
        if self.relative:
            x = torch.matmul(x_j - x_i, self.weight)
        else:
            x = torch.matmul(x_j, self.weight)
        return x

    def update(self, aggr_out, x):
        """
        Apply ``nn`` to ``[x, aggr_out]``, then the optional bias and
        L2 normalisation.
        """
        out = self.nn(torch.cat((x, aggr_out), dim=1))
        if self.bias is not None:
            out = out + self.bias
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        return out


class RSAGEConv(SAGEConv):
    """
    Residual SAGE convolution layer (with activation, batch normalization)
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False):
        # Input width is out_channels + in_channels because SAGEConv.update
        # feeds the MLP the concatenation of x and the aggregated messages.
        mlp = MLP([out_channels + in_channels, out_channels], act, norm, bias)
        super(RSAGEConv, self).__init__(in_channels, out_channels, mlp, norm, bias, relative)


class SemiGCNConv(nn.Module):
    """
    SemiGCN convolution layer (with activation, batch normalization)
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(SemiGCNConv, self).__init__()
        self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias)
        m = []
        if act:
            m.append(act_layer(act))
        if norm:
            m.append(norm_layer(norm, out_channels))
        self.unlinear = nn.Sequential(*m)

    def forward(self, x, edge_index):
        """
        Apply the GCN convolution followed by activation / norm.
        """
        out = self.unlinear(self.gconv(x, edge_index))
        return out


class GinConv(tg.nn.GINConv):
    """
    GINConv layer (with activation, batch normalization)

    The ``aggr`` argument is accepted but not forwarded to the parent
    constructor.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'):
        super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias))

    def forward(self, x, edge_index):
        """
        Delegate to ``tg.nn.GINConv.forward``.
        """
        return super(GinConv, self).forward(x, edge_index)


class GraphConv(nn.Module):
    """
    Static graph convolution layer

    ``conv`` selects the wrapped layer (case-insensitive): 'edge', 'mr',
    'gat', 'gcn', 'gin', 'sage' or 'rsage'. Any other value raises
    ``NotImplementedError``.
    """

    def __init__(self, in_channels, out_channels, conv='edge',
                 act='relu', norm=None, bias=True, heads=8):
        super(GraphConv, self).__init__()
        conv_type = conv.lower()
        if conv_type == 'edge':
            self.gconv = EdgConv(in_channels, out_channels, act, norm, bias)
        elif conv_type == 'mr':
            self.gconv = MRConv(in_channels, out_channels, act, norm, bias)
        elif conv_type == 'gat':
            # Per-head width, so the concatenated heads total about out_channels.
            self.gconv = GATConv(in_channels, out_channels//heads, act, norm, bias, heads)
        elif conv_type == 'gcn':
            self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias)
        elif conv_type == 'gin':
            self.gconv = GinConv(in_channels, out_channels, act, norm, bias)
        elif conv_type == 'sage':
            self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False)
        elif conv_type == 'rsage':
            self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, True)
        else:
            raise NotImplementedError('conv {} is not implemented'.format(conv))

    def forward(self, x, edge_index):
        """
        Apply the selected convolution to ``x`` over ``edge_index``.
        """
        return self.gconv(x, edge_index)


class DynConv(GraphConv):
    """
    Dynamic graph convolution layer

    The edge index is rebuilt on every forward pass by a ``DilatedKnnGraph``
    constructed from ``kernel_size``, ``dilation`` and ``**kwargs``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, heads=8, **kwargs):
        super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads)
        self.k = kernel_size
        self.d = dilation
        self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs)

    def forward(self, x, batch=None):
        """
        Build the edge index from ``x`` (and ``batch``), then convolve.
        """
        edge_index = self.dilated_knn_graph(x, batch)
        return super(DynConv, self).forward(x, edge_index)


class PlainDynBlock(nn.Module):
    """
    Plain Dynamic graph convolution block

    ``res_scale`` is stored but not used in ``forward``.
    """

    def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, res_scale=1, **kwargs):
        super(PlainDynBlock, self).__init__()
        self.body = DynConv(channels, channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)
        self.res_scale = res_scale

    def forward(self, x, batch=None):
        """
        Return ``(body(x, batch), batch)``.
        """
        return self.body(x, batch), batch


class ResDynBlock(nn.Module):
    """
    Residual Dynamic graph convolution block
    """

    def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, res_scale=1, **kwargs):
        super(ResDynBlock, self).__init__()
        self.body = DynConv(channels, channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)
        self.res_scale = res_scale

    def forward(self, x, batch=None):
        """
        Return ``(body(x, batch) + x * res_scale, batch)``.
        """
        return self.body(x, batch) + x*self.res_scale, batch


class DenseDynBlock(nn.Module):
    """
    Dense Dynamic graph convolution block
    """

    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
                 act='relu', norm=None, bias=True, **kwargs):
        super(DenseDynBlock, self).__init__()
        self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)

    def forward(self, x, batch=None):
        """
        Return ``(cat((x, body(x, batch)), 1), batch)``.
        """
        dense = self.body(x, batch)
        return torch.cat((x, dense), 1), batch


class ResGraphBlock(nn.Module):
    """
    Residual Static graph convolution block
    """

    def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1):
        super(ResGraphBlock, self).__init__()
        self.body = GraphConv(channels, channels, conv, act, norm, bias, heads)
        self.res_scale = res_scale

    def forward(self, x, edge_index):
        """
        Return ``(body(x, edge_index) + x * res_scale, edge_index)``.
        """
        return self.body(x, edge_index) + x*self.res_scale, edge_index


class DenseGraphBlock(nn.Module):
    """
    Dense Static graph convolution block
    """

    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8):
        super(DenseGraphBlock, self).__init__()
        self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads)

    def forward(self, x, edge_index):
        """
        Return ``(cat((x, body(x, edge_index)), 1), edge_index)``.
        """
        dense = self.body(x, edge_index)
        return torch.cat((x, dense), 1), edge_index