Yanisadel committed on
Commit
6d6a20f
·
verified ·
1 Parent(s): 3bd05b8

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +3 -0
chatNT.py CHANGED
@@ -1324,6 +1324,9 @@ class MultiHeadAttention(nn.Module):
1324
  attention_weights = attention_weights / sqrt_key_size
1325
  if attention_mask is not None:
1326
  attention_weights = torch.where(attention_mask, attention_weights, -1e30)
 
 
 
1327
  if attention_weight_bias is not None:
1328
  attention_weights = F.softmax(
1329
  attention_weights + attention_weight_bias, dim=-1
 
1324
  attention_weights = attention_weights / sqrt_key_size
1325
  if attention_mask is not None:
1326
  attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1327
+
1328
+ attention_weights = attention_weights.to(value_heads.dtype)
1329
+
1330
  if attention_weight_bias is not None:
1331
  attention_weights = F.softmax(
1332
  attention_weights + attention_weight_bias, dim=-1