emmas96 committed on
Commit
499faaf
·
1 Parent(s): 483853f

add warnings and progress bars

Browse files
Files changed (1) hide show
  1. app.py +41 -26
app.py CHANGED
@@ -75,7 +75,7 @@ def predict_dti():
75
  from cddd.inference import InferenceModel
76
  CDDD_MODEL_DIR = 'src/encoders/cddd'
77
  cddd_model = InferenceModel(CDDD_MODEL_DIR)
78
- embedding = cddd_model.seq_to_emb([smiles])
79
  #from huggingface_hub import hf_hub_download
80
  #precomputed_embs = f'{selected_encoder}_encoding.csv'
81
  #REPO_ID = "emmas96/Lenselink"
@@ -89,12 +89,12 @@ def predict_dti():
89
  REPO_ID = "emmas96/hyperpcm"
90
  checkpoint_path = hf_hub_download(REPO_ID, MOLBERT_MODEL_DIR)
91
  molbert_model = MolBertFeaturizer(checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled')
92
- embedding = molbert_model.transform([smiles])
93
  else:
94
  #st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
95
- embedding = None
96
  st.image('molecule_encoder.png')
97
- if embedding is not None:
98
  #st.write(f'{selected_encoder} embedding')
99
  #st.write(embedding)
100
  st.image('molecule_encoder_done.png')
@@ -108,7 +108,8 @@ def predict_dti():
108
  sequence = st.text_input('Enter the amino-acid sequence of the query protein target', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
109
 
110
  if sequence:
111
- st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')
 
112
 
113
  with prot_col2:
114
  selected_encoder = st.selectbox(
@@ -117,41 +118,45 @@ def predict_dti():
117
  if sequence:
118
  if selected_encoder == 'SeqVec':
119
  from bio_embeddings.embed import SeqVecEmbedder
120
- encoder = SeqVecEmbedder()
121
- embeddings = encoder.embed_batch([sequence])
 
122
  for emb in embeddings:
123
- embedding = encoder.reduce_per_protein(emb)
124
  break
125
  elif selected_encoder == 'UniRep':
126
  from jax_unirep.utils import load_params
127
  params = load_params()
128
  from jax_unirep.featurize import get_reps
129
  embedding, h_final, c_final = get_reps([sequence])
130
- embedding = embedding.mean(axis=0)
131
  elif selected_encoder == 'ESM-1b':
132
  from bio_embeddings.embed import ESM1bEmbedder
133
  encoder = ESM1bEmbedder()
134
  embeddings = encoder.embed_batch([sequence])
135
  for emb in embeddings:
136
- embedding = encoder.reduce_per_protein(emb)
137
  break
138
  elif selected_encoder == 'ProtT5':
139
  from bio_embeddings.embed import ProtTransT5XLU50Embedder
140
  encoder = ProtTransT5XLU50Embedder()
141
  embeddings = encoder.embed_batch([sequence])
142
  for emb in embeddings:
143
- embedding = encoder.reduce_per_protein(emb)
144
  break
145
  else:
146
- #st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
147
- embedding = None
148
  st.image('protein_encoder.png')
149
- if embedding is not None:
150
  #st.write(f'{selected_encoder} embedding')
151
  #st.write(embedding)
152
  st.image('protein_encoder_done.png')
153
 
154
- st.write('TODO run inference with HyperPCM on the given drug compound and protein target.')
 
 
 
155
 
156
 
157
  def retrieval():
@@ -165,23 +170,33 @@ def retrieval():
165
  if sequence:
166
  col1, col2 = st.columns(2)
167
  with col1:
168
- st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')
 
169
 
170
  with col2:
171
- st.write('Currently encoding the protein with SecVec...')
172
  st.image('protein_encoder_done.png')
173
 
174
- #from bio_embeddings.embed import SeqVecEmbedder
175
- #encoder = SeqVecEmbedder()
176
- #embeddings = encoder.embed_batch([sequence])
177
- #for emb in embeddings:
178
- # embedding = encoder.reduce_per_protein(emb)
179
- # break
180
- st.write('Encoding complete.')
 
 
 
 
 
 
 
 
 
 
181
 
182
  st.markdown('### Retrieval')
183
- st.write('TODO HyperPCM predicts the QSAR model for the given protein target.')
184
-
185
  col1, col2 = st.columns(2)
186
  with col1:
187
  selected_dataset = st.selectbox(
 
75
  from cddd.inference import InferenceModel
76
  CDDD_MODEL_DIR = 'src/encoders/cddd'
77
  cddd_model = InferenceModel(CDDD_MODEL_DIR)
78
+ drug_embedding = cddd_model.seq_to_emb([smiles])
79
  #from huggingface_hub import hf_hub_download
80
  #precomputed_embs = f'{selected_encoder}_encoding.csv'
81
  #REPO_ID = "emmas96/Lenselink"
 
89
  REPO_ID = "emmas96/hyperpcm"
90
  checkpoint_path = hf_hub_download(REPO_ID, MOLBERT_MODEL_DIR)
91
  molbert_model = MolBertFeaturizer(checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled')
92
+ drug_embedding = molbert_model.transform([smiles])
93
  else:
94
  #st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
95
+ drug_embedding = None
96
  st.image('molecule_encoder.png')
97
+ if drug_embedding is not None:
98
  #st.write(f'{selected_encoder} embedding')
99
  #st.write(embedding)
100
  st.image('molecule_encoder_done.png')
 
108
  sequence = st.text_input('Enter the amino-acid sequence of the query protein target', value='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA', placeholder='HXHVWPVQDAKARFSEFLDACITEGPQIVSRRGAEEAVLVPIGEWRRLQAAA')
109
 
110
  if sequence:
111
+ #st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')
112
+ st.error('Visualization of protein to be added soon.')
113
 
114
  with prot_col2:
115
  selected_encoder = st.selectbox(
 
118
  if sequence:
119
  if selected_encoder == 'SeqVec':
120
  from bio_embeddings.embed import SeqVecEmbedder
121
+ encoder = SeqVecEmbedder()
122
+ with st.spinner('Currently encoding the query protein target with SeqVec...'):
123
+ embeddings = encoder.embed_batch([sequence])
124
  for emb in embeddings:
125
+ prot_embedding = encoder.reduce_per_protein(emb)
126
  break
127
  elif selected_encoder == 'UniRep':
128
  from jax_unirep.utils import load_params
129
  params = load_params()
130
  from jax_unirep.featurize import get_reps
131
  embedding, h_final, c_final = get_reps([sequence])
132
+ prot_embedding = embedding.mean(axis=0)
133
  elif selected_encoder == 'ESM-1b':
134
  from bio_embeddings.embed import ESM1bEmbedder
135
  encoder = ESM1bEmbedder()
136
  embeddings = encoder.embed_batch([sequence])
137
  for emb in embeddings:
138
+ prot_embedding = encoder.reduce_per_protein(emb)
139
  break
140
  elif selected_encoder == 'ProtT5':
141
  from bio_embeddings.embed import ProtTransT5XLU50Embedder
142
  encoder = ProtTransT5XLU50Embedder()
143
  embeddings = encoder.embed_batch([sequence])
144
  for emb in embeddings:
145
+ prot_embedding = encoder.reduce_per_protein(emb)
146
  break
147
  else:
148
+ st.warning('Chosen encoder above.')
149
+ prot_embedding = None
150
  st.image('protein_encoder.png')
151
+ if prot_embedding is not None:
152
  #st.write(f'{selected_encoder} embedding')
153
  #st.write(embedding)
154
  st.image('protein_encoder_done.png')
155
 
156
+ if not drug_embedding or not prot_embedding:
157
+ st.error('Witing for computed drug and target embeddings...')
158
+ else:
159
+ st.warning('In the future inference will be run with HyperPCM on the given drug compound and protein target...')
160
 
161
 
162
  def retrieval():
 
170
  if sequence:
171
  col1, col2 = st.columns(2)
172
  with col1:
173
+ #st.markdown('\n\n\n\n Plot of protein to be added soon. \n\n\n\n')
174
+ st.error('Visualization of protein to be added soon.')
175
 
176
  with col2:
177
+ #st.write('Currently encoding the protein with SecVec...')
178
  st.image('protein_encoder_done.png')
179
 
180
+ from bio_embeddings.embed import SeqVecEmbedder
181
+ encoder = SeqVecEmbedder()
182
+ with st.spinner('Currently encoding the query protein target with SeqVec...'):
183
+ embeddings = encoder.embed_batch([sequence])
184
+ for emb in embeddings:
185
+ embedding = encoder.reduce_per_protein(emb)
186
+ break
187
+ st.success('Encoding complete.')
188
+
189
+ st.markdown('### Inference')
190
+
191
+ import time
192
+ progress_text = "HyperPCM predicts the QSAR model for the query protein target. Please wait."
193
+ my_bar = st.progress(0, text=progress_text)
194
+ for i in range(100):
195
+ time.sleep(0.1)
196
+ my_bar.progress(i + 1, text=progress_text)
197
 
198
  st.markdown('### Retrieval')
199
+
 
200
  col1, col2 = st.columns(2)
201
  with col1:
202
  selected_dataset = st.selectbox(