# hyper-dti / app.py
# Commit 6a624f6 by knfn081: "develop working retrieval app for pre-defined target" (7.95 kB)
import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw
sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM
# Resolve every bundled resource relative to this script (not the current
# working directory) so the app works no matter where Streamlit is launched.
base_path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(base_path, 'data')
# Pre-trained HyperPCM weights loaded by the Retrieval page.
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Page configuration is issued before any other Streamlit output.
st.set_page_config(layout="wide")
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
st.markdown('')
st.markdown(
    """
🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
"""
)
# Development toggle: uncomment to flag the app as unfinished.
# st.error('WARNING! This app is currently under development and should not be used!')
st.divider()
def about_page():
    """Render the "About" page: a short abstract of HyperPCM and its architecture figure."""
    st.markdown(
        """
### About
HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
multi-task generalization in various domains, such as personalized federated learning and neural architecture
search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
"""
    )
    # Resolve the figure against the script directory (like data_path and
    # checkpoint_path) so it loads even when Streamlit runs from another cwd.
    st.image(os.path.join(base_path, 'figures', 'hyper-dti.png'), caption='Overview of HyperPCM architecture.')
def _figure_path(name):
    """Return the path of image ``name`` inside the app's ``figures`` folder.

    Resolved against ``base_path`` (like ``data_path`` and ``checkpoint_path``)
    so the images still load when Streamlit is started from another directory.
    """
    return os.path.join(base_path, 'figures', name)


def _load_query_embedding(sequence):
    """Return the pre-computed SeqVec embedding of ``sequence``, or ``None``.

    Embeddings are read from the bundled Lenselink test-set encoding file; for a
    sequence that is not in that file an error is shown and ``None`` is returned.
    """
    # TODO make SeqVec embedding on the spot
    encoding_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
    # NOTE: pickle.load can execute arbitrary code; acceptable only because this
    # file ships with the app. Never point it at user-supplied data.
    with open(encoding_file, 'rb') as handle:
        test_set = pickle.load(handle)
    try:
        return test_set[sequence]
    except KeyError:
        st.error('No pre-computed SeqVec encoding is available for this sequence yet.')
        return None


def _predict_activities(dataset, dataloader, batch_size):
    """Score every drug in ``dataloader`` with HyperPCM for the query target.

    Args:
        dataset: The ``DrugRetrieval`` dataset; also passed to HyperPCM as ``memory``.
        dataloader: Loader over ``dataset`` yielding ``(batch, labels)`` pairs.
        batch_size: Batch size of ``dataloader``, used for the progress bar.

    Returns:
        A DataFrame with columns ``SMILES`` and ``Prediction``, sorted by
        descending prediction and re-indexed 0..n-1 (empty if no batches).
    """
    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    my_bar = st.progress(0, text=progress_text)

    # Release memory still held from a previous Streamlit rerun before
    # building a fresh model.
    gc.collect()
    torch.cuda.empty_cache()

    model = HyperPCM(memory=dataset).to(device)
    model = torch.nn.DataParallel(model)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    n_items = max(len(dataset), 1)  # guard the progress fraction against an empty dataset
    smiles, preds = [], []
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader):
            logits = model(batch)
            # presumably batch['mids'] holds the SMILES strings of the batch — TODO confirm
            smiles.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            my_bar.progress(min((step + 1) * batch_size / n_items, 1.0), text=progress_text)
    my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

    if not preds:
        return pd.DataFrame({'SMILES': [], 'Prediction': []})
    results = pd.DataFrame({
        'SMILES': np.concatenate(smiles),
        # reshape(-1) keeps this 1-D even if the model returns (batch, 1) logits.
        'Prediction': np.concatenate(preds).reshape(-1),
    })
    return results.sort_values(by='Prediction', ascending=False).reset_index(drop=True)


def _show_top_k(results, k, n_cols=5):
    """Draw the ``k`` best-ranked molecules of ``results`` in ``n_cols`` columns.

    Rank ``r`` goes to column ``r % n_cols``, so the grid reads row by row.
    Each image is captioned with its prediction. Fewer than ``k`` molecules
    are drawn when fewer predictions exist; unparsable SMILES are reported
    instead of drawn.
    """
    if results.empty:
        st.warning('No drug compounds to rank.')
        return
    top = results.head(k)
    cols = st.columns(n_cols)
    for rank, (smi, pred) in enumerate(zip(top['SMILES'], top['Prediction'])):
        with cols[rank % n_cols]:
            mol = Chem.MolFromSmiles(smi)  # returns None for invalid SMILES
            if mol is None:
                st.warning(f'Could not parse SMILES: {smi}')
                continue
            st.image(Draw.MolToImage(mol), caption=f"{pred:.2f}")


def retrieval():
    """Render the "Retrieval" page.

    Four columns let the user pick a query target (amino-acid sequence), a
    target encoder, a drug database and a drug encoder. Once all four are
    available, HyperPCM scores every drug of the database for the query
    target and the top-k compounds are drawn.
    """
    st.markdown('## Retrieve top-k most active drug compounds')
    st.write('In the future this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Initialised up front so later checks never hit an undefined name, e.g.
    # when no sequence is entered or a selector is set to 'None'.
    query_embedding = None
    dataset = None
    dataloader = None
    batch_size = 64

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence in ('HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', ex_target):
            st.image(_figure_path('ex_protein.jpeg'), use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        selected_target_encoder = st.selectbox(
            'Select target encoder', ('SeqVec', 'None')
        )
        if sequence:
            if selected_target_encoder == 'SeqVec':
                st.image(_figure_path('protein_encoder_done.png'))
                with st.spinner('Encoding in progress...'):
                    query_embedding = _load_query_embedding(sequence)
                if query_embedding is not None:
                    st.success('Encoding complete.')
            else:
                st.image(_figure_path('protein_encoder.png'))
                st.warning('Choose encoder above...')

    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'None')
        )
        if selected_database == 'Lenselink':
            _, c2 = st.columns(2)
            with c2:
                st.image(_figure_path('multi_molecules.png'), use_column_width='always')
            if query_embedding is None:
                # The dataset is built around the query embedding, so wait for it.
                st.warning('Encode a query target first...')
            else:
                with st.spinner('Loading data...'):
                    dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                    dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
                st.success('Data loaded.')
        else:
            st.warning('Choose database above...')

    with col4:
        selected_drug_encoder = st.selectbox(
            'Select drug encoder', ('CDDD', 'None')
        )
        # 'None' is a non-empty (truthy) string, so compare explicitly.
        if selected_database != 'None':
            if selected_drug_encoder == 'CDDD':
                st.image(_figure_path('molecule_encoder_done.png'))
                st.success('Encoding complete.')
            else:
                st.image(_figure_path('molecule_encoder.png'))
                st.warning('Choose encoder above...')

    if query_embedding is None or dataset is None or selected_drug_encoder != 'CDDD':
        return

    st.markdown('### Inference')
    # NOTE: Streamlit reruns the whole script on every widget change, so
    # moving the slider below repeats this inference.
    results = _predict_activities(dataset, dataloader, batch_size)

    st.markdown('### Retrieval')
    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
    _show_top_k(results, selected_k)
# Sidebar navigation: each page label maps to the function that renders it.
page_names_to_func = dict(Retrieval=retrieval, About=about_page)
selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
st.sidebar.markdown('')
render_page = page_names_to_func[selected_page]
render_page()