Update chatNT.py
chatNT.py CHANGED
@@ -1324,6 +1324,9 @@ class MultiHeadAttention(nn.Module):
         attention_weights = attention_weights / sqrt_key_size
         if attention_mask is not None:
             attention_weights = torch.where(attention_mask, attention_weights, -1e30)
+
+        attention_weights = attention_weights.to(value_heads.dtype)
+
         if attention_weight_bias is not None:
             attention_weights = F.softmax(
                 attention_weights + attention_weight_bias, dim=-1
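The inserted line casts `attention_weights` to the dtype of `value_heads` before the softmax and the weighted sum. A plausible motivation, sketched below under assumed mixed-precision settings (the tensor names and shapes are illustrative, not taken from chatNT.py): if the attention scores are accumulated in float32 while the values are bfloat16, the matmul between probabilities and values raises a dtype mismatch, and the cast resolves it.

import torch
import torch.nn.functional as F

# Minimal sketch of the mixed-precision issue the cast addresses; the
# tensor names and sizes below are illustrative assumptions, not chatNT.py's.
torch.manual_seed(0)
queries = torch.randn(1, 4, 8)                                  # scores path in fp32
keys = torch.randn(1, 4, 8)
value_heads = torch.randn(1, 4, 8, dtype=torch.bfloat16)        # values in bf16

sqrt_key_size = 8 ** 0.5
attention_weights = queries @ keys.transpose(-2, -1) / sqrt_key_size   # fp32 scores
attention_mask = torch.ones(1, 4, 4, dtype=torch.bool)
attention_weights = torch.where(attention_mask, attention_weights, -1e30)

# The commit's added line: match the value dtype before softmax / matmul.
attention_weights = attention_weights.to(value_heads.dtype)

probs = F.softmax(attention_weights, dim=-1)
out = probs @ value_heads   # without the cast, this matmul raises a dtype mismatch
print(out.dtype)            # torch.bfloat16

Note that -1e30 is still representable in bfloat16 (its range matches float32's exponent), so the mask fill survives the cast and softmax still zeroes the masked positions.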