import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin
from utils.data_util import get_atom_feature_dims, get_bond_feature_dims


##############################
#    Basic layers
##############################
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation layer from its (case-insensitive) name.

    Args:
        act_type: One of ``'relu'``, ``'leakyrelu'`` or ``'prelu'``.
        inplace: Passed to ``nn.ReLU`` / ``nn.LeakyReLU``; unused for PReLU.
        neg_slope: Negative slope for LeakyReLU and initial slope for PReLU.
        n_prelu: Number of learnable slopes for PReLU.

    Returns:
        The constructed activation ``nn.Module``.

    Raises:
        NotImplementedError: If ``act_type`` is not one of the names above.
    """
    act = act_type.lower()
    if act == 'relu':
        return nn.ReLU(inplace)
    if act == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if act == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError('activation layer [%s] is not found' % act)


def norm_layer(norm_type, nc):
    """Build a normalization layer over ``nc`` channels.

    Args:
        norm_type: One of ``'batch'``, ``'layer'`` or ``'instance'``
            (case-insensitive).
        nc: Number of channels / features to normalize.

    Returns:
        ``nn.BatchNorm1d`` (affine), ``nn.LayerNorm`` (elementwise affine) or
        ``nn.InstanceNorm1d`` (no affine parameters).

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the names above.
    """
    norm = norm_type.lower()
    if norm == 'batch':
        return nn.BatchNorm1d(nc, affine=True)
    if norm == 'layer':
        return nn.LayerNorm(nc, elementwise_affine=True)
    if norm == 'instance':
        # NOTE(review): InstanceNorm1d expects (N, C, L) or unbatched (C, L)
        # input; if it is fed 2-D (N, C) features (e.g. from MLP below) PyTorch
        # will treat them as (C, L). Confirm this is the intended behavior.
        return nn.InstanceNorm1d(nc, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm)


class MultiSeq(Seq):
    """``nn.Sequential`` whose stages may consume and produce several values.

    ``forward`` receives its positional arguments as a tuple. Whenever the
    running value is exactly a ``tuple`` it is unpacked into the next module's
    positional arguments; any other value is passed as a single argument.
    """

    def __init__(self, *args):
        super(MultiSeq, self).__init__(*args)

    def forward(self, *inputs):
        for module in self._modules.values():
            # Exact type check (not isinstance): tuple subclasses returned by a
            # module are forwarded as one value rather than unpacked.
            if type(inputs) is tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class MLP(Seq):
    """Stack of ``Linear`` layers, each optionally followed by norm/act/dropout.

    Args:
        channels: Layer widths; layer ``i`` maps ``channels[i-1] -> channels[i]``.
        act: Activation name for ``act_layer``; ``None`` or ``'none'`` disables it.
        norm: Normalization name for ``norm_layer``; ``None`` or ``'none'``
            disables it.
        bias: Whether the ``Linear`` layers have a bias term.
        drop: Dropout probability appended after each block when ``> 0``.
        last_lin: If True, the final ``Linear`` gets no norm, activation or
            dropout after it.

    The flat list of created layers is also kept as ``self.m``.
    """

    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0., last_lin=False):
        use_norm = norm is not None and norm.lower() != 'none'
        use_act = act is not None and act.lower() != 'none'
        last = len(channels) - 1
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if last_lin and i == last:
                continue  # bare linear output layer
            if use_norm:
                m.append(norm_layer(norm, channels[i]))
            if use_act:
                m.append(act_layer(act))
            if drop > 0:
                # Linear outputs have no spatial dims, so channel-wise
                # Dropout2d does not apply; PyTorch deprecates dropout2d on
                # non-3-D/4-D input and names plain dropout as the
                # behavior-preserving replacement (element-wise for 2-D input).
                m.append(nn.Dropout(drop))
        super(MLP, self).__init__(*m)
        # Plain (non-Module) list, kept for callers that read ``mlp.m``.
        self.m = m


class AtomEncoder(nn.Module):
    """Embed integer-coded atom features: one ``nn.Embedding`` per column, summed.

    Table ``i`` has ``get_atom_feature_dims()[i]`` rows and ``emb_dim`` columns
    and is Xavier-uniform initialized.
    """

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()
        for dim in full_atom_feature_dims:
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        """Return ``sum_i embedding_i(x[:, i])`` over the columns of ``x``.

        ``x`` must be an integer tensor (it is used as ``nn.Embedding`` indices)
        with no more columns than there are embedding tables.
        """
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:, i])
        return x_embedding


class BondEncoder(nn.Module):
    """Embed integer-coded bond features: one ``nn.Embedding`` per column, summed.

    Table ``i`` has ``get_bond_feature_dims()[i]`` rows and ``emb_dim`` columns
    and is Xavier-uniform initialized.
    """

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()
        for dim in full_bond_feature_dims:
            emb = nn.Embedding(dim, emb_dim)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Return ``sum_i embedding_i(edge_attr[:, i])`` over the columns.

        ``edge_attr`` must be an integer tensor (used as ``nn.Embedding``
        indices) with no more columns than there are embedding tables.
        """
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
        return bond_embedding


class MM_BondEncoder(nn.Module):
    """Differentiable variant of ``BondEncoder``.

    Replaces the embedding lookup by one-hot encoding followed by a matrix
    multiplication (a bias-free ``Linear`` per feature), so the input may be a
    Float tensor instead of Long indices and gradients can be back-propagated
    through this layer.
    """

    def __init__(self, emb_dim):
        super(MM_BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        self.full_bond_feature_dims = get_bond_feature_dims()
        for dim in self.full_bond_feature_dims:
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Embed concatenated per-feature one-hot (or soft) bond vectors.

        ``edge_attr`` is split along dim 1 into chunks of sizes
        ``self.full_bond_feature_dims`` (``torch.split`` requires these to sum
        to ``edge_attr.shape[1]``). Chunk ``i`` is projected by
        ``self.bond_embedding_list[i]`` and the projections are summed.
        Works for any number of bond features, not just three.
        """
        chunks = torch.split(edge_attr, self.full_bond_feature_dims, dim=1)
        return sum(emb(chunk) for emb, chunk in zip(self.bond_embedding_list, chunks))


class MM_AtomEncoder(nn.Module):
    """Differentiable variant of ``AtomEncoder``.

    Replaces the embedding lookup by one-hot encoding followed by a matrix
    multiplication (a bias-free ``Linear`` per feature), so the input may be a
    Float tensor instead of Long indices and gradients can be back-propagated
    through this layer.
    """

    def __init__(self, emb_dim):
        super(MM_AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        self.full_atom_feature_dims = get_atom_feature_dims()
        for dim in self.full_atom_feature_dims:
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        """Embed concatenated per-feature one-hot (or soft) atom vectors.

        ``x`` is split along dim 1 into chunks of sizes
        ``self.full_atom_feature_dims`` (``torch.split`` requires these to sum
        to ``x.shape[1]``). Chunk ``i`` is projected by
        ``self.atom_embedding_list[i]`` and the projections are summed.
        """
        chunks = torch.split(x, self.full_atom_feature_dims, dim=1)
        return sum(emb(chunk) for emb, chunk in zip(self.atom_embedding_list, chunks))