import logging
from typing import Optional, Tuple
import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
from torch.nn.attention import SDPBackend
from .nested_tensor import NestedTensor
log = logging.getLogger(__name__)
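# This module implements scaled_dot_product_attention for jagged-layout
# NestedTensors: input validation, per-backend eligibility checks, backend
# selection, and preprocessing of jagged inputs into the dense varlen form
# expected by the flash-attention and memory-efficient kernels.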
def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
if (
not isinstance(query, NestedTensor)
or not isinstance(key, NestedTensor)
or not isinstance(value, NestedTensor)
):
raise ValueError(
f"Expected query, key, and value to be nested tensors, "
f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
f"and value.is_nested: {value.is_nested} instead."
)
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
if query.device != key.device or query.device != value.device:
raise ValueError(
f"Expected query, key, and value to have the same device type, "
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
raise ValueError(
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
raise ValueError(
f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
)
if attn_mask is not None:
# TODO: Figure out whether masks are actually supported for this layout or not
raise ValueError("Masks are not yet supported!")
if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
raise ValueError(
f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
)
def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
# This is expected to be called after check_tensor_shapes ensuring that the
# size() calls won't error since the inputs are all 4 dimensional
q_batch_size = params.query.size(0)
k_batch_size = params.key.size(0)
v_batch_size = params.value.size(0)
# num_heads logic for nested input is checked in
# _check_for_seq_len_0_nested, as there is handling there to make sure
# num_heads is not ragged
return q_batch_size == k_batch_size and q_batch_size == v_batch_size
def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
max_size = 256
query_size_last = params.query.size(-1)
key_size_last = params.key.size(-1)
value_size_last = params.value.size(-1)
same_head_dim_size = (
query_size_last == key_size_last and query_size_last == value_size_last
)
if not (
same_head_dim_size
and (query_size_last % 8 == 0)
and (query_size_last <= max_size)
):
if debug:
log.warning(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same "
"last dimension and to be a multiple of 8 and less than or equal to 256. "
"Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
query_size_last,
key_size_last,
value_size_last,
)
return False
return True
def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
param: torch.Tensor, param_name: str, debug=False
) -> bool:
assert isinstance(param, NestedTensor), "param should be a jagged NT"
if param._ragged_idx == 1:
# num_head_dims is ragged
if debug:
log.warning(
"Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
param_name,
)
return False
# This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
if param._min_seqlen == 0:
if debug:
log.warning(
"Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
param_name,
)
return False
return True
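# For example, num_heads sizes of (1, 8, 8) for (q, k, v) are broadcastable to 8
# and pass the check below, while (4, 8, 8) fails since 4 is neither the max nor 1.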
def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
max_size = max(q_size, k_size, v_size)
if (
(q_size != max_size and q_size != 1)
or (k_size != max_size and k_size != 1)
or (v_size != max_size and v_size != 1)
):
if debug:
log.warning(
"Both fused kernels require query, key and value to have broadcastable %s, "
"got Query %s %d, Key %s %d, Value %s %d instead.",
param_name,
param_name,
q_size,
param_name,
k_size,
param_name,
v_size,
)
return False
return True
def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
# When this function is called, we are assured that the NT is 4-dimensional
q_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.query, "query", debug
)
if params.query.is_nested
else True
)
# short circuit if any is unsafe
if not q_is_safe:
return False
k_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.key, "key", debug
)
if params.key.is_nested
else True
)
# short circuit if any is unsafe
if not k_is_safe:
return False
v_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.value, "value", debug
)
if params.value.is_nested
else True
)
# short circuit if any is unsafe
if not v_is_safe:
return False
# We now know none of the inputs have ragged num_heads, so we can safely
# access .size(1)
q_num_heads = params.query.size(1)
k_num_heads = params.key.size(1)
v_num_heads = params.value.size(1)
same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
if not same_num_heads:
if (
params.query.requires_grad
or params.key.requires_grad
or params.value.requires_grad
):
if debug:
log.warning(
"Both fused kernels do not support training with broadcasted NT inputs."
)
return False
return _try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads", debug
)
return True
def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_head_dim_size_flash_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
if (
not params.query.transpose(1, 2).is_contiguous()
or not params.key.transpose(1, 2).is_contiguous()
or not params.value.transpose(1, 2).is_contiguous()
):
if debug:
log.warning(
"If inputs are nested tensors they must be contiguous after transposing."
)
return False
if params.is_causal:
if debug:
log.warning(
"Nested tensors for query / key are not supported when is_causal=True."
)
return False
return True
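# Backend selection: try flash attention, then memory-efficient attention, then
# math, returning the first backend whose global toggle and per-input checks
# (both the generic CUDA checks and the jagged-specific ones above) pass.
# If no backend qualifies, the failed constraints are logged and SDPBackend.ERROR
# is returned.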
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
and not math_sdp_enabled()
):
return SDPBackend.ERROR
ordering = (
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
return SDPBackend.FLASH_ATTENTION
if backend == SDPBackend.EFFICIENT_ATTENTION:
if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
params
):
return SDPBackend.EFFICIENT_ATTENTION
if backend == SDPBackend.MATH:
if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
return SDPBackend.MATH
log.warning("Memory efficient kernel not used because:")
can_use_efficient_attention(params, debug=True)
_can_use_efficient_sdpa_jagged(params, debug=True)
log.warning("Flash attention kernel not used because:")
can_use_flash_attention(params, debug=True)
_can_use_flash_sdpa_jagged(params, debug=True)
log.warning("Math attention kernel not used because:")
_can_use_math_sdpa_jagged(params, debug=True)
return SDPBackend.ERROR
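# Illustrative example (hypothetical values): for a jagged NT with offsets
# [0, 2, 5, 7] and no explicit lengths, _cumulative_and_max_seq_len_nnz returns
# (tensor([0, 2, 5, 7], dtype=torch.int32), qkv._max_seqlen, 7).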
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
# This function calculates two pieces of metadata needed by the flash-attention
# and efficient_attention kernels: the cumulative sequence lengths over a batch
# of sequences and the maximum sequence length.
# It returns a tuple of (cumulative sequence lengths, maximum sequence length,
# total number of elements, i.e. the last entry of the cumulative sequence lengths).
if not isinstance(qkv, NestedTensor):
raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
if qkv.lengths() is None:
# TODO: Explore performance impact of copying
cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
max_seqlen = qkv._max_seqlen
n_elem = qkv.values().shape[0]
else:
# TODO: Explore performance impact of copying
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
batch_size = qkv.size(0)
max_seqlen = qkv._max_seqlen
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())
return cumulative_seqlen, max_seqlen, n_elem
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
# This function checks if a nested tensor is valid for
# use with the flash-attention and efficient_attention kernels without
# needing to call contiguous on the nested tensor input.
# It checks that the storage offsets' adjacent differences are a constant
# multiple of the previous tensor in the nested tensor and that the strides
# are monotonically decreasing. This check is done after calling transpose on
# the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
# Returns True if the storage can be used directly, False if contiguous() must be called first.
assert isinstance(tensor, NestedTensor)
offsets = tensor.offsets()
strides = tensor._strides
n_tensors = offsets.size(0) - 1
if n_tensors <= 1:
return True
# Check initially that the tensor strides are in strictly descending order
prev_stride = strides[1]
for stride in strides[2:]:
if prev_stride <= stride:
# This would mean that the last stride is greater than the seq_len
# stride
return False
prev_stride = stride
# Congrats you made it!
return True
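# _view_as_dense returns the flat (Nnz, num_heads, head_dim) buffer that the fused
# varlen kernels expect: a jagged NT already stores its data in that form in
# .values(), while a dense tensor is reshaped with .view().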
def _view_as_dense(
tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
if tensor.is_nested:
return tensor.values()
return tensor.view(Nnz, num_heads, head_dim)
# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# q_batch_size = query.size(0)
# k_batch_size = key.size(0)
# v_batch_size = value.size(0)
# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
# q_num_heads = query.size(1)
# k_num_heads = key.size(1)
# v_num_heads = value.size(1)
# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
# head_dim_qk = query.size(3)
# head_dim_v = value.size(3)
# q_t = query.transpose(1, 2)
# k_t = key.transpose(1, 2)
# v_t = value.transpose(1, 2)
# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
# # output_batch_size/num_heads then they are 1
# q_batch_size_needs_broadcast = q_batch_size != output_batch_size
# k_batch_size_needs_broadcast = k_batch_size != output_batch_size
# v_batch_size_needs_broadcast = v_batch_size != output_batch_size
# # If {*}_batch_size_needs_broadcast, then
# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
# # this is because needs_broadcast indicates that the batch_size is 1
# # and hence there is only 1 value for seq_len
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
# # ..., output_batch_size * {*}_t.size(1)]
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
# if q_batch_size_needs_broadcast or not q_t.is_nested:
# max_seqlen_batch_q = q_t.size(1)
# cumulative_sequence_length_q = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_q,
# max_seqlen_batch_q,
# device=q_t.device,
# dtype=torch.int32,
# )
# Nnz_q = output_batch_size * max_seqlen_batch_q
# else:
# (
# cumulative_sequence_length_q,
# max_seqlen_batch_q,
# Nnz_q,
# ) = _cumulative_and_max_seq_len_nnz(q_t)
# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
# assert k_t.size(1) == v_t.size(1)
# max_seqlen_batch_kv = k_t.size(1)
# cumulative_sequence_length_kv = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_kv,
# max_seqlen_batch_kv,
# device=k_t.device,
# dtype=torch.int32,
# )
# Nnz_kv = output_batch_size * max_seqlen_batch_kv
# else:
# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
# _cumulative_and_max_seq_len_nnz(v_t)
# if k_batch_size_needs_broadcast
# else _cumulative_and_max_seq_len_nnz(k_t)
# )
# q_num_heads_needs_broadcast = q_num_heads != output_num_heads
# k_num_heads_needs_broadcast = k_num_heads != output_num_heads
# v_num_heads_needs_broadcast = v_num_heads != output_num_heads
# if not q_t.is_nested:
# query_buffer_reshaped = q_t.expand(
# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
# )
# query_buffer_reshaped = query_buffer_reshaped.reshape(
# Nnz_q, output_num_heads, head_dim_qk
# )
# else:
# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
# q_t = q_t.contiguous()
# # If we are broadcasting then Nnz_q will be the output_batch_size since
# # seq_len is 1
# effective_batch_size_q = (
# output_batch_size if q_batch_size_needs_broadcast else Nnz_q
# )
# query_buffer_reshaped = _view_as_dense(
# q_t, effective_batch_size_q, output_num_heads, head_dim_qk
# )
# # If the physical layout of the NestedTensor's storage
# # is not: batch, {seq_len}, num_heads, head_dim then we need
# # to call contiguous
# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
# k_t = k_t.contiguous()
# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
# v_t = v_t.contiguous()
# effective_batch_size_k = (
# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
# )
# key_buffer_reshaped = _view_as_dense(
# k_t, effective_batch_size_k, output_num_heads, head_dim_qk
# )
# effective_batch_size_v = (
# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
# )
# value_buffer_reshaped = _view_as_dense(
# v_t, effective_batch_size_v, output_num_heads, head_dim_v
# )
# if not q_batch_size_needs_broadcast:
# output_shape = q_t._size
# if head_dim_v != head_dim_qk:
# output_shape[-1] = head_dim_v
# if q_num_heads_needs_broadcast:
# output_shape[1] = output_num_heads
# else:
# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
# output_shape[0] = q_t.size(1)
# output_shape[1] = output_num_heads
# output_shape[2] = head_dim_v
# return (
# query_buffer_reshaped,
# key_buffer_reshaped,
# value_buffer_reshaped,
# cumulative_sequence_length_q,
# cumulative_sequence_length_kv,
# max_seqlen_batch_q,
# max_seqlen_batch_kv,
# output_shape,
# )
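# Preprocessing for the fused kernels: convert [B, num_heads, {seq_len}, head_dim]
# jagged q/k/v into dense (Nnz, num_heads, head_dim) buffers plus the cumulative
# sequence length / max sequence length metadata consumed by the varlen
# flash-attention and efficient-attention ops.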
def _sdpa_nested_preprocessing(query, key, value):
# Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
q_batch_size = query.size(0)
k_batch_size = key.size(0)
v_batch_size = value.size(0)
q_num_heads = query.size(1)
k_num_heads = key.size(1)
v_num_heads = value.size(1)
if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
q_num_heads == k_num_heads and k_num_heads == v_num_heads
):
raise RuntimeError(
"This path is currently not implemented for jagged layout NT."
)
# return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
num_heads = query.size(1)
head_dim_qk = query.size(3)
head_dim_v = value.size(3)
q_t = query.transpose(1, 2)
k_t = key.transpose(1, 2)
v_t = value.transpose(1, 2)
(
cumulative_sequence_length_q,
max_seqlen_batch_q,
Nnz_q,
) = _cumulative_and_max_seq_len_nnz(q_t)
(
cumulative_sequence_length_kv,
max_seqlen_batch_kv,
Nnz_kv,
) = _cumulative_and_max_seq_len_nnz(k_t)
# [TODO] K and V have to have the same Nnz; we should probably torch_check this.
# For now, assume it holds so we don't have to iterate over v.
# If the physical layout of the NestedTensor's storage
# is not: batch, {seq_len}, num_heads, head_dim then we need
# to call contiguous
if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
q_t = q_t.contiguous()
if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
k_t = k_t.contiguous()
if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
v_t = v_t.contiguous()
query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
output_nt_info = {
"offsets": q_t.offsets(),
"_max_seqlen": q_t._max_seqlen,
"_min_seqlen": q_t._min_seqlen,
}
return (
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
)
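# Illustrative example (hypothetical sizes): padding a head_dim of 100 with
# alignment_size=8 appends 4 zero columns to give 104; with slice=True the
# result is immediately sliced back to the original 100 columns.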
def _pad_last_dim(
tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
# FlashAttentionV2 requires the head dimension to be a multiple of 8.
# This padding was previously done within the kernel; however, that could cause
# the kernel to alias query, key, and value. So instead we pad the head
# dimension to a multiple of 8 here, in the composite region.
last_dim_size = tensor.size(-1)
if last_dim_size % alignment_size == 0:
return tensor
pad_count = alignment_size - (last_dim_size % alignment_size)
tensor = torch.nn.functional.pad(tensor, [0, pad_count])
if slice:
return tensor[..., 0:last_dim_size]
return tensor
# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
# TODO: Investigate why math.sqrt() isn't properly handled by Dynamo.
softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
return softmax_scale
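# _post_process_flash_output undoes _pad_last_dim for dense outputs: if the head
# dim was padded (e.g. 100 -> 104), the output is sliced back to og_size columns.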
def _post_process_flash_output(out: torch.Tensor, og_size):
if not out.is_nested and out.size(-1) != og_size:
out = out[..., 0:og_size]
return out
def jagged_scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
assert (
isinstance(query, NestedTensor)
and isinstance(key, NestedTensor)
and isinstance(value, NestedTensor)
)
# Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
# second batch dim instead). For this case, we can just send the dense buffers through
# vanilla SDPA.
if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
from torch.nested._internal.ops import extract_kwargs
output = F.scaled_dot_product_attention(
query._values,
key._values,
value._values,
attn_mask=(
attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return NestedTensor(output, **extract_kwargs(query))
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal
)
if backend_choice == SDPBackend.FLASH_ATTENTION:
og_size = query.size(-1)
query_padded = _pad_last_dim(query, 8, False)
key_padded = _pad_last_dim(key, 8, False)
value_padded = _pad_last_dim(value, 8, False)
# We need to calculate the scale based on the original (pre-padding) head dim size
og_scale = _calculate_scale(query, scale)
(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
(
attention,
logsumexp,
philox_seed,
philox_offset,
debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
is_causal,
False,
scale=og_scale,
)
# Reshape output to convert nnz to batch_size and seq_len
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
attention = nested_view_from_values_offsets(
attention.squeeze(0), output_nt_info["offsets"]
).transpose(1, 2)
return _post_process_flash_output(attention, og_size)
elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query, key, value)
(
attention,
log_sumexp,
seed,
offset,
max_seqlen_q,
max_seqlen_batch_kv,
) = torch.ops.aten._efficient_attention_forward(
query_reshaped.unsqueeze(0),
key_reshaped.unsqueeze(0),
value_reshaped.unsqueeze(0),
None,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
int(is_causal),
compute_logsumexp,
scale=scale,
)
# Reshape output to convert nnz to batch_size and seq_len
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
return nested_view_from_values_offsets(
attention.squeeze(0), output_nt_info["offsets"]
).transpose(1, 2)
elif backend_choice == SDPBackend.MATH:
# save the offsets and shape of the inputs, so we can reshape the final output
# query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
# attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
offsets = query.offsets()
d1 = query._size[1]
d2 = value._size[-1]
# convert jagged layout Nested Tensor to strided layout Nested Tensor
# which supports the math implementation of SDPA
def get_strided_layout_nested_tensor(jagged_layout_nt):
lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
transpose = torch.transpose(jagged_layout_nt, 1, 2)
tensor_list = transpose.values().split(list(lengths), dim=0)
strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
strided_nt = strided_nt.transpose(1, 2).contiguous()
return strided_nt
query = get_strided_layout_nested_tensor(query)
key = get_strided_layout_nested_tensor(key)
value = get_strided_layout_nested_tensor(value)
attn_out = torch._scaled_dot_product_attention_math(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)[0]
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# convert strided layout Nested Tensor back to jagged layout Nested Tensor
attn_out = attn_out.transpose(1, 2).contiguous().values()
attn_out = attn_out.view(-1, d1, d2)
attn_out = nested_view_from_values_offsets(attn_out, offsets)
attn_out = attn_out.transpose(1, 2)
return attn_out
else:
raise RuntimeError(
"No viable backend for scaled_dot_product_attention was found."
)
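# A minimal usage sketch (illustrative only, not part of this module): jagged SDPA
# is normally reached through F.scaled_dot_product_attention on jagged-layout
# NestedTensors shaped [B, num_heads, {seq_len}, head_dim], e.g.:
#
#   qkv = [torch.randn(3, 8, 16), torch.randn(5, 8, 16)]  # ragged seq lens 3 and 5
#   q = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#   k = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#   v = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#   out = F.scaled_dot_product_attention(q, k, v)  # dispatches to this module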