Spaces:
Running
on
Zero
Upload x_transformer_1_23_2.py
Browse files — x_transformer_1_23_2.py: +11 −4
x_transformer_1_23_2.py
CHANGED
|
@@ -268,7 +268,8 @@ class Attend(nn.Module):
|
|
| 268 |
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
| 269 |
|
| 270 |
# PyTorch 2.3-2.4 SDPA backend code...
|
| 271 |
-
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
|
|
|
|
| 272 |
|
| 273 |
# New PyTorch 2.5 SDPA backend code:
|
| 274 |
# with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
|
@@ -501,7 +502,8 @@ class AutoregressiveWrapper(Module):
|
|
| 501 |
ignore_index = -100,
|
| 502 |
pad_value = 0,
|
| 503 |
mask_prob = 0.,
|
| 504 |
-
add_attn_z_loss = False
|
|
|
|
| 505 |
):
|
| 506 |
super().__init__()
|
| 507 |
self.pad_value = pad_value
|
|
@@ -516,6 +518,7 @@ class AutoregressiveWrapper(Module):
|
|
| 516 |
|
| 517 |
# whether to add router z-loss
|
| 518 |
self.add_attn_z_loss = add_attn_z_loss
|
|
|
|
| 519 |
|
| 520 |
@torch.inference_mode()
|
| 521 |
@eval_decorator
|
|
@@ -709,8 +712,12 @@ class AutoregressiveWrapper(Module):
|
|
| 709 |
|
| 710 |
if add_attn_z_loss:
|
| 711 |
loss = loss + cache.attn_z_loss
|
| 712 |
-
|
| 713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
|
| 715 |
#===============================================================================
|
| 716 |
|
|
|
|
| 268 |
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
| 269 |
|
| 270 |
# PyTorch 2.3-2.4 SDPA backend code...
|
| 271 |
+
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
|
| 272 |
+
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
|
| 273 |
|
| 274 |
# New PyTorch 2.5 SDPA backend code:
|
| 275 |
# with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
|
|
|
| 502 |
ignore_index = -100,
|
| 503 |
pad_value = 0,
|
| 504 |
mask_prob = 0.,
|
| 505 |
+
add_attn_z_loss = False,
|
| 506 |
+
return_cache=False
|
| 507 |
):
|
| 508 |
super().__init__()
|
| 509 |
self.pad_value = pad_value
|
|
|
|
| 518 |
|
| 519 |
# whether to add router z-loss
|
| 520 |
self.add_attn_z_loss = add_attn_z_loss
|
| 521 |
+
self.return_cache = return_cache
|
| 522 |
|
| 523 |
@torch.inference_mode()
|
| 524 |
@eval_decorator
|
|
|
|
| 712 |
|
| 713 |
if add_attn_z_loss:
|
| 714 |
loss = loss + cache.attn_z_loss
|
| 715 |
+
|
| 716 |
+
if self.return_cache:
|
| 717 |
+
return loss, acc, cache
|
| 718 |
+
|
| 719 |
+
else:
|
| 720 |
+
return loss, acc
|
| 721 |
|
| 722 |
#===============================================================================
|
| 723 |
|