"""Streamlit front-end for HyperDTI.

The script renders a title, project links, and a sidebar selector that
dispatches to one of the page functions registered in
``page_names_to_func``.
"""
import os
import sys

import numpy as np
import pandas as pd
import streamlit as st
from rdkit import Chem
from rdkit.Chem import Draw

# Make the project's ``src/`` directory importable (resolved against the
# current working directory).
sys.path.insert(0, os.path.abspath("src/"))

st.set_page_config(layout="wide")

basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")

# Checkpoint locations of the pre-trained drug encoders.
CDDD_MODEL_DIR = 'checkpoints/CDDD/default_model'
MOLBERT_MODEL_DIR = 'checkpoints/MolBert/molbert_100epochs/checkpoints/last.ckpt'

# Message shown when the selected encoder has no branch below.
_NO_ENCODER_MESSAGE = 'No pre-trained version of HyperPCM is available for the chosen encoder.'

st.title('HyperDTI: Task-conditioned modeling of drug-target interactions.\n')
st.markdown('')
# The GitHub URL previously carried a doubled scheme ("https://https://"),
# which made the link unusable.
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)
    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
    """
)


def about_page():
    """Render the "About" page: a heading followed by the project abstract."""
    st.markdown(
        """
        ### About

        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios.

        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
        """
        # TODO: st.image('hyper-dti.png')
    )


def _show_embedding(encoder_name, embedding):
    """Write the encoder name and its embedding to the page; no-op for ``None``."""
    if embedding is None:
        return
    st.write(f'{encoder_name} embedding')
    st.write(embedding)


def _embed_drug(smiles, encoder_name):
    """Embed *smiles* with the named drug encoder.

    Returns whatever the encoder produces, or ``None`` (after writing a
    notice to the page) when *encoder_name* has no branch here. Encoder
    packages are imported lazily so that only the selected one is loaded.
    """
    if encoder_name == 'CDDD':
        from cddd.inference import InferenceModel

        cddd_model = InferenceModel(CDDD_MODEL_DIR)
        return cddd_model.seq_to_emb([smiles])

    if encoder_name == 'MolBERT':
        from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer

        molbert_model = MolBertFeaturizer(
            MOLBERT_MODEL_DIR,
            max_seq_len=500,
            embedding_type='average-1-cat-pooled',
        )
        return molbert_model.transform([smiles])

    st.write(_NO_ENCODER_MESSAGE)
    return None


def _embed_with_bio_embedder(encoder, sequence):
    """Embed *sequence* with a ``bio_embeddings`` embedder and reduce it per protein."""
    # NOTE(review): the embedder instance is invoked directly with a
    # one-element list, exactly as before; confirm against the installed
    # bio_embeddings version that this call form is supported.
    per_residue = encoder([sequence])
    return encoder.reduce_per_protein(per_residue)


def _embed_target(sequence, encoder_name):
    """Embed *sequence* with the named protein encoder.

    Returns the encoder output, or ``None`` (after writing a notice to the
    page) when *encoder_name* has no branch here. Encoder packages are
    imported lazily so that only the selected one is loaded.
    """
    if encoder_name == 'SeqVec':
        from bio_embeddings.embed import SeqVecEmbedder

        return _embed_with_bio_embedder(SeqVecEmbedder(), sequence)

    if encoder_name == 'UniRep':
        from jax_unirep.utils import load_params
        # NOTE(review): the loaded parameters are not passed on to
        # ``get_reps``; the call is kept in case it is relied on for a
        # side effect - confirm and remove if unnecessary.
        params = load_params()
        from jax_unirep.featurize import get_reps

        embedding, h_final, c_final = get_reps([sequence])
        return embedding.mean(axis=0)

    if encoder_name == 'ESM-1b':
        from bio_embeddings.embed import ESM1bEmbedder

        return _embed_with_bio_embedder(ESM1bEmbedder(), sequence)

    if encoder_name == 'ProtT5':
        from bio_embeddings.embed import ProtTransT5XLU50Embedder

        return _embed_with_bio_embedder(ProtTransT5XLU50Embedder(), sequence)

    st.write(_NO_ENCODER_MESSAGE)
    return None


def _render_drug_column():
    """Render the drug input, its 2D depiction and the chosen encoder's embedding."""
    st.markdown('### Drug')
    smiles = st.text_input(
        'Enter the SMILES of the query drug compound',
        value='CC(=O)OC1=CC=CC=C1C(=O)O',
        placeholder='CC(=O)OC1=CC=CC=C1C(=O)O',
    )
    if not smiles:
        return

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; drawing
        # None would raise, so report the problem and stop this column.
        st.error('The entered SMILES could not be parsed. Please check the input.')
        return
    st.image(Draw.MolToImage(mol))  # , width = 140)

    selected_encoder = st.selectbox(
        'Select encoder for drug compound', ('None', 'CDDD', 'MolBERT')
    )
    _show_embedding(selected_encoder, _embed_drug(smiles, selected_encoder))


