Yanisadel committed
Commit 862ba01 · verified · 1 Parent(s): 7b1c32a

Update chatNT.py (#3)


- Update chatNT.py (94f611d6d404ad9fe26c0dcd8ed5ad55f4bed639)

Files changed (1)
  1. chatNT.py +20 -53
chatNT.py CHANGED
@@ -28,7 +28,6 @@ class RotaryEmbeddingConfig:
 class PerceiverResamplerConfig:
     """
     Parameters to initialize an PerceiverResampler model.
-
     Args:
         emb_layer_norm_before: Whether to use layer norm before the first attention
             layer.
@@ -93,9 +92,7 @@ class PerceiverResamplerConfig:
 class GptConfig:
     """
     Parameters to initialize a Gpt model.
-
     NOTE: the pad token is not defined
-
     Args:
         vocab_size: Token vocabulary.
         eos_token_id: used to stop sentence generation
@@ -191,7 +188,6 @@ class GptConfig:
 class NucleotideTransformerConfig:
     """
     Parameters to initialize an NT model.
-
     Args:
         alphabet_size: Token vocabulary.
         pad_token_id: ID of pad token.
@@ -364,21 +360,20 @@ class ChatNTConfig(PretrainedConfig):
         return output
 
 
-class ChatNTDecoder(nn.Module):
+class TorchBioBrainDecoder(nn.Module):
     def __init__(
         self,
         gpt_config: GptConfig,
         seq_token_id: int,
     ):
         """
-        Initializes the ChatNT decoder, using a GPT model for text generation with
+        Initializes the BioBrain decoder, using a GPT model for text generation with
             bio embeddings.
-
         Args:
             gpt_config: Configuration for the GPT model
            seq_token_id: Index of the SEQ token
         """
-        super(ChatNTDecoder, self).__init__()
+        super(TorchBioBrainDecoder, self).__init__()
         self.gpt_config = gpt_config
         self.seq_token_id = seq_token_id
 
@@ -390,13 +385,11 @@ class ChatNTDecoder(nn.Module):
     ) -> torch.Tensor:
         """
         Forward pass through the model.
-
         Args:
             english_token_ids: Tensor of English token IDs with shape
                 (batch_size, num_english_tokens).
             projected_bio_embeddings: Optional tensor of bio embeddings with shape
                 (batch_size, num_bio_sequences, ?, embed_dim).
-
         Returns:
             torch.Tensor: The logits from the GPT model,
                 shaped (batch_size, num_english_tokens, vocab_size).
@@ -452,13 +445,11 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Inserts resampled embeddings in input_embeddings, starting at the SEQ token
-
         Args:
             tokens (torch.Tensor): Shape (batch_size, num_tokens)
             input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
             resampled_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
-
         Returns:
             Tuple[torch.Tensor, torch.Tensor]:
                 - input_embeddings with resampled_embeddings inserted at the SEQ token
@@ -521,11 +512,9 @@ class ChatNTDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Removes the logits corresponding to the unused embeddings.
-
         Args:
             tokens: Input english tokens.
             logits: Input logits.
-
         Returns:
             Cleaned logits, last values will be equal to 0.
         """
@@ -582,7 +571,7 @@ class ChatNTDecoder(nn.Module):
         return logits_acc, tokens_acc
 
 
-class ChatNT(PreTrainedModel):
+class TorchMultiOmicsModel(PreTrainedModel):
     config_class = ChatNTConfig
 
     def __init__(self, config: ChatNTConfig) -> None:
@@ -625,11 +614,11 @@ class ChatNT(PreTrainedModel):
         # Correct seq_token_id
         self.seq_token_id -= 1
 
-        self.chatnt_encoder = ChatNTEncoder(nt_config=self.nt_config)
-        self.chatnt_decoder = ChatNTDecoder(
+        self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config)
+        self.biobrain_decoder = TorchBioBrainDecoder(
             gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
         )
-        self.projection_model = MultiModalPerceiverResamplerProjection(
+        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
             perceiver_resampler_config=self.perceiver_resampler_config,
             input_embed_dim=self.nt_config.embed_dim,
             embed_dim=self.gpt_config.embed_dim,
@@ -645,27 +634,21 @@ class ChatNT(PreTrainedModel):
         projected_bio_embeddings: torch.Tensor = None,
     ) -> dict[str, torch.Tensor]:
         """
-
         Args:
             multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                 english_tokens_ids: Represents the prompt tokens (english tokens)
                     Shape (batch_size, num_english_tokens)
-
                 bio_tokens_ids: Represents the bio sequences tokens
                     Shape (batch_size, num_bio_sequences, num_bio_tokens)
-
             projection_english_tokens_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
-
             projected_bio_embeddings (projected_bio_embeddings, optional):
                 Shape (batch_size, num_bio_sequencse, ?, embed_dim).
                 Defaults to None.
-
         Returns:
             dict[str, torch.Tensor] containing:
                 - logits:
                     Shape (batch_size, num_tokens, vocab_size)
-
                 - projected_bio_embeddings:
                     Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
@@ -702,7 +685,7 @@ class ChatNT(PreTrainedModel):
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
             bio_embeddings_list = [
-                self.chatnt_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
+                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
@@ -718,7 +701,7 @@ class ChatNT(PreTrainedModel):
         projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
-        logits = self.chatnt_decoder(
+        logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
@@ -741,7 +724,6 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
-
         Returns:
             Sinusoidal positions of shape (self.max_seq_len, self.dim).
         """
@@ -774,11 +756,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
        """
        Prepare a tensor to apply the RoPE mechanism.
-
        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.
-
        Returns:
            The even indices in the last dimension have their sign flipped.
            Tensor of shape (batch_size, seq_len, num_heads, head_dim).
@@ -795,12 +775,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
    ) -> torch.Tensor:
        """
        Applies rotary embeddings to x.
-
        Args:
            x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                typically this is the key or query tensor.
            sincos: Tuple of sine and cosine tensors for position encoding.
-
        Returns:
            RoPE embeddings tensor.
        """
@@ -818,12 +796,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Applies rotary embeddings to k and q.
-
        Args:
            k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
            q: value tensor of shape (batch_size, seq_len, num_heads, head_dim),
            positions: optional positions offset useful when caching,
-
        Returns:
            RoPE embeddings for the keys and values.
        """
@@ -1141,11 +1117,9 @@ def build_causal_attention_mask(
    """
    Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
    to an attention layer.
-
    Args:
        batch_size: Batch size.
        seq_len: Length of the sequences.
-
    Returns:
        Batch of causal masks.
    """
@@ -1498,12 +1472,12 @@ class RobertaLMHead(nn.Module):
         return {"embeddings": embeddings, "logits": logits}
 
 
-class NucleotideTransformer(nn.Module):
+class TorchNucleotideTransformer(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(NucleotideTransformer, self).__init__()
+        super(TorchNucleotideTransformer, self).__init__()
         self.nt_config = nt_config
 
         # Other cases are not implemented
@@ -1551,13 +1525,11 @@ class NucleotideTransformer(nn.Module):
     ) -> torch.Tensor:
         """
         Computes the embeddings based on the input tokens.
-
         Args:
             tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
             attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                 If no mask is provided, a mask by default which equals 1 over all non
                 pad tokens and 0 over pad tokens is computed.
-
         Returns:
             Dictionary containing the final embeddings and logits.
         """
@@ -1585,11 +1557,9 @@ def build_padding_attention_mask(
 ) -> torch.Tensor:
     """
     Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
-
     Args:
         tokens: Batch of sequences of shape (batch_size, seq_len).
        pad_token_id: Int corresponding to the <pad> token to mask.
-
     Returns:
        Batch of attention masks, masking out <pad> tokens.
     """
@@ -1599,14 +1569,14 @@ def build_padding_attention_mask(
     return padding_mask
 
 
-class ChatNTEncoder(nn.Module):
+class TorchBioBrainEncoder(nn.Module):
     def __init__(
         self,
         nt_config: NucleotideTransformerConfig,
     ):
-        super(ChatNTEncoder, self).__init__()
+        super(TorchBioBrainEncoder, self).__init__()
         self.nt_config = nt_config
-        self.nt_model = NucleotideTransformer(self.nt_config)
+        self.nt_model = TorchNucleotideTransformer(self.nt_config)
 
     def forward(
         self,
@@ -1616,7 +1586,6 @@ class ChatNTEncoder(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
         Returns:
             torch.Tensor:
                 Shape (batch_size, num_bio_tokens, embed_dim)
@@ -1626,7 +1595,7 @@ class ChatNTEncoder(nn.Module):
         return bio_embeddings
 
 
-class MultiModalPerceiverResamplerBlock(nn.Module):
+class TorchMultiModalPerceiverResamplerBlock(nn.Module):
     def __init__(
         self,
         num_heads: int,
@@ -1714,7 +1683,7 @@ class MultiModalPerceiverResamplerBlock(nn.Module):
         return {"embeddings": x}
 
 
-class MultiModalPerceiverResampler(nn.Module):
+class TorchMultiModalPerceiverResampler(nn.Module):
     """
     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
     """
@@ -1726,7 +1695,6 @@ class MultiModalPerceiverResampler(nn.Module):
     ):
         """
         Initialize a Perceiver Resampler model.
-
         Args:
             config: Dataclass containing model hyperparameters.
             name: Name for module (custom will break weight loading).
@@ -1736,7 +1704,7 @@ class MultiModalPerceiverResampler(nn.Module):
         self.name = name
         self.layers = nn.ModuleList(
             [
-                MultiModalPerceiverResamplerBlock(
+                TorchMultiModalPerceiverResamplerBlock(
                     num_heads=self.config.attention_heads,
                     embed_dim=self.config.embed_dim,
                     key_size=self.config.key_size,
@@ -1823,7 +1791,7 @@ class MultiModalPerceiverResampler(nn.Module):
         return outs
 
 
-class MultiModalPerceiverResamplerProjection(nn.Module):
+class TorchMultiModalPerceiverResamplerProjection(nn.Module):
     def __init__(
         self,
         perceiver_resampler_config: PerceiverResamplerConfig,
@@ -1843,7 +1811,7 @@ class MultiModalPerceiverResamplerProjection(nn.Module):
 
         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-        self.perceiver_resampler = MultiModalPerceiverResampler(config=self.config)
+        self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
 
     def forward(
         self,
@@ -1855,10 +1823,8 @@ class MultiModalPerceiverResamplerProjection(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
-
             bio_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_tokens, embed_dim)
-
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
@@ -1901,3 +1867,4 @@ def build_perceiver_padding_attention_mask(
     padding_mask = padding_mask[:, None, None, :]
     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
     return padding_mask
+
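
For context, the forward() documented in this diff takes a tuple of English prompt tokens and bio-sequence tokens, plus the English tokens used by the projection module, and returns a dict with "logits" and "projected_bio_embeddings". A minimal usage sketch, assuming chatNT.py is loaded as Hub remote code; the repo id "InstaDeepAI/ChatNT", the placeholder token ids, and tokenizer handling are assumptions and not part of this commit:

import torch
from transformers import AutoModel

# Assumed repo id; the chatNT.py above would be fetched and executed as remote code.
model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)

batch_size, num_english_tokens = 1, 32
num_bio_sequences, num_bio_tokens = 1, 512

# Placeholder ids only: real inputs come from the English and bio tokenizers,
# and the English prompt must contain the SEQ placeholder token where the
# decoder inserts the projected bio embeddings.
english_tokens = torch.randint(0, 100, (batch_size, num_english_tokens))
bio_tokens = torch.randint(0, 4, (batch_size, num_bio_sequences, num_bio_tokens))

outs = model(
    multi_omics_tokens_ids=(english_tokens, bio_tokens),
    projection_english_tokens_ids=english_tokens,
)
print(outs["logits"].shape)                    # (batch_size, num_tokens, vocab_size)
print(outs["projected_bio_embeddings"].shape)  # (batch_size, num_bio_sequences, ?, embed_dim)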