# Hugging Face Spaces page status when captured: Runtime error
from huggingface_hub import from_pretrained_keras | |
import gradio as gr | |
from rdkit import Chem, RDLogger | |
from rdkit.Chem.Draw import MolsToGridImage | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow import keras | |
import pandas as pd | |
# Featurization: one-hot encoders for atom and bond features
class Featurizer:
    """One-hot encoder driven by named feature-extractor methods.

    Every key of ``allowable_sets`` must also be the name of a method on the
    (sub)class. That method takes the raw input and returns its value for that
    feature. Each feature owns a contiguous slice of the output vector, with one
    slot per allowed value in sorted order.

    Attributes:
        dim: Total length of the encoded vector.
        features_mapping: ``{feature_name: {allowed_value: slot_index}}``.
    """

    def __init__(self, allowable_sets):
        self.dim = 0
        self.features_mapping = {}
        for feature_name, allowed_values in allowable_sets.items():
            ordered_values = sorted(allowed_values)
            offset = self.dim
            self.features_mapping[feature_name] = {
                value: offset + position
                for position, value in enumerate(ordered_values)
            }
            self.dim += len(ordered_values)

    def encode(self, inputs):
        """Return a float vector of length ``self.dim`` with one 1.0 per known feature value.

        Values outside a feature's allowed set leave that feature's slots at 0.0.
        """
        encoded = np.zeros((self.dim,))
        for feature_name, value_to_slot in self.features_mapping.items():
            extractor = getattr(self, feature_name)
            slot = value_to_slot.get(extractor(inputs))
            if slot is not None:
                encoded[slot] = 1.0
        return encoded
class AtomFeaturizer(Featurizer):
    """Featurizer for RDKit atoms.

    The construction is inherited unchanged from ``Featurizer``. The methods
    below are the feature extractors named by the ``allowable_sets`` keys.
    """

    def symbol(self, atom):
        """Element symbol of the atom (``atom.GetSymbol()``)."""
        return atom.GetSymbol()

    def n_valence(self, atom):
        """Total valence of the atom (``atom.GetTotalValence()``)."""
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        """Total hydrogen count on the atom (``atom.GetTotalNumHs()``)."""
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        """Lower-cased name of the atom's hybridization enum value."""
        hybrid = atom.GetHybridization()
        return hybrid.name.lower()
class BondFeaturizer(Featurizer):
    """Featurizer for RDKit bonds, with one extra trailing "no bond" slot.

    ``encode(None)`` sets only that trailing slot. ``graph_from_molecule`` uses
    it for atom self-loops.
    """

    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)
        # Reserve the last position of the vector for the "no bond" flag.
        self.dim += 1

    def encode(self, bond):
        """Encode ``bond``; ``None`` yields a vector with only the last slot set."""
        if bond is not None:
            return super().encode(bond)
        no_bond = np.zeros((self.dim,))
        no_bond[-1] = 1.0
        return no_bond

    def bond_type(self, bond):
        """Lower-cased name of the bond's type enum value."""
        kind = bond.GetBondType()
        return kind.name.lower()

    def conjugated(self, bond):
        """Whether the bond is conjugated (``bond.GetIsConjugated()``)."""
        return bond.GetIsConjugated()
# Allowed values per feature. The key order fixes the layout of the encoded
# vectors: atoms get 13 + 7 + 5 + 4 = 29 slots; bonds get 4 + 2 slots plus
# BondFeaturizer's extra "no bond" slot = 7. Presumably this layout must match
# what the pretrained model was trained on, so keep the order unchanged.
_ATOM_ALLOWABLE_SETS = {
    "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
    "n_valence": set(range(7)),
    "n_hydrogens": set(range(5)),
    "hybridization": {"s", "sp", "sp2", "sp3"},
}
_BOND_ALLOWABLE_SETS = {
    "bond_type": {"single", "double", "triple", "aromatic"},
    "conjugated": {False, True},
}

atom_featurizer = AtomFeaturizer(allowable_sets=_ATOM_ALLOWABLE_SETS)
bond_featurizer = BondFeaturizer(allowable_sets=_BOND_ALLOWABLE_SETS)
def molecule_from_smiles(smiles):
    """Parse ``smiles`` into an RDKit molecule, sanitizing as much as possible.

    Sanitization is first attempted in full. If one operation fails, it is
    retried with that single operation disabled. Stereochemistry is then
    (re)assigned.

    Args:
        smiles: SMILES string.

    Returns:
        The RDKit molecule.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` at all.
    """
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # RDKit reports parse failures by returning None rather than raising;
        # passing None on to SanitizeMol would fail with an opaque
        # argument-type error instead of naming the bad input.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)
    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
