Tom Aarsen
committed on
Commit · 30e2384
Parent(s): 12e66dc
Cast attention_mask to bool in SDPA
I'm pretty sure this is correct. It was an int tensor before, which SDPA doesn't like.
model.py
CHANGED
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
             query=xq.transpose(1, 2),
             key=xk.transpose(1, 2),
             value=xv.transpose(1, 2),
-            attn_mask=attention_mask,
+            attn_mask=attention_mask.bool(),
             dropout_p=0,
         ).transpose(1, 2)
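A minimal sketch of why the cast matters, calling `torch.nn.functional.scaled_dot_product_attention` directly. The tensor shapes and the padding-mask layout below are illustrative assumptions, not taken from model.py: SDPA interprets a boolean `attn_mask` as "True = attend" and a floating-point mask as an additive bias, so a 0/1 integer mask fits neither convention and needs the `.bool()` cast.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 5, 8)
k = torch.randn(1, 4, 5, 8)
v = torch.randn(1, 4, 5, 8)

# Typical tokenizer-style padding mask: 1 = attend, 0 = ignore, int dtype
int_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Cast to bool (True = attend) and broadcast to (batch, 1, 1, seq_k),
# which SDPA expands across heads and query positions.
attn_mask = int_mask[:, None, None, :].bool()

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0)
```

Passing `int_mask` itself would not mean "mask out the zeros"; the cast makes the intent explicit to SDPA.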