Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torch.nn import Sequential as Seq, Linear as Lin | |
| from utils.data_util import get_atom_feature_dims, get_bond_feature_dims | |
| ############################## | |
| # Basic layers | |
| ############################## | |
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its case-insensitive name.

    Args:
        act_type: one of 'relu', 'leakyrelu', 'prelu' (any letter case).
        inplace: forwarded to ReLU / LeakyReLU.
        neg_slope: negative slope for LeakyReLU, initial weight for PReLU.
        n_prelu: number of learnable PReLU parameters.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        NotImplementedError: if the name is not one of the supported ones.
    """
    name = act_type.lower()
    # Factories (not instances) so only the requested layer is constructed.
    factories = {
        'relu': lambda: nn.ReLU(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
    }
    build = factories.get(name)
    if build is None:
        raise NotImplementedError('activation layer [%s] is not found' % name)
    return build()
def norm_layer(norm_type, nc):
    """Build a 1-D normalization module from its case-insensitive name.

    Args:
        norm_type: 'batch' (affine BatchNorm1d), 'layer' (affine LayerNorm)
            or 'instance' (non-affine InstanceNorm1d), any letter case.
        nc: number of features / channels to normalize.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        NotImplementedError: if the name is not one of the supported ones.
    """
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm1d(nc, affine=True)
    if kind == 'layer':
        return nn.LayerNorm(nc, elementwise_affine=True)
    if kind == 'instance':
        return nn.InstanceNorm1d(nc, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % kind)
class MultiSeq(Seq):
    """``Sequential`` container whose modules may take several positional inputs.

    The call arguments are passed to the first module unpacked. After that,
    a module output whose type is exactly ``tuple`` is unpacked into the next
    module's positional arguments; any other output is passed as a single
    argument. With no modules, the argument tuple itself is returned.
    """

    def __init__(self, *args):
        super(MultiSeq, self).__init__(*args)

    def forward(self, *inputs):
        result = inputs
        for layer in self._modules.values():
            # Exact-type check: tuple subclasses (e.g. namedtuples) are NOT
            # unpacked; they are handed to the next layer as one argument.
            result = layer(*result) if type(result) is tuple else layer(result)
        return result
class MLP(Seq):
    """Multi-layer perceptron assembled as a ``Sequential`` of Linear blocks.

    For every consecutive pair ``(channels[i - 1], channels[i])`` a ``Linear``
    layer is appended, optionally followed by normalization, activation and
    dropout (in that order).

    Args:
        channels: layer widths, input width first. Fewer than two entries
            produce an empty container.
        act: activation name forwarded to ``act_layer``; ``None`` or
            ``'none'`` (any case) disables it.
        norm: normalization name forwarded to ``norm_layer``; ``None`` or
            ``'none'`` (any case) disables it.
        bias: whether each ``Linear`` layer has a bias term.
        drop: dropout probability; no dropout layer is added when ``<= 0``.
        last_lin: if True, the final ``Linear`` is left bare (no norm,
            activation or dropout after it).
    """

    def __init__(self, channels, act='relu',
                 norm=None, bias=True,
                 drop=0., last_lin=False):
        m = []
        last = len(channels) - 1
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias=bias))
            if last_lin and i == last:
                continue
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[i]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                # Element-wise dropout: Linear outputs are (..., C) feature
                # vectors. Dropout2d is channel-wise dropout for (N, C, H, W)
                # maps; it is deprecated for 2-D input and mis-handles 3-D.
                m.append(nn.Dropout(drop))
        # Kept for backward compatibility: external code may read ``mlp.m``.
        self.m = m
        super(MLP, self).__init__(*self.m)
class AtomEncoder(nn.Module):
    """Embed integer-coded atom features as a sum of per-column lookups.

    One ``nn.Embedding`` (Xavier-uniform initialised) is created per entry of
    ``get_atom_feature_dims()``, each entry being that column's vocabulary
    size.
    """

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        full_atom_feature_dims = get_atom_feature_dims()
        for num_categories in full_atom_feature_dims:
            table = nn.Embedding(num_categories, emb_dim)
            nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Sum the embeddings of every column of ``x``.

        Assumes ``x`` is an integer index tensor of shape
        (num_atoms, num_features) — TODO confirm against callers. Column ``c``
        is looked up in ``atom_embedding_list[c]``; returns 0 if ``x`` has no
        columns.
        """
        num_cols = x.shape[1]
        return sum(self.atom_embedding_list[col](x[:, col]) for col in range(num_cols))
class BondEncoder(nn.Module):
    """Embed integer-coded bond features as a sum of per-column lookups.

    One ``nn.Embedding`` (Xavier-uniform initialised) is created per entry of
    ``get_bond_feature_dims()``, each entry being that column's vocabulary
    size.
    """

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        full_bond_feature_dims = get_bond_feature_dims()
        for num_categories in full_bond_feature_dims:
            table = nn.Embedding(num_categories, emb_dim)
            nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Sum the embeddings of every column of ``edge_attr``.

        Assumes ``edge_attr`` is an integer index tensor of shape
        (num_edges, num_features) — TODO confirm against callers. Column ``c``
        is looked up in ``bond_embedding_list[c]``; returns 0 if there are no
        columns.
        """
        num_cols = edge_attr.shape[1]
        return sum(self.bond_embedding_list[col](edge_attr[:, col]) for col in range(num_cols))
class MM_BondEncoder(nn.Module):
    """Differentiable bond encoder using one bias-free Linear map per feature.

    Replaces the lookup in an embedding module with one-hot encoding followed
    by matrix multiplication. This allows Float inputs instead of Long indices,
    so gradients can backpropagate through the input.

    Works for any number of bond features returned by
    ``get_bond_feature_dims()`` (previously exactly three were assumed).
    """

    def __init__(self, emb_dim):
        super(MM_BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        # One width per bond feature; each is that feature's one-hot size.
        self.full_bond_feature_dims = get_bond_feature_dims()
        for dim in self.full_bond_feature_dims:
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Embed concatenated one-hot (or soft) bond features.

        ``edge_attr`` is split along dim 1 into chunks of sizes
        ``full_bond_feature_dims`` (torch.split raises if the sizes do not add
        up). Each chunk is projected by its own Linear map, and the
        projections are summed.
        """
        chunks = torch.split(edge_attr, self.full_bond_feature_dims, dim=1)
        bond_embedding = 0
        for emb, chunk in zip(self.bond_embedding_list, chunks):
            bond_embedding = bond_embedding + emb(chunk)
        return bond_embedding
class MM_AtomEncoder(nn.Module):
    """Differentiable atom encoder using one bias-free Linear map per feature.

    Replaces the lookup in an embedding module with one-hot encoding followed
    by matrix multiplication. This allows Float inputs instead of Long indices,
    so gradients can backpropagate through the input.
    """

    def __init__(self, emb_dim):
        super(MM_AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        # One width per atom feature; each is that feature's one-hot size.
        self.full_atom_feature_dims = get_atom_feature_dims()
        for width in self.full_atom_feature_dims:
            projection = nn.Linear(width, emb_dim, bias=False)
            nn.init.xavier_uniform_(projection.weight.data)
            self.atom_embedding_list.append(projection)

    def forward(self, x):
        """Embed concatenated one-hot (or soft) atom features.

        ``x`` is split along dim 1 into chunks of sizes
        ``full_atom_feature_dims``. Each chunk is projected by its own Linear
        map, and the projections are summed.
        """
        chunks = torch.split(x, self.full_atom_feature_dims, dim=1)
        total = 0
        for projection, chunk in zip(self.atom_embedding_list, chunks):
            total += projection(chunk)
        return total