Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import Sequential as Seq, Linear as Lin | |
from utils.data_util import get_atom_feature_dims, get_bond_feature_dims | |
############################## | |
# Basic layers | |
############################## | |
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
    """Construct an activation module from its case-insensitive name.

    Args:
        act_type: ``'relu'``, ``'leakyrelu'`` or ``'prelu'`` (any case).
        inplace: Forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``; unused for PReLU.
        neg_slope: Negative slope for LeakyReLU and initial slope for PReLU.
        n_prelu: Number of learnable slopes (``num_parameters``) for PReLU.

    Returns:
        A newly constructed activation ``nn.Module``.

    Raises:
        NotImplementedError: If ``act_type`` is not one of the names above.
    """
    name = act_type.lower()
    # Factories are lambdas so only the requested module is ever instantiated.
    factories = {
        'relu': lambda: nn.ReLU(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
    }
    factory = factories.get(name)
    if factory is None:
        raise NotImplementedError('activation layer [%s] is not found' % name)
    return factory()
def norm_layer(norm_type, nc):
    """Construct a normalization module over ``nc`` features from its name.

    Args:
        norm_type: Case-insensitive name: ``'batch'`` (``BatchNorm1d`` with
            affine parameters), ``'layer'`` (``LayerNorm`` with elementwise
            affine parameters) or ``'instance'`` (``InstanceNorm1d`` without
            affine parameters).
        nc: Number of features / channels to normalize.

    Returns:
        A newly constructed normalization ``nn.Module``.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the names above.
    """
    kind = norm_type.lower()
    factories = {
        'batch': lambda: nn.BatchNorm1d(nc, affine=True),
        'layer': lambda: nn.LayerNorm(nc, elementwise_affine=True),
        'instance': lambda: nn.InstanceNorm1d(nc, affine=False),
    }
    if kind not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % kind)
    return factories[kind]()
