File size: 7,947 Bytes
6a624f6 89f9a8d 6a624f6 72f9629 89f9a8d 6a624f6 89f9a8d c0e52ad 89f9a8d 6a624f6 89f9a8d 6a624f6 89f9a8d 6a624f6 89f9a8d 6a624f6 d5d696e fc729c7 d5d696e fc729c7 6a624f6 372f84d bc3d031 afba6b8 501c3b1 afba6b8 bc3d031 6a624f6 bc3d031 40cf9b4 2f247b2 40cf9b4 63ce71b 9e6028d 601b6c8 63ce71b 9e6028d 6a624f6 601b6c8 6a624f6 4f1ea03 b318bc6 601b6c8 6a624f6 4f1ea03 6a624f6 4f1ea03 601b6c8 4f1ea03 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 4f1ea03 3c1ebe4 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 4f1ea03 6a624f6 2f637fc fc729c7 7f14dcf bc3d031 6a624f6 bc3d031 5b730e0 bc3d031 |
|
import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw
sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM
# Locations of the bundled data and the trained checkpoint, resolved relative
# to this file's directory.
base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Page chrome shared by every page of the app: wide layout, title and the
# links to the code repository and the paper.
st.set_page_config(layout="wide")
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
st.markdown('')
# The GitHub link previously read "https://https://github.com/..." and did not
# resolve; it now points at the repository directly.
st.markdown(
"""
🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
"""
)
#st.error('WARNING! This app is currently under development and should not be used!')
st.divider()
def about_page():
    """Render the 'About' page: the project abstract followed by the architecture figure."""
    # Markdown body shown on the page; kept flush-left so the rendered text is
    # exactly the literal below.
    abstract_md = """
### About
HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
multi-task generalization in various domains, such as personalized federated learning and neural architecture
search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
"""
    st.markdown(abstract_md)

    figure_path = 'figures/hyper-dti.png'
    figure_caption = 'Overview of HyperPCM architecture.'
    st.image(figure_path, caption=figure_caption)
def _predict_activities(dataset, dataloader, batch_size, progress_bar, progress_text):
    """Score every batch from ``dataloader`` with the HyperPCM checkpoint.

    Args:
        dataset: The drug dataset; passed to ``HyperPCM`` as ``memory`` and
            used for its length when reporting progress.
        dataloader: Iterable yielding ``(batch, labels)`` pairs where ``batch``
            is indexable by ``'mids'`` and accepted by the model.
        batch_size: Batch size the loader was built with (progress only).
        progress_bar: Streamlit progress element to update after each batch.
        progress_text: Text shown next to the progress bar.

    Returns:
        Tuple ``(smiles, preds)``: two lists with one entry per batch, the
        ``batch['mids']`` values and the flattened model outputs.
    """
    gc.collect()
    torch.cuda.empty_cache()

    model = HyperPCM(memory=dataset).to(device)
    model = torch.nn.DataParallel(model)
    # map_location lets the checkpoint load on hosts without the device it
    # was saved from (e.g. a CUDA checkpoint on a CPU-only machine).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    total = len(dataset)
    smiles, preds = [], []
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader, start=1):
            logits = model(batch).detach().cpu().numpy()
            smiles.append(batch['mids'])
            # Flatten so a trailing singleton axis or a 0-d output (possible
            # for a final batch of one if the model squeezes; not verifiable
            # from this file) still concatenates into a 1-D column.
            preds.append(np.ravel(logits))
            # Report progress after the batch is done; clamp because the last
            # batch is usually partial and would otherwise exceed 1.0.
            progress_bar.progress(min(1.0, batch_size * step / total), text=progress_text)
    return smiles, preds


