# Streamlit front-end for HyperDTI: renders a page header, then dispatches to
# the page selected in the sidebar ("About" or "Display DTI").
import os
import sys

import numpy as np
import pandas as pd
import streamlit as st
from rdkit import Chem
from rdkit.Chem import Draw

# NOTE(review): "src/" is resolved against the current working directory,
# whereas `datapath` below is resolved against this file's directory —
# confirm the app is always launched from the directory that contains src/.
sys.path.insert(0, os.path.abspath("src/"))

st.set_page_config(layout="wide")

basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")

# Example drug (acetylsalicylic acid) pre-filled in the SMILES input box.
DEFAULT_SMILES = "CC(=O)OC1=CC=CC=C1C(=O)O"

st.title('HyperDTI: Task-conditioned modeling of drug-target interactions.\n')
st.markdown('')
st.markdown(
    """
    🧬 Github: [ml-jku/hyper-dti](https://github.com/ml-jku/hyper-dti)
    📝 NeurIPS 2022 AI4Science workshop paper: [OpenReview](https://openreview.net/forum?id=dIX34JWnIAL)\n
    """
)


def about_page():
    """Render the "About" page with a short summary of the HyperPCM model."""
    st.markdown(
        """
        ### About
        HyperNetworks have been established as an effective technique to achieve fast adaptation of parameters for
        neural networks. Recently, HyperNetwork predictions conditioned on descriptors of tasks have improved
        multi-task generalization in various domains, such as personalized federated learning and neural architecture
        search. Especially powerful results were achieved in few- and zero-shot settings, attributed to the increased
        information sharing by the HyperNetwork. With the rise of new diseases fast discovery of drugs is needed which
        requires models that are able to generalize drug-target interaction predictions in low-data scenarios.
        In this work, we propose the HyperPCM model, a task-conditioned HyperNetwork approach for the problem of
        predicting drug-target interactions in drug discovery. Our model learns to generate a QSAR model specialized on
        a given protein target. We demonstrate state-of-the-art performance over previous methods on multiple
        well-known benchmarks, particularly in zero-shot settings for unseen protein targets.
        """
    )
    # TODO: add an overview figure, e.g. st.image('hyper-dti.png').


def display_dti():
    """Render the "Display DTI" page.

    Asks for a query drug compound as a SMILES string and, when it parses,
    draws its structure in the middle of three columns. An empty input renders
    nothing further; an unparsable SMILES is reported with ``st.error`` instead
    of crashing the page.
    """
    st.markdown('##')
    smiles = st.text_input(
        "Enter the SMILES of the query drug compound",
        value=DEFAULT_SMILES,
        placeholder=DEFAULT_SMILES,
    ).strip()
    if not smiles:
        return

    # MolFromSmiles returns None (it does not raise) for invalid input, and
    # MolToImage cannot draw None, so check before rendering.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        st.error(f"Could not parse SMILES: {smiles}")
        return

    mol_img = Draw.MolToImage(mol)
    # The two empty outer columns keep the image horizontally centred.
    _, center_col, _ = st.columns(3)
    with center_col:
        st.image(mol_img, width=140)
    st.markdown('##')


def display_protein():
    """Placeholder for a protein-target page; currently renders nothing.

    Not registered in ``page_names_to_func``, so it is unreachable from the
    sidebar.

    TODO: implement the protein-target input (amino-acid sequence entry,
    embedding-based similarity search, predicted-structure view).
    Example query sequences:
        HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA
        AHKLFIGGLPNYLNDDQVKELLTSFGPLKAFNLVKDSATGLSKGYAFCEYVDINVTDQAIAGLNGMQLGDKKLLVQRASVGAKNA
    """


# Sidebar navigation: each label maps to the function that renders its page.
page_names_to_func = {
    'About': about_page,
    'Display DTI': display_dti,
}

selected_page = st.sidebar.selectbox('Choose function', list(page_names_to_func))
st.sidebar.markdown('')
page_names_to_func[selected_page]()