knfn081 committed on
Commit
6a624f6
·
1 Parent(s): e3bf276

develop working retrieval app for pre-defined target

Browse files
Files changed (3) hide show
  1. app.py +108 -317
  2. figures/multi_molecules.png +0 -0
  3. requirements.txt +0 -1
app.py CHANGED
@@ -1,30 +1,37 @@
 
 
1
  import os
2
  import sys
3
- #import torch
4
- import numpy as np
 
5
  import pandas as pd
6
  import streamlit as st
7
- #import esm
8
 
9
  from rdkit import Chem
10
  from rdkit.Chem import Draw
11
 
12
  sys.path.insert(0, os.path.abspath("src/"))
 
 
13
 
14
- st.set_page_config(layout="wide")
 
 
 
15
 
16
- basepath = os.path.dirname(__file__)
17
- datapath = os.path.join(basepath, "data")
18
 
19
- st.title('HyperDTI: Task-conditioned modeling of drug-target interactions.\n')
20
  st.markdown('')
21
  st.markdown(
22
  """
23
  🧬 Github: [ml-jku/hyper-dti](https://https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
24
  """
25
  )
26
- st.error('WARNING! This app is currently under development and should not be used!')
27
-
28
 
29
  def about_page():
30
  st.markdown(
@@ -41,355 +48,139 @@ def about_page():
41
  In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
42
  predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
43
  a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
44
- well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
 
45
  """
46
  )
47
 
48
  st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')
49
-
50
-
51
- '''
52
- def predict_dti():
53
- st.markdown('## Predict drug-target interaction')
54
-
55
- st.write('In the future this page can be used to predict interactions betweek a query drug compound and a query protein target by the HyperPCM mdoel.')
56
-
57
- col1, col2 = st.columns(2)
58
 
59
- with col1:
60
- st.markdown('### Drug')
61
-
62
- mol_col1, mol_col2 = st.columns(2)
63
-
64
- with mol_col1:
65
- smiles = st.text_input('Enter query SMILES', value='CC(=O)OC1=CC=CC=C1C(=O)O', placeholder='CC(=O)OC1=CC=CC=C1C(=O)O')
66
- if smiles:
67
- mol = Chem.MolFromSmiles(smiles)
68
- mol_img = Chem.Draw.MolToImage(mol)
69
- st.image(mol_img) #, width = 140)
70
-
71
- with mol_col2:
72
- selected_encoder = st.selectbox(
73
- 'Select encoder',('None', 'CDDD', 'MolBERT', 'Dummy')
74
- )
75
- if smiles:
76
- if selected_encoder == 'CDDD':
77
- from cddd.inference import InferenceModel
78
- CDDD_MODEL_DIR = 'src/encoders/cddd'
79
- cddd_model = InferenceModel(CDDD_MODEL_DIR)
80
- drug_embedding = cddd_model.seq_to_emb([smiles])
81
- #from huggingface_hub import hf_hub_download
82
- #precomputed_embs = f'{selected_encoder}_encoding.csv'
83
- #REPO_ID = "emmas96/Lenselink"
84
- #embs_path = hf_hub_download(REPO_ID, precomputed_embs)
85
- #embs = pd.read_csv(embs_path)
86
- #embedding = embs[smiles]
87
- elif selected_encoder == 'MolBERT':
88
- from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
89
- from huggingface_hub import hf_hub_download
90
- CDDD_MODEL_DIR = 'encoders/molbert/last.ckpt'
91
- REPO_ID = "emmas96/hyperpcm"
92
- checkpoint_path = hf_hub_download(REPO_ID, MOLBERT_MODEL_DIR)
93
- molbert_model = MolBertFeaturizer(checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled')
94
- drug_embedding = molbert_model.transform([smiles])
95
- elif selected_encoder == 'Dummy':
96
- drug_embedding = [0,1,2,3,4,5]
97
- else:
98
- drug_embedding = None
99
- st.image('figures/molecule_encoder.png')
100
- st.warning('Choose encoder above...')
101
-
102
- if drug_embedding is not None:
103
- st.image('figures/molecule_encoder_done.png')
104
- st.success('Encoding complete.')
105
-
106
- with col2:
107
- st.markdown('### Target')
108
-
109
- prot_col1, prot_col2 = st.columns(2)
110
-
111
- with prot_col1:
112
- sequence = st.text_input('Enter query amino-acid sequence', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
113
-
114
- if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
115
- st.image('figures/ex_protein.jpeg')
116
- elif sequence:
117
- st.error('Visualization comming soon...')
118
-
119
- with prot_col2:
120
- selected_encoder = st.selectbox(
121
- 'Select encoder for protein target',('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5')
122
- )
123
-
124
- if sequence:
125
- if selected_encoder == 'SeqVec':
126
- with st.spinner('Encoding in progress...'):
127
- from bio_embeddings.embed import SeqVecEmbedder
128
- encoder = SeqVecEmbedder()
129
- embeddings = encoder.embed_batch([sequence])
130
- for emb in embeddings:
131
- prot_embedding = encoder.reduce_per_protein(emb)
132
- break
133
- elif selected_encoder == 'UniRep':
134
- with st.spinner('Encoding in progress...'):
135
- from jax_unirep.utils import load_params
136
- params = load_params()
137
- from jax_unirep.featurize import get_reps
138
- embedding, h_final, c_final = get_reps([sequence])
139
- prot_embedding = embedding.mean(axis=0)
140
- elif selected_encoder == 'ESM-1b':
141
- with st.spinner('Encoding in progress...'):
142
- from bio_embeddings.embed import ESM1bEmbedder
143
- encoder = ESM1bEmbedder()
144
- embeddings = encoder.embed_batch([sequence])
145
- for emb in embeddings:
146
- prot_embedding = encoder.reduce_per_protein(emb)
147
- break
148
- elif selected_encoder == 'ProtT5':
149
- with st.spinner('Encoding in progress...'):
150
- from bio_embeddings.embed import ProtTransT5XLU50Embedder
151
- encoder = ProtTransT5XLU50Embedder()
152
- embeddings = encoder.embed_batch([sequence])
153
- for emb in embeddings:
154
- prot_embedding = encoder.reduce_per_protein(emb)
155
- break
156
- else:
157
- prot_embedding = None
158
- st.image('figures/protein_encoder.png')
159
- st.warning('Choose encoder above...')
160
-
161
- if prot_embedding is not None:
162
- st.image('figures/protein_encoder_done.png')
163
- st.success('Encoding complete.')
164
-
165
- if drug_embedding is None or prot_embedding is None:
166
- st.warning('Waiting for both drug and target embeddings to be computed...')
167
- else:
168
- st.markdown('### Inference')
169
-
170
- import time
171
- progress_text = "HyperPCM predicts the interaction between the query drug compound toward the query protein target. Please wait."
172
- my_bar = st.progress(0, text=progress_text)
173
- for i in range(100):
174
- time.sleep(0.1)
175
- my_bar.progress(i + 1, text=progress_text)
176
- my_bar.progress(100, text="HyperPCM predicts the interaction between the query drug compound toward the query protein target. Done.")
177
-
178
- st.markdown('### Interaction')
179
- st.write('HyperPCM predicts an activity of xxx pChEMBL.')
180
- '''
181
-
182
 
183
  def retrieval():
184
  st.markdown('## Retrieve top-k most active drug compounds')
185
 
186
  st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')
187
 
188
- st.markdown('### Target')
189
-
190
- st.write(f'The top-{selected_k} most active drug coupounds from {selected_dataset} predicted by HyperPCM are: ')
191
- dummy_smiles = ['CC(=O)OC1=CC=CC=C1C(=O)O', 'COc1cc(C=O)ccc1O', 'CC(=O)Nc1ccc(O)cc1', 'CC(=O)Nc1ccc(OS(=O)(=O)O)cc1', 'CC(=O)Nc1ccc(O[C@@H]2O[C@H](C(=O)O)[C@@H](O)[C@H](O)[C@H]2O)cc1']
192
- cols = st.columns(5)
193
- for j, col in enumerate(cols):
194
- with col:
195
- for i in range(int(selected_k/5)):
196
- mol = Chem.MolFromSmiles(dummy_smiles[j])
197
- mol_img = Chem.Draw.MolToImage(mol)
198
- st.image(mol_img)
199
-
200
- '''
201
  col1, col2, col3, col4 = st.columns(4)
202
- with col2:
203
- sequence = st.text_input('Enter query amino-acid sequence', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
204
- if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
205
- st.image('figures/ex_protein.jpeg')
 
206
  elif sequence:
207
  st.error('Visualization coming soon...')
208
 
209
- with col3:
210
  selected_encoder = st.selectbox(
211
- 'Select encoder for protein target',('SeqVec', 'None')
212
  )
213
  if sequence:
214
  if selected_encoder == 'SeqVec':
215
  st.image('figures/protein_encoder_done.png')
216
  with st.spinner('Encoding in progress...'):
217
- from bio_embeddings.embed import SeqVecEmbedder
218
- encoder = SeqVecEmbedder()
219
- embeddings = encoder.embed_batch([sequence])
220
- for emb in embeddings:
221
- prot_embedding = encoder.reduce_per_protein(emb)
222
- break
223
  st.success('Encoding complete.')
224
  else:
225
- prot_embedding = None
226
  st.image('figures/protein_encoder.png')
227
  st.warning('Choose encoder above...')
228
-
229
- if prot_embedding is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  st.markdown('### Inference')
231
 
232
- import time
233
- progress_text = "HyperPCM predicts the QSAR model for the query protein target. Please wait."
234
  my_bar = st.progress(0, text=progress_text)
235
- for i in range(100):
236
- time.sleep(0.1)
237
- my_bar.progress(i + 1, text=progress_text)
238
- my_bar.progress(100, text="HyperPCM predicts the QSAR model for the query protein target. Done.")
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  st.markdown('### Retrieval')
241
 
242
- col1, col2 = st.columns(2)
243
- with col1:
244
- selected_dataset = st.selectbox(
245
- 'Select dataset from which the drug compounds should be retrieved',('Lenselink', 'Davis')
246
- )
247
- with col2:
248
- selected_k = st.selectbox(
249
- 'Select the top-k number of drug compounds to retrieve',(5, 10, 15, 20)
250
- )
251
 
252
- st.write(f'The top-{selected_k} most active drug coupounds from {selected_dataset} predicted by HyperPCM are: ')
253
- dummy_smiles = ['CC(=O)OC1=CC=CC=C1C(=O)O', 'COc1cc(C=O)ccc1O', 'CC(=O)Nc1ccc(O)cc1', 'CC(=O)Nc1ccc(OS(=O)(=O)O)cc1', 'CC(=O)Nc1ccc(O[C@@H]2O[C@H](C(=O)O)[C@@H](O)[C@H](O)[C@H]2O)cc1']
254
  cols = st.columns(5)
255
  for j, col in enumerate(cols):
256
  with col:
257
  for i in range(int(selected_k/5)):
258
- mol = Chem.MolFromSmiles(dummy_smiles[j])
259
  mol_img = Chem.Draw.MolToImage(mol)
260
- st.image(mol_img)
261
- '''
262
-
263
- '''
264
- def display_protein():
265
- st.markdown('## Display protein structure')
266
- st.write('In the future this page will display the ESM predicted sequence of a protein target.')
267
-
268
- st.markdown('### Target')
269
- sequence = st.text_input('Enter the amino-acid sequence of the query protein target', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
270
 
271
- if sequence:
272
-
273
- st.image('figures/ex_protein.jpeg')
274
-
275
- model = esm.pretrained.esmfold_v1()
276
- model = model.eval().cuda()
277
-
278
- with torch.no_grad():
279
- output = model.infer_pdb(sequence)
280
- st.write(output)
281
-
282
- with open("result.pdb", "w") as f:
283
- f.write(output)
284
-
285
-
286
- struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
287
- print(struct.b_factor.mean())
288
-
289
-
290
- """
291
- model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
292
- batch_converter = alphabet.get_batch_converter()
293
- batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence),])
294
-
295
- # Extract per-residue representations (on CPU)
296
- with torch.no_grad():
297
- results = model(batch_tokens, repr_layers=[12], return_contacts=True)
298
- token_representations = results["representations"][12]
299
 