def _render_top_k(results, selected_k, n_cols=5):
    """Draw the ``selected_k`` best-ranked molecules in a grid of ``n_cols`` columns.

    ``results`` must be sorted best-first with a 0..n-1 index and provide the
    columns ``'SMILES'`` and ``'Prediction'``. Ranks beyond the available rows
    are skipped, as are entries RDKit cannot parse.
    """
    n_rows = selected_k // n_cols
    for j, col in enumerate(st.columns(n_cols)):
        with col:
            for row in range(n_rows):
                rank = j + n_cols * row
                if rank >= len(results):
                    break
                caption = f"{results.loc[rank, 'Prediction']:.2f}"
                mol = Chem.MolFromSmiles(results.loc[rank, 'SMILES'])
                if mol is None:
                    # MolFromSmiles returns None for unparsable input and
                    # MolToImage cannot draw that.
                    st.warning(f'Could not draw structure ({caption}).')
                    continue
                st.image(Draw.MolToImage(mol), caption=caption)


def retrieval():
    """Render the retrieval page.

    Collects a query protein sequence, looks up its pre-computed SeqVec
    embedding, loads the selected drug database, runs HyperPCM over all drugs
    and displays the top-k compounds by predicted activity. Inference only
    runs when both a query embedding and a drug database are available.
    """
    batch_size = 64
    # Defaults so later stages can test availability even when the widgets
    # above them did not produce a value (e.g. an empty sequence box).
    query_embedding = None
    dataset = None
    dataloader = None

    st.markdown('## Retrieve top-k most active drug compounds')
    st.write('In the future this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')

    head_left, head_right = st.columns(2)
    with head_left:
        st.markdown('### Query Target')
    with head_right:
        st.markdown('### Drug Database')

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        # A picture is only bundled for these two example sequences.
        if sequence in ('HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', ex_target):
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        target_encoder = st.selectbox(
            'Select target encoder', ('SeqVec', 'None')
        )
        if sequence:
            if target_encoder == 'SeqVec':
                st.image('figures/protein_encoder_done.png')
                with st.spinner('Encoding in progress...'):
                    # TODO make SeqVec embedding on the spot
                    # NOTE: pickle is only acceptable here because the file
                    # ships with the app; never point this at untrusted data.
                    encoding_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
                    with open(encoding_file, 'rb') as handle:
                        test_set = pickle.load(handle)
                    try:
                        query_embedding = test_set[sequence]
                    except KeyError:
                        query_embedding = None
                if query_embedding is None:
                    st.error('Query sequence not found in the pre-computed SeqVec encodings.')
                else:
                    st.success('Encoding complete.')
            else:
                st.image('figures/protein_encoder.png')
                st.warning('Choose encoder above...')

    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'None')
        )
        if selected_database == 'Lenselink':
            _, c2 = st.columns(2)
            with c2:
                st.image('figures/multi_molecules.png', use_column_width='always')  # , width=125)
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')
        else:
            st.warning('Choose database above...')

    with col4:
        drug_encoder = st.selectbox(
            'Select drug encoder', ('CDDD', 'None')
        )
        # The selectbox returns the string 'None', which is truthy, so compare
        # against it explicitly instead of relying on truthiness.
        if selected_database != 'None':
            if drug_encoder == 'CDDD':
                st.image('figures/molecule_encoder_done.png')
                st.success('Encoding complete.')
            else:
                st.image('figures/molecule_encoder.png')
                st.warning('Choose encoder above...')

    # Nothing to predict without both an encoded target and a drug database.
    if query_embedding is None or dataloader is None:
        return

    st.markdown('### Inference')
    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    my_bar = st.progress(0, text=progress_text)
    smiles, preds = _predict_activities(dataset, dataloader, batch_size, my_bar, progress_text)
    my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

    if not preds:
        # np.concatenate raises on an empty list; an empty database simply
        # has nothing to rank.
        st.warning('No drug compounds available for retrieval.')
        return

    st.markdown('### Retrieval')
    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)

    results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
    results = results.sort_values(by='Prediction', ascending=False)
    # Renumber rows 0..n-1 in ranked order so they can be addressed by rank.
    results = results.reset_index()
    # Written to the server's stdout, not to the page.
    print(results.head(10))

    _render_top_k(results, selected_k)
# Registry of the pages offered in the sidebar: display name -> render function.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

# Let the user pick a page in the sidebar, then render it.
selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
st.sidebar.markdown('')
render_selected_page = page_names_to_func[selected_page]
render_selected_page()
|