|
|
|
import gc |
|
import os |
|
import sys |
|
import torch |
|
import pickle |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
from torch.utils.data import DataLoader |
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
|
|
sys.path.insert(0, os.path.abspath("src/")) |
|
from src.dataset import DrugRetrieval, collate_target |
|
from hyper_dti.models.hyper_pcm import HyperPCM |
|
|
|
# Locations of the bundled data and model checkpoint, resolved relative to
# the directory containing this file.
base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
# Page-level configuration and the header shown above the tabs.
# set_page_config is issued before any other Streamlit call in this file.
st.set_page_config(layout="wide")

st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
# Project links. The GitHub URL previously carried a duplicated scheme
# ("https://https://..."), which produced a dead link.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL); TBA in JCIM. \n
    """
)
|
|
|
def about_page():
    """Render the 'About' tab: a project summary followed by the architecture figure."""
    about_text = """
        ### About

        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios.

        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
        """
    st.markdown(about_text)

    # Architecture overview shipped alongside the app.
    st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')
|
|
|
|
|
def _render_top_k(results, selected_database):
    """Show the top-k slider, the molecule grid and the CSV download button.

    Args:
        results: DataFrame with ``SMILES`` and ``Prediction`` columns, sorted
            by descending prediction and indexed ``0..n-1``.
        selected_database: Database name interpolated into the widget labels.
    """
    st.markdown('### Retrieval')

    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)

    n_cols = 5
    # Molecules are laid out row-major: rank r goes to column r % 5, row r // 5.
    for j, col in enumerate(st.columns(n_cols)):
        with col:
            for i in range(selected_k // n_cols):
                rank = j + n_cols * i
                if rank >= len(results):
                    # Fewer than k compounds available; nothing left for this column.
                    break
                mol = Chem.MolFromSmiles(results.loc[rank, 'SMILES'])
                if mol is None:
                    # RDKit returns None for strings it cannot parse; drawing None would raise.
                    st.warning('Could not render molecule.')
                    continue
                st.image(Draw.MolToImage(mol), caption=f"{results.loc[rank, 'Prediction']:.2f}")

    st.download_button(f'Download retrieved drug compounds from the {selected_database} database.', results.head(selected_k).to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')


def retrieval():
    """Render the 'Retrieval' tab.

    Collects a query protein sequence and a drug database from the user,
    obtains a SeqVec embedding for the sequence (from the bundled pickle when
    present, otherwise by running the SeqVec embedder), and shows the top-k
    compounds ranked by HyperPCM. For the built-in example target on the
    Lenselink database, results are read from a bundled CSV instead of running
    the model.
    """
    st.markdown('## Retrieval of most active drug compounds')

    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Stays None until a sequence has been entered and encoded, so the later
    # `is not None` checks are well defined even for an empty text box.
    query_embedding = None
    dataset = None
    dataloader = None
    batch_size = 2048

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == ex_target:
            st.image('figures/lenselink_ex_target.jpeg', use_column_width='always')
        elif sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        # One-element tuple: a bare ('SeqVec') is just a string, which some
        # Streamlit versions would iterate character by character.
        st.selectbox(
            'Select target encoder', ('SeqVec',)
        )
        if sequence:
            st.image('figures/protein_encoder_done.png')
            with st.spinner('Encoding in progress...'):
                # NOTE: pickle executes arbitrary code on load; this file must
                # only ever be the trusted one bundled with the app.
                with open(os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
                    test_set = pickle.load(handle)

                if sequence in test_set:
                    query_embedding = test_set[sequence]
                else:
                    # Imported lazily: only needed for sequences without a cached embedding.
                    from bio_embeddings.embed import SeqVecEmbedder
                    encoder = SeqVecEmbedder()
                    # Only one sequence was submitted, so only the first result is used.
                    for emb in encoder.embed_batch([sequence]):
                        query_embedding = encoder.reduce_per_protein(emb)
                        break

            st.success('Encoding complete.')

    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'Davis', 'DUD-E')
        )
        c1, c2 = st.columns(2)
        with c2:
            st.image('figures/multi_molecules.png', use_column_width='always')
        if query_embedding is not None:
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')
        else:
            # Without an encoded target there is nothing to pair the drugs with.
            st.info('Enter an amino-acid sequence to load the drug database.')

    with col4:
        st.selectbox(
            'Select drug encoder', ('CDDD',)
        )
        st.image('figures/molecule_encoder_done.png')
        st.success('Encoding complete.')

    if sequence == ex_target and selected_database == 'Lenselink':
        # Example target on Lenselink: use the precomputed ranking.
        st.markdown('### Inference')

        progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
        my_bar = st.progress(0, text=progress_text)
        my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

        # Resolved through data_path, like the embedding pickle stored next to it.
        results = pd.read_csv(os.path.join(data_path, 'Lenselink/processed/ex_results.csv'))
        _render_top_k(results, selected_database)

    elif query_embedding is not None:
        st.markdown('### Inference')

        progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
        my_bar = st.progress(0, text=progress_text)

        # Free what we can before instantiating the model.
        gc.collect()
        torch.cuda.empty_cache()
        model = HyperPCM(memory=dataset).to(device)
        # Wrapped in DataParallel before loading the state dict, as in the original code.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        model.eval()

        smiles = []
        preds = []
        with torch.no_grad():
            for i, (batch, _labels) in enumerate(dataloader):
                logits = model(batch).detach().cpu().numpy()
                smiles.append(batch['mids'])
                preds.append(logits)
                # Fraction of the dataset covered by the batches before this one.
                my_bar.progress((batch_size * i) / len(dataset), text=progress_text)
        my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

        results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
        results = results.sort_values(by='Prediction', ascending=False)
        # drop=True keeps the pre-sort row numbers out of the downloadable CSV.
        results = results.reset_index(drop=True)

        _render_top_k(results, selected_database)
|
|
|
|
|
|
|
# Tab label -> function that renders that tab's content.
page_names_to_func = {'Retrieval': retrieval, 'About': about_page}

tab1, tab2 = st.tabs(page_names_to_func.keys())

# Render each page inside its tab, in the same order as before.
for tab, page_name in ((tab1, 'Retrieval'), (tab2, 'About')):
    with tab:
        page_names_to_func[page_name]()
|
|