class MultiSeq(Seq):
    """``Sequential`` container whose stages may pass several values along.

    The positional arguments of the call form the initial state (a tuple).
    Before each stage, a state whose type is exactly ``tuple`` is unpacked
    into positional arguments; any other value (including tuple subclasses
    such as named tuples) is passed as a single argument. The last stage's
    output is returned. With no stages, the tuple of inputs is returned.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, *inputs):
        state = inputs
        for stage in self._modules.values():
            state = stage(*state) if type(state) is tuple else stage(state)
        return state
class MLP(Seq):
    """Stack of ``Linear`` layers, each optionally followed by norm, activation and dropout.

    Args:
        channels: Sequence of layer widths. ``len(channels) - 1`` ``Linear``
            layers are built, layer ``i`` mapping ``channels[i - 1]`` to
            ``channels[i]``.
        act: Activation name passed to ``act_layer``. ``None`` or ``'none'``
            (any case) disables the activation.
        norm: Normalization name passed to ``norm_layer``. ``None`` or
            ``'none'`` (any case) disables normalization.
        bias: Whether each ``Linear`` layer has a bias term.
        drop: Dropout probability appended after each block when ``> 0``.
        last_lin: If true, the final ``Linear`` layer is left bare (no norm,
            activation or dropout after it).
    """

    def __init__(self, channels, act='relu',
                 norm=None, bias=True,
                 drop=0., last_lin=False):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if (i == len(channels) - 1) and last_lin:
                continue
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[i]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                # Plain element-wise dropout. nn.Dropout2d is meant for
                # (N, C, H, W) maps. PyTorch deprecates 2-D input to it (it
                # warns on every forward and is slated to error) and
                # recommends plain dropout to keep the same behavior.
                # NOTE(review): assumes features reach here as (N, C).
                m.append(nn.Dropout(drop))
        # Kept as a plain list attribute for callers that inspect ``self.m``.
        self.m = m
        super(MLP, self).__init__(*self.m)
class AtomEncoder(nn.Module):
    """Embed categorical atom features as a sum of per-column embeddings.

    One ``nn.Embedding(vocab_size, emb_dim)`` table (Xavier-uniform
    initialised) is created for each entry returned by
    ``get_atom_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        for vocab_size in get_atom_feature_dims():
            table = nn.Embedding(vocab_size, emb_dim)
            nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Return ``sum_i atom_embedding_list[i](x[:, i])`` over the columns of ``x``.

        ``x[:, i]`` is used as embedding indices, so it must hold integer
        values. NOTE(review): presumably ``x`` is (num_atoms, num_features);
        confirm against callers. Returns ``0`` if ``x`` has no columns.
        """
        return sum(
            self.atom_embedding_list[col](x[:, col])
            for col in range(x.shape[1])
        )
class BondEncoder(nn.Module):
    """Embed categorical bond features as a sum of per-column embeddings.

    One ``nn.Embedding(vocab_size, emb_dim)`` table (Xavier-uniform
    initialised) is created for each entry returned by
    ``get_bond_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        for vocab_size in get_bond_feature_dims():
            table = nn.Embedding(vocab_size, emb_dim)
            nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Return ``sum_i bond_embedding_list[i](edge_attr[:, i])`` over the columns.

        ``edge_attr[:, i]`` is used as embedding indices, so it must hold
        integer values. NOTE(review): presumably ``edge_attr`` is
        (num_edges, num_features); confirm against callers. Returns ``0`` if
        ``edge_attr`` has no columns.
        """
        return sum(
            self.bond_embedding_list[col](edge_attr[:, col])
            for col in range(edge_attr.shape[1])
        )
class MM_BondEncoder(nn.Module):
    """Differentiable bond encoder: one bias-free ``Linear`` per bond feature, summed.

    Replaces the lookup of an embedding module with a matrix multiplication,
    so the input can be Float (one-hot or soft encodings) instead of Long
    indices, and gradients can be backpropagated through the layer. The
    input must already be encoded; this module does not perform the
    one-hot encoding itself.
    """

    def __init__(self, emb_dim):
        super(MM_BondEncoder, self).__init__()
        self.bond_embedding_list = nn.ModuleList()
        self.full_bond_feature_dims = get_bond_feature_dims()
        for dim in self.full_bond_feature_dims:
            emb = nn.Linear(dim, emb_dim, bias=False)
            nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Embed pre-encoded bond features.

        Args:
            edge_attr: Tensor whose dim 1 is the concatenation of one block
                per bond feature, block ``i`` having width
                ``full_bond_feature_dims[i]``. Its size along dim 1 must equal
                the sum of those widths (``torch.split`` raises otherwise).
                NOTE(review): presumably 2-D (num_edges, sum(dims)); confirm.

        Returns:
            The sum over features of ``bond_embedding_list[i](block_i)``.
        """
        # Iterate over however many features the spec defines, rather than
        # hard-coding exactly three (matches MM_AtomEncoder).
        chunks = torch.split(edge_attr, self.full_bond_feature_dims, dim=1)
        bond_embedding = 0
        for emb, chunk in zip(self.bond_embedding_list, chunks):
            bond_embedding = bond_embedding + emb(chunk)
        return bond_embedding
class MM_AtomEncoder(nn.Module):
    """Differentiable atom encoder: one bias-free ``Linear`` per atom feature, summed.

    Replaces the lookup of an embedding module with a matrix multiplication,
    so the input can be Float (one-hot or soft encodings) instead of Long
    indices, and gradients can be backpropagated through the layer. The
    input must already be encoded; this module does not perform the
    one-hot encoding itself.
    """

    def __init__(self, emb_dim):
        super(MM_AtomEncoder, self).__init__()
        self.atom_embedding_list = nn.ModuleList()
        self.full_atom_feature_dims = get_atom_feature_dims()
        for width in self.full_atom_feature_dims:
            proj = nn.Linear(width, emb_dim, bias=False)
            nn.init.xavier_uniform_(proj.weight.data)
            self.atom_embedding_list.append(proj)

    def forward(self, x):
        """Embed pre-encoded atom features.

        ``x`` is split along dim 1 into blocks of widths
        ``full_atom_feature_dims`` (``torch.split`` raises if the widths do
        not add up to ``x.shape[1]``). Each block ``k`` goes through
        ``atom_embedding_list[k]``, and the results are summed.
        NOTE(review): presumably ``x`` is 2-D (num_atoms, sum(dims)); confirm.
        """
        blocks = torch.split(x, self.full_atom_feature_dims, dim=1)
        return sum(
            self.atom_embedding_list[k](blocks[k])
            for k in range(len(self.full_atom_feature_dims))
        )