emmas96 commited on
Commit
1a8ab89
·
1 Parent(s): c2ba285

load molecule encoder checkpoints from HuggingFace model card

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -62,13 +62,19 @@ def display_dti():
62
  )
63
  if selected_encoder == 'CDDD':
64
  from cddd.inference import InferenceModel
65
- CDDD_MODEL_DIR = 'checkpoints/CDDD/default_model'
66
- cddd_model = InferenceModel(CDDD_MODEL_DIR)
 
 
 
67
  embedding = cddd_model.seq_to_emb([smiles])
68
  elif selected_encoder == 'MolBERT':
69
  from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
70
- MOLBERT_MODEL_DIR = 'checkpoints/MolBert/molbert_100epochs/checkpoints/last.ckpt'
71
- molbert_model = MolBertFeaturizer(MOLBERT_MODEL_DIR, max_seq_len=500, embedding_type='average-1-cat-pooled')
 
 
 
72
  embedding = molbert_model.transform([smiles])
73
  else:
74
  st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')
 
62
  )
63
  if selected_encoder == 'CDDD':
64
  from cddd.inference import InferenceModel
65
+ from huggingface_hub import hf_hub_download
66
+ CDDD_MODEL_DIR = 'encoders/cddd'
67
+ REPO_ID = "emmas96/hyperpcm"
68
+ checkpoint_path = hf_hub_download(REPO_ID, CDDD_MODEL_DIR)
69
+ cddd_model = InferenceModel(checkpoint_path)
70
  embedding = cddd_model.seq_to_emb([smiles])
71
  elif selected_encoder == 'MolBERT':
72
  from molbert.utils.featurizer.molbert_featurizer import MolBertFeaturizer
73
+ from huggingface_hub import hf_hub_download
74
+ CDDD_MODEL_DIR = 'encoders/molbert/last.ckpt'
75
+ REPO_ID = "emmas96/hyperpcm"
76
+ checkpoint_path = hf_hub_download(REPO_ID, MOLBERT_MODEL_DIR)
77
+ molbert_model = MolBertFeaturizer(checkpoint_path, max_seq_len=500, embedding_type='average-1-cat-pooled')
78
  embedding = molbert_model.transform([smiles])
79
  else:
80
  st.write('No pre-trained version of HyperPCM is available for the chosen encoder.')