Spaces:
Sleeping
Sleeping
import numpy as np | |
import multiprocessing | |
import rdkit | |
import rdkit.Chem as Chem | |
rdkit.RDLogger.DisableLog('rdApp.*') | |
from SmilesPE.pretokenizer import atomwise_tokenizer | |
def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True): | |
if type(smiles) is not str or smiles == '': | |
return '', False | |
if ignore_cistrans: | |
smiles = smiles.replace('/', '').replace('\\', '') | |
if replace_rgroup: | |
tokens = atomwise_tokenizer(smiles) | |
for j, token in enumerate(tokens): | |
if token[0] == '[' and token[-1] == ']': | |
symbol = token[1:-1] | |
if symbol[0] == 'R' and symbol[1:].isdigit(): | |
tokens[j] = f'[{symbol[1:]}*]' | |
elif Chem.AtomFromSmiles(token) is None: | |
tokens[j] = '*' | |
smiles = ''.join(tokens) | |
try: | |
canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral)) | |
success = True | |
except: | |
canon_smiles = smiles | |
success = False | |
return canon_smiles, success | |
def convert_smiles_to_canonsmiles( | |
smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16): | |
with multiprocessing.Pool(num_workers) as p: | |
results = p.starmap(canonicalize_smiles, | |
[(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list], | |
chunksize=128) | |
canon_smiles, success = zip(*results) | |
return list(canon_smiles), np.mean(success) | |
class SmilesEvaluator(object): | |
def __init__(self, gold_smiles, num_workers=16): | |
self.gold_smiles = gold_smiles | |
self.gold_canon_smiles, self.gold_valid = convert_smiles_to_canonsmiles(gold_smiles, num_workers=num_workers) | |
self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles, | |
ignore_chiral=True, num_workers=num_workers) | |
self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles, | |
ignore_cistrans=True, num_workers=num_workers) | |
self.gold_canon_smiles = self._replace_empty(self.gold_canon_smiles) | |
self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral) | |
self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans) | |
def _replace_empty(self, smiles_list): | |
"""Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty.""" | |
return [smiles if smiles is not None and type(smiles) is str and smiles != "" else "<empty>" | |
for smiles in smiles_list] | |
def evaluate(self, pred_smiles): | |
results = {} | |
results['gold_valid'] = self.gold_valid | |
# Canon SMILES | |
pred_canon_smiles, pred_valid = convert_smiles_to_canonsmiles(pred_smiles) | |
results['canon_smiles_em'] = (np.array(self.gold_canon_smiles) == np.array(pred_canon_smiles)).mean() | |
results['pred_valid'] = pred_valid | |
# Ignore chirality (Graph exact match) | |
pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_chiral=True) | |
results['graph'] = (np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral)).mean() | |
# Ignore double bond cis/trans | |
pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles, ignore_cistrans=True) | |
results['canon_smiles'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans)).mean() | |
# Evaluate on molecules with chiral centers | |
chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g]) | |
results['chiral_ratio'] = len(chiral) / len(self.gold_smiles) | |
results['chiral'] = (chiral[:, 0] == chiral[:, 1]).mean() if len(chiral) > 0 else -1 | |
return results | |