Update chatNT.py (#3)

- Update chatNT.py (94f611d6d404ad9fe26c0dcd8ed5ad55f4bed639)

chatNT.py CHANGED
@@ -28,7 +28,6 @@ class RotaryEmbeddingConfig:
 class PerceiverResamplerConfig:
     """
     Parameters to initialize an PerceiverResampler model.
-
     Args:
         emb_layer_norm_before: Whether to use layer norm before the first attention
             layer.
@@ -93,9 +92,7 @@ class PerceiverResamplerConfig:
 class GptConfig:
     """
     Parameters to initialize a Gpt model.
-
     NOTE: the pad token is not defined
-
     Args:
         vocab_size: Token vocabulary.
         eos_token_id: used to stop sentence generation
@@ -191,7 +188,6 @@ class GptConfig:
 class NucleotideTransformerConfig:
     """
     Parameters to initialize an NT model.
-
     Args:
         alphabet_size: Token vocabulary.
         pad_token_id: ID of pad token.
@@ -364,21 +360,20 @@ class ChatNTConfig(PretrainedConfig):
         return output
 
 
-class ChatNTDecoder(nn.Module):
+class TorchBioBrainDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the
+        Initializes the BioBrain decoder, using a GPT model for text generation with
         bio embeddings.
-
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
         """
-        super(
+        super(TorchBioBrainDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id
 
@@ -390,13 +385,11 @@ class ChatNTDecoder(nn.Module):
     ) -> torch.Tensor:
         """
         Forward pass through the model.
-
         Args:
             english_token_ids: Tensor of English token IDs with shape
                 (batch_size, num_english_tokens).
             projected_bio_embeddings: Optional tensor of bio embeddings with shape
                 (batch_size, num_bio_sequences, ?, embed_dim).
-
         Returns:
             torch.Tensor: The logits from the GPT model,
             shaped (batch_size, num_english_tokens, vocab_size).
@@ -452,13 +445,11 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Inserts resampled embeddings in input_embeddings, starting at the SEQ token
-
         Args:
             tokens (torch.Tensor): Shape (batch_size, num_tokens)
             input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
             resampled_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
-
         Returns:
             Tuple[torch.Tensor, torch.Tensor]:
                 - input_embeddings with resampled_embeddings inserted at the SEQ token
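The hunk above only tightens the docstring of the helper that splices resampled bio embeddings into the prompt embeddings at the SEQ token. For orientation, a simplified, hypothetical sketch of that splice for a single bio sequence (the function name and length handling are invented here, not taken from chatNT.py):

```python
import torch

def splice_at_seq_token(
    tokens: torch.Tensor,            # (batch_size, num_tokens)
    input_embeddings: torch.Tensor,  # (batch_size, num_tokens, embed_dim)
    resampled: torch.Tensor,         # (batch_size, bio_sequence_length, embed_dim)
    seq_token_id: int,
) -> torch.Tensor:
    # Assumes every row contains at least one SEQ token.
    spliced = []
    for b in range(tokens.shape[0]):
        pos = int((tokens[b] == seq_token_id).nonzero()[0])  # first SEQ position
        spliced.append(
            torch.cat(
                [input_embeddings[b, :pos], resampled[b], input_embeddings[b, pos:]],
                dim=0,
            )
        )
    # (batch_size, num_tokens + bio_sequence_length, embed_dim)
    return torch.stack(spliced, dim=0)
```

The real helper handles several bio sequences per prompt and returns a second tensor as well, per the Tuple return type in the docstring above.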
@@ -521,11 +512,9 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Removes the logits corresponding to the unused embeddings.
-
         Args:
             tokens: Input english tokens.
             logits: Input logits.
-
         Returns:
             Cleaned logits, last values will be equal to 0.
         """
@@ -582,7 +571,7 @@ class ChatNTDecoder(nn.Module):
         return logits_acc, tokens_acc
 
 
-class ChatNT(PreTrainedModel):
+class TorchMultiOmicsModel(PreTrainedModel):
     config_class = ChatNTConfig
 
     def __init__(self, config: ChatNTConfig) -> None:
@@ -625,11 +614,11 @@ class ChatNT(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1
 
-        self.
-        self.
+        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
+        self.biobrain_decoder = TorchBioBrainDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model =
+        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
@@ -645,27 +634,21 @@ class ChatNT(PreTrainedModel):
         projected_bio_embeddings: torch.Tensor = None,
     ) -> dict[str, torch.Tensor]:
         """
-
         Args:
             multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                 english_tokens_ids: Represents the prompt tokens (english tokens)
                     Shape (batch_size, num_english_tokens)
-
                 bio_tokens_ids: Represents the bio sequences tokens
                     Shape (batch_size, num_bio_sequences, num_bio_tokens)
-
             projection_english_tokens_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
-
             projected_bio_embeddings (projected_bio_embeddings, optional):
                 Shape (batch_size, num_bio_sequencse, ?, embed_dim).
                 Defaults to None.
-
         Returns:
             dict[str, torch.Tensor] containing:
                 - logits:
                     Shape (batch_size, num_tokens, vocab_size)
-
                 - projected_bio_embeddings:
                     Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
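The docstring above changes only whitespace, but it documents the shape contract of the top-level forward. A shape-only sketch with dummy tensors (vocabulary sizes and dimensions below are arbitrary placeholders, not values from the file):

```python
import torch

batch_size, num_english_tokens = 2, 16
num_bio_sequences, num_bio_tokens = 3, 512

english_tokens_ids = torch.randint(0, 32_000, (batch_size, num_english_tokens))
bio_tokens_ids = torch.randint(0, 4_100, (batch_size, num_bio_sequences, num_bio_tokens))
multi_omics_tokens_ids = (english_tokens_ids, bio_tokens_ids)
# The projection tokens mirror the English prompt tokens.
projection_english_tokens_ids = english_tokens_ids.clone()

# outs = model(multi_omics_tokens_ids, projection_english_tokens_ids)
# outs["logits"]: (batch_size, num_tokens, vocab_size)
# outs["projected_bio_embeddings"]: (batch_size, num_bio_sequences, ?, embed_dim)
```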
@@ -702,7 +685,7 @@ class ChatNT(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.
+                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
@@ -718,7 +701,7 @@ class ChatNT(PreTrainedModel):
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
-        logits = self.
+        logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
@@ -741,7 +724,6 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
-
         Returns:
             Sinusoidal positions of shape (self.max_seq_len, self.dim).
         """
@@ -774,11 +756,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
         """
         Prepare a tensor to apply the RoPE mechanism.
-
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
-
         Returns:
             The even indices in the last dimension have their sign flipped.
             Tensor of shape (batch_size, seq_len, num_heads, head_dim).
@@ -795,12 +775,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> torch.Tensor:
         """
         Applies rotary embeddings to x.
-
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
             sincos: Tuple of sine and cosine tensors for position encoding.
-
         Returns:
             RoPE embeddings tensor.
         """
@@ -818,12 +796,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Applies rotary embeddings to k and q.
-
         Args:
             k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
             q: value tensor of shape (batch_size, seq_len, num_heads, head_dim),
             positions: optional positions offset useful when caching,
-
         Returns:
             RoPE embeddings for the keys and values.
         """
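The three hunks above touch only the docstrings of the RoPE helpers. As a reference for the behaviour they describe, here is a common GPT-J-style formulation of the two operations (illustrative only; the file's implementation may differ, for example in how sin/cos are expanded or cached):

```python
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # x: (batch_size, seq_len, num_heads, head_dim)
    x1 = x[..., ::2]   # even positions of the last dimension
    x2 = x[..., 1::2]  # odd positions of the last dimension
    out = torch.stack((-x2, x1), dim=-1)
    return out.flatten(-2)  # back to (batch_size, seq_len, num_heads, head_dim)

def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # sin and cos are assumed to already cover the full head_dim (each value
    # repeated for the interleaved pairs) and to broadcast over batch and heads.
    return (x * cos) + (rotate_every_two(x) * sin)
```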
@@ -1141,11 +1117,9 @@ def build_causal_attention_mask(
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.
-
     Args:
         batch_size: Batch size.
         seq_len: Length of the sequences.
-
     Returns:
         Batch of causal masks.
     """
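For readers unfamiliar with the (batch_size, 1, seq_len, seq_len) mask layout named above, a minimal sketch of such a causal mask (not necessarily the file's exact dtype or device handling):

```python
import torch

def causal_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    # Lower-triangular matrix: position i may attend to positions <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Broadcast to (batch_size, 1, seq_len, seq_len) for the attention layer.
    return mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
```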
@@ -1498,12 +1472,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}
 
 
-class NucleotideTransformer(nn.Module):
+class TorchNucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(
+        super(TorchNucleotideTransformer, self).__init__()
         self.nt_config = nt_config
 
         # Other cases are not implemented
@@ -1551,13 +1525,11 @@ class NucleotideTransformer(nn.Module):
     ) -> torch.Tensor:
         """
         Computes the embeddings based on the input tokens.
-
         Args:
             tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
             attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                 If no mask is provided, a mask by default which equals 1 over all non
                 pad tokens and 0 over pad tokens is computed.
-
         Returns:
             Dictionary containing the final embeddings and logits.
         """
@@ -1585,11 +1557,9 @@ def build_padding_attention_mask(
 ) -> torch.Tensor:
     """
     Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
-
     Args:
         tokens: Batch of sequences of shape (batch_size, seq_len).
         pad_token_id: Int corresponding to the <pad> token to mask.
-
     Returns:
         Batch of attention masks, masking out <pad> tokens.
     """
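Analogously to the causal mask, a padding mask with the same (batch_size, 1, seq_len, seq_len) layout can be sketched as follows (illustrative, not the file's code):

```python
import torch

def padding_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    keep = tokens != pad_token_id                   # (batch_size, seq_len)
    mask = keep[:, None, None, :]                   # (batch_size, 1, 1, seq_len)
    # Repeat over the query dimension so every query row masks the same <pad> keys.
    return mask.expand(-1, 1, tokens.shape[1], -1)  # (batch_size, 1, seq_len, seq_len)
```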
@@ -1599,14 +1569,14 @@ def build_padding_attention_mask(
     return padding_mask
 
 
-class ChatNTEncoder(nn.Module):
+class TorchBioBrainEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(
+        super(TorchBioBrainEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model =
+        self.nt_model = TorchNucleotideTransformer(self.nt_config)
 
     def forward(
         self,
@@ -1616,7 +1586,6 @@ class ChatNTEncoder(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
         Returns:
             torch.Tensor:
                 Shape (batch_size, num_bio_tokens, embed_dim)
@@ -1626,7 +1595,7 @@ class ChatNTEncoder(nn.Module):
         return bio_embeddings
 
 
-class MultiModalPerceiverResamplerBlock(nn.Module):
+class TorchMultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1683,7 @@ class MultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}
 
 
-class MultiModalPerceiverResampler(nn.Module):
+class TorchMultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1726,7 +1695,6 @@ class MultiModalPerceiverResampler(nn.Module):
     ):
         """
         Initialize a Perceiver Resampler model.
-
         Args:
             config: Dataclass containing model hyperparameters.
             name: Name for module (custom will break weight loading).
@@ -1736,7 +1704,7 @@ class MultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-                MultiModalPerceiverResamplerBlock(
+                TorchMultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
@@ -1823,7 +1791,7 @@ class MultiModalPerceiverResampler(nn.Module):
         return outs
 
 
-class MultiModalPerceiverResamplerProjection(nn.Module):
+class TorchMultiModalPerceiverResamplerProjection(nn.Module):
     def __init__(
         self,
         perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1811,7 @@ class MultiModalPerceiverResamplerProjection(nn.Module):
 
         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler =
+        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
 
     def forward(
         self,
@@ -1855,10 +1823,8 @@ class MultiModalPerceiverResamplerProjection(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
             bio_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_tokens, embed_dim)
-
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
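The projection module above wires together a linear bio projection, an English token embedding, and the (renamed) perceiver resampler. A shape-level sketch of the first two steps, with invented dimensions and batch sizes:

```python
import torch
from torch import nn

input_embed_dim, embed_dim, english_vocab_size = 1024, 768, 32_000  # placeholders
bio_projection = nn.Linear(input_embed_dim, embed_dim)
token_embedding = nn.Embedding(english_vocab_size, embed_dim)

bio_embeddings = torch.randn(2, 512, input_embed_dim)            # (batch, num_bio_tokens, input_embed_dim)
english_token_ids = torch.randint(0, english_vocab_size, (2, 16))

projected_bio = bio_projection(bio_embeddings)                    # (batch, num_bio_tokens, embed_dim)
english_embeddings = token_embedding(english_token_ids)           # (batch, num_english_tokens, embed_dim)
# Per the forward signature above, the multimodal perceiver resampler then maps
# projected_bio (together with the embedded English tokens) to a fixed number of
# resampled embeddings of shape (batch, resampled_length, embed_dim).
```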
@@ -1901,3 +1867,4 @@ def build_perceiver_padding_attention_mask(
     padding_mask = padding_mask[:, None, None, :]
     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
     return padding_mask
+