|
import os |
|
import sys |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
|
|
# Make the packages that live under ``src/`` importable. The path is resolved
# against the current working directory, so the app is expected to be started
# from the repository root.
sys.path.insert(0, os.path.abspath("src/"))

# Page-wide layout configuration; issued before any other Streamlit call.
st.set_page_config(layout="wide")

# Directory of this script and the ``data`` folder next to it.
basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")

# Page header: title followed by links to the code and the paper.
st.title('HyperDTI: Task-conditioned modeling of drug-target interactions.\n')
st.markdown('')
# The GitHub link previously read ``https://https://github.com/...`` (scheme
# duplicated), which produced a dead link in the rendered page.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
    """
)
|
|
|
|
|
def about_page():
    """Render the static "About" section describing the HyperPCM approach."""
    about_text = """
        ### About

        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios.

        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
        """
    # The whole page is a single markdown block.
    st.markdown(about_text)
|
|
|
|
|
def _embed_drug(smiles, encoder_name):
    """Embed a single SMILES string with the chosen pre-trained drug encoder.

    Returns the encoder output for ``[smiles]``, or ``None`` when
    ``encoder_name`` is not one of the supported encoders. The heavy encoder
    packages are imported lazily so they are only needed when selected.
    """
    if encoder_name == 'CDDD':
        from cddd.inference import InferenceModel

        cddd_model = InferenceModel('src/encoders/cddd')
        return cddd_model.seq_to_emb([smiles])

    if encoder_name == 'MolBERT':
        from huggingface_hub import hf_hub_download
        from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer

        # The checkpoint path used to be stored in ``CDDD_MODEL_DIR`` while the
        # download call read the undefined ``MOLBERT_MODEL_DIR`` (NameError).
        molbert_checkpoint = 'encoders/molbert/last.ckpt'
        checkpoint_path = hf_hub_download('emmas96/hyperpcm', molbert_checkpoint)
        molbert_model = MolBertFeaturizer(
            checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled'
        )
        return molbert_model.transform([smiles])

    return None


def _embed_protein(sequence, encoder_name):
    """Embed a single amino-acid sequence with the chosen protein encoder.

    Returns a per-protein embedding, or ``None`` when ``encoder_name`` is not
    supported (or the embedder yields nothing for the sequence).
    """
    if encoder_name == 'UniRep':
        from jax_unirep.featurize import get_reps

        # The original additionally called ``load_params()`` and discarded the
        # result; that unused call is dropped.
        h_avg, _h_final, _c_final = get_reps([sequence])
        return h_avg.mean(axis=0)

    # The remaining encoders share the same bio_embeddings interface.
    if encoder_name == 'SeqVec':
        from bio_embeddings.embed import SeqVecEmbedder as embedder_cls
    elif encoder_name == 'ESM-1b':
        from bio_embeddings.embed import ESM1bEmbedder as embedder_cls
    elif encoder_name == 'ProtT5':
        from bio_embeddings.embed import ProtTransT5XLU50Embedder as embedder_cls
    else:
        return None

    encoder = embedder_cls()
    # Only one sequence is submitted, so only the first result is relevant.
    for per_residue in encoder.embed_batch([sequence]):
        return encoder.reduce_per_protein(per_residue)
    return None


def _show_embedding(encoder_name, embedding):
    """Write the embedding to the page, or a notice when none is available."""
    if embedding is None:
        st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
        return
    st.write(f'{encoder_name} embedding')
    st.write(embedding)


def display_dti():
    """Render the drug/target page: one column per input with its embedding.

    The left column takes a SMILES string, draws the molecule and shows the
    embedding from the selected drug encoder. The right column takes an
    amino-acid sequence and shows the embedding from the selected protein
    encoder. Embeddings are only computed for non-empty, valid input.
    """
    st.markdown('##')
    drug_col, target_col = st.columns(2)

    with drug_col:
        st.markdown('### Drug')
        smiles = st.text_input(
            'Enter the SMILES of the query drug compound',
            value='CC(=O)OC1=CC=CC=C1C(=O)O',
            placeholder='CC(=O)OC1=CC=CC=C1C(=O)O',
        )

        # MolFromSmiles returns None for input it cannot parse; drawing that
        # would raise, so report it instead.
        mol = Chem.MolFromSmiles(smiles) if smiles else None
        if mol is not None:
            st.image(Chem.Draw.MolToImage(mol))
        elif smiles:
            st.error('The SMILES string could not be parsed.')

        drug_encoder = st.selectbox(
            'Select encoder for drug compound', ('None', 'CDDD', 'MolBERT')
        )
        if mol is not None:
            _show_embedding(drug_encoder, _embed_drug(smiles, drug_encoder))

    with target_col:
        st.markdown('### Target')
        sequence = st.text_input(
            'Enter the amino-acid sequence of the query protein target',
            value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
            placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
        )

        if sequence:
            st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')

        protein_encoder = st.selectbox(
            'Select encoder for protein target',
            ('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5'),
        )
        if sequence:
            _show_embedding(protein_encoder, _embed_protein(sequence, protein_encoder))
|
|
|
def retrieval():
    """Render the retrieval page for a query protein target.

    Shows the target input and encoder selector, then five molecule drawings
    side by side. The molecules are hard-coded SMILES strings; neither the
    entered sequence nor the selected encoder influences them yet.
    """
    st.markdown('##')

    st.markdown('### Target')
    sequence = st.text_input(
        'Enter the amino-acid sequence of the query protein target',
        value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
        placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
    )

    if sequence:
        st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')

    # The widget is rendered for the user, but its value is not consumed yet.
    st.selectbox(
        'Select encoder for protein target',
        ('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5'),
    )

    st.markdown('### Retrieval of top-5 drug compounds from ChEMBL:')

    # Fixed example molecules, drawn one per column in this order.
    example_smiles = (
        'CC(=O)OC1=CC=CC=C1C(=O)O',
        'COc1cc(C=O)ccc1O',
        'CC(=O)Nc1ccc(O)cc1',
        'CC(=O)Nc1ccc(OS(=O)(=O)O)cc1',
        'CC(=O)Nc1ccc(O[C@@H]2O[C@H](C(=O)O)[C@@H](O)[C@H](O)[C@H]2O)cc1',
    )
    columns = st.columns(len(example_smiles))
    for column, smiles in zip(columns, example_smiles):
        with column:
            mol = Chem.MolFromSmiles(smiles)
            st.image(Chem.Draw.MolToImage(mol))
|
|
|
|
|
def display_protein():
    """Placeholder for a protein search / structure page (currently a no-op).

    The body of this function consists solely of this string, so calling it
    does nothing. The text below is a parked draft implementation kept for
    reference; it is never executed. Its indentation has been reconstructed
    and may not match the author's intent.

    NOTE(review): the draft uses names that are not imported in the visible
    part of this file (torch, Client, requests, bsio, render_mol, showmol,
    render_pdb, visualize_3D_Coordinates, model, start, residues_marker,
    id_PDB) -- confirm before re-enabling.
    NOTE(review): the draft passes ``esm_search`` itself as the
    ``batch_converter`` argument, ``esm_search`` has no return statement, and
    the SQL query is built by string concatenation -- all of these need
    attention if the code is revived.

    Draft::

        sequence = st.text_input("Enter the amino-acid sequence of the query protein target", value="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA", placeholder="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA")

        if sequence:
            def esm_search(model, sequnce, batch_converter,top_k=5):
                batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequnce),])

                # Extract per-residue representations (on CPU)
                with torch.no_grad():
                    results = model(batch_tokens, repr_layers=[12], return_contacts=True)
                token_representations = results["representations"][12]

                token_list = token_representations.tolist()[0][0][0]

                client = Client(
                    url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])

                result = client.fetch("SELECT seq, distance('topK=500')(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768")

                result_temp_seq = []

                for i in result:
                    # result_temp_coords = i['seq']
                    result_temp_seq.append(i['seq'])

                result_temp_seq = list(set(result_temp_seq))

            result_temp_seq = esm_search(model, sequence, esm_search,top_k=5)
            st.text('search result: ')
            # tab1, tab2, tab3, tab4, = st.tabs(["Cat", "Dog", "Owl"])
            if st.button(result_temp_seq[0]):
                print(result_temp_seq[0])
            elif st.button(result_temp_seq[1]):
                print(result_temp_seq[1])
            elif st.button(result_temp_seq[2]):
                print(result_temp_seq[2])
            elif st.button(result_temp_seq[3]):
                print(result_temp_seq[3])
            elif st.button(result_temp_seq[4]):
                print(result_temp_seq[4])

            start[2] = st.pyplot(visualize_3D_Coordinates(result_temp_coords).figure)
            def show_protein_structure(sequence):
                headers = {
                    'Content-Type': 'application/x-www-form-urlencoded',
                }
                response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
                name = sequence[:3] + sequence[-3:]
                pdb_string = response.content.decode('utf-8')
                with open('predicted.pdb', 'w') as f:
                    f.write(pdb_string)
                struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
                b_value = round(struct.b_factor.mean(), 4)
                render_mol(pdb_string)
            if residues_marker:
                start[3] = showmol(render_pdb_resn(viewer = render_pdb(id = id_PDB),resn_lst = [residues_marker]))
            else:
                start[3] = showmol(render_pdb(id = id_PDB))
            st.session_state['xq'] = st.session_state.model

        # example proteins ["HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"], ["AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"]
    """
|
|
|
# Sidebar registry: label shown in the selector -> function that renders the page.
page_names_to_func = {'About': about_page}

# Let the user pick a page in the sidebar, then render the chosen one.
selected_page = st.sidebar.selectbox(
    label='Choose function',
    options=page_names_to_func.keys(),
)
st.sidebar.markdown('')
page_names_to_func[selected_page]()
|
|
|
|