Yanisadel committed on
Commit
18c36cb
·
verified ·
1 Parent(s): 112bf64

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +1 -3
chatNT.py CHANGED
@@ -721,6 +721,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
721
  english_token_ids=english_token_ids,
722
  projected_bio_embeddings=projected_bio_embeddings,
723
  )
 
724
 
725
  outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
726
 
@@ -927,9 +928,6 @@ class TorchGptGroupedQueryAttention(nn.Module):
927
  attention_weights = nn.functional.softmax(attention_logits, dim=-1)
928
  attention_weights = attention_weights.to(values.dtype)
929
 
930
- print(f"Attention weights type : ", attention_weights.dtype)
931
- print(f"Values type : ", values.dtype)
932
-
933
  values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
934
  values = values.contiguous().view(batch_size, seq_len, -1)
935
 
 
721
  english_token_ids=english_token_ids,
722
  projected_bio_embeddings=projected_bio_embeddings,
723
  )
724
+ logits = logits.to(torch.float32)
725
 
726
  outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
727
 
 
928
  attention_weights = nn.functional.softmax(attention_logits, dim=-1)
929
  attention_weights = attention_weights.to(values.dtype)
930
 
 
 
 
931
  values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
932
  values = values.contiguous().view(batch_size, seq_len, -1)
933