File size: 8,348 Bytes
6a624f6
 
89f9a8d
 
6a624f6
 
 
72f9629
89f9a8d
6a624f6
89f9a8d
c0e52ad
 
 
89f9a8d
6a624f6
 
89f9a8d
6a624f6
 
 
 
89f9a8d
6a624f6
89f9a8d
35aa55d
d5d696e
fc729c7
 
35aa55d
fc729c7
 
372f84d
bc3d031
 
afba6b8
501c3b1
afba6b8
bc3d031
 
 
 
 
 
 
 
 
 
6a624f6
 
bc3d031
 
40cf9b4
2f247b2
40cf9b4
63ce71b
9e6028d
35aa55d
63ce71b
35aa55d
9e6028d
6a624f6
 
 
 
 
 
601b6c8
6a624f6
 
 
35aa55d
6a624f6
4f1ea03
b318bc6
601b6c8
6a624f6
4f1ea03
6a624f6
4f1ea03
601b6c8
4f1ea03
 
 
35aa55d
6a624f6
 
35aa55d
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
 
6a624f6
4f1ea03
 
6a624f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
3c1ebe4
6a624f6
4f1ea03
 
6a624f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1ea03
 
6a624f6
 
 
 
 
 
 
4f1ea03
 
 
 
 
6a624f6
4f1ea03
6a624f6
2f637fc
fc729c7
7f14dcf
bc3d031
6a624f6
 
bc3d031
 
5b730e0
 
 
bc3d031
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198

import gc
import os
import sys
import torch
import pickle
import numpy as np
import pandas as pd
import streamlit as st
from torch.utils.data import DataLoader

from rdkit import Chem
from rdkit.Chem import Draw

sys.path.insert(0, os.path.abspath("src/"))
from src.dataset import DrugRetrieval, collate_target
from hyper_dti.models.hyper_pcm import HyperPCM

# Data and checkpoint locations are built from the directory containing this script.
base_path = os.path.dirname(__file__)
data_path = os.path.join(base_path, 'data')
# Pretrained HyperPCM weights loaded by the retrieval page.
checkpoint_path = os.path.join(base_path, 'checkpoints/lpo/cv2_test_fold6_1402/model_updated.t7')
# Prefer the GPU when one is available; inference also works on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.set_page_config(layout="wide")

# Page header shared by every page: title plus links to code and paper.
st.title('HyperDTI: Robust Task-Conditioned Modeling of Drug-Target Interactions\n')
st.markdown('')
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL) TBA Journal of Chemical Information and Modeling. \n
    """
)

def about_page():
    """Render the 'About' page: a short description of HyperPCM and an overview figure."""
    st.markdown(
        """      
        ### About
        
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for 
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved 
        multi-task generalization in various domains, such as personalized federated learning and neural architecture 
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased 
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which 
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios. 
        
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of 
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple 
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets. This app demonstrates the 
        model as a retrieval task of the top-k most active drug compounds predicted for a given query target. 
        """
    )

    # Resolve the figure next to the app (like data_path / checkpoint_path)
    # instead of relative to whatever directory the app was launched from.
    st.image(os.path.join(base_path, 'figures', 'hyper-dti.png'), caption='Overview of HyperPCM architecture.')
    

def _figure(name):
    """Return the path of a bundled image in the ``figures`` directory next to this app."""
    return os.path.join(base_path, 'figures', name)


def _embed_query_sequence(sequence):
    """Return a SeqVec embedding for the amino-acid ``sequence``.

    Embeddings precomputed for the Lenselink test split are looked up first;
    otherwise SeqVec is run on the fly (``bio_embeddings`` is imported lazily
    so the app starts without it).

    Raises:
        RuntimeError: if SeqVec yields no embedding for the sequence.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this must only
    # ever read the bundled data file, never a user-supplied path.
    cache_file = os.path.join(data_path, 'Lenselink/processed/SeqVec_encoding_test.pickle')
    with open(cache_file, 'rb') as handle:
        test_set = pickle.load(handle)

    if sequence in test_set.keys():
        return test_set[sequence]

    from bio_embeddings.embed import SeqVecEmbedder
    encoder = SeqVecEmbedder()
    # Only one sequence is submitted, so the first embedding is the one we need.
    for emb in encoder.embed_batch([sequence]):
        return encoder.reduce_per_protein(emb)
    raise RuntimeError('SeqVec returned no embedding for the query sequence.')


