"""Streamlit demo app for HyperPCM drug-target interaction retrieval.

The app has two pages, selectable from the sidebar:

* ``Retrieval`` -- pick a query protein target and a drug database, run the
  HyperPCM checkpoint over the database and show the top-k predicted compounds.
* ``About`` -- the paper abstract and the architecture figure.
"""
import gc
import os
import pickle
import sys

import torch
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader
from rdkit import Chem
from rdkit.Chem import Draw

# Must run before the project-local imports below so that ``src`` resolves.
sys.path.insert(0, os.path.abspath("src/"))

from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM

base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.set_page_config(layout="wide")
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
st.markdown('')
# The GitHub link previously carried a doubled scheme ("https://https://"),
# which produced a dead link.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)
    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
    """
)
#st.error('WARNING! This app is currently under development and should not be used!')
#st.divider()


def about_page():
    """Render the 'About' page: paper abstract followed by the architecture figure."""
    st.markdown(
        """
        ### About
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target.
        We demonstrate state-of-the-art performance over previous methods on multiple well-known benchmarks,
        particularly in zero-shot settings for unseen protein targets. This app demonstrates the model as a retrieval
        task of the top-k most active drug compounds predicted for a given query target.
        """
    )
    st.image('figures/hyper-dti.png', caption='Overview of HyperPCM architecture.')


def _load_query_embedding(sequence):
    """Return the pre-computed SeqVec embedding stored for ``sequence``.

    Returns ``None`` when the sequence has no entry in the encoding file, so
    the caller can report it instead of crashing with a ``KeyError``.
    """
    encoding_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
    # NOTE(security): unpickling executes arbitrary code. This is only
    # acceptable because the file is a local artifact shipped with the app;
    # never point this at user-supplied data.
    with open(encoding_file, 'rb') as handle:
        test_set = pickle.load(handle)
    # TODO make SeqVec embedding on the spot
    try:
        return test_set[sequence]
    except KeyError:
        return None


def _predict_activities(dataset, dataloader, batch_size):
    """Run the HyperPCM checkpoint over ``dataloader`` and collect its predictions.

    A Streamlit progress bar is updated once per batch.

    Returns a ``DataFrame`` with columns ``SMILES`` and ``Prediction``, in
    dataloader order (empty when the dataloader yields no batches).
    """
    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    my_bar = st.progress(0, text=progress_text)

    # Release whatever a previous script rerun left behind before building the model.
    gc.collect()
    torch.cuda.empty_cache()

    model = HyperPCM(memory=dataset).to(device)
    model = torch.nn.DataParallel(model)
    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    smiles = []
    preds = []
    n_total = len(dataset)
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader):
            logits = model(batch).detach().cpu().numpy()
            smiles.append(batch['mids'])
            # Flatten so every batch contributes a 1-D array.
            # Assumes the model emits one score per compound -- TODO confirm.
            preds.append(logits.reshape(-1))
            done = min(1.0, (batch_size * (step + 1)) / n_total)
            my_bar.progress(done, text=progress_text)
    my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

    if not preds:
        # np.concatenate raises on an empty list; return an empty table instead.
        return pd.DataFrame({'SMILES': [], 'Prediction': []})
    return pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})


def _show_top_k(results, k, n_cols=5):
    """Draw the ``k`` best-ranked molecules of ``results`` in an ``n_cols``-wide grid.

    ``results`` must be sorted best-first and indexed 0..n-1. Ranks are laid
    out row by row, so column ``j`` holds ranks ``j``, ``j + n_cols``, ...
    """
    n_rows = k // n_cols
    for j, col in enumerate(st.columns(n_cols)):
        with col:
            for i in range(n_rows):
                rank = j + n_cols * i
                if rank >= len(results):
                    # Fewer compounds than requested.
                    break
                caption = f"{results.loc[rank, 'Prediction']:.2f}"
                mol = Chem.MolFromSmiles(results.loc[rank, 'SMILES'])
                if mol is None:
                    # RDKit returns None for SMILES it cannot parse.
                    st.warning(caption)
                    continue
                st.image(Draw.MolToImage(mol), caption=caption)


def retrieval():
    """Render the 'Retrieval' page.

    Lets the user pick a query target, target encoder, drug database and drug
    encoder. Once both a query embedding and a database are available, runs
    inference and displays the top-k predicted compounds.
    """
    st.markdown('## Retrieve top-k most active drug compounds')
    st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence in ('HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', ex_target):
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    # Bound up front so an empty sequence box cannot leave it undefined.
    query_embedding = None
    with col2:
        target_encoder = st.selectbox(
            'Select target encoder', ('SeqVec', 'None')
        )
        if sequence:
            if target_encoder == 'SeqVec':
                st.image('figures/protein_encoder_done.png')
                with st.spinner('Encoding in progress...'):
                    query_embedding = _load_query_embedding(sequence)
                if query_embedding is None:
                    st.error('No pre-computed SeqVec encoding is available for this sequence.')
                else:
                    st.success('Encoding complete.')
            else:
                st.image('figures/protein_encoder.png')
                st.warning('Choose encoder above...')

    batch_size = 64
    dataset = None
    dataloader = None
    with col3:
        selected_database = st.selectbox(
            'Select database', ('Lenselink', 'None')
        )
        if selected_database == 'Lenselink':
            _, c2 = st.columns(2)
            with c2:
                st.image('figures/multi_molecules.png', use_column_width='always')  # , width=125)
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False,
                                        collate_fn=collate_target)
            st.success('Data loaded.')
        else:
            st.warning('Choose database above...')

    with col4:
        drug_encoder = st.selectbox(
            'Select drug encoder', ('CDDD', 'None')
        )
        # NOTE(review): selected_database is always a non-empty string (even
        # 'None'), so this condition is always true -- confirm the intent.
        if selected_database:
            if drug_encoder == 'CDDD':
                st.image('figures/molecule_encoder_done.png')
                st.success('Encoding complete.')
            else:
                st.image('figures/molecule_encoder.png')
                st.warning('Choose encoder above...')

    # Inference needs both an encoded query and a loaded database.
    if query_embedding is not None and dataloader is not None:
        st.markdown('### Inference')
        results = _predict_activities(dataset, dataloader, batch_size)

        st.markdown('### Retrieval')
        selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
        results = results.sort_values(by='Prediction', ascending=False)
        results = results.reset_index()
        print(results.head(10))

        _show_top_k(results, selected_k)


page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page
}
selected_page = st.sidebar.selectbox('Choose function', list(page_names_to_func))
st.sidebar.markdown('')
page_names_to_func[selected_page]()