"""Streamlit demo app for HyperPCM.

Given a query protein target, the app encodes the target, lets HyperPCM score
every drug of the selected database, and shows the top-k compounds with the
highest predicted activity.
"""
import gc
import os
import pickle
import sys

import torch
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw

# The search path is extended before the project imports below
# (presumably so that `hyper_dti` resolves from ./src -- TODO confirm).
sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM

base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of drugs announced to the user per database (keyed by data directory name).
_DATABASE_SIZES = {
    'Lenselink': 314707,
    'Davis': 30056,
    'DUDE': 1434019,
    'DrugBank': 10681,
}

# Example inputs for which a pre-rendered structure figure is shipped.
_EX_TARGET = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
_EX_PROTEIN = 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA'

_PROGRESS_WAIT = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
_PROGRESS_DONE = "HyperPCM is predicting the QSAR model for the query protein target. Done."

# Abstract shared by the "About" menu entry and the "About" tab.
_ABSTRACT = (
    "HyperNetworks have been established as an effective technique to achieve fast adaptation of "
    "parameters for neural networks. "
    "Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved multi-task "
    "generalization in various domains, such as personalized federated learning and neural "
    "architecture search. "
    "Especially powerful results were achieved in few- and zero-shot settings, attributed to the "
    "increased information sharing by the HyperNetwork. "
    "With the rise of new diseases fast discovery of drugs is needed which requires models that can "
    "generalize drug-target interaction predictions in low-data scenarios. "
    "In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the "
    "problem of predicting drug-target interactions in drug discovery. "
    "Our model learns to generate a QSAR model specialized on a given protein target. "
    "We demonstrate state-of-the-art performance over previous methods on multiple well-known "
    "benchmarks, particularly in zero-shot settings for unseen protein targets. "
    "This app demonstrates the model as a retrieval task of the top-k most active drug compounds "
    "predicted for a given query target."
)

# The GitHub URL previously carried a doubled "https://" scheme and did not resolve.
_LINKS = (
    "🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)\n\n"
    "📝 Paper: [JCIM 2024](https://pubs.acs.org/doi/10.1021/acs.jcim.3c01417); "
    "[NeurIPS 2022 AI4Science workshop](https://openreview.net/forum?id=dIX34JWnIAL)\n"
)

# The backslash in G{\"u}nter is escaped so that it survives Python's string
# parsing and the displayed BibTeX entry stays valid.
_CITATION = '''### Citation

Please cite our work using the following reference.

```bibtex
@article{svensson2024hyperpcm,
    title={{HyperPCM: Robust Task-Conditioned Modeling of Drug--Target Interactions}},
    author={Svensson, Emma and Hoedt, Pieter-Jan and Hochreiter, Sepp and Klambauer, G{\\"u}nter},
    journal={Journal of Chemical Information and Modeling},
    volume = {64},
    number = {7},
    pages = {2539-2553},
    year = {2024},
    doi = {10.1021/acs.jcim.3c01417},
    publisher={ACS Publications}
}
```
'''

st.set_page_config(
    page_title='HyperPCM',
    layout='centered',
    menu_items={'About': f"# HyperPCM\n\n{_ABSTRACT}\n"},
)

