File size: 7,947 Bytes
6a624f6
 
89f9a8d
 
6a624f6
 
 
72f9629
89f9a8d
6a624f6
89f9a8d
c0e52ad
 
 
89f9a8d
6a624f6
 
89f9a8d
6a624f6
 
 
 
89f9a8d
6a624f6
89f9a8d
6a624f6
d5d696e
fc729c7
 
d5d696e
fc729c7
 
6a624f6
 
372f84d
bc3d031
 
afba6b8
501c3b1
afba6b8
bc3d031
 
 
 
 
 
 
 
 
 
6a624f6
 
bc3d031
 
40cf9b4
2f247b2
40cf9b4
63ce71b
9e6028d
601b6c8
63ce71b
 
9e6028d
6a624f6
 
 
 
 
 
601b6c8
6a624f6
 
 
 
 
4f1ea03
b318bc6
601b6c8
6a624f6
4f1ea03
6a624f6
4f1ea03
601b6c8
4f1ea03
 
 
6a624f6
 
 
 
 
 
4f1ea03
 
6a624f6
4f1ea03
 
6a624f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
3c1ebe4
6a624f6
4f1ea03
 
6a624f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
 
6a624f6
 
 
 
 
 
 
4f1ea03
 
 
 
 
6a624f6
4f1ea03
6a624f6
2f637fc
fc729c7
7f14dcf
bc3d031
6a624f6
 
bc3d031
 
5b730e0
 
 
bc3d031
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader

from rdkit import Chem
from rdkit.Chem import Draw

sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM

# Data and checkpoint locations are anchored at this file's directory.
base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configure the page layout before any content is rendered.
st.set_page_config(layout="wide")

st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions.\n')
st.markdown('')
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
    """
)
#st.error('WARNING! This app is currently under development and should not be used!')
st.divider()

def about_page():
    """Render the "About" page: the project abstract followed by the architecture figure."""
    about_text = """      
        ### About
        
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for 
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved 
        multi-task generalization in various domains, such as personalized federated learning and neural architecture 
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased 
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which 
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios. 
        
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of 
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple 
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the 
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target. 
        """
    st.markdown(about_text)

    # Architecture overview figure, shown below the abstract.
    figure_path = 'figures/hyper-dti.png'
    figure_caption = 'Overview of HyperPCM architecture.'
    st.image(figure_path, caption=figure_caption)
    

def retrieval():
    st.markdown('## Retrieve top-k most active drug compounds')

    st.write('In the furute this page will retrieve the top-k drug compounds that are predicted to have the highest activity toward the given protein target from either the Lenselink or Davis datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2: 
        st.markdown('### Drug Database')
    
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA' or sequence == ex_target:
            st.image('figures/ex_protein.jpeg', use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')
    
    with col2:
        selected_encoder = st.selectbox(
                'Select target encoder',('SeqVec', 'None')
            )
        if sequence:
            if selected_encoder == 'SeqVec':
                st.image('figures/protein_encoder_done.png')
                with st.spinner('Encoding in progress...'):
                    # TODO make SeqVec embedding on the spot
                    
                    with open(os.path.join(data_path, f'Lenselink/processed/SeqVec_encoding_test.pickle'), 'rb') as handle:
                        test_set = pickle.load(handle)
                    # TODO handle case if sequence not in test set
                    query_embedding = test_set[sequence]
                st.success('Encoding complete.')
            else: 
                query_embedding = None
                st.image('figures/protein_encoder.png')
                st.warning('Choose encoder above...')
    
    with col3:
        selected_database = st.selectbox(
                'Select database',('Lenselink', 'None')
            )
        if selected_database == 'Lenselink':
            c1, c2 = st.columns(2)
            with c2:
                st.image('figures/multi_molecules.png', use_column_width='always') #, width=125)
            with st.spinner('Loading data...'):
                batch_size = 64
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')
        else: 
            dataset = None
            dataloader = None
            st.warning('Choose database above...')
    
    with col4:
        selected_encoder = st.selectbox(
                'Select drug encoder',('CDDD', 'None')
            )
        if selected_database:
            if selected_encoder == 'CDDD':
                st.image('figures/molecule_encoder_done.png')
                st.success('Encoding complete.')
            else: 
                st.image('figures/molecule_encoder.png')
                st.warning('Choose encoder above...')
                
    if query_embedding is not None:
        st.markdown('### Inference')
        
        progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
        my_bar = st.progress(0, text=progress_text)
        
        gc.collect()
        torch.cuda.empty_cache()
        memory = dataset
        model = HyperPCM(memory=memory).to(device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()

        with torch.set_grad_enabled(False):

            smiles = []
            preds = []
            i = 0
            for batch, labels in dataloader:
                pids, proteins, mids, molecules = batch['pids'], batch['targets'], batch['mids'], batch['drugs']

                logits = model(batch)
                logits = logits.detach().cpu().numpy()

                smiles.append(mids)
                preds.append(logits)
                my_bar.progress((batch_size*i)/len(dataset), text=progress_text)
                i += 1
        my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")
                
    
        st.markdown('### Retrieval')
        
        selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
        
        results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
        results = results.sort_values(by='Prediction', ascending=False)
        results = results.reset_index()
        
        print(results.head(10))
        
        cols = st.columns(5)
        for j, col in enumerate(cols):
            with col:
                for i in range(int(selected_k/5)):
                    mol = Chem.MolFromSmiles(results.loc[j + 5*i, 'SMILES'])
                    mol_img = Chem.Draw.MolToImage(mol)
                    st.image(mol_img, caption=f"{results.loc[j + 5*i, 'Prediction']:.2f}")
    
    
    
# Sidebar navigation: each label maps to the function that renders that page.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
st.sidebar.markdown('')
render_page = page_names_to_func[selected_page]
render_page()