Update chatNT.py
chatNT.py CHANGED
@@ -721,6 +721,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
         )
+        logits = logits.to(torch.float32)
 
         outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
 
@@ -927,9 +928,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
         attention_weights = nn.functional.softmax(attention_logits, dim=-1)
         attention_weights = attention_weights.to(values.dtype)
 
-        print(f"Attention weights type : ", attention_weights.dtype)
-        print(f"Values type : ", values.dtype)
-
         values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
         values = values.contiguous().view(batch_size, seq_len, -1)
 
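Taken together, the two hunks settle one dtype policy: attention weights are cast down to the values' dtype before the einsum (the debug prints that were verifying this are deleted), and the final logits are cast up to float32 before being returned. Below is a minimal, self-contained sketch of that pattern, not ChatNT's actual code; the tensor shapes, the lm_head projection, and the bfloat16 choice are illustrative assumptions.

import torch
import torch.nn as nn

# Illustrative shapes only; ChatNT's real dimensions differ.
batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16

attention_logits = torch.randn(batch_size, num_heads, seq_len, seq_len)  # float32
values = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16)

# Softmax in float32 for stability, then down-cast so the einsum sees a single
# dtype; torch.einsum rejects mixed float32/bfloat16 operands (mirrors lines 928-929).
attention_weights = nn.functional.softmax(attention_logits, dim=-1)
attention_weights = attention_weights.to(values.dtype)

# b = batch, h = heads, t = query position, T = key position, d = head_dim.
values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
values = values.contiguous().view(batch_size, seq_len, -1)

# Hypothetical output head, kept in bfloat16; the commit then up-casts the
# resulting logits to float32 before they are returned (mirrors new line 724).
lm_head = nn.Linear(num_heads * head_dim, 32, dtype=torch.bfloat16)
logits = lm_head(values).to(torch.float32)
assert logits.dtype == torch.float32

Up-casting only the logits keeps the heavy attention math in the low-precision dtype while handing samplers and loss functions full float32 values, the usual compromise for mixed-precision inference.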