def graph_from_molecule(molecule):
    """Convert an RDKit molecule into atom features, pair features and pair indices.

    Every atom contributes one self-loop pair (encoded with the "no bond"
    vector) followed by one pair per neighbor. Each bond therefore appears
    once from each of its two atoms.

    Returns:
        Tuple of numpy arrays ``(atom_features, bond_features, pair_indices)``:
        one row per atom, one row per pair, and one ``[atom_idx, other_idx]``
        row per pair.
    """
    atom_rows, pair_rows, index_rows = [], [], []
    for atom in molecule.GetAtoms():
        idx = atom.GetIdx()
        atom_rows.append(atom_featurizer.encode(atom))

        # Self-loop so that each atom also passes a message to itself.
        index_rows.append([idx, idx])
        pair_rows.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            other = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(idx, other)
            index_rows.append([idx, other])
            pair_rows.append(bond_featurizer.encode(bond))

    return np.array(atom_rows), np.array(pair_rows), np.array(index_rows)
def graphs_from_smiles(smiles_list):
    """Featurize each SMILES string and stack the results as ragged tensors.

    Returns:
        ``(atom_features, bond_features, pair_indices)`` as ``tf.RaggedTensor``s
        with one row per input molecule (float32, float32, int64), ready for
        ``tf.data.Dataset.from_tensor_slices``.
    """
    graphs = [graph_from_molecule(molecule_from_smiles(smiles)) for smiles in smiles_list]

    atom_features_list = [atoms for atoms, _, _ in graphs]
    bond_features_list = [bonds for _, bonds, _ in graphs]
    pair_indices_list = [pairs for _, _, pairs in graphs]

    ragged_atoms = tf.ragged.constant(atom_features_list, dtype=tf.float32)
    ragged_bonds = tf.ragged.constant(bond_features_list, dtype=tf.float32)
    ragged_pairs = tf.ragged.constant(pair_indices_list, dtype=tf.int64)
    return ragged_atoms, ragged_bonds, ragged_pairs
