vumichien committed on
Commit
4c86b48
·
1 Parent(s): 3229a1d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import gradio as gr
3
+ from rdkit import Chem, RDLogger
4
+ from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from tensorflow import keras
8
+
9
+ # Config
10
class Featurizer:
    """One-hot encoder over a group of categorical features.

    Every key of ``allowable_sets`` names one feature. Subclasses are expected
    to provide a method with that same name which extracts the feature's value
    from the object being encoded. Each allowed value owns one slot of a flat
    vector of length ``dim``.

    Attributes are set in ``__init__``:
        dim: total number of slots across all features.
        features_mapping: ``{feature_name: {value: slot_index}}``.
    """

    def __init__(self, allowable_sets):
        """Assign a slot index to every allowed value of every feature.

        Args:
            allowable_sets: Mapping from feature name to a collection of the
                values that feature may take. The values of one feature must
                be sortable against each other.
        """
        self.dim = 0
        self.features_mapping = {}
        for name, allowed in allowable_sets.items():
            # Set iteration order is arbitrary; sorting pins the slot layout.
            ordered = sorted(allowed)
            # Slots of this feature start right after the previous feature's.
            self.features_mapping[name] = {
                value: index for index, value in enumerate(ordered, start=self.dim)
            }
            self.dim += len(ordered)

    def encode(self, inputs):
        """Return the concatenated one-hot vector for ``inputs``.

        Args:
            inputs: Object handed to each feature-extraction method.

        Returns:
            A 1-D ``numpy`` array of length ``dim`` holding 1.0 in the slot of
            each recognised feature value. A value that is not in the
            feature's allowable set leaves that feature's slots at zero.
        """
        output = np.zeros((self.dim,))
        for name, mapping in self.features_mapping.items():
            # Dispatch by name to the extractor method (e.g. ``self.symbol``).
            value = getattr(self, name)(inputs)
            if value in mapping:
                output[mapping[value]] = 1.0
        return output
27
+
28
+
29
class AtomFeaturizer(Featurizer):
    """Featurizer whose extractor methods read properties off an atom.

    The constructor is inherited unchanged from ``Featurizer`` and takes the
    same ``allowable_sets`` argument. Each method below is looked up by name
    from the keys of ``allowable_sets`` when encoding.

    NOTE(review): ``atom`` is presumably an ``rdkit.Chem.Atom`` -- confirm.
    """

    def symbol(self, atom):
        """Return ``atom.GetSymbol()``."""
        return atom.GetSymbol()

    def n_valence(self, atom):
        """Return ``atom.GetTotalValence()``."""
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        """Return ``atom.GetTotalNumHs()``."""
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        """Return the lower-cased name of ``atom.GetHybridization()``."""
        kind = atom.GetHybridization()
        return kind.name.lower()
44
+
45
+
46
class BondFeaturizer(Featurizer):
    """Featurizer for bonds with one extra trailing slot meaning "no bond"."""

    def __init__(self, allowable_sets):
        """Build the feature slots, then reserve one more for a missing bond."""
        super().__init__(allowable_sets)
        self.dim = self.dim + 1

    def encode(self, bond):
        """Return the feature vector for ``bond``.

        When ``bond`` is ``None`` the vector is all zeros except for a 1.0 in
        the last slot; otherwise the inherited one-hot encoding is returned
        (its last slot stays zero).
        """
        if bond is not None:
            return super().encode(bond)

        missing = np.zeros((self.dim,))
        missing[-1] = 1.0
        return missing

    def bond_type(self, bond):
        """Return the lower-cased name of ``bond.GetBondType()``."""
        kind = bond.GetBondType()
        return kind.name.lower()

    def conjugated(self, bond):
        """Return ``bond.GetIsConjugated()``."""
        return bond.GetIsConjugated()
