Update chatNT.py
chatNT.py
CHANGED
@@ -364,21 +364,21 @@ class ChatNTConfig(PretrainedConfig):
         return output
 
 
-class TorchBioBrainDecoder(nn.Module):
+class ChatNTDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the BioBrain decoder, using a GPT model for text generation with
+        Initializes the ChatNT decoder, using a GPT model for text generation with
         bio embeddings.
 
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
         """
-        super(TorchBioBrainDecoder, self).__init__()
+        super(ChatNTDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id
 
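For orientation (this is an assumption drawn from the constructor arguments above, not from the decoder body, which this diff does not show): a `seq_token_id` like this is typically used to find SEQ placeholder tokens in the English prompt and splice the projected bio embeddings in at those positions. A toy sketch of that splice:

import torch

# Toy splice (assumed behaviour): replace the embedding at each SEQ-token
# position with a projected bio embedding. Shapes are illustrative only.
seq_token_id = 3
english_token_ids = torch.tensor([[5, 3, 7, 9]])   # one SEQ token, position 1
token_embeddings = torch.randn(1, 4, 8)            # (batch, seq_len, embed_dim)
projected_bio_embedding = torch.randn(1, 8)        # one bio sequence, same dim

mask = english_token_ids == seq_token_id           # (batch, seq_len) boolean
token_embeddings[mask] = projected_bio_embedding   # splice at SEQ positions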
@@ -582,7 +582,7 @@ class TorchBioBrainDecoder(nn.Module):
         return logits_acc, tokens_acc
 
 
-class TorchMultiOmicsModel(PreTrainedModel):
+class ChatNT(PreTrainedModel):
     config_class = ChatNTConfig
 
     def __init__(self, config: ChatNTConfig) -> None:
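`config_class` is the standard Hugging Face hook that ties a `PreTrainedModel` subclass to its config type, so `from_pretrained` knows which config to instantiate. A minimal sketch of the pattern with toy names (not the real ChatNT classes):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

    def __init__(self, hidden: int = 8, **kwargs):
        super().__init__(**kwargs)
        self.hidden = hidden

class ToyModel(PreTrainedModel):
    config_class = ToyConfig  # same pattern as ChatNT above

    def __init__(self, config: ToyConfig) -> None:
        super().__init__(config)
        self.proj = nn.Linear(config.hidden, config.hidden)

model = ToyModel(ToyConfig())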
@@ -625,11 +625,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1
 
-        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
-        self.biobrain_decoder = TorchBioBrainDecoder(
+        self.chatnt_encoder = ChatNTEncoder(nt_config=self.nt_config)
+        self.chatnt_decoder = ChatNTDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
+        self.projection_model = MultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
@@ -702,7 +702,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
+                self.chatnt_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
@@ -718,7 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
         projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
-        logits = self.biobrain_decoder(
+        logits = self.chatnt_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
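Taken together, the last three hunks give the forward flow: encode each bio sequence, project it into the GPT embedding space, stack, then decode against the English tokens. A minimal runnable sketch with stub modules (all dimensions and stub internals are assumptions, not the real ChatNT components):

import torch
import torch.nn as nn

NT_DIM, GPT_DIM, BIO_LEN, ENG_LEN, VOCAB = 16, 32, 12, 10, 100

encoder = nn.Embedding(VOCAB, NT_DIM)         # stands in for chatnt_encoder
projection = nn.Linear(NT_DIM, GPT_DIM)       # stands in for projection_model
decoder_embed = nn.Embedding(VOCAB, GPT_DIM)  # toy decoder pieces
decoder_head = nn.Linear(GPT_DIM, VOCAB)

bio_token_ids = torch.randint(0, VOCAB, (2, 3, BIO_LEN))  # (batch, num_bio_sequences, bio_len)
english_token_ids = torch.randint(0, VOCAB, (2, ENG_LEN))
num_bio_sequences = bio_token_ids.shape[1]

# Encode each bio sequence separately, as in the list comprehension above.
bio_embeddings_list = [
    encoder(bio_token_ids[:, n]) for n in range(num_bio_sequences)
]
projected = [projection(e) for e in bio_embeddings_list]
projected_bio_embeddings = torch.stack(projected, dim=1)
# -> (batch, num_bio_sequences, bio_len, GPT_DIM)

# Toy decode over the English tokens only (the real decoder also consumes
# projected_bio_embeddings, as the hunk above shows).
logits = decoder_head(decoder_embed(english_token_ids))  # (2, ENG_LEN, VOCAB)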
@@ -1498,12 +1498,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}
 
 
-class TorchNucleotideTransformer(nn.Module):
+class NucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(TorchNucleotideTransformer, self).__init__()
+        super(NucleotideTransformer, self).__init__()
         self.nt_config = nt_config
 
         # Other cases are not implemented
@@ -1599,14 +1599,14 @@ def build_padding_attention_mask(
     return padding_mask
 
 
-class TorchBioBrainEncoder(nn.Module):
+class ChatNTEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(TorchBioBrainEncoder, self).__init__()
+        super(ChatNTEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model = TorchNucleotideTransformer(self.nt_config)
+        self.nt_model = NucleotideTransformer(self.nt_config)
 
     def forward(
         self,
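The encoder side is a thin wrapper: `ChatNTEncoder` owns a `NucleotideTransformer` and its forward returns the bio embeddings. A self-contained sketch of that wrapper pattern (the stub transformer and its dimension are assumptions):

import torch
import torch.nn as nn

class StubNT(nn.Module):
    # Placeholder for NucleotideTransformer: token ids -> embeddings.
    def __init__(self, embed_dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(100, embed_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embed(tokens)  # (batch, bio_len, embed_dim)

class SketchEncoder(nn.Module):
    # Mirrors the wrapper structure in the hunk above.
    def __init__(self, nt_model: nn.Module):
        super().__init__()
        self.nt_model = nt_model

    def forward(self, bio_token_ids: torch.Tensor) -> torch.Tensor:
        bio_embeddings = self.nt_model(bio_token_ids)
        return bio_embeddings

encoder = SketchEncoder(StubNT())
out = encoder(torch.randint(0, 100, (2, 12)))  # -> shape (2, 12, 16)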
@@ -1626,7 +1626,7 @@ class TorchBioBrainEncoder(nn.Module):
         return bio_embeddings
 
 
-class TorchMultiModalPerceiverResamplerBlock(nn.Module):
+class MultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1714,7 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}
 
 
-class TorchMultiModalPerceiverResampler(nn.Module):
+class MultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1736,7 +1736,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-                TorchMultiModalPerceiverResamplerBlock(
+                MultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
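The resampler body is a stack of identical blocks built from the config. The sketch below reproduces the ModuleList-from-config construction with a placeholder block (a single MultiheadAttention; the real MultiModalPerceiverResamplerBlock is more involved), and `num_layers` is an assumed config field:

from dataclasses import dataclass
import torch.nn as nn

@dataclass
class ToyResamplerConfig:
    attention_heads: int = 4
    embed_dim: int = 32
    key_size: int = 8
    num_layers: int = 2  # assumed field name

config = ToyResamplerConfig()
layers = nn.ModuleList(
    [
        # Placeholder block; the real one takes num_heads/embed_dim/key_size
        # as the hunk above shows.
        nn.MultiheadAttention(
            embed_dim=config.embed_dim,
            num_heads=config.attention_heads,
            batch_first=True,
        )
        for _ in range(config.num_layers)
    ]
)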
@@ -1823,7 +1823,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         return outs
 
 
-class TorchMultiModalPerceiverResamplerProjection(nn.Module):
+class MultiModalPerceiverResamplerProjection(nn.Module):
     def __init__(
         self,
         perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1843,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
 
         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
+        self.perceiver_resampler = MultiModalPerceiverResampler(config=self.config)
 
     def forward(
         self,
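The projection module maps NT embeddings (width `input_embed_dim`) into the GPT embedding space (width `embed_dim`) before resampling. A sketch of that shape flow, with mean-pooled chunks standing in for the real MultiModalPerceiverResampler (placeholder assumption):

import torch
import torch.nn as nn

input_embed_dim, embed_dim, num_resampled = 16, 32, 4

bio_projection = nn.Linear(input_embed_dim, embed_dim)  # as in the hunk above
bio_embeddings = torch.randn(2, 12, input_embed_dim)    # (batch, bio_len, NT dim)

projected = bio_projection(bio_embeddings)              # (2, 12, embed_dim)
# Placeholder resampling: chunk the sequence and mean-pool each chunk to get
# a fixed number of embeddings per bio sequence.
resampled = torch.stack(
    [c.mean(dim=1) for c in projected.chunk(num_resampled, dim=1)], dim=1
)                                                       # (2, num_resampled, embed_dim)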