st.title('HyperPCM: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
st.markdown(_LINKS)


def about_page():
    """Render the 'About' tab: abstract, architecture figure and citation."""
    st.markdown(f"### About\n\n{_ABSTRACT}\n")
    st.image('figures/hyper-dti.png', caption='Overview of the HyperPCM architecture.', use_column_width='always')
    st.markdown(_CITATION)


def _encode_target(sequence):
    """Return the SeqVec embedding of ``sequence``, or ``None`` if none is produced.

    Embeddings of the bundled Lenselink test targets are read from a pickle;
    any other sequence is embedded on the fly with ``SeqVecEmbedder``.
    """
    # NOTE: pickle is only used on data files shipped with the app; never point
    # this at untrusted files.
    with open(os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
        test_set = pickle.load(handle)
    if sequence in test_set:
        return test_set[sequence]

    # Imported lazily so the dependency is only needed for unseen sequences.
    from bio_embeddings.embed import SeqVecEmbedder
    encoder = SeqVecEmbedder()
    # Only one sequence is submitted, so only the first embedding is used.
    for emb in encoder.embed_batch([sequence]):
        return encoder.reduce_per_protein(emb)
    return None


def _run_inference(dataset, dataloader, batch_size, progress_bar, progress_text):
    """Score every drug in ``dataloader`` with HyperPCM loaded from ``checkpoint_path``.

    ``progress_bar`` is updated after each batch. Returns a pair of
    concatenated arrays: the drug identifiers (``batch['mids']``) and the
    corresponding model outputs.
    """
    gc.collect()
    torch.cuda.empty_cache()

    model = HyperPCM(memory=dataset).to(device)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
    model.eval()

    drug_ids = []
    preds = []
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader):
            logits = model(batch)
            drug_ids.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            progress_bar.progress((batch_size * step) / len(dataset), text=progress_text)
    return np.concatenate(drug_ids), np.concatenate(preds)


def _render_top_k(results, selected_database, selected_k, n_cols=5):
    """Draw the first ``selected_k`` rows of ``results`` as a molecule grid plus a CSV download.

    ``results`` must be ordered best-first with a 0..n-1 index and hold the
    columns 'SMILES' and 'Prediction' (and 'DrugBank ID' for DrugBank).
    """
    show_id = selected_database == 'DrugBank'
    for j, col in enumerate(st.columns(n_cols)):
        with col:
            for i in range(selected_k // n_cols):
                # Fill row by row: column j shows ranks j, j + n_cols, ...
                rank = j + n_cols * i
                mol = Chem.MolFromSmiles(results.loc[rank, 'SMILES'])
                mol_img = Draw.MolToImage(mol)
                caption = f"{results.loc[rank, 'Prediction']:.2f}"
                if show_id:
                    caption = f"{results.loc[rank, 'DrugBank ID']}:\n{caption}"
                st.image(mol_img, caption=caption)

    st.download_button(
        f'Download retrieved drug compounds from the {selected_database} database.',
        results.head(selected_k).to_csv(index=False).encode('utf-8'),
        file_name='retrieved_drugs.csv',
    )


def retrieval():
    """Render the 'Retrieval' tab.

    Reads the query sequence and database from the widgets, encodes the target,
    runs HyperPCM over the selected drug database (or loads precomputed results
    for the bundled example target on Lenselink) and shows the top-k compounds.
    """
    st.markdown('## Retrieval of most active drug compounds')
    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Stay defined even when the user clears the sequence field; the original
    # code raised NameError in that case.
    query_embedding = None
    dataset = None
    dataloader = None
    batch_size = 2048

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        sequence = st.text_input('Enter amino-acid sequence', value=_EX_TARGET, placeholder=_EX_TARGET)
        if sequence == _EX_TARGET:
            st.image('figures/lenselink_ex_target.jpeg', use_column_width='always')
        elif sequence == _EX_PROTEIN:
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        st.selectbox('Select target encoder', ['SeqVec'])
        if sequence:
            st.image('figures/target_encoder_done.png', use_column_width='always')
            with st.spinner('Encoding in progress...'):
                query_embedding = _encode_target(sequence)
            st.success('Encoding complete.')

    with col3:
        selected_database = st.selectbox('Select database', ('Lenselink', 'Davis', 'DUD-E', 'DrugBank'))
        # The data directory is named without the hyphen.
        if selected_database == 'DUD-E':
            selected_database = 'DUDE'
        n_drugs = _DATABASE_SIZES[selected_database]
        st.image('figures/multi_drugs.png', use_column_width='always')
        if query_embedding is not None:
            with st.spinner(f'Loading {n_drugs} drugs...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False,
                                        collate_fn=collate_target)
            st.success(f'{n_drugs} drugs loaded.')

    with col4:
        st.selectbox('Select drug encoder', ['CDDD'])
        st.image('figures/drug_encoder_done.png', use_column_width='always')
        st.success('Encoding complete.')

    if query_embedding is None:
        st.warning('Please enter an amino-acid sequence to run the retrieval.')
        return

    slider_label = f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = '

    if sequence == _EX_TARGET and selected_database == 'Lenselink':
        # Results for the bundled example are precomputed, so no inference is run.
        st.markdown('### Inference')
        my_bar = st.progress(0, text=_PROGRESS_WAIT)
        my_bar.progress(100, text=_PROGRESS_DONE)

        st.markdown('### Retrieval')
        selected_k = st.slider(slider_label, min_value=5, max_value=20, value=5, step=5)
        results = pd.read_csv(os.path.join(data_path, 'Lenselink/processed/ex_results.csv'))
        _render_top_k(results, selected_database, selected_k)
        return

    # NOTE(review): Streamlit reruns the script on every widget change, so moving
    # the slider below repeats the full inference -- consider caching the results.
    st.markdown('### Inference')
    my_bar = st.progress(0, text=_PROGRESS_WAIT)
    drug_ids, predictions = _run_inference(dataset, dataloader, batch_size, my_bar, _PROGRESS_WAIT)
    my_bar.progress(100, text=_PROGRESS_DONE)

    st.markdown('### Retrieval')
    selected_k = st.slider(slider_label, min_value=5, max_value=20, value=5, step=5)

    if selected_database != 'DrugBank':
        # For these databases the identifiers are used directly as SMILES.
        results = pd.DataFrame({'SMILES': drug_ids, 'Prediction': predictions})
    else:
        # DrugBank identifiers are mapped to structures through a lookup table.
        with open(os.path.join(data_path, f'{selected_database}/processed/drugbank.pickle'), 'rb') as handle:
            lookup = pickle.load(handle)
        structure = [lookup[i] for i in drug_ids]
        results = pd.DataFrame({'SMILES': structure, 'DrugBank ID': drug_ids, 'Prediction': predictions})

    results = results.sort_values(by='Prediction', ascending=False)
    # The pre-sort position is kept as an 'index' column (also in the download).
    results = results.reset_index()
    _render_top_k(results, selected_database, selected_k)


page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page
}

# One tab per page, in dictionary order.
for tab, render_page in zip(st.tabs(list(page_names_to_func)), page_names_to_func.values()):
    with tab:
        render_page()