Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -640,7 +640,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
640 |
|
641 |
def forward(
|
642 |
self,
|
643 |
-
multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
|
644 |
projection_english_tokens_ids: torch.Tensor,
|
645 |
projected_bio_embeddings: torch.Tensor = None,
|
646 |
) -> dict[str, torch.Tensor]:
|
@@ -671,8 +671,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
671 |
"""
|
672 |
english_token_ids, bio_token_ids = multi_omics_tokens_ids
|
673 |
english_token_ids = english_token_ids.clone()
|
674 |
-
bio_token_ids = bio_token_ids.clone()
|
675 |
projection_english_tokens_ids = projection_english_tokens_ids.clone()
|
|
|
|
|
676 |
if projected_bio_embeddings is not None:
|
677 |
projected_bio_embeddings = projected_bio_embeddings.clone()
|
678 |
|
|
|
640 |
|
641 |
def forward(
|
642 |
self,
|
643 |
+
multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None],
|
644 |
projection_english_tokens_ids: torch.Tensor,
|
645 |
projected_bio_embeddings: torch.Tensor = None,
|
646 |
) -> dict[str, torch.Tensor]:
|
|
|
671 |
"""
|
672 |
english_token_ids, bio_token_ids = multi_omics_tokens_ids
|
673 |
english_token_ids = english_token_ids.clone()
|
|
|
674 |
projection_english_tokens_ids = projection_english_tokens_ids.clone()
|
675 |
+
if bio_token_ids is not None:
|
676 |
+
bio_token_ids = bio_token_ids.clone()
|
677 |
if projected_bio_embeddings is not None:
|
678 |
projected_bio_embeddings = projected_bio_embeddings.clone()
|
679 |
|