emmas96 committed on
Commit
ef4a36f
·
1 Parent(s): 2abf58c

add dummy encoders for drug and target

Browse files
Files changed (1) hide show
  1. app.py +59 -20
app.py CHANGED
@@ -46,29 +46,68 @@ def about_page():
46
 
47
  def display_dti():
48
  st.markdown('##')
49
- smiles = st.text_input('Enter the SMILES of the query drug compound', value='CC(=O)OC1=CC=CC=C1C(=O)O', placeholder='CC(=O)OC1=CC=CC=C1C(=O)O')
50
 
51
- if smiles:
52
- mol = Chem.MolFromSmiles(smiles)
53
- mol_img = Chem.Draw.MolToImage(mol)
54
- col1, col2, col3 = st.columns(3)
55
- with col1:
56
- st.write('')
57
- with col2:
58
  st.image(mol_img, width = 140)
59
- with col3:
60
- st.write('')
61
- st.markdown('##')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- selected_encoder = st.selectbox(
64
- 'Select encoder for drug compound',('None', 'CDDD')
65
- )
66
- if selected_encoder == 'CDDD':
67
- from cddd.inference import InferenceModel
68
- CDDD_MODEL_DIR = 'checkpoints/CDDD/default_model'
69
- cddd_model = InferenceModel(CDDD_MODEL_DIR)
70
- embedding = cddd_model.seq_to_emb([smiles])
71
- st.write(f'CDDD embedding: {embedding}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  def display_protein():
 
46
 
47
  def display_dti():
48
  st.markdown('##')
49
+ col1, col2 = st.columns(2)
50
 
51
+ with col1:
52
+ st.markdown('### Drug')
53
+ smiles = st.text_input('Enter the SMILES of the query drug compound', value='CC(=O)OC1=CC=CC=C1C(=O)O', placeholder='CC(=O)OC1=CC=CC=C1C(=O)O')
54
+
55
+ if smiles:
56
+ mol = Chem.MolFromSmiles(smiles)
57
+ mol_img = Chem.Draw.MolToImage(mol)
58
  st.image(mol_img, width = 140)
59
+
60
+ selected_encoder = st.selectbox(
61
+ 'Select encoder for drug compound',('None', 'CDDD', 'MolBERT')
62
+ )
63
+ if selected_encoder == 'CDDD':
64
+ from cddd.inference import InferenceModel
65
+ CDDD_MODEL_DIR = 'checkpoints/CDDD/default_model'
66
+ cddd_model = InferenceModel(CDDD_MODEL_DIR)
67
+ embedding = cddd_model.seq_to_emb([smiles])
68
+ st.write(f'CDDD embedding: {embedding}')
69
+ elif selected_encoder == 'MolBERT':
70
+ from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
71
+ MOLBERT_MODEL_DIR = 'checkpoints/MolBert/molbert_100epochs/checkpoints/last.ckpt'
72
+ molbert_model = MolBertFeaturizer(MOLBERT_MODEL_DIR, max_seq_len=500, embedding_type='average-1-cat-pooled')
73
+ embedding = molbert_model.transform([smiles])
74
+ else:
75
+ st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
76
 
77
+ with col2:
78
+ st.markdown('### Target')
79
+ sequence = st.text_input('Enter the amino-acid sequence of the query protein target', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
80
+
81
+ if sequence:
82
+ st.write('Plot of protein to be added soon.')
83
+
84
+ selected_encoder = st.selectbox(
85
+ 'Select encoder for protein target',('None', 'SeqVec', 'UniRep', 'ESM-1b', 'ProtT5')
86
+ )
87
+ if selected_encoder == 'SeqVec':
88
+ from bio_embeddings.embed import SeqVecEmbedder
89
+ encoder SeqVecEmbedder()
90
+ embedding = encoder([sequence])
91
+ embedding = encoder.reduce_per_protein(embedding)
92
+ st.write(f'SeqVec embedding: {embedding}')
93
+ elif selected_encoder == 'UniRep':
94
+ #from jax_unirep.utils import load_params
95
+ #params = load_params()
96
+ from jax_unirep.featurize import get_reps
97
+ embedding, h_final, c_final = get_reps([sequence])
98
+ embedding = embedding.mean(axis=0)
99
+ elif selected_encoder == 'ESM-1b':
100
+ from bio_embeddings.embed import ESM1bEmbedder
101
+ encoder = ESM1bEmbedder()
102
+ embedding = encoder([sequence])
103
+ embedding = encoder.reduce_per_protein(embedding)
104
+ elif selected_encoder == 'ProtT5':
105
+ from bio_embeddings.embed import ProtTransT5XLU50Embedder
106
+ encoder = ProtTransT5XLU50Embedder()
107
+ embedding = encoder([sequence])
108
+ embedding = encoder.reduce_per_protein(embedding)
109
+ else:
110
+ st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
111
 
112
 
113
  def display_protein():