300
- token_list = token_representations.tolist()[0][0][0]
301
-
302
- client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
303
-
304
- result = client.fetch("SELECT seq, distance('topK=500')(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768")
305
-
306
- result_temp_seq = []
307
-
308
- for i in result:
309
- # result_temp_coords = i['seq']
310
- result_temp_seq.append(i['seq'])
311
-
312
- result_temp_seq = list(set(result_temp_seq))
313
 
314
- if st.button(result_temp_seq[0]):
315
- print(result_temp_seq[0])
316
- elif st.button(result_temp_seq[1]):
317
- print(result_temp_seq[1])
318
- elif st.button(result_temp_seq[2]):
319
- print(result_temp_seq[2])
320
- elif st.button(result_temp_seq[3]):
321
- print(result_temp_seq[3])
322
- elif st.button(result_temp_seq[4]):
323
- print(result_temp_seq[4])
324
-
325
- start[2] = st.pyplot(visualize_3D_Coordinates(result_temp_coords).figure)
326
-
327
- headers = {
328
- 'Content-Type': 'application/x-www-form-urlencoded',
329
- }
330
- response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
331
- name = sequence[:3] + sequence[-3:]
332
- pdb_string = response.content.decode('utf-8')
333
- with open('predicted.pdb', 'w') as f:
334
- f.write(pdb_string)
335
- struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
336
- b_value = round(struct.b_factor.mean(), 4)
337
- render_mol(pdb_string)
338
- if residues_marker:
339
- start[3] = showmol(render_pdb_resn(viewer = render_pdb(id = id_PDB),resn_lst = [residues_marker]))
340
- else:
341
- start[3] = showmol(render_pdb(id = id_PDB))
342
- st.session_state['xq'] = st.session_state.model
343
-
344
- # example proteins ["HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"], ["AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"]
345
- """
346
-
347
- def display_context():
348
- st.markdown('## Display context')
349
- st.write('In the future this page will visualize the context module for a given protein, i.e., show important features and highly ranked / related proteins from the context.')
350
- '''
351
-
352
- def references():
353
- st.markdown(
354
- '''
355
- ## References
356
-
357
- Schmidhuber, J., “Learning to control fast-weight memories: An alternative to dynamic recurrent networks.” Neural Computation, 1992.
358
-
359
- Davis, M. I., et al. "Comprehensive analysis of kinase inhibitor selectivity." Nature Biotechnology 29.11 (2011): 1046-1051.
360
-
361
- Ha, D., et al. “HyperNetworks”. ICLR, 2017.
362
-
363
- Lenselink, E. B., et al. "Beyond the hype: deep neural networks outperform established methods using a ChEMBL bioactivity benchmark set." Journal of Cheminformatics 9.1 (2017): 1-14.
364
-
365
- Alley, E. C., et al. "Unified rational protein engineering with sequence-based deep representation learning." Nature Methods 16.12 (2019): 1315-1322.
366
-
367
- Chang, O., et al., “Principled weight initialization for hypernetworks.” ICLR, 2019.
368
-
369
- Heinzinger, M., et al. "Modeling aspects of the language of life through transfer-learning protein sequences." BMC Bioinformatics 20.1 (2019): 1-17.
370
-
371
- Winter, R., et al. "Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations." Chemical Science 10.6 (2019): 1692-1701.
372
-
373
- Fabian, B., et al. "Molecular representation learning with language models and domain-relevant auxiliary tasks." Workshop for ML4Molecules (2020).
374
-
375
- Elnaggar, A., et al. "ProtTrans: Toward understanding the language of life through self-supervised learning." IEEE Transactions on Pattern Analysis and Machine Intelligence 44 (2021): 7112–7127.
376
-
377
- Rives, A., et al. "Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences." Proceedings of the National Academy of Sciences 118.15 (2021): e2016239118.
378
-
379
- Kim, P. T., et al. "Unsupervised Representation Learning for Proteochemometric Modeling." International Journal of Molecular Sciences 22.23 (2021): 12882.
380
-
381
- Schimunek, J., et al., “Context-enriched molecule representations improve few-shot drug discovery.” ICLR, 2023.
382
-
383
- '''
384
- )
385
-
386
  page_names_to_func = {
387
- 'About': about_page,
388
- #'Predict DTI': predict_dti,
389
- 'Retrieve Top-k': retrieval,
390
- #'Display Protein': display_protein,
391
- #'Display Context': display_context,
392
- #'References': references
393
  }
394
 
395
  selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
 
1
+
2
+ import gc
3
  import os
4
  import sys
5
+ import torch
6
+ import pickle
7
+ import numpy as np
8
  import pandas as pd
9
  import streamlit as st
10
+ from torch.utils.data import DataLoader
11
 
12
  from rdkit import Chem
13
  from rdkit.Chem import Draw
14
 
15
  sys.path.insert(0, os.path.abspath("src/"))
16
+ from src.dataset import DrugRetrieval, collate_target
17
+ from hyper_dti.models.hyper_pcm import HyperPCM
18
 
19
+ base_path = os.path.dirname(__file__)
20
+ data_path = os.path.join(base_path, 'data')
21
+ checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
+ st.set_page_config(layout="wide")
 
25
 
26
+ st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
27
  st.markdown('')
28
  st.markdown(
29
  """