def _predict_activities(dataset, dataloader, batch_size):
    """Score every compound in ``dataloader`` with HyperPCM.

    Shows a Streamlit progress bar while running.

    Returns:
        DataFrame with columns ``SMILES`` and ``Prediction``, sorted by
        descending prediction, with a fresh 0..n-1 index.
    """
    progress_text = "HyperPCM is predicting the QSAR model for the query protein target. Please wait."
    my_bar = st.progress(0, text=progress_text)

    # Release memory left over from a previous Streamlit rerun before building the model.
    gc.collect()
    torch.cuda.empty_cache()
    model = HyperPCM(memory=dataset).to(device)
    # Wrapped before loading — presumably the checkpoint keys carry the
    # DataParallel 'module.' prefix (TODO confirm).
    model = torch.nn.DataParallel(model)
    # map_location lets a checkpoint saved on GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    smiles = []
    preds = []
    n_items = len(dataset)
    with torch.set_grad_enabled(False):
        for i, (batch, _labels) in enumerate(dataloader):
            logits = model(batch)
            smiles.append(batch['mids'])
            preds.append(logits.detach().cpu().numpy())
            # Fraction completed after this batch, clamped for the short last batch.
            my_bar.progress(min((i + 1) * batch_size / n_items, 1.0), text=progress_text)
    my_bar.progress(100, text="HyperPCM is predicting the QSAR model for the query protein target. Done.")

    if not preds:
        # Empty database: np.concatenate would fail on an empty list.
        return pd.DataFrame({'SMILES': [], 'Prediction': []})

    results = pd.DataFrame({
        'SMILES': np.concatenate(smiles),
        # ravel() accepts logits shaped (batch,) or (batch, 1); assumes one score per compound.
        'Prediction': np.concatenate(preds).ravel(),
    })
    return results.sort_values(by='Prediction', ascending=False).reset_index(drop=True)


def _render_top_k(results, k, n_cols=5):
    """Draw the first ``k`` rows of ``results`` as molecule images.

    Images are laid out row-major across ``n_cols`` columns and captioned
    with the prediction.
    """
    cols = st.columns(n_cols)
    # Never index past the end when fewer than k compounds were scored.
    for rank in range(min(k, len(results))):
        row = results.loc[rank]
        with cols[rank % n_cols]:
            mol = Chem.MolFromSmiles(row['SMILES'])
            if mol is None:
                # RDKit returns None for SMILES it cannot parse.
                st.warning(f"Could not parse SMILES: {row['SMILES']}")
            else:
                st.image(Draw.MolToImage(mol), caption=f"{row['Prediction']:.2f}")


def retrieval():
    """Render the retrieval page.

    The user picks a query protein, a target encoder, a drug database and a
    drug encoder. HyperPCM then scores the database and the top-k compounds
    are drawn.
    """
    st.markdown('## Retrieval of most active drug compounds')

    st.write('Use HyperPCM to generate a QSAR model for a selected query protein target and retrieve the top-k drug compounds predicted to have the highest activity toward the given protein target from the Lenselink datasets.')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Query Target')
    with col2:
        st.markdown('### Drug Database')

    # Defaults keep every later step well-defined whichever widgets are unset
    # (an empty sequence used to leave query_embedding unbound -> NameError).
    query_embedding = None
    dataset = None
    dataloader = None
    batch_size = 64

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        ex_target = 'YTKMKTATNIYIFNLALADALATSTLPFQSVNYLMGTWPFGTILCKIVISIDYYNMFTSIFTLCTMSVDRYIAVCHPVKALDFRTPRNAKTVNVCNWI'
        sequence = st.text_input('Enter amino-acid sequence', value=ex_target, placeholder=ex_target)
        # Only this one sequence has a pre-rendered structure image.
        if sequence == 'HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA':
            st.image(_figure('ex_protein.jpeg'), use_column_width='always')
        elif sequence:
            st.error('Visualization coming soon...')

    with col2:
        selected_target_encoder = st.selectbox(
                'Select target encoder', ('SeqVec', 'None')
            )
        if sequence:
            if selected_target_encoder == 'SeqVec':
                st.image(_figure('protein_encoder_done.png'))
                with st.spinner('Encoding in progress...'):
                    query_embedding = _embed_query_sequence(sequence)
                st.success('Encoding complete.')
            else:
                st.image(_figure('protein_encoder.png'))
                st.warning('Choose encoder above...')

    with col3:
        selected_database = st.selectbox(
                'Select database', ('Lenselink', 'None')
            )
        if selected_database == 'Lenselink':
            _, c2 = st.columns(2)
            with c2:
                st.image(_figure('multi_molecules.png'), use_column_width='always')
            with st.spinner('Loading data...'):
                dataset = DrugRetrieval(os.path.join(data_path, selected_database), sequence, query_embedding)
                dataloader = DataLoader(dataset, num_workers=2, batch_size=batch_size, shuffle=False, collate_fn=collate_target)
            st.success('Data loaded.')
        else:
            st.warning('Choose database above...')

    with col4:
        selected_drug_encoder = st.selectbox(
                'Select drug encoder', ('CDDD', 'None')
            )
        # The selectbox returns the string 'None' (always truthy), so compare explicitly.
        if selected_database != 'None':
            if selected_drug_encoder == 'CDDD':
                st.image(_figure('molecule_encoder_done.png'))
                st.success('Encoding complete.')
            else:
                st.image(_figure('molecule_encoder.png'))
                st.warning('Choose encoder above...')

    # Inference needs both a query embedding and a loaded database.
    if query_embedding is None or dataloader is None:
        return

    st.markdown('### Inference')
    results = _predict_activities(dataset, dataloader, batch_size)

    st.markdown('### Retrieval')
    selected_k = st.slider(f'Top-k most active drug compounds {selected_database} predicted by HyperPCM are, for k = ', 5, 20, 5, 5)
    _render_top_k(results, selected_k)
    
    
# Sidebar navigation: page title -> function that renders that page.
page_names_to_func = dict(Retrieval=retrieval, About=about_page)

selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
st.sidebar.markdown('')

render_selected_page = page_names_to_func[selected_page]
render_selected_page()