lhallee committed
Commit 3687a41 · verified · 1 Parent(s): 32d5094

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +3 -3
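Per the commit message, the file was pushed with huggingface_hub. A minimal sketch of such an upload, with the repo id left as a placeholder since the target repository is not named in this view:

from huggingface_hub import HfApi

api = HfApi()  # assumes a token is already configured, e.g. via `huggingface-cli login`
api.upload_file(
    path_or_fileobj="modeling_esm_plusplus.py",
    path_in_repo="modeling_esm_plusplus.py",
    repo_id="<user>/<repo>",  # placeholder, not taken from this commit
    commit_message="Upload modeling_esm_plusplus.py with huggingface_hub",
)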
modeling_esm_plusplus.py CHANGED
@@ -711,7 +711,7 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
-            embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
+            embeddings = get_embeddings(residue_embeddings, attention_mask)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
@@ -743,11 +743,11 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask)
-            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
+            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
-                embeddings_dict[seq] = emb
+                embeddings_dict[seq] = emb.cpu()

             if save:
                 torch.save(embeddings_dict, save_path)
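In both hunks the `.cpu()` call is removed from the batched `embeddings` tensor and applied instead to the per-sequence `emb` at the moment it is stored, so pooling, masking, and reshaping run on the compute device and only the final, much smaller tensor is copied to host memory. A minimal sketch of that pattern, with `mean_pool` standing in for `get_embeddings` (assumed here to do attention-mask-aware mean pooling; the real helper is defined elsewhere in modeling_esm_plusplus.py):

import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for get_embeddings: average residue embeddings over non-padding positions.
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)
    return (residue_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

@torch.no_grad()
def store_batch(seqs, residue_embeddings, attention_mask, embeddings_dict):
    # residue_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len); both on the model device.
    embeddings = mean_pool(residue_embeddings, attention_mask)  # pooled on-device, no .cpu() here
    for seq, emb in zip(seqs, embeddings):
        # Only the final per-sequence vector crosses to host memory.
        embeddings_dict[seq] = emb.cpu()

Deferring the transfer this way presumably avoids copying the full padded batch tensor to the CPU when only the per-sequence embedding is kept in embeddings_dict.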