Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +3 -3
modeling_esm_plusplus.py
CHANGED
|
@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
| 575 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
| 576 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 577 |
x = self.embed(input_ids)
|
| 578 |
-
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
|
| 579 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
| 580 |
|
| 581 |
for seq, emb in zip(seqs, embeddings):
|
|
@@ -595,10 +595,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
| 595 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
| 596 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 597 |
x = self.embed(input_ids)
|
| 598 |
-
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
|
| 599 |
if full_precision:
|
| 600 |
residue_embeddings = residue_embeddings.float()
|
| 601 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
| 602 |
for seq, emb in zip(seqs, embeddings):
|
| 603 |
embeddings_dict[seq] = emb
|
| 604 |
|
|
|
|
| 575 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
| 576 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 577 |
x = self.embed(input_ids)
|
| 578 |
+
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
| 579 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
| 580 |
|
| 581 |
for seq, emb in zip(seqs, embeddings):
|
|
|
|
| 595 |
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
| 596 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 597 |
x = self.embed(input_ids)
|
| 598 |
+
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
|
| 599 |
if full_precision:
|
| 600 |
residue_embeddings = residue_embeddings.float()
|
| 601 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
|
| 602 |
for seq, emb in zip(seqs, embeddings):
|
| 603 |
embeddings_dict[seq] = emb
|
| 604 |
|