Yanisadel committed
Commit 7b1c32a · verified · 1 Parent(s): 9ca04b0

Update chatNT.py

Files changed (1): chatNT.py (+19 -19)
chatNT.py CHANGED
@@ -364,21 +364,21 @@ class ChatNTConfig(PretrainedConfig):
         return output
 
 
-class TorchBioBrainDecoder(nn.Module):
+class ChatNTDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the BioBrain decoder, using a GPT model for text generation with
+        Initializes the ChatNT decoder, using a GPT model for text generation with
         bio embeddings.
 
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
         """
-        super(TorchBioBrainDecoder, self).__init__()
+        super(ChatNTDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id
 
@@ -582,7 +582,7 @@ class TorchBioBrainDecoder(nn.Module):
         return logits_acc, tokens_acc
 
 
-class TorchMultiOmicsModel(PreTrainedModel):
+class ChatNT(PreTrainedModel):
     config_class = ChatNTConfig
 
     def __init__(self, config: ChatNTConfig) -> None:
@@ -625,11 +625,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1
 
-        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
-        self.biobrain_decoder = TorchBioBrainDecoder(
+        self.chatnt_encoder = ChatNTEncoder(nt_config=self.nt_config)
+        self.chatnt_decoder = ChatNTDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
+        self.projection_model = MultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
@@ -702,7 +702,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
+                self.chatnt_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
@@ -718,7 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
-        logits = self.biobrain_decoder(
+        logits = self.chatnt_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
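Together with the previous hunk, this keeps the same three-stage forward flow under the new attribute names: encode each bio sequence, project it into the decoder's embedding space, then decode the English tokens conditioned on the result. A condensed sketch of that flow; the argument lists are trimmed to what the visible call sites show, and the projection call in particular is a guess at the interface:

# Condensed, illustrative view of the renamed forward pass.
def forward(self, english_token_ids, bio_token_ids, num_bio_sequences):
    # 1. Encode each bio sequence with the renamed ChatNTEncoder.
    bio_embeddings_list = [
        self.chatnt_encoder(bio_token_ids=bio_token_ids[:, i])
        for i in range(num_bio_sequences)
    ]
    # 2. Resample/project each embedding into the GPT embedding space.
    projected_bio_embeddings = torch.stack(
        [self.projection_model(e) for e in bio_embeddings_list], dim=1
    )
    # 3. Decode English tokens conditioned on the projected bio embeddings.
    return self.chatnt_decoder(
        english_token_ids=english_token_ids,
        projected_bio_embeddings=projected_bio_embeddings,
    )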
@@ -1498,12 +1498,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}
 
 
-class TorchNucleotideTransformer(nn.Module):
+class NucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(TorchNucleotideTransformer, self).__init__()
+        super(NucleotideTransformer, self).__init__()
         self.nt_config = nt_config
 
         # Other cases are not implemented
@@ -1599,14 +1599,14 @@ def build_padding_attention_mask(
     return padding_mask
 
 
-class TorchBioBrainEncoder(nn.Module):
+class ChatNTEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(TorchBioBrainEncoder, self).__init__()
+        super(ChatNTEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model = TorchNucleotideTransformer(self.nt_config)
+        self.nt_model = NucleotideTransformer(self.nt_config)
 
     def forward(
         self,
@@ -1626,7 +1626,7 @@ class TorchBioBrainEncoder(nn.Module):
         return bio_embeddings
 
 
-class TorchMultiModalPerceiverResamplerBlock(nn.Module):
+class MultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1714,7 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}
 
 
-class TorchMultiModalPerceiverResampler(nn.Module):
+class MultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1736,7 +1736,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-                TorchMultiModalPerceiverResamplerBlock(
+                MultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
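The resampler's docstring describes a stack of successive PerceiverResamplerBlocks held in an nn.ModuleList, as this hunk shows. For readers unfamiliar with the pattern, a self-contained sketch of the general idea, with a plain nn.MultiheadAttention standing in for the real block (all names and dimensions here are illustrative):

import torch
from torch import nn

class TinyResampler(nn.Module):
    # Illustrative stand-in for the Perceiver Resampler pattern: a fixed set
    # of learned latent queries cross-attends to a variable-length input.
    def __init__(self, num_layers: int, num_latents: int, embed_dim: int, num_heads: int):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, embed_dim))
        self.layers = nn.ModuleList(
            [
                nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:  # (B, T, D)
        x = self.latents.expand(inputs.shape[0], -1, -1)  # (B, L, D)
        for attn in self.layers:
            # Each block lets the latents attend to the full input sequence.
            x = x + attn(x, inputs, inputs, need_weights=False)[0]
        return x  # (B, L, D): fixed-length summary of the input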
@@ -1823,7 +1823,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
         return outs
 
 
-class TorchMultiModalPerceiverResamplerProjection(nn.Module):
+class MultiModalPerceiverResamplerProjection(nn.Module):
    def __init__(
        self,
        perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1843,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
 
         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
+        self.perceiver_resampler = MultiModalPerceiverResampler(config=self.config)
 
     def forward(
         self,
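After this commit the classes exported by chatNT.py carry the shorter names, with ChatNT as the top-level PreTrainedModel. A hedged loading sketch; the Hub repository id is an assumption, and the checkpoint is presumed to be consumed through trust_remote_code like other custom-code models:

from transformers import AutoModel

# Assumed repository id; loads the custom ChatNT class from chatNT.py.
model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)
print(type(model).__name__)  # expected to print "ChatNT" after this rename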
 