Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +3 -3
modeling_esm_plusplus.py
CHANGED
@@ -711,7 +711,7 @@ class EmbeddingMixin:
|
|
711 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
712 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
713 |
residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
|
714 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
715 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
716 |
if full_embeddings:
|
717 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
@@ -743,11 +743,11 @@ class EmbeddingMixin:
|
|
743 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
744 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
745 |
residue_embeddings = self._embed(input_ids, attention_mask)
|
746 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
747 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
748 |
if full_embeddings:
|
749 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
750 |
-
embeddings_dict[seq] = emb
|
751 |
|
752 |
if save:
|
753 |
torch.save(embeddings_dict, save_path)
|
|
|
711 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
712 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
713 |
residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
|
714 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
715 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
716 |
if full_embeddings:
|
717 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
|
|
743 |
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
744 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
745 |
residue_embeddings = self._embed(input_ids, attention_mask)
|
746 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
|
747 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
748 |
if full_embeddings:
|
749 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
750 |
+
embeddings_dict[seq] = emb.cpu()
|
751 |
|
752 |
if save:
|
753 |
torch.save(embeddings_dict, save_path)
|