Fix RuntimeError: pad attn scores back to the original query sequence length, instead of the unpadded sequence length (which was effectively a no-op).

#17
by Birchlabs - opened

Prevents a RuntimeError on line 382's `pad_input(…).reshape()`:

```
shape '[1, 4096, 4096]' is invalid for input of size 9400320
```
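The size mismatch is simple arithmetic: the unpadded attention output holds one row per non-padding token (2295 here), which is fewer elements than the padded target shape requires. A quick check using the numbers from the error above:

```python
# elements actually present in the unpadded output: 2295 tokens, 32 heads, head_dim 128
available = 2295 * 32 * 128
# elements the reshape target shape [1, 4096, 4096] requires
required = 1 * 4096 * 4096
print(available, required)  # 9400320 16777216
```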

Before this change, `pad_input()` was basically just doing a `.unsqueeze(0)`:

```python
attn_output.shape
# torch.Size([2295, 32, 128])
pad_input(attn_output, indices_q, bsz, max_seqlen_q).shape
# torch.Size([1, 2295, 32, 128])
```

After this change, `pad_input()` actually pads the input back to the original query sequence length:

```python
pad_input(attn_output, indices_q, bsz, q_len).shape
# torch.Size([1, 4096, 32, 128])
```

and the reshape succeeds:

```python
pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size).shape
# torch.Size([1, 4096, 4096])
```
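For intuition, `pad_input` (from flash-attn's `bert_padding` module) scatters the unpadded token rows back into a zero-initialized `(batch, seqlen, ...)` tensor at the positions recorded in `indices`. A minimal sketch of that behavior, with a hypothetical function name and toy shapes (this is not the library's actual implementation):

```python
import torch

def pad_input_sketch(hidden_states, indices, batch, seqlen):
    """Scatter unpadded token rows back into a padded (batch, seqlen, ...) tensor.

    hidden_states: (total_tokens, ...) rows for the non-padding tokens only
    indices: flat positions of those tokens within the (batch * seqlen) grid
    """
    trailing = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen, *trailing), dtype=hidden_states.dtype)
    output[indices] = hidden_states          # pad positions stay zero
    return output.reshape(batch, seqlen, *trailing)

# toy example mirroring the PR's shapes, scaled down:
# 5 unpadded tokens, 2 heads, head_dim 4, padded back to seqlen 8
attn_output = torch.randn(5, 2, 4)
indices_q = torch.tensor([0, 1, 2, 3, 4])
padded = pad_input_sketch(attn_output, indices_q, 1, 8)
print(padded.shape)  # torch.Size([1, 8, 2, 4])
```

Calling it with the unpadded length (as before this change) just reproduces the input with a batch dimension prepended, which is why the subsequent reshape to the full query length failed.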

I was getting a similar error. Thanks for the change... when will this get merged?

