danieldk (HF Staff) committed
Commit 02598b2 · 1 Parent(s): 02bea52

Use `torch.amp` rather than deprecated `torch.cuda.amp`

torch-ext/triton_layer_norm/layer_norm.py CHANGED
@@ -10,7 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -981,7 +981,7 @@ class RMSNorm(torch.nn.Module):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(
         ctx,
         x,
@@ -1040,7 +1040,7 @@ class LayerNormLinearFn(torch.autograd.Function):
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
    def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
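
For reference, a minimal sketch (not part of this commit) of how the new-style decorators are used: `torch.cuda.amp.custom_fwd`/`custom_bwd` are deprecated in favor of `torch.amp.custom_fwd`/`custom_bwd`, which take a `device_type` argument (available since PyTorch 2.4). The `ScaleFn` class below is purely illustrative; only the decorator usage mirrors the change applied to `LayerNormLinearFn`.

# Illustrative sketch, not from the repository: a toy autograd.Function
# wrapped with the torch.amp decorators that replace the deprecated
# torch.cuda.amp ones.
import torch
from torch.amp import custom_fwd, custom_bwd  # requires PyTorch >= 2.4

class ScaleFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_fwd
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_out):
        # Gradient w.r.t. x only; scale is a plain Python number, so no gradient.
        return grad_out * ctx.scale, None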