File size: 10,757 Bytes
6a624f6
 
89f9a8d
 
6a624f6
 
 
72f9629
89f9a8d
6a624f6
89f9a8d
c0e52ad
 
 
89f9a8d
6a624f6
 
89f9a8d
6a624f6
 
 
 
89f9a8d
846a053
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89f9a8d
35aa55d
d5d696e
fc729c7
 
af1becc
fc729c7
 
372f84d
bc3d031
 
afba6b8
501c3b1
afba6b8
bc3d031
 
 
 
 
 
 
 
 
 
6a624f6
 
bc3d031
 
40cf9b4
87ce5b7
40cf9b4
63ce71b
9e6028d
35aa55d
63ce71b
35aa55d
9e6028d
6a624f6
 
 
 
 
 
601b6c8
6a624f6
92ffdca
6a624f6
a8c3bc5
 
 
6a624f6
4f1ea03
b318bc6
601b6c8
6a624f6
4f1ea03
4c77fbb
4f1ea03
601b6c8
87ce5b7
d2ef912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a624f6
 
 
d2ef912
6a624f6
eaad7fd
 
d2ef912
 
 
 
 
 
 
 
6a624f6
 
 
4c77fbb
6a624f6
87ce5b7
d2ef912
6a624f6
d2ef912
c542f1d
7cbdd67
 
 
c542f1d
 
 
 
 
d2ef912
c542f1d
 
 
 
 
 
 
 
 
6146fc6
d2ef912
c542f1d
 
4f1ea03
3c1ebe4
6a624f6
4f1ea03
 
6a624f6
 
 
 
 
956fc7c
6a624f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
 
6a624f6
 
 
 
 
 
4f1ea03
 
 
 
6a624f6
4f1ea03
6a624f6
031d745
d2ef912
 
 
 
bc3d031
6a624f6
 
bc3d031
 
d2ef912
 
 
 
 
 
 
 
bc3d031
d2ef912
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238

import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader

from rdkit import Chem
from rdkit.Chem import Draw

sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM

# Resolve everything relative to this script (made absolute) so the app does not
# depend on the working directory it was launched from.
base_path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(base_path, 'data')
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.set_page_config(
    page_title='HyperDTI',
    layout='centered',
    menu_items={
        # Markdown shown in the app menu's "About" entry.
        'About':
        '''
        # HyperDTI
        
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for 
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved 
        multi-task generalization in various domains, such as personalized federated learning and neural architecture 
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased 
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which 
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios. 
        
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of 
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple 
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the 
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target.
        '''
    },
)

st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
# Header links. The GitHub URL previously had a doubled "https://https://" scheme,
# which produced a broken link.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL); TBA in JCIM. \n
    """
)

def about_page():
    """Render the 'About' tab: a project summary followed by the architecture figure."""
    st.markdown(
        """      
        ### About
        
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for 
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved 
        multi-task generalization in various domains, such as personalized federated learning and neural architecture 
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased 
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which 
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios. 
        
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of 
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple 
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the 
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target. 
        """
    )

    # Resolve the figure relative to this script, as is already done for data_path,
    # rather than the current working directory. Otherwise the image is missing
    # whenever the app is launched from another directory.
    figure_path = os.path.join(base_path, 'figures', 'hyper-dti.png')
    st.image(figure_path, caption='Overview of HyperPCM architecture.', use_column_width='always')
    

def _figure(name):
    """Return the path of a bundled figure, resolved relative to this script (not the cwd)."""
    return os.path.join(base_path, 'figures', name)


def _encode_query(sequence):
    """Return the per-protein SeqVec embedding for ``sequence``, or None if none was produced.

    The precomputed Lenselink test-set encodings are checked first. The SeqVec
    embedder is imported lazily and only run when the sequence is not in that file.
    """
    encodings_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
    # NOTE: pickle must only be used on trusted files; this one ships with the app's data.
    with open(encodings_file, 'rb') as handle:
        test_set = pickle.load(handle)

    # Direct dict membership test; no need to materialise the keys into a list.
    if sequence in test_set:
        return test_set[sequence]

    from bio_embeddings.embed import SeqVecEmbedder
    encoder = SeqVecEmbedder()
    # Only one sequence is passed, so only the first yielded embedding is needed.
    embedding = next(iter(encoder.embed_batch([sequence])), None)
    if embedding is None:
        return None
    return encoder.reduce_per_protein(embedding)


def _predict(dataset, dataloader, batch_size, progress_bar, progress_text):
    """Score every drug in ``dataloader`` with HyperPCM for the query target.

    Returns a DataFrame with 'SMILES' and 'Prediction' columns, sorted by
    descending prediction and re-indexed 0..n-1, so the original row index does
    not leak into the CSV download.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # The drug dataset is passed to the model as its memory.
    model = HyperPCM(memory=dataset).to(device)
    # Wrapped before loading. The checkpoint keys presumably carry DataParallel's
    # 'module.' prefix (TODO confirm).
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
    model.eval()

    n_drugs = len(dataset)
    smiles, preds = [], []
    with torch.no_grad():
        for step, (batch, _labels) in enumerate(dataloader):
            logits = model(batch)
            smiles.append(batch['mids'])
            # Ravel so that (batch,) and (batch, 1) shaped logits both form a 1-D column.
            preds.append(np.ravel(logits.detach().cpu().numpy()))
            # Report the fraction of drugs done *after* this batch, clamped for the last partial batch.
            done = min(batch_size * (step + 1), n_drugs)
            progress_bar.progress(done / n_drugs, text=progress_text)

    if not smiles:
        # np.concatenate raises on an empty list; an empty database yields no results.
        return pd.DataFrame({'SMILES': [], 'Prediction': []})

    results = pd.DataFrame({'SMILES': np.concatenate(smiles), 'Prediction': np.concatenate(preds)})
    return results.sort_values(by='Prediction', ascending=False).reset_index(drop=True)


