Tom Aarsen committed
Commit 30e2384 · 1 parent: 12e66dc

Cast attention_mask to bool in SDPA


I'm pretty sure this is correct: the mask was an int tensor before, which SDPA's `attn_mask` argument doesn't accept (it expects a boolean or floating-point tensor).

Files changed (1)
  1. model.py +1 -1
model.py CHANGED
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
         query=xq.transpose(1, 2),
         key=xk.transpose(1, 2),
         value=xv.transpose(1, 2),
-        attn_mask=attention_mask,
+        attn_mask=attention_mask.bool(),
         dropout_p=0,
     ).transpose(1, 2)
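For context, a minimal standalone sketch (not the repo's actual `model.py`; tensor shapes and the padding-mask convention are assumptions) of why the cast matters: PyTorch's `scaled_dot_product_attention` treats a boolean `attn_mask` as "True = attend" and a floating mask as additive scores, so an integer 0/1 padding mask matches neither convention until it is cast to bool.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Typical HF-style padding mask: 1 for real tokens, 0 for padding (int tensor).
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Cast to bool (True = attend) and broadcast to (batch, 1, 1, key_len),
# mirroring the `.bool()` call added in this commit.
mask = attention_mask.bool()[:, None, None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```

Passing the raw int tensor instead of `mask` would not follow either supported mask convention, which is what the one-line fix above addresses.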