def prepare_batch(x_batch, y_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph

    Args:
        x_batch: Tuple ``(atom_features, bond_features, pair_indices)``. The code
            calls ``row_lengths()`` and ``merge_dims()`` on each element, so they
            are expected to be ragged tensors with one row per molecule
            (presumably a batch of ``graphs_from_smiles`` output -- confirm for
            other callers).
        y_batch: Targets for the batch; returned unchanged.

    Returns:
        ``((atom_features, bond_features, pair_indices, molecule_indicator), y_batch)``.
        The three feature tensors have their molecules concatenated along the
        first axis and are densified with ``to_tensor()``. ``pair_indices`` are
        shifted so they index rows of the merged ``atom_features``.
        ``molecule_indicator`` holds, for every merged atom, the batch position
        of the molecule it came from.
    """
    atom_features, bond_features, pair_indices = x_batch
    # Obtain number of atoms and bonds for each graph (molecule)
    # num_bonds counts rows of bond_features, i.e. one per atom pair. For
    # graph_from_molecule output this includes self-loops and both directions
    # of every bond.
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()
    # Obtain partition indices (molecule_indicator), which will be used to
    # gather (sub)graphs from global graph in model later on
    # NOTE(review): len() on a tensor inside Dataset.map presumably relies on
    # AutoGraph converting it to a dynamic shape lookup -- confirm before
    # calling this outside tf.data.
    molecule_indices = tf.range(len(num_atoms))
    # Batch position k repeated once for each atom of molecule k.
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)
    # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
    # 'pair_indices' (and merging ragged tensors) actualizes the global graph
    # For every pair belonging to molecule k >= 1, gather_indices holds k - 1.
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    # increment[k - 1] = atoms in molecules 0..k-1, i.e. the row offset of
    # molecule k's first atom in the merged atom_features.
    increment = tf.cumsum(num_atoms[:-1])
    # One offset per pair. Molecule 0's pairs (the leading num_bonds[0] entries)
    # get 0 through front padding; every other pair gets its molecule's offset.
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    # Broadcast each pair's offset across every column (the source and target
    # atom indices, for graph_from_molecule output).
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch
def MPNNDataset(X, y, batch_size=32, shuffle=False):
    """Build a batched ``tf.data`` pipeline that merges each batch into one graph.

    Args:
        X: Tuple of ragged tensors as returned by ``graphs_from_smiles``.
        y: Targets aligned with ``X``.
        batch_size: Molecules per batch.
        shuffle: If true, shuffle with a 1024-element buffer before batching.
    """
    # -1 is tf.data's AUTOTUNE value for both parallelism and prefetch size.
    autotune = -1
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        dataset = dataset.shuffle(1024)
    batched = dataset.batch(batch_size)
    merged = batched.map(prepare_batch, autotune)
    return merged.prefetch(autotune)
# Hugging Face Hub repository holding the pretrained MPNN model.
MODEL_REPO_ID = "keras-io/MPNN-for-molecular-property-prediction"
model = from_pretrained_keras(MODEL_REPO_ID)
def predict(smiles, label):
    """Predict with the MPNN for one SMILES string and render the molecule.

    Args:
        smiles: SMILES string entered in the first textbox. Surrounding
            whitespace is ignored.
        label: Value from the second textbox. It is only shown in the legend as
            ``y_true`` and is not used for the prediction.

    Returns:
        The path ``'img.png'``, written to the current working directory.

    Raises:
        ValueError: If ``smiles`` is empty or only whitespace.
    """
    if smiles is None or not smiles.strip():
        # An empty molecule would otherwise fail deep inside the TF pipeline.
        raise ValueError("Please enter a SMILES string.")
    smiles = smiles.strip()

    molecules = [molecule_from_smiles(smiles)]
    graph_inputs = graphs_from_smiles([smiles])  # renamed: avoid shadowing builtin input()
    labels = pd.Series([label])
    test_dataset = MPNNDataset(graph_inputs, labels)
    y_pred = tf.squeeze(model.predict(test_dataset), axis=1).numpy()

    legends = [
        f"y_true/y_pred = {y_true}/{y_hat:.2f}"
        for y_true, y_hat in zip(labels, y_pred)
    ]
    MolsToGridImage(
        molecules, molsPerRow=1, legends=legends, returnPNG=False, subImgSize=(650, 650)
    ).save("img.png")
    return 'img.png'
inputs = [
    gr.Textbox(label='Smiles of molecular'),
    gr.Textbox(label='Molecular permeability'),
]
examples = [
    ["CO/N=C(C(=O)N[C@H]1[C@H]2SCC(=C(N2C1=O)C(O)=O)C)/c3csc(N)n3", 0],
    ["[C@H]37[C@H]2[C@@]([C@](C(COC(C1=CC(=CC=C1)[S](O)(=O)=O)=O)=O)(O)[C@@H](C2)C)(C[C@@H]([C@@H]3[C@@]4(C(=CC5=C(C4)C=N[N]5C6=CC=CC=C6)C(=C7)C)C)O)C", 1],
    ["CNCCCC2(C)C(=O)N(c1ccccc1)c3ccccc23", 1],
    ["O.N[C@@H](C(=O)NC1C2CCC(=C(N2C1=O)C(O)=O)Cl)c3ccccc3", 0],
    ["[C@@]4([C@@]3([C@H]([C@H]2[C@@H]([C@@]1(C(=CC(=O)CC1)CC2)C)[C@H](C3)O)CC4)C)(C(COC(C)=O)=O)OC(CC)=O", 1],
    ["[C@]34([C@H](C2[C@@](F)([C@@]1(C(=CC(=O)C=C1)[C@@H](F)C2)C)[C@@H](O)C3)C[C@H]5OC(O[C@@]45C(=O)COC(=O)C6CC6)(C)C)C", 1],
]
demo = gr.Interface(
    fn=predict,
    title="Predict blood-brain barrier permeability of molecular",
    description="Message-passing neural network (MPNN) for molecular property prediction",
    inputs=inputs,
    examples=examples,
    outputs="image",
    article="Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the keras example from <a href=\"https://keras.io/examples/graph/mpnn-molecular-graphs/\">Alexander Kensert</a>",
)
# launch(enable_queue=...) is deprecated in later Gradio 3.x releases and is
# no longer accepted by Gradio 4, where it raises a TypeError at startup.
# Enable queuing via .queue() when available; fall back for old versions.
if hasattr(demo, "queue"):
    demo.queue().launch(debug=False)
else:
    demo.launch(debug=False, enable_queue=True)