def _show_top_k(results, selected_k, selected_database):
    """Draw the first ``selected_k`` rows of ``results`` and offer them as a CSV download.

    ``results`` needs 'SMILES' and 'Prediction' columns and must already be in
    ranked order. Rank r is placed in column r % 5, so the grid reads row by row.
    If there are fewer than ``selected_k`` rows, only the available ones are drawn.
    """
    top = results.head(selected_k).reset_index(drop=True)
    cols = st.columns(5)
    for j, col in enumerate(cols):
        with col:
            for row in range(j, len(top), 5):
                smi = top.loc[row, 'SMILES']
                pred = top.loc[row, 'Prediction']
                mol = Chem.MolFromSmiles(smi)
                if mol is None:
                    # RDKit returns None for unparsable SMILES; MolToImage would then raise.
                    st.warning(f'Could not parse SMILES {smi} ({pred:.2f})')
                    continue
                st.image(Chem.Draw.MolToImage(mol), caption=f"{pred:.2f}")

    st.download_button(f'Download retrieved drug compounds from the {selected_database} database.', top.to_csv(index=False).encode('utf-8'), file_name='retrieved_drugs.csv')


def retrieval():
    """Render the retrieval tab.

    The user picks a query protein sequence and a drug database. The target is
    encoded, and the top-k drugs that HyperPCM predicts as most active are shown.
    For the bundled example target on Lenselink, precomputed results from
    ``ex_results.csv`` are displayed instead of running the model.
    """
    st.markdown('## Retrieval of most active drug compounds')

    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Bound up front so later steps can detect "nothing to run". Previously an empty
    # sequence (or an encoder yielding nothing) left these unbound, causing a NameError.
    query_embedding = None
    dataset = None
    dataloader = None
    batch_size = 2048

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        if sequence == ex_target:
            st.image(_figure('lenselink_ex_target.jpeg'), use_column_width='always')
        elif sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image(_figure('ex_protein.jpeg'), use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        # Only one target encoder is offered; the selection itself is not used further.
        st.selectbox('Select target encoder', ['SeqVec'])
        if sequence:
            st.image(_figure('protein_encoder_done.png'), use_column_width='always')
            with st.spinner('Encoding in progress...'):
                query_embedding = _encode_query(sequence)
            if query_embedding is None:
                st.error('Encoding failed.')
            else:
                st.success('Encoding complete.')

    with col3:
        selected_database = st.selectbox('Select database', ('Lenselink', 'Davis', 'DUD-E'))
        # The on-disk directory name has no hyphen.
        if selected_database == 'DUD-E':
            selected_database = 'DUDE'
        _, c2 = st.columns(2)
        with c2:
            st.image(_figure('multi_molecules.png'), use_column_width='always')
        if query_embedding is not None:
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')

    with col4:
        # Only one drug encoder is offered; the selection itself is not used further.
        st.selectbox('Select drug encoder', ['CDDD'])
        st.image(_figure('molecule_encoder_done.png'), use_column_width='always')
        st.success('Encoding complete.')

    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    done_text = "HyperPCM is predicting the QSAR model for the query protein target. Done."

    if sequence == ex_target and selected_database == 'Lenselink':
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        my_bar.progress(100, text=done_text)

        st.markdown('### Retrieval')
        selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)

        # Precomputed predictions for the bundled example target.
        results = pd.read_csv(os.path.join(data_path, 'Lenselink', 'processed', 'ex_results.csv'))
        _show_top_k(results, selected_k, selected_database)

    elif dataset is not None:
        st.markdown('### Inference')
        my_bar = st.progress(0, text=progress_text)
        results = _predict(dataset, dataloader, batch_size, my_bar, progress_text)
        my_bar.progress(100, text=done_text)

        st.markdown('### Retrieval')
        selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
        _show_top_k(results, selected_k, selected_database)



# Tab label -> function that renders that tab. The tabs are built by iterating
# this mapping, so labels and renderers cannot drift apart.
page_names_to_func = {
    'Retrieval': retrieval,
    'About': about_page,
}

# st.tabs takes a sequence of labels, so pass a real list rather than a dict_keys
# view. The returned containers come back in the same order as the labels.
for tab, render_page in zip(st.tabs(list(page_names_to_func)), page_names_to_func.values()):
    with tab:
        render_page()