Use `torch.amp` rather than deprecated `torch.cuda.amp`
torch-ext/triton_layer_norm/layer_norm.py (CHANGED)

@@ -10,7 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -981,7 +981,7 @@ class RMSNorm(torch.nn.Module):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(
         ctx,
         x,
@@ -1040,7 +1040,7 @@ class LayerNormLinearFn(torch.autograd.Function):
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
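For context, a minimal sketch of the migrated API (not taken from `layer_norm.py`; `ScaleFn` is a hypothetical example): the `torch.amp` versions of `custom_fwd`/`custom_bwd` require an explicit `device_type` argument (available in PyTorch ≥ 2.4), whereas the deprecated `torch.cuda.amp` versions were CUDA-only and took no arguments.

```python
import torch
from torch.amp import custom_fwd, custom_bwd


class ScaleFn(torch.autograd.Function):
    """Hypothetical example: y = x * scale, made autocast-aware."""

    @staticmethod
    @custom_fwd(device_type="cuda")  # replaces the deprecated torch.cuda.amp.custom_fwd
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")  # replaces the deprecated torch.cuda.amp.custom_bwd
    def backward(ctx, grad_out):
        # Gradient w.r.t. x only; scale is a non-tensor argument, so return None for it.
        return grad_out * ctx.scale, None


if torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda", requires_grad=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = ScaleFn.apply(x, 2.0)
    y.sum().backward()
```

The decorator bodies behave the same as before on CUDA; the only change needed at call sites is the import path and the `device_type="cuda"` argument, which is exactly what this commit applies to `LayerNormLinearFn`.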