Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py  CHANGED  (+7 -3)

@@ -603,6 +603,7 @@ class EmbeddingMixin:
         tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
+        truncate: bool = True,
         full_embeddings: bool = False,
         embed_dtype: torch.dtype = torch.float32,
         pooling_types: List[str] = ['mean'],
@@ -654,8 +655,9 @@ class EmbeddingMixin:
         )
         >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
         sequences = sorted(sequences, key=len, reverse=True)
+        hidden_size = self.config.hidden_size
         collate_fn = build_collator(tokenizer)
         device = self.device
         pooler = Pooler(pooling_types) if not full_embeddings else None
@@ -686,7 +688,7 @@ class EmbeddingMixin:
                     embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                         if full_embeddings:
-                            emb = emb[mask.bool()]
+                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                         c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                   (seq, emb.cpu().numpy().tobytes()))
 
@@ -716,7 +718,9 @@ class EmbeddingMixin:
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 residue_embeddings = self._embed(input_ids, attention_mask)
                 embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
-                for seq, emb in zip(seqs, embeddings):
+                for seq, emb, mask in zip(seqs, embeddings, attention_mask):
+                    if full_embeddings:
+                        emb = emb[mask.bool()].reshape(-1, hidden_size)
                     embeddings_dict[seq] = emb
 
         if save:
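
The write path now reshapes masked full embeddings to (-1, hidden_size) before serializing with .tobytes(), which stores only the raw buffer and discards shape. Below is a minimal read-back sketch for the SQLite branch, not part of this commit: the database path "embeddings.db" is hypothetical, float32 storage is an assumption (the sql branch saves embeddings in whatever dtype the model produced), and hidden_size must match the config.hidden_size of the model that wrote the table.

import sqlite3
import numpy as np

hidden_size = 1280  # assumption: must equal model.config.hidden_size used at write time

conn = sqlite3.connect("embeddings.db")  # hypothetical path
c = conn.cursor()
# Rows were written as (sequence, embedding_bytes) by the INSERT OR REPLACE above.
for seq, blob in c.execute("SELECT * FROM embeddings"):
    emb = np.frombuffer(blob, dtype=np.float32)  # assumed float32 storage
    # Restore the (num_residues, hidden_size) layout that .tobytes() flattened;
    # this mirrors the .reshape(-1, hidden_size) added on the write path.
    emb = emb.reshape(-1, hidden_size)
    print(seq[:10], emb.shape)
conn.close()

On the other new argument: with truncate=True (the default) sequences are still clipped to max_len before deduplication, matching the previous behavior, while truncate=False now keeps sequences at full length.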