Update chatNT.py
chatNT.py CHANGED
@@ -364,21 +364,21 @@ class ChatNTConfig(PretrainedConfig):
         return output


-class
+class ChatNTDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the
+        Initializes the ChatNT decoder, using a GPT model for text generation with
         bio embeddings.

         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
         """
-        super(
+        super(ChatNTDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id

@@ -582,7 +582,7 @@ class TorchBioBrainDecoder(nn.Module):
         return logits_acc, tokens_acc


-class
+class ChatNT(PreTrainedModel):
     config_class = ChatNTConfig

     def __init__(self, config: ChatNTConfig) -> None:
@@ -625,11 +625,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1

-        self.
-        self.
+        self.chatnt_encoder = ChatNTEncoder(nt_config=self.nt_config)
+        self.chatnt_decoder = ChatNTDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model =
+        self.projection_model = MultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
@@ -702,7 +702,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.
+                self.chatnt_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]

@@ -718,7 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
         projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

         # decode
-        logits = self.
+        logits = self.chatnt_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
@@ -1498,12 +1498,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}


-class
+class NucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(
+        super(NucleotideTransformer, self).__init__()
         self.nt_config = nt_config

         # Other cases are not implemented
@@ -1599,14 +1599,14 @@ def build_padding_attention_mask(
     return padding_mask


-class
+class ChatNTEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(
+        super(ChatNTEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model =
+        self.nt_model = NucleotideTransformer(self.nt_config)

     def forward(
         self,
@@ -1626,7 +1626,7 @@ class TorchBioBrainEncoder(nn.Module):
         return bio_embeddings


-class
+class MultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1714,7 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}


-class
+class MultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1736,7 +1736,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-
+                MultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
@@ -1823,7 +1823,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         return outs


-class
+class MultiModalPerceiverResamplerProjection(nn.Module):
     def __init__(
         self,
         perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1843,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):

         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler =
+        self.perceiver_resampler = MultiModalPerceiverResampler(config=self.config)

     def forward(
         self,
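For orientation, below is a minimal sketch of how the renamed pieces appear to be wired, based only on the lines visible in this diff: ChatNT builds a ChatNTEncoder (which wraps a NucleotideTransformer), a MultiModalPerceiverResamplerProjection, and a ChatNTDecoder; each bio sequence is encoded, projected into the decoder's embedding space, stacked, and decoded together with the English tokens. The toy module internals, tensor shapes, and the placement of bio embeddings relative to the text tokens are illustrative assumptions, not the actual implementation in chatNT.py.

# Illustrative sketch only (not part of chatNT.py): toy stand-ins showing the
# data flow between the renamed modules. Shapes and internals are assumptions.
import torch
import torch.nn as nn


class ToyChatNTEncoder(nn.Module):
    # Stand-in for ChatNTEncoder / NucleotideTransformer: embeds bio token ids.
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, bio_token_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(bio_token_ids)  # (batch, bio_seq_len, nt_embed_dim)


class ToyProjection(nn.Module):
    # Stand-in for MultiModalPerceiverResamplerProjection: maps NT embeddings
    # into the GPT embedding space.
    def __init__(self, input_embed_dim: int, embed_dim: int):
        super().__init__()
        self.bio_projection = nn.Linear(input_embed_dim, embed_dim)

    def forward(self, bio_embeddings: torch.Tensor) -> torch.Tensor:
        return self.bio_projection(bio_embeddings)


class ToyChatNTDecoder(nn.Module):
    # Stand-in for ChatNTDecoder: consumes English tokens plus projected bio embeddings.
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, english_token_ids, projected_bio_embeddings):
        x = self.embed(english_token_ids)               # (batch, text_len, embed_dim)
        bio = projected_bio_embeddings.flatten(1, 2)    # merge the bio sequences
        x = torch.cat([bio, x], dim=1)                  # placement of bio context is an assumption
        return self.lm_head(x)


# Forward flow mirroring the diff: encode each bio sequence, project, stack, decode.
batch, num_bio_sequences, bio_len, text_len = 2, 3, 16, 8
encoder = ToyChatNTEncoder(vocab_size=6, embed_dim=32)
projection_model = ToyProjection(input_embed_dim=32, embed_dim=64)
decoder = ToyChatNTDecoder(vocab_size=100, embed_dim=64)

bio_token_ids = torch.randint(0, 6, (batch, num_bio_sequences, bio_len))
english_token_ids = torch.randint(0, 100, (batch, text_len))

bio_embeddings_list = [
    encoder(bio_token_ids=bio_token_ids[:, i]) for i in range(num_bio_sequences)
]
projected_bio_embeddings = torch.stack(
    [projection_model(e) for e in bio_embeddings_list], dim=1
)
logits = decoder(
    english_token_ids=english_token_ids,
    projected_bio_embeddings=projected_bio_embeddings,
)
print(logits.shape)  # (2, 3*16 + 8, 100)

Note that the real modules are built from configuration objects (GptConfig, NucleotideTransformerConfig, PerceiverResamplerConfig) rather than the plain dimensions used above.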