# PLA-Net/gcn_lib/sparse/torch_nn.py
import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin
from utils.data_util import get_atom_feature_dims, get_bond_feature_dims
##############################
# Basic layers
##############################
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    # activation layer
    act = act_type.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm_type, nc):
    # normalization layer 1d
    norm = norm_type.lower()
    if norm == 'batch':
        layer = nn.BatchNorm1d(nc, affine=True)
    elif norm == 'layer':
        layer = nn.LayerNorm(nc, elementwise_affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm1d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer
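

# Usage sketch (illustrative only, not part of the original module): both
# factories map a lower-cased string identifier to a single torch layer.
#   >>> act_layer('leakyrelu', neg_slope=0.1)   # nn.LeakyReLU(0.1)
#   >>> norm_layer('batch', 64)                 # nn.BatchNorm1d(64)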


class MultiSeq(Seq):
    def __init__(self, *args):
        super(MultiSeq, self).__init__(*args)

    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
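

# Usage sketch (illustrative only): MultiSeq unpacks tuple outputs, so layers
# whose forward takes several positional arguments (e.g. (x, edge_index)) can
# be chained; with single-tensor modules it behaves like nn.Sequential.
#   >>> net = MultiSeq(Lin(8, 16), nn.ReLU(), Lin(16, 4))
#   >>> net(torch.randn(2, 8)).shape
#   torch.Size([2, 4])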


class MLP(Seq):
    def __init__(self, channels, act='relu',
                 norm=None, bias=True,
                 drop=0., last_lin=False):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if (i == len(channels) - 1) and last_lin:
                # keep the last layer purely linear
                pass
            else:
                if norm is not None and norm.lower() != 'none':
                    m.append(norm_layer(norm, channels[i]))
                if act is not None and act.lower() != 'none':
                    m.append(act_layer(act))
                if drop > 0:
                    m.append(nn.Dropout2d(drop))
        super(MLP, self).__init__(*m)
        self.m = m
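

# Usage sketch (illustrative only): `channels` lists successive layer widths;
# with defaults, MLP([16, 32, 64]) builds Lin(16, 32) -> ReLU -> Lin(32, 64) -> ReLU.
# Set last_lin=True to leave the final Lin without norm/activation/dropout.
#   >>> mlp = MLP([16, 32, 64], act='relu', norm='batch')
#   >>> mlp(torch.randn(8, 16)).shape
#   torch.Size([8, 64])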


class AtomEncoder(nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:, i])
        return x_embedding
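

# Usage sketch (illustrative only; assumes `x` holds one integer-coded column
# per entry of get_atom_feature_dims()): each categorical atom feature is
# embedded separately and the embeddings are summed.
#   >>> enc = AtomEncoder(emb_dim=128)
#   >>> x = torch.zeros(10, len(get_atom_feature_dims()), dtype=torch.long)
#   >>> enc(x).shape
#   torch.Size([10, 128])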


class BondEncoder(nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
        return bond_embedding
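

# Usage sketch (illustrative only): mirrors AtomEncoder, but for integer-coded
# bond features; `edge_attr` needs one column per entry of get_bond_feature_dims().
#   >>> enc = BondEncoder(emb_dim=128)
#   >>> edge_attr = torch.zeros(6, len(get_bond_feature_dims()), dtype=torch.long)
#   >>> enc(edge_attr).shape
#   torch.Size([6, 128])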


class MM_BondEncoder(nn.Module):
    # Replaces the lookup in the embedding module with one-hot encoding
    # followed by a matrix multiplication, so the input can be of Float type
    # instead of Long type (and gradients can propagate through the layer).
    def __init__(self, emb_dim):
        super(MM_BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        self.full_bond_feature_dims = get_bond_feature_dims()

        for i, dim in enumerate(self.full_bond_feature_dims):
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        # Each bond feature arrives as a one-hot block; split the blocks and
        # embed each with its own bias-free Linear layer (assumes exactly
        # three bond feature dimensions).
        edge_attr1, edge_attr2, edge_attr3 = torch.split(edge_attr, self.full_bond_feature_dims, dim=1)
        bond_embedding = self.bond_embedding_list[0](edge_attr1) + self.bond_embedding_list[1](edge_attr2) + self.bond_embedding_list[2](edge_attr3)
        return bond_embedding
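

# Usage sketch (illustrative only; assumes edge_attr is the column-wise
# concatenation of one-hot blocks, one block per entry of
# get_bond_feature_dims(), so the encoder is differentiable w.r.t. its input):
#   >>> import torch.nn.functional as F
#   >>> dims = get_bond_feature_dims()
#   >>> enc = MM_BondEncoder(emb_dim=128)
#   >>> idx = torch.zeros(5, len(dims), dtype=torch.long)
#   >>> edge_attr = torch.cat([F.one_hot(idx[:, i], d).float()
#   ...                        for i, d in enumerate(dims)], dim=1)
#   >>> enc(edge_attr).shape
#   torch.Size([5, 128])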


class MM_AtomEncoder(nn.Module):
    # Replaces the lookup in the embedding module with one-hot encoding
    # followed by a matrix multiplication, so the input can be of Float type
    # instead of Long type (and gradients can propagate through the layer).
    def __init__(self, emb_dim):
        super(MM_AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        self.full_atom_feature_dims = get_atom_feature_dims()

        for i, dim in enumerate(self.full_atom_feature_dims):
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        # Each atom feature arrives as a one-hot block in x; split the blocks
        # and embed each with its own bias-free Linear layer.
        split = torch.split(x, self.full_atom_feature_dims, dim=1)
        atom_embedding = 0
        for i in range(len(self.full_atom_feature_dims)):
            atom_embedding += self.atom_embedding_list[i](split[i])
        return atom_embedding
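

# Usage sketch (illustrative only): same idea as MM_BondEncoder, but looped
# over however many atom feature blocks get_atom_feature_dims() defines.
#   >>> import torch.nn.functional as F
#   >>> dims = get_atom_feature_dims()
#   >>> enc = MM_AtomEncoder(emb_dim=128)
#   >>> idx = torch.zeros(10, len(dims), dtype=torch.long)
#   >>> x = torch.cat([F.one_hot(idx[:, i], d).float()
#   ...                for i, d in enumerate(dims)], dim=1)
#   >>> enc(x).shape
#   torch.Size([10, 128])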