File size: 10,757 Bytes
6a624f6 89f9a8d 6a624f6 72f9629 89f9a8d 6a624f6 89f9a8d c0e52ad 89f9a8d 6a624f6 89f9a8d 6a624f6 89f9a8d 846a053 89f9a8d 35aa55d d5d696e fc729c7 af1becc fc729c7 372f84d bc3d031 afba6b8 501c3b1 afba6b8 bc3d031 6a624f6 bc3d031 40cf9b4 87ce5b7 40cf9b4 63ce71b 9e6028d 35aa55d 63ce71b 35aa55d 9e6028d 6a624f6 601b6c8 6a624f6 92ffdca 6a624f6 a8c3bc5 6a624f6 4f1ea03 b318bc6 601b6c8 6a624f6 4f1ea03 4c77fbb 4f1ea03 601b6c8 87ce5b7 d2ef912 6a624f6 d2ef912 6a624f6 eaad7fd d2ef912 6a624f6 4c77fbb 6a624f6 87ce5b7 d2ef912 6a624f6 d2ef912 c542f1d 7cbdd67 c542f1d d2ef912 c542f1d 6146fc6 d2ef912 c542f1d 4f1ea03 3c1ebe4 6a624f6 4f1ea03 6a624f6 956fc7c 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 031d745 d2ef912 bc3d031 6a624f6 bc3d031 d2ef912 bc3d031 d2ef912 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw
sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM
# Inference device: the first CUDA device if PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data and checkpoint locations are anchored at the directory that contains
# this script (the head component of ``__file__``).
base_path = os.path.split(__file__)[0]
data_path = os.path.join(base_path, 'data')
# Trained HyperPCM weights loaded by the retrieval tab.
checkpoint_path = os.path.join(
    base_path,
    'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7',
)
# Page-level configuration. The 'About' entry is markdown shown in Streamlit's
# app menu.
st.set_page_config(
    page_title='HyperDTI',
    layout='centered',
    menu_items={
        'About':
'''
# HyperDTI
HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
multi-task generalization in various domains, such as personalized federated learning and neural architecture
search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
'''
    }
)
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
# Empty markdown element used as a vertical spacer below the title.
st.markdown('')
# Project links shown under the title (repository, workshop paper, journal).
st.markdown(
    """
🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL); TBA in JCIM. \n
"""
)
def about_page():
    """Render the "About" tab: a project summary and the architecture overview figure."""
    st.markdown(
        """
### About
HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
multi-task generalization in various domains, such as personalized federated learning and neural architecture
search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
"""
    )
    # Resolve the figure against the script directory (as data_path does) rather
    # than the current working directory, so it loads wherever the app is launched.
    st.image(
        os.path.join(base_path, 'figures/hyper-dti.png'),
        caption='Overview of HyperPCM architecture.',
        use_column_width='always',
    )
_PROGRESS_TEXT = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
_PROGRESS_DONE_TEXT = "HyperPCM is predicting the QSAR model for the query protein target. Done."
_BATCH_SIZE = 2048


def _figure_path(name):
    """Return the path of a bundled figure, anchored at the script directory."""
    return os.path.join(base_path, 'figures', name)


def _encode_target(sequence):
    """Return the SeqVec per-protein embedding of ``sequence``, or None.

    Embeddings pre-computed for the Lenselink test targets are looked up in a
    pickle from the app's own data directory first. Any other sequence is embedded
    on the fly with bio_embeddings' ``SeqVecEmbedder``, which is imported only
    when needed. Returns None if the embedder yields no embedding.
    """
    # Local, app-bundled file (not user-supplied), so unpickling is acceptable here.
    with open(os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
        precomputed = pickle.load(handle)
    if sequence in precomputed:
        return precomputed[sequence]

    from bio_embeddings.embed import SeqVecEmbedder

    encoder = SeqVecEmbedder()
    # Only one sequence is submitted, so only the first yielded embedding matters.
    embedding = next(iter(encoder.embed_batch([sequence])), None)
    if embedding is None:
        return None
    return encoder.reduce_per_protein(embedding)


def _run_inference(dataset, dataloader):
    """Score every drug in ``dataset`` with HyperPCM for the query target.

    Returns a DataFrame with ``SMILES`` and ``Prediction`` columns, sorted by
    descending prediction and re-indexed 0..n-1.
    """
    my_bar = st.progress(0, text=_PROGRESS_TEXT)
    # Release unreferenced objects (e.g. a model from an earlier rerun) before
    # building a new model.
    gc.collect()
    torch.cuda.empty_cache()
    model = HyperPCM(memory=dataset).to(device)
    # The state dict is loaded after wrapping, so the checkpoint keys are
    # presumably saved with DataParallel's ``module.`` prefix — TODO confirm.
    model = torch.nn.DataParallel(model)
    # Returning ``storage`` unchanged keeps tensors on the CPU while loading, so a
    # GPU-saved checkpoint also loads on CPU-only hosts.
    model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
    model.eval()

    n_drugs = len(dataset)
    batch_size = dataloader.batch_size
    smiles, preds = [], []
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader, start=1):
            logits = model(batch)
            smiles.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            # Fraction of drugs scored so far. Capped because the last batch may be partial.
            my_bar.progress(min(1.0, batch_size * step / n_drugs), text=_PROGRESS_TEXT)
    my_bar.progress(100, text=_PROGRESS_DONE_TEXT)

    if not preds:
        # np.concatenate rejects an empty list, so return an empty table instead.
        return pd.DataFrame({'SMILES': [], 'Prediction': []})
    results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
    # drop=True keeps the pre-sort row numbers out of the table and the CSV download.
    return results.sort_values(by='Prediction', ascending=False).reset_index(drop=True)


