import os
import sys

import numpy as np
import pandas as pd
import streamlit as st
from rdkit import Chem
from rdkit.Chem import Draw

sys.path.insert(0, os.path.abspath("src/"))

st.set_page_config(layout="wide")

basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")
st.title('HyperDTI: Task-conditioned modeling of drug-target interactions.')
st.markdown('')
st.markdown(
    """
    🧬 GitHub: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti) 📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)
    """
)
def about_page():
    st.markdown(
        """
        ### About

        HyperNetworks have been established as an effective technique for fast adaptation of neural network
        parameters. Recently, HyperNetwork predictions conditioned on task descriptors have improved
        multi-task generalization in various domains, such as personalized federated learning and neural
        architecture search. Especially strong results were achieved in few- and zero-shot settings,
        attributed to the increased information sharing enabled by the HyperNetwork. With the rise of new
        diseases, drugs need to be discovered quickly, which requires models that can generalize drug-target
        interaction predictions in low-data scenarios.

        In this work, we propose HyperPCM, a task-conditioned HyperNetwork approach to predicting
        drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
        """
    )
    # st.image('hyper-dti.png')  # TODO
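
# Minimal, illustrative sketch of the idea described above (an assumption for illustration,
# not the actual HyperPCM implementation): a hypernetwork maps a protein (task) embedding to
# the parameters of a small per-target QSAR head, which then scores a drug embedding. All
# dimensions, the random parameters, and the function name are placeholders.
def _toy_hypernetwork_sketch(protein_emb: np.ndarray, drug_emb: np.ndarray) -> float:
    rng = np.random.default_rng(0)
    d_protein, d_drug = protein_emb.shape[0], drug_emb.shape[0]
    # Hypernetwork: one linear map producing the (d_drug + 1) parameters of a linear QSAR head.
    hyper_weights = rng.standard_normal((d_drug + 1, d_protein)) / np.sqrt(d_protein)
    head_params = hyper_weights @ protein_emb
    head_w, head_b = head_params[:-1], head_params[-1]
    # Per-target QSAR head: scores the drug embedding for this specific protein target.
    logit = float(head_w @ drug_emb + head_b)
    return 1.0 / (1.0 + np.exp(-logit))  # interaction probability
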
def display_dti():
    st.markdown('##')

    col1, col2 = st.columns(2)
    with col1:
        st.markdown('### Drug')

        smiles = st.text_input(
            'Enter the SMILES of the query drug compound',
            value='CC(=O)OC1=CC=CC=C1C(=O)O',
            placeholder='CC(=O)OC1=CC=CC=C1C(=O)O'
        )
        if smiles:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                st.error('Could not parse the given SMILES.')
            else:
                mol_img = Draw.MolToImage(mol)
                st.image(mol_img, width=140)

        selected_encoder = st.selectbox(
            'Select encoder for drug compound', ('None', 'CDDD', 'MolBERT')
        )
        if selected_encoder == 'CDDD':
            from cddd.inference import InferenceModel
            CDDD_MODEL_DIR = 'checkpoints/CDDD/default_model'
            cddd_model = InferenceModel(CDDD_MODEL_DIR)
            embedding = cddd_model.seq_to_emb([smiles])
            st.write(f'CDDD embedding: {embedding}')
        elif selected_encoder == 'MolBERT':
            from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
            MOLBERT_MODEL_DIR = 'checkpoints/MolBert/molbert_100epochs/checkpoints/last.ckpt'
            molbert_model = MolBertFeaturizer(MOLBERT_MODEL_DIR, max_seq_len=500, embedding_type='average-1-cat-pooled')
            embedding = molbert_model.transform([smiles])
            st.write(f'MolBERT embedding: {embedding}')
        else:
            st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
    with col2:
        st.markdown('### Target')

        sequence = st.text_input(
            'Enter the amino-acid sequence of the query protein target',
            value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
            placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA'
        )
        if sequence:
            st.write('Plot of protein to be added soon.')

        selected_encoder = st.selectbox(
            'Select encoder for protein target', ('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5')
        )
        if selected_encoder == 'SeqVec':
            from bio_embeddings.embed import SeqVecEmbedder
            encoder = SeqVecEmbedder()
            embedding = encoder.embed(sequence)
            embedding = encoder.reduce_per_protein(embedding)
            st.write(f'SeqVec embedding: {embedding}')
        elif selected_encoder == 'UniRep':
            from jax_unirep.featurize import get_reps
            # get_reps returns (h_avg, h_final, c_final); h_avg is the per-residue average.
            embedding, h_final, c_final = get_reps([sequence])
            embedding = embedding.mean(axis=0)
            st.write(f'UniRep embedding: {embedding}')
        elif selected_encoder == 'ESM-1b':
            from bio_embeddings.embed import ESM1bEmbedder
            encoder = ESM1bEmbedder()
            embedding = encoder.embed(sequence)
            embedding = encoder.reduce_per_protein(embedding)
            st.write(f'ESM-1b embedding: {embedding}')
        elif selected_encoder == 'ProtT5':
            from bio_embeddings.embed import ProtTransT5XLU50Embedder
            encoder = ProtTransT5XLU50Embedder()
            embedding = encoder.embed(sequence)
            embedding = encoder.reduce_per_protein(embedding)
            st.write(f'ProtT5 embedding: {embedding}')
        else:
            st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
def display_protein():
    # Disabled work in progress: similarity search and structure prediction for the query protein.
    # The code below depends on helpers and services (torch, an ESM model and batch_converter, a
    # vector-database Client, render_mol/render_pdb/showmol, bsio, requests) that are not defined
    # in this file, so it is kept commented out.
    """
    sequence = st.text_input("Enter the amino-acid sequence of the query protein target", value="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA", placeholder="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA")
    if sequence:
        def esm_search(model, sequence, batch_converter, top_k=5):
            batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequence)])
            # Extract per-residue representations (on CPU).
            with torch.no_grad():
                results = model(batch_tokens, repr_layers=[12], return_contacts=True)
            token_representations = results["representations"][12]
            token_list = token_representations.tolist()[0][0][0]

            client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
            result = client.fetch(
                "SELECT seq, distance('topK=500')(representations, " + str(token_list) + ") as dist "
                "FROM default.esm_protein_indexer_768"
            )
            result_temp_seq = [i['seq'] for i in result]
            return list(set(result_temp_seq))

        result_temp_seq = esm_search(model, sequence, batch_converter, top_k=5)
        st.text('search result: ')
        if st.button(result_temp_seq[0]):
            print(result_temp_seq[0])
        elif st.button(result_temp_seq[1]):
            print(result_temp_seq[1])
        elif st.button(result_temp_seq[2]):
            print(result_temp_seq[2])
        elif st.button(result_temp_seq[3]):
            print(result_temp_seq[3])
        elif st.button(result_temp_seq[4]):
            print(result_temp_seq[4])
        start[2] = st.pyplot(visualize_3D_Coordinates(result_temp_coords).figure)

    def show_protein_structure(sequence):
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
        name = sequence[:3] + sequence[-3:]
        pdb_string = response.content.decode('utf-8')
        with open('predicted.pdb', 'w') as f:
            f.write(pdb_string)
        struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
        b_value = round(struct.b_factor.mean(), 4)
        render_mol(pdb_string)

    if residues_marker:
        start[3] = showmol(render_pdb_resn(viewer=render_pdb(id=id_PDB), resn_lst=[residues_marker]))
    else:
        start[3] = showmol(render_pdb(id=id_PDB))

    st.session_state['xq'] = st.session_state.model

    # Example proteins: "HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA",
    # "AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"
    """
page_names_to_func = {
    'About': about_page,
    'Display DTI': display_dti,
}

selected_page = st.sidebar.selectbox('Choose function', page_names_to_func.keys())
st.sidebar.markdown('')
page_names_to_func[selected_page]()