30
  🧬 Github: [ml-jku/hyper-dti](https://https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
31
  """
32
  )
33
+ #st.error('WARNING! This app is currently under development and should not be used!')
34
+ st.divider()
35
 
36
  def about_page():
37
  st.markdown(
 
48
  In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
49
  predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
50
  a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
51
+ well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
52
+ model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
53
  """
54
  )
55
 
56
  st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')
 
 
 
 
 
 
 
 
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def retrieval():
60
  st.markdown('## Retrieve top-k most active drug compounds')
61
 
62
  st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')
63
 
64
+ col1, col2 = st.columns(2)
65
+ with col1:
66
+ st.markdown('### Query Target')
67
+ with col2:
68
+ st.markdown('### Drug Database')
69
+
 
 
 
 
 
 
 
70
  col1, col2, col3, col4 = st.columns(4)
71
+ with col1:
72
+ ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
73
+ sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
74
+ if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA' or sequence == ex_target:
75
+ st.image('figures/ex_protein.jpeg', use_column_width='always')
76
  elif sequence:
77
  st.error('Visualization coming soon...')
78
 
79
+ with col2:
80
  selected_encoder = st.selectbox(
81
+ 'Select target encoder',('SeqVec', 'None')
82
  )
83
  if sequence:
84
  if selected_encoder == 'SeqVec':