def _show_top_k(results, selected_database):
    """Render the top-k molecules in a 5-column grid, plus a CSV download button.

    ``results`` must have ``SMILES`` and ``Prediction`` columns and already be
    ordered by descending ``Prediction``.
    """
    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
    top_k = results.head(selected_k)
    cols = st.columns(5)
    # Row-major layout: the compound ranked r goes to column r % 5, row r // 5.
    # Iterating up to len(top_k) also copes with databases smaller than k.
    for j, col in enumerate(cols):
        with col:
            for rank in range(j, len(top_k), 5):
                smi = top_k['SMILES'].iloc[rank]
                mol = Chem.MolFromSmiles(smi)
                if mol is None:
                    # RDKit returns None for SMILES it cannot parse.
                    st.warning(f'Could not parse SMILES: {smi}')
                    continue
                st.image(Draw.MolToImage(mol), caption=f"{top_k['Prediction'].iloc[rank]:.2f}")
    st.download_button(f'Download retrieved drug compounds from the {selected_database} database.', top_k.to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')


def retrieval():
    """Render the "Retrieval" tab.

    The user picks a query protein sequence and a drug database. HyperPCM then
    generates a QSAR model for the target, and the top-k compounds predicted to
    be most active are displayed. For the bundled example target on Lenselink,
    pre-computed results are shown instead of running the model.
    """
    st.markdown('## Retrieval of most active drug compounds')
    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')
    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Bound up front so an empty sequence no longer triggers a NameError further down.
    query_embedding = None
    dataset = dataloader = None

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == ex_target:
            st.image(_figure_path('lenselink_ex_target.jpeg'), use_column_width='always')
        elif sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image(_figure_path('ex_protein.jpeg'), use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        st.selectbox('Select target encoder', (['SeqVec']))
        if sequence:
            st.image(_figure_path('protein_encoder_done.png'), use_column_width='always')
            with st.spinner('Encoding in progress...'):
                query_embedding = _encode_target(sequence)
            if query_embedding is None:
                st.error('Encoding failed.')
            else:
                st.success('Encoding complete.')

    with col3:
        selected_database = st.selectbox('Select database', ('Lenselink', 'Davis', 'DUD-E'))
        # The DUD-E data lives in a directory named without the hyphen.
        if selected_database == 'DUD-E':
            selected_database = 'DUDE'
        _, c2 = st.columns(2)
        with c2:
            st.image(_figure_path('multi_molecules.png'), use_column_width='always')
        if query_embedding is not None:
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=_BATCH_SIZE, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')

    with col4:
        st.selectbox('Select drug encoder', (['CDDD']))
        st.image(_figure_path('molecule_encoder_done.png'), use_column_width='always')
        st.success('Encoding complete.')

    if sequence == ex_target and selected_database == 'Lenselink':
        # Bundled example: show pre-computed predictions instead of running the model.
        st.markdown('### Inference')
        my_bar = st.progress(0, text=_PROGRESS_TEXT)
        my_bar.progress(100, text=_PROGRESS_DONE_TEXT)
        st.markdown('### Retrieval')
        results = pd.read_csv(os.path.join(data_path, 'Lenselink/processed/ex_results.csv'))
        _show_top_k(results, selected_database)
    elif dataset is not None:
        st.markdown('### Inference')
        results = _run_inference(dataset, dataloader)
        st.markdown('### Retrieval')
        _show_top_k(results, selected_database)
# Tab label -> function that renders that tab's content. Dict order is tab order.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

# Create one tab per entry and render each page inside its own tab, so adding a
# page only requires a new dict entry.
for tab, render_page in zip(st.tabs(list(page_names_to_func)), page_names_to_func.values()):
    with tab:
        render_page()
|