def _render_target_column():
    """Render the protein input and the chosen encoder's embedding."""
    st.markdown('### Target')
    sequence = st.text_input(
        'Enter the amino-acid sequence of the query protein target',
        value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
        placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA',
    )
    if not sequence:
        return

    st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')

    selected_encoder = st.selectbox(
        'Select encoder for protein target', ('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5')
    )
    _show_embedding(selected_encoder, _embed_target(sequence, selected_encoder))


def display_dti():
    """Render the "Display DTI" page as two side-by-side columns.

    The left column handles the drug compound (SMILES input), the right
    column the protein target (amino-acid sequence input). Each column is
    rendered by its own helper, so an early exit in one column (empty or
    unparsable input) never prevents the other from being drawn.
    """
    st.markdown('##')
    col1, col2 = st.columns(2)

    with col1:
        _render_drug_column()

    with col2:
        _render_target_column()


def display_protein():
    """Placeholder for a protein search / structure page; currently does nothing.

    The function is not registered in ``page_names_to_func``. The disabled
    draft below is kept for reference only; it refers to names (``torch``,
    ``Client``, ``requests``, ``bsio``, ``render_mol``, ``showmol``, ...)
    that this file does not import or define.
    """
    # sequence = st.text_input("Enter the amino-acid sequence of the query protein target",
    #                          value="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA",
    #                          placeholder="HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA")
    #
    # if sequence:
    #     def esm_search(model, sequnce, batch_converter,top_k=5):
    #         batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", sequnce),])
    #         # Extract per-residue representations (on CPU)
    #         with torch.no_grad():
    #             results = model(batch_tokens, repr_layers=[12], return_contacts=True)
    #         token_representations = results["representations"][12]
    #         token_list = token_representations.tolist()[0][0][0]
    #         client = Client(
    #             url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    #         result = client.fetch("SELECT seq, distance('topK=500')(representations, " + str(token_list) + ')'+ "as dist FROM default.esm_protein_indexer_768")
    #         result_temp_seq = []
    #         for i in result:
    #             # result_temp_coords = i['seq']
    #             result_temp_seq.append(i['seq'])
    #         result_temp_seq = list(set(result_temp_seq))
    #
    #     result_temp_seq = esm_search(model, sequence, esm_search,top_k=5)
    #     st.text('search result: ')
    #     # tab1, tab2, tab3, tab4, = st.tabs(["Cat", "Dog", "Owl"])
    #     if st.button(result_temp_seq[0]):
    #         print(result_temp_seq[0])
    #     elif st.button(result_temp_seq[1]):
    #         print(result_temp_seq[1])
    #     elif st.button(result_temp_seq[2]):
    #         print(result_temp_seq[2])
    #     elif st.button(result_temp_seq[3]):
    #         print(result_temp_seq[3])
    #     elif st.button(result_temp_seq[4]):
    #         print(result_temp_seq[4])
    #     start[2] = st.pyplot(visualize_3D_Coordinates(result_temp_coords).figure)
    #
    # def show_protein_structure(sequence):
    #     headers = {
    #         'Content-Type': 'application/x-www-form-urlencoded',
    #     }
    #     response = requests.post('https://api.esmatlas.com/foldSequence/v1/pdb/', headers=headers, data=sequence)
    #     name = sequence[:3] + sequence[-3:]
    #     pdb_string = response.content.decode('utf-8')
    #     with open('predicted.pdb', 'w') as f:
    #         f.write(pdb_string)
    #     struct = bsio.load_structure('predicted.pdb', extra_fields=["b_factor"])
    #     b_value = round(struct.b_factor.mean(), 4)
    #     render_mol(pdb_string)
    #     if residues_marker:
    #         start[3] = showmol(render_pdb_resn(viewer = render_pdb(id = id_PDB),resn_lst = [residues_marker]))
    #     else:
    #         start[3] = showmol(render_pdb(id = id_PDB))
    #     st.session_state['xq'] = st.session_state.model
    #
    # example proteins
    # ["HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA"],
    # ["AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA"]


# Sidebar label -> page renderer. Insertion order is the order shown in the
# sidebar selector.
page_names_to_func = {
    'About': about_page,
    'Display DTI': display_dti,
}

selected_page = st.sidebar.selectbox('Choose function', list(page_names_to_func))
st.sidebar.markdown('')
page_names_to_func[selected_page]()