85
  st.image('figures/protein_encoder_done.png')
86
  with st.spinner('Encoding in progress...'):
87
+ # TODO make SeqVec embedding on the spot
88
+
89
+ with open(os.path.join(data_path, f'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
90
+ test_set = pickle.load(handle)
91
+ # TODO handle case if sequence not in test set
92
+ query_embedding = test_set[sequence]
93
  st.success('Encoding complete.')
94
  else:
95
+ query_embedding = None
96
  st.image('figures/protein_encoder.png')
97
  st.warning('Choose encoder above...')
98
+
99
+ with col3:
100
+ selected_database = st.selectbox(
101
+ 'Select database',('Lenselink', 'None')
102
+ )
103
+ if selected_database == 'Lenselink':
104
+ c1, c2 = st.columns(2)
105
+ with c2:
106
+ st.image('figures/multi_molecules.png', use_column_width='always') #, width=125)
107
+ with st.spinner('Loading data...'):
108
+ batch_size = 64
109
+ dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
110
+ dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
111
+ st.success('Data loaded.')
112
+ else:
113
+ dataset = None
114
+ dataloader = None
115
+ st.warning('Choose database above...')
116
+
117
+ with col4:
118
+ selected_encoder = st.selectbox(
119
+ 'Select drug encoder',('CDDD', 'None')
120
+ )
121
+ if selected_database:
122
+ if selected_encoder == 'CDDD':
123
+ st.image('figures/molecule_encoder_done.png')
124
+ st.success('Encoding complete.')
125
+ else:
126
+ st.image('figures/molecule_encoder.png')
127
+ st.warning('Choose encoder above...')
128
+
129
+ if query_embedding is not None:
130
  st.markdown('### Inference')
131
 
132
+ progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
 
133
  my_bar = st.progress(0, text=progress_text)
 
 
 
 
134
 
135
+ gc.collect()
136
+ torch.cuda.empty_cache()
137
+ memory = dataset
138
+ model = HyperPCM(memory=memory).to(device)
139
+ model = torch.nn.DataParallel(model)
140
+ model.load_state_dict(torch.load(checkpoint_path))
141
+ model.eval()
142
+
143
+ with torch.set_grad_enabled(False):
144
+
145
+ smiles = []
146
+ preds = []
147
+ i = 0
148
+ for batch, labels in dataloader:
149
+ pids, proteins, mids, molecules = batch['pids'], batch['targets'], batch['mids'], batch['drugs']
150
+
151
+ logits = model(batch)
152
+ logits = logits.detach().cpu().numpy()
153
+
154
+ smiles.append(mids)
155
+ preds.append(logits)
156
+ my_bar.progress((batch_size*i)/len(dataset), text=progress_text)
157
+ i += 1
158
+ my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")
159
+
160
+
161
  st.markdown('### Retrieval')
162
 
163
+ selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
164
+
165
+ results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
166
+ results = results.sort_values(by='Prediction', ascending=False)
167
+ results = results.reset_index()
168
+
169
+ print(results.head(10))
 
 
170
 
 
 
171
  cols = st.columns(5)
172
  for j, col in enumerate(cols):
173
  with col:
174
  for i in range(int(selected_k/5)):
175
+ mol = Chem.MolFromSmiles(results.loc[j + 5*i, 'SMILES'])
176
  mol_img = Chem.Draw.MolToImage(mol)
177
+ st.image(mol_img, caption=f"{results.loc[j + 5*i, 'Prediction']:.2f}")
 
 
 
 
 
 
 
 
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  page_names_to_func = {
182
+ 'Retrieval': retrieval,
183
+ 'About': about_page
 
 
 
 
184
  }
185
 
186
  selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
figures/multi_molecules.png ADDED
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- #setuptools
2
  rdkit #==2022.3.5
3
  #torch
4
  #jax_unirep
 
 
1
  rdkit #==2022.3.5
2
  #torch
3
  #jax_unirep