Yanisadel commited on
Commit
6ed7d0b
·
verified ·
1 Parent(s): 18c36cb

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +3 -2
chatNT.py CHANGED
@@ -640,7 +640,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
640
 
641
  def forward(
642
  self,
643
- multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
644
  projection_english_tokens_ids: torch.Tensor,
645
  projected_bio_embeddings: torch.Tensor = None,
646
  ) -> dict[str, torch.Tensor]:
@@ -671,8 +671,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
671
  """
672
  english_token_ids, bio_token_ids = multi_omics_tokens_ids
673
  english_token_ids = english_token_ids.clone()
674
- bio_token_ids = bio_token_ids.clone()
675
  projection_english_tokens_ids = projection_english_tokens_ids.clone()
 
 
676
  if projected_bio_embeddings is not None:
677
  projected_bio_embeddings = projected_bio_embeddings.clone()
678
 
 
640
 
641
  def forward(
642
  self,
643
+ multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None],
644
  projection_english_tokens_ids: torch.Tensor,
645
  projected_bio_embeddings: torch.Tensor = None,
646
  ) -> dict[str, torch.Tensor]:
 
671
  """
672
  english_token_ids, bio_token_ids = multi_omics_tokens_ids
673
  english_token_ids = english_token_ids.clone()
 
674
  projection_english_tokens_ids = projection_english_tokens_ids.clone()
675
+ if bio_token_ids is not None:
676
+ bio_token_ids = bio_token_ids.clone()
677
  if projected_bio_embeddings is not None:
678
  projected_bio_embeddings = projected_bio_embeddings.clone()
679