knfn081
committed on
Commit
·
6a624f6
1
Parent(s):
e3bf276
develop working retrieval app for pre-defined target
Browse files
- app.py +108 -317
- figures/multi_molecules.png +0 -0
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,30 +1,37 @@
|
|
|
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
|
4 |
-
import
|
|
|
5 |
import pandas as pd
|
6 |
import streamlit as st
|
7 |
-
|
8 |
|
9 |
from rdkit import Chem
|
10 |
from rdkit.Chem import Draw
|
11 |
|
12 |
sys.path.insert(0, os.path.abspath("src/"))
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
datapath = os.path.join(basepath, "data")
|
18 |
|
19 |
-
st.title('HyperDTI: Task-
|
20 |
st.markdown('')
|
21 |
st.markdown(
|
22 |
"""
|
23 |
🧬 Github: [ml-jku/hyper-dti](https://https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
|
24 |
"""
|
25 |
)
|
26 |
-
st.error('WARNING! This app is currently under development and should not be used!')
|
27 |
-
|
28 |
|
29 |
def about_page():
|
30 |
st.markdown(
|
@@ -41,355 +48,139 @@ def about_page():
|
|
41 |
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
|
42 |
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
|
43 |
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
|
44 |
-
well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
|
|
|
45 |
"""
|
46 |
)
|
47 |
|
48 |
st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')
|
49 |
-
|
50 |
-
|
51 |
-
'''
|
52 |
-
def predict_dti():
|
53 |
-
st.markdown('## Predict drug-target interaction')
|
54 |
-
|
55 |
-
st.write('In the future this page can be used to predict interactions betweek a query drug compound and a query protein target by the HyperPCM mdoel.')
|
56 |
-
|
57 |
-
col1, col2 = st.columns(2)
|
58 |
|
59 |
-
with col1:
|
60 |
-
st.markdown('### Drug')
|
61 |
-
|
62 |
-
mol_col1, mol_col2 = st.columns(2)
|
63 |
-
|
64 |
-
with mol_col1:
|
65 |
-
smiles = st.text_input('Enter query SMILES', value='CC(=O)OC1=CC=CC=C1C(=O)O', placeholder='CC(=O)OC1=CC=CC=C1C(=O)O')
|
66 |
-
if smiles:
|
67 |
-
mol = Chem.MolFromSmiles(smiles)
|
68 |
-
mol_img = Chem.Draw.MolToImage(mol)
|
69 |
-
st.image(mol_img) #, width = 140)
|
70 |
-
|
71 |
-
with mol_col2:
|
72 |
-
selected_encoder = st.selectbox(
|
73 |
-
'Select encoder',('None', 'CDDD', 'MolBERT', 'Dummy')
|
74 |
-
)
|
75 |
-
if smiles:
|
76 |
-
if selected_encoder == 'CDDD':
|
77 |
-
from cddd.inference import InferenceModel
|
78 |
-
CDDD_MODEL_DIR = 'src/encoders/cddd'
|
79 |
-
cddd_model = InferenceModel(CDDD_MODEL_DIR)
|
80 |
-
drug_embedding = cddd_model.seq_to_emb([smiles])
|
81 |
-
#from huggingface_hub import hf_hub_download
|
82 |
-
#precomputed_embs = f'{selected_encoder}_encoding.csv'
|
83 |
-
#REPO_ID = "emmas96/Lenselink"
|
84 |
-
#embs_path = hf_hub_download(REPO_ID, precomputed_embs)
|
85 |
-
#embs = pd.read_csv(embs_path)
|
86 |
-
#embedding = embs[smiles]
|
87 |
-
elif selected_encoder == 'MolBERT':
|
88 |
-
from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
|
89 |
-
from huggingface_hub import hf_hub_download
|
90 |
-
CDDD_MODEL_DIR = 'encoders/molbert/last.ckpt'
|
91 |
-
REPO_ID = "emmas96/hyperpcm"
|
92 |
-
checkpoint_path = hf_hub_download(REPO_ID, MOLBERT_MODEL_DIR)
|
93 |
-
molbert_model = MolBertFeaturizer(checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled')
|
94 |
-
drug_embedding = molbert_model.transform([smiles])
|
95 |
-
elif selected_encoder == 'Dummy':
|
96 |
-
drug_embedding = [0,1,2,3,4,5]
|
97 |
-
else:
|
98 |
-
drug_embedding = None
|
99 |
-
st.image('figures/molecule_encoder.png')
|
100 |
-
st.warning('Choose encoder above...')
|
101 |
-
|
102 |
-
if drug_embedding is not None:
|
103 |
-
st.image('figures/molecule_encoder_done.png')
|
104 |
-
st.success('Encoding complete.')
|
105 |
-
|
106 |
-
with col2:
|
107 |
-
st.markdown('### Target')
|
108 |
-
|
109 |
-
prot_col1, prot_col2 = st.columns(2)
|
110 |
-
|
111 |
-
with prot_col1:
|
112 |
-
sequence = st.text_input('Enter query amino-acid sequence', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
|
113 |
-
|
114 |
-
if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
|
115 |
-
st.image('figures/ex_protein.jpeg')
|
116 |
-
elif sequence:
|
117 |
-
st.error('Visualization comming soon...')
|
118 |
-
|
119 |
-
with prot_col2:
|
120 |
-
selected_encoder = st.selectbox(
|
121 |
-
'Select encoder for protein target',('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5')
|
122 |
-
)
|
123 |
-
|
124 |
-
if sequence:
|
125 |
-
if selected_encoder == 'SeqVec':
|
126 |
-
with st.spinner('Encoding in progress...'):
|
127 |
-
from bio_embeddings.embed import SeqVecEmbedder
|
128 |
-
encoder = SeqVecEmbedder()
|
129 |
-
embeddings = encoder.embed_batch([sequence])
|
130 |
-
for emb in embeddings:
|
131 |
-
prot_embedding = encoder.reduce_per_protein(emb)
|
132 |
-
break
|
133 |
-
elif selected_encoder == 'UniRep':
|
134 |
-
with st.spinner('Encoding in progress...'):
|
135 |
-
from jax_unirep.utils import load_params
|
136 |
-
params = load_params()
|
137 |
-
from jax_unirep.featurize import get_reps
|
138 |
-
embedding, h_final, c_final = get_reps([sequence])
|
139 |
-
prot_embedding = embedding.mean(axis=0)
|
140 |
-
elif selected_encoder == 'ESM-1b':
|
141 |
-
with st.spinner('Encoding in progress...'):
|
142 |
-
from bio_embeddings.embed import ESM1bEmbedder
|
143 |
-
encoder = ESM1bEmbedder()
|
144 |
-
embeddings = encoder.embed_batch([sequence])
|
145 |
-
for emb in embeddings:
|
146 |
-
prot_embedding = encoder.reduce_per_protein(emb)
|
147 |
-
break
|
148 |
-
elif selected_encoder == 'ProtT5':
|
149 |
-
with st.spinner('Encoding in progress...'):
|
150 |
-
from bio_embeddings.embed import ProtTransT5XLU50Embedder
|
151 |
-
encoder = ProtTransT5XLU50Embedder()
|
152 |
-
embeddings = encoder.embed_batch([sequence])
|
153 |
-
for emb in embeddings:
|
154 |
-
prot_embedding = encoder.reduce_per_protein(emb)
|
155 |
-
break
|
156 |
-
else:
|
157 |
-
prot_embedding = None
|
158 |
-
st.image('figures/protein_encoder.png')
|
159 |
-
st.warning('Choose encoder above...')
|
160 |
-
|
161 |
-
if prot_embedding is not None:
|
162 |
-
st.image('figures/protein_encoder_done.png')
|
163 |
-
st.success('Encoding complete.')
|
164 |
-
|
165 |
-
if drug_embedding is None or prot_embedding is None:
|
166 |
-
st.warning('Waiting for both drug and target embeddings to be computed...')
|
167 |
-
else:
|
168 |
-
st.markdown('### Inference')
|
169 |
-
|
170 |
-
import time
|
171 |
-
progress_text = "HyperPCM predicts the interaction between the query drug compound toward the query protein target. Please wait."
|
172 |
-
my_bar = st.progress(0, text=progress_text)
|
173 |
-
for i in range(100):
|
174 |
-
time.sleep(0.1)
|
175 |
-
my_bar.progress(i + 1, text=progress_text)
|
176 |
-
my_bar.progress(100, text="HyperPCM predicts the interaction between the query drug compound toward the query protein target. Done.")
|
177 |
-
|
178 |
-
st.markdown('### Interaction')
|
179 |
-
st.write('HyperPCM predicts an activity of xxx pChEMBL.')
|
180 |
-
'''
|
181 |
-
|
182 |
|
183 |
def retrieval():
|
184 |
st.markdown('## Retrieve top-k most active drug compounds')
|
185 |
|
186 |
st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')
|
187 |
|
188 |
-
st.
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
with col:
|
195 |
-
for i in range(int(selected_k/5)):
|
196 |
-
mol = Chem.MolFromSmiles(dummy_smiles[j])
|
197 |
-
mol_img = Chem.Draw.MolToImage(mol)
|
198 |
-
st.image(mol_img)
|
199 |
-
|
200 |
-
'''
|
201 |
col1, col2, col3, col4 = st.columns(4)
|
202 |
-
with
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
206 |
elif sequence:
|
207 |
st.error('Visualization coming soon...')
|
208 |
|
209 |
-
with
|
210 |
selected_encoder = st.selectbox(
|
211 |
-
'Select encoder
|
212 |
)
|
213 |
if sequence:
|
214 |
if selected_encoder == 'SeqVec':
|
215 |
st.image('figures/protein_encoder_done.png')
|
216 |
with st.spinner('Encoding in progress...'):
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
st.success('Encoding complete.')
|
224 |
else:
|
225 |
-
|
226 |
st.image('figures/protein_encoder.png')
|
227 |
st.warning('Choose encoder above...')
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
st.markdown('### Inference')
|
231 |
|
232 |
-
|
233 |
-
progress_text = "HyperPCM predicts the QSAR model for the query protein target. Please wait."
|
234 |
my_bar = st.progress(0, text=progress_text)
|
235 |
-
for i in range(100):
|
236 |
-
time.sleep(0.1)
|
237 |
-
my_bar.progress(i + 1, text=progress_text)
|
238 |
-
my_bar.progress(100, text="HyperPCM predicts the QSAR model for the query protein target. Done.")
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
st.markdown('### Retrieval')
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
'Select the top-k number of drug compounds to retrieve',(5, 10, 15, 20)
|
250 |
-
)
|
251 |
|
252 |
-
st.write(f'The top-{selected_k} most active drug coupounds from {selected_dataset} predicted by HyperPCM are: ')
|
253 |
-
dummy_smiles = ['CC(=O)OC1=CC=CC=C1C(=O)O', 'COc1cc(C=O)ccc1O', 'CC(=O)Nc1ccc(O)cc1', 'CC(=O)Nc1ccc(OS(=O)(=O)O)cc1', 'CC(=O)Nc1ccc(O[C@@H]2O[C@H](C(=O)O)[C@@H](O)[C@H](O)[C@H]2O)cc1']
|
254 |
cols = st.columns(5)
|
255 |
for j, col in enumerate(cols):
|
256 |
with col:
|
257 |
for i in range(int(selected_k/5)):
|
258 |
-
mol = Chem.MolFromSmiles(
|
259 |
mol_img = Chem.Draw.MolToImage(mol)
|
260 |
-
st.image(mol_img)
|
261 |
-
'''
|
262 |
-
|
263 |
-
'''
|
264 |
-
def display_protein():
|
265 |
-
st.markdown('## Display protein structure')
|
266 |
-
st.write('In the future this page will display the ESM predicted sequence of a protein target.')
|
267 |
-
|
268 |
-
st.markdown('### Target')
|
269 |
-
sequence = st.text_input('Enter the amino-acid sequence of the query protein target', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
|
270 |
|
271 |
-
if sequence:
|
272 |
-
|
273 |
-
st.image('figures/ex_protein.jpeg')
|
274 |
-
|
275 |
-
model = esm.pretrained.esmfold_v1()
|
276 |
-
model = model.eval().cuda()
|
277 |
-
|
278 |
-
with torch.no_grad():
|
279 |
-
output = model.infer_pdb(sequence)
|
280 |
-
st.write(output)
|
281 |
-
|
282 |
-
with open("result.pdb", "w") as f:
|
283 |
-
f.write(output)
|
284 |
-
|
285 |
-
|
286 |
-
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
|
287 |
-
print(struct.b_factor.mean())
|
288 |
-
|
289 |
-
|
290 |
-
"""
|
291 |
-
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
|
292 |
-
batch_converter = alphabet.get_batch_converter()
|
293 |
-
batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence),])
|
294 |
-
|
295 |
-
# Extract per-residue representations (on CPU)
|
296 |
-
with torch.no_grad():
|
297 |
-
results = model(batch_tokens, repr_layers=[12], return_contacts=True)
|
298 |
-
token_representations = results["representations"][12]
|
299 |
|
300 |
-
token_list = token_representations.tolist()[0][0][0]
|
301 |
-
|
302 |
-
client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
|
303 |
-
|
304 |
-
result = client.fetch("SELECT seq, distance('topK=500')(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768")
|
305 |
-
|
306 |
-
result_temp_seq = []
|
307 |
-
|
308 |
-
for i in result:
|
309 |
-
# result_temp_coords = i['seq']
|
310 |
-
result_temp_seq.append(i['seq'])
|
311 |
-
|
312 |
-
result_temp_seq = list(set(result_temp_seq))
|
313 |
|
314 |
-
if st.button(result_temp_seq[0]):
|
315 |
-
print(result_temp_seq[0])
|
316 |
-
elif st.button(result_temp_seq[1]):
|
317 |
-
print(result_temp_seq[1])
|
318 |
-
elif st.button(result_temp_seq[2]):
|
319 |
-
print(result_temp_seq[2])
|
320 |
-
elif st.button(result_temp_seq[3]):
|
321 |
-
print(result_temp_seq[3])
|
322 |
-
elif st.button(result_temp_seq[4]):
|
323 |
-
print(result_temp_seq[4])
|
324 |
-
|
325 |
-
start[2] = st.pyplot(visualize_3D_Coordinates(result_temp_coords).figure)
|
326 |
-
|
327 |
-
headers = {
|
328 |
-
'Content-Type': 'application/x-www-form-urlencoded',
|
329 |
-
}
|
330 |
-
response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
|
331 |
-
name = sequence[:3] + sequence[-3:]
|
332 |
-
pdb_string = response.content.decode('utf-8')
|
333 |
-
with open('predicted.pdb', 'w') as f:
|
334 |
-
f.write(pdb_string)
|
335 |
-
struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
|
336 |
-
b_value = round(struct.b_factor.mean(), 4)
|
337 |
-
render_mol(pdb_string)
|
338 |
-
if residues_marker:
|
339 |
-
start[3] = showmol(render_pdb_resn(viewer = render_pdb(id = id_PDB),resn_lst = [residues_marker]))
|
340 |
-
else:
|
341 |
-
start[3] = showmol(render_pdb(id = id_PDB))
|
342 |
-
st.session_state['xq'] = st.session_state.model
|
343 |
-
|
344 |
-
# example proteins ["HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"], ["AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"]
|
345 |
-
"""
|
346 |
-
|
347 |
-
def display_context():
|
348 |
-
st.markdown('## Display context')
|
349 |
-
st.write('In the future this page will visualize the context module for a given protein, i.e., show important features and highly ranked / related proteins from the context.')
|
350 |
-
'''
|
351 |
-
|
352 |
-
def references():
|
353 |
-
st.markdown(
|
354 |
-
'''
|
355 |
-
## References
|
356 |
-
|
357 |
-
Schmidhuber, J., “Learning to control fast-weight memories: An alternative to dynamic recurrent networks.” Neural Computation, 1992.
|
358 |
-
|
359 |
-
Davis, M. I., et al. "Comprehensive analysis of kinase inhibitor selectivity." Nature Biotechnology 29.11 (2011): 1046-1051.
|
360 |
-
|
361 |
-
Ha, D., et al. “HyperNetworks”. ICLR, 2017.
|
362 |
-
|
363 |
-
Lenselink, E. B., et al. "Beyond the hype: deep neural networks outperform established methods using a ChEMBL bioactivity benchmark set." Journal of Cheminformatics 9.1 (2017): 1-14.
|
364 |
-
|
365 |
-
Alley, E. C., et al. "Unified rational protein engineering with sequence-based deep representation learning." Nature Methods 16.12 (2019): 1315-1322.
|
366 |
-
|
367 |
-
Chang, O., et al., “Principled weight initialization for hypernetworks.” ICLR, 2019.
|
368 |
-
|
369 |
-
Heinzinger, M., et al. "Modeling aspects of the language of life through transfer-learning protein sequences." BMC Bioinformatics 20.1 (2019): 1-17.
|
370 |
-
|
371 |
-
Winter, R., et al. "Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations." Chemical Science 10.6 (2019): 1692-1701.
|
372 |
-
|
373 |
-
Fabian, B., et al. "Molecular representation learning with language models and domain-relevant auxiliary tasks." Workshop for ML4Molecules (2020).
|
374 |
-
|
375 |
-
Elnaggar, A., et al. "ProtTrans: Toward understanding the language of life through self-supervised learning." IEEE Transactions on Pattern Analysis and Machine Intelligence 44 (2021): 7112–7127.
|
376 |
-
|
377 |
-
Rives, A., et al. "Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences." Proceedings of the National Academy of Sciences 118.15 (2021): e2016239118.
|
378 |
-
|
379 |
-
Kim, P. T., et al. "Unsupervised Representation Learning for Proteochemometric Modeling." International Journal of Molecular Sciences 22.23 (2021): 12882.
|
380 |
-
|
381 |
-
Schimunek, J., et al., “Context-enriched molecule representations improve few-shot drug discovery.” ICLR, 2023.
|
382 |
-
|
383 |
-
'''
|
384 |
-
)
|
385 |
-
|
386 |
page_names_to_func = {
|
387 |
-
'
|
388 |
-
|
389 |
-
'Retrieve Top-k': retrieval,
|
390 |
-
#'Display Protein': display_protein,
|
391 |
-
#'Display Context': display_context,
|
392 |
-
#'References': references
|
393 |
}
|
394 |
|
395 |
selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
|
|
|
1 |
+
|
2 |
+
import gc
|
3 |
import os
|
4 |
import sys
|
5 |
+
import torch
|
6 |
+
import pickle
|
7 |
+
import numpy as np
|
8 |
import pandas as pd
|
9 |
import streamlit as st
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
|
12 |
from rdkit import Chem
|
13 |
from rdkit.Chem import Draw
|
14 |
|
15 |
sys.path.insert(0, os.path.abspath("src/"))
|
16 |
+
from src.dataset import DrugRetrieval, collate_target
|
17 |
+
from hyper_dti.models.hyper_pcm import HyperPCM
|
18 |
|
19 |
+
base_path = os.path.dirname(__file__)
|
20 |
+
data_path = os.path.join(base_path, 'data')
|
21 |
+
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
|
24 |
+
st.set_page_config(layout="wide")
|
|
|
25 |
|
26 |
+
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
|
27 |
st.markdown('')
|
28 |
st.markdown(
|
29 |
"""
|
30 |
🧬 Github: [ml-jku/hyper-dti](https://https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
|
31 |
"""
|
32 |
)
|
33 |
+
#st.error('WARNING! This app is currently under development and should not be used!')
|
34 |
+
st.divider()
|
35 |
|
36 |
def about_page():
|
37 |
st.markdown(
|
|
|
48 |
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
|
49 |
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
|
50 |
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
|
51 |
+
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
|
52 |
+
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
|
53 |
"""
|
54 |
)
|
55 |
|
56 |
st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def retrieval():
|
60 |
st.markdown('## Retrieve top-k most active drug compounds')
|
61 |
|
62 |
st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')
|
63 |
|
64 |
+
col1, col2 = st.columns(2)
|
65 |
+
with col1:
|
66 |
+
st.markdown('### Query Target')
|
67 |
+
with col2:
|
68 |
+
st.markdown('### Drug Database')
|
69 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
col1, col2, col3, col4 = st.columns(4)
|
71 |
+
with col1:
|
72 |
+
ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
|
73 |
+
sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
|
74 |
+
if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA' or sequence == ex_target:
|
75 |
+
st.image('figures/ex_protein.jpeg', use_column_width='always')
|
76 |
elif sequence:
|
77 |
st.error('Visualization coming soon...')
|
78 |
|
79 |
+
with col2:
|
80 |
selected_encoder = st.selectbox(
|
81 |
+
'Select target encoder',('SeqVec', 'None')
|
82 |
)
|
83 |
if sequence:
|
84 |
if selected_encoder == 'SeqVec':
|
85 |
st.image('figures/protein_encoder_done.png')
|
86 |
with st.spinner('Encoding in progress...'):
|
87 |
+
# TODO make SeqVec embedding on the spot
|
88 |
+
|
89 |
+
with open(os.path.join(data_path, f'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
|
90 |
+
test_set = pickle.load(handle)
|
91 |
+
# TODO handle case if sequence not in test set
|
92 |
+
query_embedding = test_set[sequence]
|
93 |
st.success('Encoding complete.')
|
94 |
else:
|
95 |
+
query_embedding = None
|
96 |
st.image('figures/protein_encoder.png')
|
97 |
st.warning('Choose encoder above...')
|
98 |
+
|
99 |
+
with col3:
|
100 |
+
selected_database = st.selectbox(
|
101 |
+
'Select database',('Lenselink', 'None')
|
102 |
+
)
|
103 |
+
if selected_database == 'Lenselink':
|
104 |
+
c1, c2 = st.columns(2)
|
105 |
+
with c2:
|
106 |
+
st.image('figures/multi_molecules.png', use_column_width='always') #, width=125)
|
107 |
+
with st.spinner('Loading data...'):
|
108 |
+
batch_size = 64
|
109 |
+
dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
|
110 |
+
dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
|
111 |
+
st.success('Data loaded.')
|
112 |
+
else:
|
113 |
+
dataset = None
|
114 |
+
dataloader = None
|
115 |
+
st.warning('Choose database above...')
|
116 |
+
|
117 |
+
with col4:
|
118 |
+
selected_encoder = st.selectbox(
|
119 |
+
'Select drug encoder',('CDDD', 'None')
|
120 |
+
)
|
121 |
+
if selected_database:
|
122 |
+
if selected_encoder == 'CDDD':
|
123 |
+
st.image('figures/molecule_encoder_done.png')
|
124 |
+
st.success('Encoding complete.')
|
125 |
+
else:
|
126 |
+
st.image('figures/molecule_encoder.png')
|
127 |
+
st.warning('Choose encoder above...')
|
128 |
+
|
129 |
+
if query_embedding is not None:
|
130 |
st.markdown('### Inference')
|
131 |
|
132 |
+
progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
|
|
|
133 |
my_bar = st.progress(0, text=progress_text)
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
gc.collect()
|
136 |
+
torch.cuda.empty_cache()
|
137 |
+
memory = dataset
|
138 |
+
model = HyperPCM(memory=memory).to(device)
|
139 |
+
model = torch.nn.DataParallel(model)
|
140 |
+
model.load_state_dict(torch.load(checkpoint_path))
|
141 |
+
model.eval()
|
142 |
+
|
143 |
+
with torch.set_grad_enabled(False):
|
144 |
+
|
145 |
+
smiles = []
|
146 |
+
preds = []
|
147 |
+
i = 0
|
148 |
+
for batch, labels in dataloader:
|
149 |
+
pids, proteins, mids, molecules = batch['pids'], batch['targets'], batch['mids'], batch['drugs']
|
150 |
+
|
151 |
+
logits = model(batch)
|
152 |
+
logits = logits.detach().cpu().numpy()
|
153 |
+
|
154 |
+
smiles.append(mids)
|
155 |
+
preds.append(logits)
|
156 |
+
my_bar.progress((batch_size*i)/len(dataset), text=progress_text)
|
157 |
+
i += 1
|
158 |
+
my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")
|
159 |
+
|
160 |
+
|
161 |
st.markdown('### Retrieval')
|
162 |
|
163 |
+
selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
|
164 |
+
|
165 |
+
results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
|
166 |
+
results = results.sort_values(by='Prediction', ascending=False)
|
167 |
+
results = results.reset_index()
|
168 |
+
|
169 |
+
print(results.head(10))
|
|
|
|
|
170 |
|
|
|
|
|
171 |
cols = st.columns(5)
|
172 |
for j, col in enumerate(cols):
|
173 |
with col:
|
174 |
for i in range(int(selected_k/5)):
|
175 |
+
mol = Chem.MolFromSmiles(results.loc[j + 5*i, 'SMILES'])
|
176 |
mol_img = Chem.Draw.MolToImage(mol)
|
177 |
+
st.image(mol_img, caption=f"{results.loc[j + 5*i, 'Prediction']:.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
page_names_to_func = {
|
182 |
+
'Retrieval': retrieval,
|
183 |
+
'About': about_page
|
|
|
|
|
|
|
|
|
184 |
}
|
185 |
|
186 |
selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
|
figures/multi_molecules.png
ADDED
![]() |
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
#setuptools
|
2 |
rdkit #==2022.3.5
|
3 |
#torch
|
4 |
#jax_unirep
|
|
|
|
|
1 |
rdkit #==2022.3.5
|
2 |
#torch
|
3 |
#jax_unirep
|