Use `torch.amp` rather than deprecated `torch.cuda.amp`
torch-ext/triton_layer_norm/layer_norm.py
CHANGED
@@ -10,7 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -981,7 +981,7 @@ class RMSNorm(torch.nn.Module):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(
         ctx,
         x,
@@ -1040,7 +1040,7 @@ class LayerNormLinearFn(torch.autograd.Function):
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
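For context, the replacement decorators live in `torch.amp` and take an explicit `device_type` argument. The sketch below is not part of this PR; `ScaleFn` is a hypothetical autograd function used only to illustrate how `custom_fwd(device_type="cuda")` and `custom_bwd(device_type="cuda")` are applied and how the result behaves under `torch.autocast`, assuming a PyTorch release (2.4 or later) where the `torch.amp` variants accept `device_type`.

```python
# Minimal sketch, assuming PyTorch >= 2.4 where torch.amp.custom_fwd/custom_bwd
# take a device_type keyword. ScaleFn is a hypothetical example, not code from
# this repository; it only mirrors the decorator usage changed in this PR.
import torch
from torch.amp import custom_fwd, custom_bwd


class ScaleFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # lets forward participate in autocast on CUDA
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")  # runs backward under the same autocast state as forward
    def backward(ctx, grad_out):
        # Gradient w.r.t. x; scale is a Python float, so it receives no gradient.
        return grad_out * ctx.scale, None


if torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = ScaleFn.apply(x, 2.0)
    y.sum().backward()
```

The older `torch.cuda.amp.custom_fwd`/`custom_bwd` decorators took no arguments and emit a deprecation warning in recent PyTorch releases, which is why the call sites here gain the `device_type="cuda"` argument rather than changing only the import.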