64
+
65
+
66
# Module-level featurizers shared by the graph-building functions.
# The key order of each dict fixes the layout of the encoded vectors,
# so it must not be rearranged.
atom_featurizer = AtomFeaturizer(
    {
        "symbol": {
            "B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S",
        },
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)

bond_featurizer = BondFeaturizer(
    {
        "bond_type": {"single", "double", "triple", "aromatic"},
        "conjugated": {True, False},
    }
)
81
+
82
def molecule_from_smiles(smiles):
    """Parse a SMILES string into a molecule, tolerating sanitization errors.

    ``MolFromSmiles(m, sanitize=True)`` should be equivalent to
    ``MolFromSmiles(m, sanitize=False)`` -> ``SanitizeMol(m)`` ->
    ``AssignStereochemistry(m, ...)``; the steps are run separately here so a
    failing sanitization step can be skipped instead of aborting.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        The parsed molecule with stereochemistry assigned.

    Raises:
        ValueError: If ``smiles`` cannot be parsed into a molecule at all.
    """
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # The parser signals a syntax failure by returning None; without this
        # guard the next call fails with an opaque argument-type error.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error.
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
95
+
96
+
97
def graph_from_molecule(molecule):
    """Turn a molecule into node features, edge features and an edge list.

    For every atom, one self-loop edge (encoded as "no bond") is emitted,
    followed by one directed edge to each of its neighbors.

    Args:
        molecule: Object exposing ``GetAtoms()`` and
            ``GetBondBetweenAtoms(i, j)``.

    Returns:
        Tuple of ``numpy`` arrays ``(atom_features, bond_features,
        pair_indices)``; rows of the last two correspond to each other.
    """
    node_rows = []
    edge_rows = []
    edge_pairs = []

    for atom in molecule.GetAtoms():
        source = atom.GetIdx()
        node_rows.append(atom_featurizer.encode(atom))

        # Self-loop: the edge has no bond behind it.
        edge_pairs.append([source, source])
        edge_rows.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            target = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(source, target)
            edge_pairs.append([source, target])
            edge_rows.append(bond_featurizer.encode(bond))

    return np.array(node_rows), np.array(edge_rows), np.array(edge_pairs)
116
+
117
+
118
def graphs_from_smiles(smiles_list):
    """Convert SMILES strings into three ragged tensors describing their graphs.

    Args:
        smiles_list: Iterable of SMILES strings.

    Returns:
        Tuple ``(atom_features, bond_features, pair_indices)`` of ragged
        tensors with dtypes ``float32``, ``float32`` and ``int64``; the outer
        dimension has one entry per input string.
    """
    graphs = [
        graph_from_molecule(molecule_from_smiles(smiles)) for smiles in smiles_list
    ]

    # Regroup per component instead of per molecule.
    atom_features_list = [atoms for atoms, _, _ in graphs]
    bond_features_list = [bonds for _, bonds, _ in graphs]
    pair_indices_list = [pairs for _, _, pairs in graphs]

    # Ragged tensors, because molecules differ in atom and bond counts;
    # they feed a tf.data.Dataset later on.
    atom_tensor = tf.ragged.constant(atom_features_list, dtype=tf.float32)
    bond_tensor = tf.ragged.constant(bond_features_list, dtype=tf.float32)
    pair_tensor = tf.ragged.constant(pair_indices_list, dtype=tf.int64)
    return atom_tensor, bond_tensor, pair_tensor
138
+
139
# NOTE(review): this repo id names a WGAN checkpoint, whereas the app's
# description says it serves a message-passing neural network (MPNN).
# Confirm that this is the intended model before deploying.
model = from_pretrained_keras("keras-io/wgan-molecular-graphs")


def _prepare_batch(x_batch, y_batch):
    """Merge the graphs of one batch into a single disconnected graph.

    Ported from the Keras MPNN example referenced in the app's article.
    NOTE(review): assumes the loaded model takes the 4-tuple returned here
    as its input -- confirm against the model.

    Args:
        x_batch: Tuple of ragged tensors
            ``(atom_features, bond_features, pair_indices)``.
        y_batch: Labels of the batch, passed through untouched.

    Returns:
        ``((atom_features, bond_features, pair_indices, molecule_indicator),
        y_batch)`` with the first three flattened to dense tensors.
    """
    atom_features, bond_features, pair_indices = x_batch

    # Number of atoms and bonds of each molecule in the batch.
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # For every atom of the merged graph, the index of the molecule it is from.
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Shift each molecule's atom indices by the number of atoms before it so
    # that edges keep pointing at the right rows after flattening.
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1])
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch


