ipd's picture
updated demo
3915777
raw
history blame
20.1 kB
# Regex that splits a SMILES string into tokens: bracket atoms ("[NH4+]"),
# Br/Cl before single-letter atoms, aromatic atoms, bonds/branches/ring
# symbols, "%NN" two-digit ring closures and single ring digits.
# Raw string: same runtime value as before, without invalid-escape warnings.
PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# Tokenizer
from transformers import BertTokenizer
# Mamba
#from mamba_ssm.models.mixer_seq_simple import MixerModel
# Data
import pandas as pd
import numpy as np
# Chemistry
from rdkit import Chem
from rdkit.Chem import PandasTools
from rdkit.Chem import Descriptors
PandasTools.RenderImagesInAllDataFrames(True)
# Standard library
import regex as re
import random
import os
import gc
from tqdm import tqdm
from huggingface_hub import hf_hub_download
tqdm.pandas()
def normalize_smiles(smi, canonical=True, isomeric=False):
    """Rewrite a SMILES string through RDKit, or return ``None`` if that fails.

    Args:
        smi: input SMILES string.
        canonical: forwarded to ``Chem.MolToSmiles`` as ``canonical``.
        isomeric: forwarded as ``isomericSmiles``; by default stereo
            information is not written to the output.

    Returns:
        The normalized SMILES string, or ``None`` when RDKit cannot parse
        or write the molecule.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:  # RDKit reports an unparsable SMILES by returning None
            return None
        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    except Exception:
        # Best effort: any RDKit error (e.g. a non-string input) maps to None,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        return None
class MolTranBertTokenizer(BertTokenizer):
    """``BertTokenizer`` variant that splits SMILES with the module-level ``PATTERN`` regex.

    WordPiece and basic tokenization are disabled; ``_tokenize`` returns the
    regex matches directly. Special tokens default to ``<bos>``/``<eos>``/
    ``<pad>``/``<mask>``, and the unknown token defaults to ``<pad>``.
    """

    def __init__(self, vocab_file: str = '',
                 do_lower_case=False,
                 unk_token='<pad>',
                 sep_token='<eos>',
                 pad_token='<pad>',
                 cls_token='<bos>',
                 mask_token='<mask>',
                 **kwargs):
        """Build the tokenizer from a vocab file with one token per line.

        NOTE(review): ``do_lower_case`` is accepted but not forwarded to
        ``BertTokenizer``; confirm SMILES case is preserved with the installed
        transformers version.
        """
        super().__init__(vocab_file,
                         unk_token=unk_token,
                         sep_token=sep_token,
                         pad_token=pad_token,
                         cls_token=cls_token,
                         mask_token=mask_token,
                         **kwargs)
        self.regex_tokenizer = re.compile(PATTERN)
        # Tokenization is done entirely by the regex in _tokenize.
        self.wordpiece_tokenizer = None
        self.basic_tokenizer = None
        # padding_idx is the (0-based) line of pad_token in the vocab file.
        # Only the newline is stripped, so a last line without a trailing
        # newline still matches; decoding is fixed to UTF-8 instead of the
        # platform locale. Raises ValueError if pad_token is absent.
        with open(vocab_file, encoding='utf-8') as f:
            self.padding_idx = [line.rstrip('\n') for line in f].index(pad_token)

    def _tokenize(self, text):
        """Return every non-overlapping ``PATTERN`` match in ``text``.

        Characters matched by no alternative of the regex are dropped.
        """
        split_tokens = self.regex_tokenizer.findall(text)
        return split_tokens

    def convert_idx_to_tokens(self, idx_tensor):
        """Map each element of ``idx_tensor.tolist()`` through ``convert_ids_to_tokens``."""
        tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
        return tokens

    def convert_tokens_to_string(self, tokens):
        """Concatenate ``tokens`` after removing exact ``<bos>``/``<eos>`` entries."""
        stopwords = ['<bos>', '<eos>']
        clean_tokens = [word for word in tokens if word not in stopwords]
        out_string = ''.join(clean_tokens)
        return out_string

    def get_padding_idx(self):
        """Return the vocab-file line index of the pad token."""
        return self.padding_idx

    def idx_to_smiles(self, torch_model, idx):
        '''Convert tokens idx back to SMILES text'''
        rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
        # Flattens one level. NOTE(review): for a 1-D ``idx`` each entry is
        # presumably a token string, so this iterates characters; joining them
        # gives the same concatenation, but the <bos>/<eos> filter then sees
        # single characters (embd_to_smiles strips special tokens afterwards).
        flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
        decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
        return decoded_smiles
class AutoEncoderLayer(nn.Module):
    """Dense autoencoder over flattened token embeddings.

    ``encoder`` maps ``feature_size`` -> ``latent_size`` and ``decoder`` maps
    ``latent_size`` -> ``feature_size``.
    """

    class Encoder(nn.Module):
        """Linear -> GELU -> LayerNorm -> bias-free Linear into the latent space."""

        def __init__(self, feature_size, latent_size):
            super().__init__()
            self.is_cuda_available = torch.cuda.is_available()
            self.fc1 = nn.Linear(feature_size, latent_size)
            self.ln_f = nn.LayerNorm(latent_size)
            self.lat = nn.Linear(latent_size, latent_size, bias=False)

        def forward(self, x):
            # With a GPU present, layers and input are moved to CUDA on every call.
            if self.is_cuda_available:
                for layer in (self.fc1, self.ln_f, self.lat):
                    layer.cuda()
                x = x.cuda()
            hidden = self.ln_f(F.gelu(self.fc1(x)))
            return self.lat(hidden)  # -> (N, latent_size)

    class Decoder(nn.Module):
        """Linear -> GELU -> LayerNorm -> bias-free Linear back to ``feature_size``."""

        def __init__(self, feature_size, latent_size):
            super().__init__()
            self.is_cuda_available = torch.cuda.is_available()
            self.fc1 = nn.Linear(latent_size, latent_size)
            self.ln_f = nn.LayerNorm(latent_size)
            self.rec = nn.Linear(latent_size, feature_size, bias=False)

        def forward(self, x):
            # With a GPU present, layers and input are moved to CUDA on every call.
            if self.is_cuda_available:
                for layer in (self.fc1, self.ln_f, self.rec):
                    layer.cuda()
                x = x.cuda()
            hidden = self.ln_f(F.gelu(self.fc1(x)))
            return self.rec(hidden)  # -> (N, feature_size)

    def __init__(self, feature_size, latent_size):
        super().__init__()
        self.encoder = self.Encoder(feature_size, latent_size)
        self.decoder = self.Decoder(feature_size, latent_size)
class LangLayer(nn.Module):
    """Token-level head: Linear -> GELU -> LayerNorm -> bias-free Linear to ``n_vocab`` logits."""

    def __init__(self, n_embd, n_vocab):
        super().__init__()
        self.is_cuda_available = torch.cuda.is_available()
        self.embed = nn.Linear(n_embd, n_embd)
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, n_vocab, bias=False)

    def forward(self, tensor):
        # With a GPU present, layers and input are moved to CUDA on every call.
        if self.is_cuda_available:
            for layer in (self.embed, self.ln_f, self.head):
                layer.cuda()
            tensor = tensor.cuda()
        normed = self.ln_f(F.gelu(self.embed(tensor)))
        return self.head(normed)
class Net(nn.Module):
    """Residual two-stage MLP head over a SMILES embedding.

    Each hidden stage is Linear -> Dropout -> GELU. While
    ``desc_skip_connection`` is True, each stage's input is added to its
    output. ``final`` maps to ``n_output`` values.
    """

    def __init__(self, smiles_embed_dim, n_output=1, dropout=0.2):
        super().__init__()
        self.desc_skip_connection = True
        # Attribute names are kept as-is (relu* are GELU) so state dicts match.
        self.fc1 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        self.fc2 = nn.Linear(smiles_embed_dim, smiles_embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(smiles_embed_dim, n_output)

    def forward(self, smiles_emb, multitask=False):
        """Return raw outputs, or element-wise sigmoid of them when ``multitask`` is truthy."""
        use_skip = self.desc_skip_connection is True

        hidden = self.relu1(self.dropout1(self.fc1(smiles_emb)))
        if use_skip:
            hidden = hidden + smiles_emb

        out = self.relu2(self.dropout2(self.fc2(hidden)))
        out = self.final(out + hidden if use_skip else out)

        return torch.sigmoid(out) if multitask else out
class MolEncoder(nn.Module):
    """Mamba ``MixerModel`` encoder whose token embeddings are masked and padded to ``config['max_len']``.

    NOTE(review): ``MixerModel`` is not imported in this module (the
    ``mamba_ssm`` import near the top is commented out), so constructing this
    class raises NameError unless the name is provided elsewhere.
    """

    def __init__(self, config, n_vocab):
        super().__init__()
        self.config = config
        self.mamba = MixerModel(
            d_model=config['n_embd'],
            n_layer=config['n_layer'],
            ssm_cfg={
                'd_state': config['d_state'],
                'd_conv': config['d_conv'],
                'expand': config['expand_factor'],
                'dt_rank': config['dt_rank'],
                'dt_min': config['dt_min'],
                'dt_max': config['dt_max'],
                'dt_init': config['dt_init'],
                'dt_scale': config['dt_scale'],
                'dt_init_floor': config['dt_init_floor'],
                # stored flags are coerced to real booleans
                'conv_bias': bool(config['conv_bias']),
                'bias': bool(config['bias']),
            },
            vocab_size=n_vocab,
            rms_norm=False,
            fused_add_norm=False,
        )
        # classification head over token embeddings
        self.lang_model = LangLayer(config['n_embd'], n_vocab)

    def forward(self, idx, mask):
        """Encode token ids, zero positions where ``mask`` is 0, and right-pad dim 1 to ``max_len``.

        NOTE(review): assumes the mamba output is (batch, seq, n_embd) and
        ``mask`` is (batch, seq) — confirm against MixerModel.
        """
        hidden = self.mamba(idx)
        keep = mask.unsqueeze(-1).expand(hidden.size()).float()
        masked = hidden * keep
        n_missing = self.config['max_len'] - masked.shape[1]
        return F.pad(masked, pad=(0, 0, 0, n_missing), value=0)
class MoLDecoder(nn.Module):
    """Decoding side of smi-ssed: a flattened-embedding autoencoder plus a token head.

    ``n_gpu`` is only stored.
    """

    def __init__(self, n_vocab, max_len, n_embd, n_gpu=None):
        super().__init__()
        self.max_len, self.n_embd, self.n_gpu = max_len, n_embd, n_gpu
        # The autoencoder compresses the whole (max_len * n_embd) embedding matrix to n_embd.
        self.autoencoder = AutoEncoderLayer(n_embd * max_len, n_embd)
        self.lm_head = LangLayer(n_embd, n_vocab)
class Smi_ssed(nn.Module):
    """granite.materials.smi-ssed 336M Parameters

    Wraps a SMILES tokenizer, a Mamba-based ``MolEncoder``, a ``MoLDecoder``
    (autoencoder + token head) and a ``Net`` prediction head.
    """

    def __init__(self, tokenizer, config=None):
        super(Smi_ssed, self).__init__()

        # configuration
        self.config = config
        self.tokenizer = tokenizer
        self.padding_idx = tokenizer.get_padding_idx()
        self.n_vocab = len(self.tokenizer.vocab)
        self.is_cuda_available = torch.cuda.is_available()

        # instantiate modules
        if self.config:
            # Mirror load_checkpoint so tokenize()/extract_*() also work on a
            # model built directly from a config dict.
            self.max_len = self.config['max_len']
            self.n_embd = self.config['n_embd']
            self.encoder = MolEncoder(self.config, self.n_vocab)
            self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
            self.net = Net(
                self.n_embd,
                n_output=self.config['n_output'] if 'n_output' in self.config else 1,
                dropout=self.config['d_dropout'],
            )

    def load_checkpoint(self, ckpt_path):
        """Rebuild all sub-modules from ``checkpoint['hparams']`` and load ``checkpoint['MODEL_STATE']``.

        Also seeds the RNGs with ``hparams['seed']`` and, if present, restores
        the RNG states stored under ``hparams['rng']``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects - only load
        # trusted checkpoints. Since PyTorch 2.6 the default is
        # weights_only=True, which may reject the numpy/python RNG states kept
        # in 'hparams'; confirm against the installed torch version.
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))

        # load hyperparameters
        self.config = checkpoint['hparams']
        self.max_len = self.config['max_len']
        self.n_embd = self.config['n_embd']
        self._set_seed(self.config['seed'])

        # instantiate modules
        self.encoder = MolEncoder(self.config, self.n_vocab)
        self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
        self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else 1, dropout=self.config['d_dropout'])

        # strict=False: missing/unexpected keys are ignored without an error.
        self.load_state_dict(checkpoint['MODEL_STATE'], strict=False)

        # load RNG states each time the model and states are loaded from checkpoint
        if 'rng' in self.config:
            rng = self.config['rng']
            for key, value in rng.items():
                if key == 'torch_state':
                    torch.set_rng_state(value.cpu())
                elif key == 'cuda_state':
                    torch.cuda.set_rng_state(value.cpu())
                elif key == 'numpy_state':
                    np.random.set_state(value)
                elif key == 'python_state':
                    random.setstate(value)
                else:
                    print('unrecognized state')

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_seed(self, value):
        """Seed python, numpy and torch (CPU + all CUDA devices) and make cuDNN deterministic."""
        print('Random Seed:', value)
        random.seed(value)
        torch.manual_seed(value)
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
        np.random.seed(value)
        cudnn.deterministic = True
        cudnn.benchmark = False

    def forward(self, smiles, batch_size=100):
        """Round-trip SMILES through ``encode`` and ``decode``."""
        return self.decode(self.encode(smiles, batch_size=batch_size, return_torch=True))

    def tokenize(self, smiles):
        """Tokenize a string or list of strings; returns (input_ids, attention_mask) tensors."""
        if isinstance(smiles, str):
            batch = [smiles]
        else:
            batch = smiles

        tokens = self.tokenizer(
            batch,
            padding=True,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
            max_length=self.max_len,
        )

        idx = tokens['input_ids'].clone().detach()
        mask = tokens['attention_mask'].clone().detach()

        if self.is_cuda_available:
            return idx.cuda(), mask.cuda()

        return idx, mask

    def extract_all(self, smiles):
        """Extract all elements from each part of smi-ssed. Be careful.

        Returns:
            ``((true_ids, pred_ids), (true_cte, pred_cte), (true_set, pred_set))``:
            input/reconstructed token ids, input/reconstructed flattened token
            embeddings, and mean-pooled/autoencoder SMILES embeddings. Rows for
            SMILES that fail ``normalize_smiles`` are filled with NaN.
        """
        # evaluation mode
        self.encoder.eval()
        self.decoder.eval()
        if self.is_cuda_available:
            self.encoder.cuda()
            self.decoder.cuda()

        # handle single str or a list of str
        smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))

        # SMILES normalization
        smiles = smiles.apply(normalize_smiles)
        null_idx = smiles[smiles.isnull()].index.to_list()  # ascending positions of SMILES that cannot normalize
        smiles = smiles.dropna()

        # Pass a plain list (as encode() does): the tokenizer is not
        # guaranteed to accept a pandas Series.
        idx, mask = self.tokenize(smiles.to_list())

        ###########
        # Encoder #
        ###########
        x = self.encoder.mamba(idx)

        # mean pooling over non-masked tokens
        token_embeddings = x
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        true_set = sum_embeddings / sum_mask  # DO NOT USE THIS FOR DOWNSTREAM TASKS, USE `pred_set` INSTEAD

        # zero masked positions, then right-pad embeddings and ids to max_len
        mask_embeddings = token_embeddings * input_mask_expanded
        token_embeddings = F.pad(mask_embeddings, pad=(0, 0, 0, self.max_len - mask_embeddings.shape[1]), value=0)
        idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx)

        true_ids = idx
        true_cte = token_embeddings.view(-1, self.max_len * self.n_embd)

        ###########
        # Decoder #
        ###########
        # CTE autoencoder
        pred_set = self.decoder.autoencoder.encoder(true_cte)
        pred_cte = self.decoder.autoencoder.decoder(pred_set)

        # reconstruct tokens
        pred_ids = self.decoder.lm_head(pred_cte.view(-1, self.max_len, self.n_embd))
        pred_ids = torch.argmax(pred_ids, dim=-1)

        # Re-insert NaN rows for unparsable SMILES. Convert to lists ONCE
        # (converting inside the loop crashed with >1 invalid SMILES); because
        # null_idx is ascending, each insert lands at its original position.
        if null_idx:
            true_ids = true_ids.tolist()
            pred_ids = pred_ids.tolist()
            true_cte = true_cte.tolist()
            pred_cte = pred_cte.tolist()
            true_set = true_set.tolist()
            pred_set = pred_set.tolist()

            for i in null_idx:
                true_ids.insert(i, np.array([np.nan] * self.max_len))
                pred_ids.insert(i, np.array([np.nan] * self.max_len))
                true_cte.insert(i, np.array([np.nan] * (self.max_len * self.n_embd)))
                pred_cte.insert(i, np.array([np.nan] * (self.max_len * self.n_embd)))
                true_set.insert(i, np.array([np.nan] * self.n_embd))
                pred_set.insert(i, np.array([np.nan] * self.n_embd))

            true_ids = torch.tensor(true_ids)
            pred_ids = torch.tensor(pred_ids)
            true_cte = torch.tensor(true_cte)
            pred_cte = torch.tensor(pred_cte)
            true_set = torch.tensor(true_set)
            pred_set = torch.tensor(pred_set)

        return ((true_ids, pred_ids),  # tokens
                (true_cte, pred_cte),  # token embeddings
                (true_set, pred_set))  # smiles embeddings

    def extract_embeddings(self, smiles):
        """Return (padded token ids, padded token embeddings, autoencoder SMILES embeddings)."""
        # evaluation mode
        self.encoder.eval()
        if self.is_cuda_available:
            self.encoder.cuda()

        # tokenizer
        idx, mask = self.tokenize(smiles)

        # encoder forward
        token_embeddings = self.encoder(idx, mask)

        # aggregate token embeddings (similar to mean pooling)
        # CAUTION: use the embeddings from the autoencoder.
        smiles_embeddings = self.decoder.autoencoder.encoder(token_embeddings.view(-1, self.max_len * self.n_embd))

        # add padding
        idx = F.pad(idx, pad=(0, self.max_len - idx.shape[1], 0, 0), value=self.padding_idx)

        return idx, token_embeddings, smiles_embeddings

    def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
        """Extract SMILES embeddings efficiently, in batches.

        Args:
            smiles: a SMILES string or an iterable of them.
            useCuda: unused; kept for backward compatibility.
            batch_size: target batch size; valid SMILES are split into
                ``max(1, n // batch_size)`` near-equal chunks.
            return_torch: return a ``torch.Tensor`` instead of a ``pandas.DataFrame``.

        Returns:
            One row per input SMILES, in input order. SMILES that fail
            ``normalize_smiles`` get a row of ``config['n_embd']`` NaNs.
        """
        # TODO: remove useCuda argument

        # handle single str or a list of str
        smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))

        # SMILES normalization
        smiles = smiles.apply(normalize_smiles)
        null_idx = smiles[smiles.isnull()].index.to_list()  # keep track of SMILES that cannot normalize
        smiles = smiles.dropna()

        flat_list = []
        if len(smiles) > 0:
            # max(1, ...) avoids n_split == 0 and the accidental one-item-per-batch
            # split when fewer than batch_size SMILES are given.
            n_split = max(1, len(smiles) // batch_size)
            batches = np.array_split(np.asarray(smiles.to_list(), dtype=object), n_split)
            embeddings = [
                self.extract_embeddings(list(batch))[2].cpu().detach().numpy()
                for batch in tqdm(batches)
            ]
            flat_list = [item for sublist in embeddings for item in sublist]

        # clear GPU memory
        if self.is_cuda_available:
            torch.cuda.empty_cache()
            gc.collect()

        # replacing null SMILES with NaN values (null_idx is ascending)
        for idx in null_idx:
            flat_list.insert(idx, np.array([np.nan] * self.config['n_embd']))
        flat_list = np.asarray(flat_list)

        if return_torch:
            return torch.tensor(flat_list)
        return pd.DataFrame(flat_list)

    def embd_to_smiles(self, embds):
        """Decode a batch of SMILES embeddings (a torch tensor) to SMILES strings."""
        # evaluation mode
        self.decoder.eval()
        if self.is_cuda_available:
            self.decoder.cuda()

        # reconstruct token embeddings
        pred_token_embds = self.decoder.autoencoder.decoder(embds)

        # reconstruct tokens
        pred_idx = self.decoder.lm_head(pred_token_embds.view(-1, self.max_len, self.n_embd))
        pred_idx = torch.argmax(pred_idx, dim=-1).cpu().detach().numpy()

        # convert idx to tokens, dropping special tokens
        pred_smiles = []
        for i in range(pred_idx.shape[0]):
            idx = pred_idx[i]
            smiles = self.tokenizer.idx_to_smiles(self, idx)
            smiles = smiles.replace('<bos>', '')  # begin token
            smiles = smiles.replace('<eos>', '')  # end token
            smiles = smiles.replace('<pad>', '')  # pad token
            pred_smiles.append(smiles)

        return pred_smiles

    def decode(self, smiles_embeddings, batch_size=100):
        """Decode SMILES embeddings back to SMILES, in batches.

        Returns an empty list for an empty input.
        """
        n_total = smiles_embeddings.shape[0]
        if n_total == 0:
            return []

        # max(1, ...) avoids the accidental one-item-per-batch split for small inputs.
        n_split = max(1, n_total // batch_size)
        embeddings = [
            self.embd_to_smiles(batch) for batch in tqdm(np.array_split(smiles_embeddings, n_split))
        ]
        pred_smiles = [item for sublist in embeddings for item in sublist]

        # clear GPU memory
        if self.is_cuda_available:
            torch.cuda.empty_cache()
            gc.collect()

        return pred_smiles

    def __str__(self):
        return 'smi-ssed'
def load_smi_ssed(folder="./smi_ssed",
                  ckpt_filename="smi-ssed_130.pt",
                  vocab_filename="bert_vocab_curated.txt"
                  ):
    """Download the smi-ssed vocab and checkpoint from the Hugging Face Hub; return the model in eval mode.

    NOTE(review): ``folder``, ``ckpt_filename`` and ``vocab_filename`` are
    currently ignored - "bert_vocab_curated.txt" and "smi_ssed_130.pt" are
    always fetched from the "ibm/materials.smi_ssed" repo.
    """
    repo_id = "ibm/materials.smi_ssed"

    vocab_path = hf_hub_download(repo_id=repo_id, filename="bert_vocab_curated.txt")
    tokenizer = MolTranBertTokenizer(vocab_path)

    model = Smi_ssed(tokenizer)
    ckpt_path = hf_hub_download(repo_id=repo_id, filename="smi_ssed_130.pt")
    model.load_checkpoint(ckpt_path)
    model.eval()

    print('Vocab size:', len(tokenizer.vocab))
    print(f'[INFERENCE MODE - {str(model)}]')
    return model