hyper-dti / app.py
emmas96's picture
change layout of about page
d2ef912
raw
history blame
9.17 kB
import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw
sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM
# Directory containing this script. abspath guards against dirname('') when
# __file__ is a bare filename, so data_path/checkpoint_path are never ''-based.
base_path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(base_path, 'data')
# HyperPCM weights that retrieval() loads for live inference on non-example queries.
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
# Prefer the GPU when one is visible to torch, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Page configuration is issued before any other st.* call
# (Streamlit documents set_page_config as needing to come first).
st.set_page_config(layout="wide")
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
# Project links shown under the title on every tab.
st.markdown(
"""
🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL); TBA in JCIM. \n
"""
)
def about_page():
    """Render the "About" tab.

    Shows a markdown description of HyperNetworks / the HyperPCM model and the
    retrieval demo, followed by the architecture overview figure.
    """
    # The literal's lines sit at column 0 on purpose: indenting them would put
    # leading spaces into the rendered markdown.
    about_text = """
### About
HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
multi-task generalization in various domains, such as personalized federated learning and neural architecture
search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
"""
    st.markdown(about_text)
    overview_caption = 'Overview of HyperPCM architecture.'
    st.image('figures/hyper-dti.png', caption=overview_caption)
def _encode_query_target(sequence):
    """Return a SeqVec per-protein embedding for ``sequence``, or None if none was produced.

    First looks the sequence up in the precomputed encodings stored in
    ``data/Lenselink/processed/SeqVec_encoding_test.pickle``; only when it is not
    found there is the SeqVec embedder imported (lazily) and run.
    """
    encoding_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
    with open(encoding_file, 'rb') as handle:
        # NOTE(review): pickle.load can execute arbitrary code; acceptable only
        # because this is a bundled data file, never user-supplied.
        precomputed = pickle.load(handle)
    if sequence in precomputed:
        return precomputed[sequence]

    from bio_embeddings.embed import SeqVecEmbedder
    encoder = SeqVecEmbedder()
    # Only one sequence was submitted, so only the first embedding is used.
    first_embedding = next(iter(encoder.embed_batch([sequence])), None)
    if first_embedding is None:
        return None
    return encoder.reduce_per_protein(first_embedding)


def _predict_activities(dataset, dataloader, batch_size, progress_bar, progress_text):
    """Score every drug in ``dataset`` with HyperPCM and return predictions best-first.

    Args:
        dataset: DrugRetrieval dataset; also passed to HyperPCM as its ``memory``.
        dataloader: DataLoader over ``dataset`` yielding ``(batch, labels)`` pairs.
        batch_size: batch size of ``dataloader``; used only for progress reporting.
        progress_bar: Streamlit progress element to update after each batch.
        progress_text: text shown next to the progress bar.

    Returns:
        DataFrame with columns ``SMILES`` and ``Prediction``, sorted by descending
        prediction and re-indexed 0..n-1 (no leftover ``index`` column).
    """
    gc.collect()
    torch.cuda.empty_cache()
    model = HyperPCM(memory=dataset).to(device)
    # Wrapped before loading, presumably because the checkpoint was saved from a
    # DataParallel model (``module.``-prefixed keys) — TODO confirm.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
    model.eval()

    smiles, preds = [], []
    n_drugs = len(dataset)
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader, start=1):
            logits = model(batch)
            # The 'mids' entries are used as the SMILES strings of the batch.
            smiles.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            # Fraction of drugs scored so far; the last batch may overshoot, so clamp.
            progress_bar.progress(min(1.0, step * batch_size / n_drugs), text=progress_text)

    if not preds:
        return pd.DataFrame({'SMILES': [], 'Prediction': []})
    results = pd.DataFrame({
        'SMILES': np.concatenate(smiles),
        # reshape(-1) accepts logits of shape (B,) or (B, 1); assumes one score per drug — TODO confirm.
        'Prediction': np.concatenate(preds).reshape(-1),
    })
    results = results.sort_values(by='Prediction', ascending=False)
    return results.reset_index(drop=True)


def _show_top_k(results, selected_database):
    """Render the top-k slider, a 5-column grid of molecule images and a CSV download.

    ``results`` must have ``SMILES`` and ``Prediction`` columns, already ordered best-first.
    """
    st.markdown('### Retrieval')
    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
    top_k = results.head(selected_k)
    cols = st.columns(5)
    # Rank r is placed in column r % 5, so each column shows ranks r, r+5, r+10, ...
    # top to bottom; iterating the actual rows also copes with fewer than k results.
    for rank, (smiles, prediction) in enumerate(zip(top_k['SMILES'], top_k['Prediction'])):
        with cols[rank % 5]:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # RDKit returns None for unparsable SMILES; MolToImage would crash on it.
                st.warning(f'Could not parse SMILES: {smiles}')
                continue
            st.image(Draw.MolToImage(mol), caption=f"{prediction:.2f}")
    st.download_button(f'Download retrieved drug compounds from the {selected_database} database.', top_k.to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')


def retrieval():
    """Render the "Retrieval" tab.

    The user enters a protein sequence and picks a drug database. The sequence is
    encoded, the database is loaded, and HyperPCM predicts activities for every
    drug. The top-k compounds are then shown. For the built-in example target on
    Lenselink, the precomputed results in ``ex_results.csv`` are shown instead of
    running the model.
    """
    st.markdown('## Retrieval of most active drug compounds')
    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')
    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    batch_size = 2048
    query_embedding = None
    dataset = dataloader = None

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == ex_target:
            st.image('figures/lenselink_ex_target.jpeg', use_column_width='always')
        elif sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')
    with col2:
        # Options must be a one-element tuple: a bare ('SeqVec') is a plain string,
        # which selectbox would split into single-character options.
        st.selectbox('Select target encoder', ('SeqVec',))
        if sequence:
            st.image('figures/protein_encoder_done.png')
            with st.spinner('Encoding in progress...'):
                query_embedding = _encode_query_target(sequence)
            if query_embedding is None:
                st.error('Encoding failed.')
            else:
                st.success('Encoding complete.')
    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'Davis', 'DUD-E')
        )
        _, c2 = st.columns(2)
        with c2:
            st.image('figures/multi_molecules.png', use_column_width='always')
        # Without a query embedding (e.g. empty sequence) there is nothing to score.
        if query_embedding is not None:
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')
    with col4:
        st.selectbox('Select drug encoder', ('CDDD',))
        st.image('figures/molecule_encoder_done.png')
        st.success('Encoding complete.')

    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    done_text = "HyperPCM is predicting the QSAR model for the query protein target. Done."
    if sequence == ex_target and selected_database == 'Lenselink':
        # Built-in example: show precomputed results rather than running the model.
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        my_bar.progress(100, text=done_text)
        results = pd.read_csv(os.path.join(data_path, 'Lenselink/processed/ex_results.csv'))
        _show_top_k(results, selected_database)
    elif dataset is not None:
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        results = _predict_activities(dataset, dataloader, batch_size, my_bar, progress_text)
        my_bar.progress(100, text=done_text)
        _show_top_k(results, selected_database)
# Tab label -> function that renders that tab. Dict insertion order fixes the tab order.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

# Build one tab per entry from the mapping itself, so labels and renderers cannot
# drift apart. st.tabs receives a real list rather than a dict_keys view.
page_tabs = st.tabs(list(page_names_to_func))
for tab, render_page in zip(page_tabs, page_names_to_func.values()):
    with tab:
        render_page()