def _mpnn_dataset(graphs, labels, batch_size=32):
    """Wrap graphs and labels in a batched ``tf.data.Dataset`` for the model."""
    dataset = tf.data.Dataset.from_tensor_slices((graphs, labels))
    return dataset.batch(batch_size).map(_prepare_batch, -1).prefetch(-1)


def predict(smiles, label):
    """Run the model on one molecule and render it with the prediction.

    Args:
        smiles: SMILES string of the molecule.
        label: Reference permeability value; it is only echoed in the image
            legend next to the prediction.

    Returns:
        Path of the PNG written to disk (``"img.png"``).
    """
    molecules = [molecule_from_smiles(smiles)]
    graphs = graphs_from_smiles([smiles])
    # A one-element label array keeps the dataset structure (features, label)
    # without needing pandas, which this file never imports.
    labels = np.array([label])
    test_dataset = _mpnn_dataset(graphs, labels)
    y_pred = tf.squeeze(model.predict(test_dataset), axis=1)
    legends = [
        f"y_true/y_pred = {true}/{float(pred):.2f}"
        for true, pred in zip([label], y_pred)
    ]
    image = MolsToGridImage(
        molecules,
        molsPerRow=1,
        legends=legends,
        returnPNG=False,
        subImgSize=(550, 550),
    )
    image.save("img.png")
    return "img.png"
150
+
151
# Two free-text fields: the molecule and its reference label.
inputs = [
    gr.Textbox(label="Smiles of molecular"),
    gr.Textbox(label="Molecular permeability"),
]

# Each example row is [SMILES string, label], matching the order of `inputs`.
examples = [
    ["CO/N=C(C(=O)N[C@H]1[C@H]2SCC(=C(N2C1=O)C(O)=O)C)/c3csc(N)n3", 0],
    [
        "[C@H]37[C@H]2[C@@]([C@](C(COC(C1=CC(=CC=C1)[S](O)(=O)=O)=O)=O)(O)[C@@H](C2)C)(C[C@@H]([C@@H]3[C@@]4(C(=CC5=C(C4)C=N[N]5C6=CC=CC=C6)C(=C7)C)C)O)C",
        1,
    ],
    ["CNCCCC2(C)C(=O)N(c1ccccc1)c3ccccc23", 1],
    ["O.N[C@@H](C(=O)NC1C2CCC(=C(N2C1=O)C(O)=O)Cl)c3ccccc3", 0],
    [
        "[C@@]4([C@@]3([C@H]([C@H]2[C@@H]([C@@]1(C(=CC(=O)CC1)CC2)C)[C@H](C3)O)CC4)C)(C(COC(C)=O)=O)OC(CC)=O",
        1,
    ],
    [
        "[C@]34([C@H](C2[C@@](F)([C@@]1(C(=CC(=O)C=C1)[C@@H](F)C2)C)[C@@H](O)C3)C[C@H]5OC(O[C@@]45C(=O)COC(=O)C6CC6)(C)C)C",
        1,
    ],
]

# Build the web UI around `predict` and start serving it.
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs="image",
    examples=examples,
    title="Predict blood-brain barrier permeability of molecular",
    description="Message-passing neural network (MPNN) for molecular property prediction",
    article='Author: <a href="https://huggingface.co/vumichien">Vu Minh Chien</a>. Based on the keras example from <a href="https://keras.io/examples/graph/mpnn-molecular-graphs/">Alexander Kensert</a>',
).launch(debug=True, enable_queue=True)