pr-include-rev-in-flake

#1 opened by drbh (HF Staff)

README.md CHANGED
@@ -1,80 +1,9 @@
1
- ---
2
- license: bsd-3-clause
3
- tags:
4
- - kernel
5
- ---
6
- # Triton layer normalization kernels.
7
-
8
- This kernel implements layers normalization using Triton. This kernel is from
9
- the [flash-attention](https://github.com/Dao-AILab/flash-attention) project.
10
-
11
- ## Functions
12
-
13
- ### Function `layer_norm`
14
-
15
- `(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, residual: Optional[torch.Tensor] = None, x1: Optional[torch.Tensor] = None, weight1: Optional[torch.Tensor] = None, bias1: Optional[torch.Tensor] = None, eps: float = 1e-06, dropout_p: float = 0.0, rowscale=None, prenorm: bool = False, residual_in_fp32: bool = False, is_rms_norm: bool = False, return_dropout_mask: bool = False, out: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None)`
16
-
17
- Apply layer normalization to the input tensor with Triton acceleration.
18
-
19
- ### Parameters
20
-
21
- - **x** (*torch.Tensor*) --
22
- Input tensor to normalize.
23
- - **weight** (*torch.Tensor*) --
24
- Scale parameter for normalization.
25
- - **bias** (*torch.Tensor*) --
26
- Shift parameter for normalization.
27
- - **residual** (*torch.Tensor*, *optional*) --
28
- Optional residual tensor to add to the input before normalization.
29
- - **x1** (*torch.Tensor*, *optional*) --
30
- Optional second input tensor to combine with *x*. When provided, the function
31
- first adds *x1* to *x* and then applies normalization.
32
- - **weight1** (*torch.Tensor*, *optional*) --
33
- Scale parameter for the second normalization.
34
- - **bias1** (*torch.Tensor*, *optional*) --
35
- Shift parameter for the second normalization.
36
- - **eps** (*float*, *optional*, defaults to 1e-6) --
37
- Small constant added for numerical stability in normalization.
38
- - **dropout_p** (*float*, *optional*, defaults to 0.0) --
39
- Dropout probability. If greater than 0, applies dropout to the input before
40
- normalization and residual addition.
41
- - **rowscale** (*torch.Tensor*, *optional*) --
42
- Optional scaling factor applied to each row of the input tensor.
43
- Not compatible with the use of *x1*.
44
- - **prenorm** (*bool*, *optional*, defaults to False) --
45
- If True, returns both the normalized output and the unnormalized input+residual.
46
- - **residual_in_fp32** (*bool*, *optional*, defaults to False) --
47
- If True, performs the residual connection in FP32 precision.
48
- - **is_rms_norm** (*bool*, *optional*, defaults to False) --
49
- If True, uses RMS normalization instead of layer normalization.
50
- - **return_dropout_mask** (*bool*, *optional*, defaults to False) --
51
- If True, returns the dropout mask used for the computation.
52
- - **out** (*torch.Tensor*, *optional*) --
53
- Output tensor for the normalized result. If *None*, a new tensor is allocated.
54
- - **residual_out** (*torch.Tensor*, *optional*) --
55
- Output tensor for the residual result when using prenorm. If *None*, a new tensor
56
- is allocated when needed.
57
-
58
- ### Returns
59
-
60
- **Type**: *torch.Tensor* or tuple of *torch.Tensor*
61
-
62
- - The normalized input.
63
- - The second normalization of the input if *weight1* is provided.
64
- - The residual tensor if *prenorm* is set.
65
- - The dropout mask if *return_dropout_mask* is set.
66
- - The dropout mask for *x1* if *x1* is provided and *return_dropout_mask* is set.
67
-
68
- ## Layers
69
-
70
- ### Class `LlamaRMSNorm`
71
-
72
- No documentation available.
73
-
74
- #### Methods
75
-
76
- ##### Method `forward`
77
-
78
- `(self, hidden_states: torch.Tensor) -> torch.Tensor`
79
-
80
- No documentation available.
 
1
+ ---
2
+ license: bsd-3-clause
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ## triton-layer-norm
8
+
9
+ Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).
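
For reference, the functions this kernel exports can be called directly once the package is built. The snippet below is a minimal, hedged sketch (illustrative only, not part of this PR); it assumes the module is importable as `triton_layer_norm`, that a CUDA device is available, and that `layer_norm_fn`/`rms_norm_fn` keep the flash-attention argument order `(x, weight, bias, ...)`.

```python
# Hedged usage sketch -- names and argument order assumed from the flash-attention API.
import torch
from triton_layer_norm import layer_norm_fn, rms_norm_fn

hidden = 1024
x = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

y = layer_norm_fn(x, weight, bias, eps=1e-6)      # standard LayerNorm
y_rms = rms_norm_fn(x, weight, None, eps=1e-6)    # RMSNorm (no bias term)
```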
 
build.toml CHANGED
@@ -1,3 +1,5 @@
1
  [general]
2
  name = "triton_layer_norm"
 
 
3
  universal = true
 
1
  [general]
2
  name = "triton_layer_norm"
3
+
4
+ [torch]
5
  universal = true
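
Read together, the hunk above leaves `build.toml` with the following contents (reconstructed from the new side of the diff):

```toml
[general]
name = "triton_layer_norm"

[torch]
universal = true
```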
build/torch-universal/triton_layer_norm/__init__.py CHANGED
@@ -1,117 +1,5 @@
1
- """Triton layer normalization kernels
2
-
3
- This kernel implements layers normalization using Triton. This kernel is from
4
- the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
5
- """
6
-
7
- from typing import Optional
8
-
9
- import torch
10
-
11
- from . import layers
12
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
13
 
 
14
 
15
- def layer_norm(
16
- x: torch.Tensor,
17
- weight: torch.Tensor,
18
- bias: torch.Tensor,
19
- residual: Optional[torch.Tensor] = None,
20
- x1: Optional[torch.Tensor] = None,
21
- weight1: Optional[torch.Tensor] = None,
22
- bias1: Optional[torch.Tensor] = None,
23
- eps: float = 1e-6,
24
- dropout_p: float = 0.0,
25
- rowscale=None,
26
- prenorm: bool = False,
27
- residual_in_fp32: bool = False,
28
- zero_centered_weight: bool = False,
29
- is_rms_norm: bool = False,
30
- return_dropout_mask: bool = False,
31
- out: Optional[torch.Tensor] = None,
32
- residual_out: Optional[torch.Tensor] = None,
33
- ):
34
- """
35
- Apply layer normalization to the input tensor with Triton acceleration.
36
-
37
- Args:
38
- x (`torch.Tensor`):
39
- Input tensor to normalize.
40
- weight (`torch.Tensor`):
41
- Scale parameter for normalization.
42
- bias (`torch.Tensor`):
43
- Shift parameter for normalization.
44
- residual (`torch.Tensor`, *optional*):
45
- Optional residual tensor to add to the input before normalization.
46
- x1 (`torch.Tensor`, *optional*):
47
- Optional second input tensor to combine with `x`. When provided, the function
48
- first adds `x1` to `x` and then applies normalization.
49
- weight1 (`torch.Tensor`, *optional*):
50
- Scale parameter for the second normalization.
51
- bias1 (`torch.Tensor`, *optional*):
52
- Shift parameter for the second normalization.
53
- eps (`float`, *optional*, defaults to 1e-6):
54
- Small constant added for numerical stability in normalization.
55
- dropout_p (`float`, *optional*, defaults to 0.0):
56
- Dropout probability. If greater than 0, applies dropout to the input before
57
- normalization and residual addition.
58
- rowscale (`torch.Tensor`, *optional*):
59
- Optional scaling factor applied to each row of the input tensor.
60
- Not compatible with the use of `x1`.
61
- prenorm (`bool`, *optional*, defaults to False):
62
- If True, returns both the normalized output and the unnormalized input+residual.
63
- residual_in_fp32 (`bool`, *optional*, defaults to False):
64
- If True, performs the residual connection in FP32 precision.
65
- zero_centered_weight (`bool`, *optional*, defaults to False):
66
- When set to true, 1.0 is added to the weight before applying it.
67
- is_rms_norm (`bool`, *optional*, defaults to False):
68
- If True, uses RMS normalization instead of layer normalization.
69
- return_dropout_mask (`bool`, *optional*, defaults to False):
70
- If True, returns the dropout mask used for the computation.
71
- out (`torch.Tensor`, *optional*):
72
- Output tensor for the normalized result. If `None`, a new tensor is allocated.
73
- residual_out (`torch.Tensor`, *optional*):
74
- Output tensor for the residual result when using prenorm. If `None`, a new tensor
75
- is allocated when needed.
76
-
77
- Returns:
78
- `torch.Tensor` or tuple of `torch.Tensor`:
79
- - The normalized input.
80
- - The second normalization of the input if `weight1` is provided.
81
- - The residual tensor if `prenorm` is set.
82
- - The dropout mask if `return_dropout_mask` is set.
83
- - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
84
- """
85
- return layer_norm_fn(
86
- x,
87
- weight,
88
- bias,
89
- residual,
90
- x1,
91
- weight1,
92
- bias1,
93
- eps,
94
- dropout_p,
95
- rowscale,
96
- prenorm,
97
- residual_in_fp32,
98
- is_rms_norm,
99
- return_dropout_mask,
100
- out=out,
101
- residual_out=residual_out,
102
- )
103
-
104
-
105
- __kernel_metadata__ = {
106
- "license": "bsd-3-clause",
107
- }
108
-
109
-
110
- __all__ = [
111
- "__kernel_metadata__",
112
- "layers",
113
- "layer_norm",
114
- "layer_norm_fn",
115
- "layer_norm_linear_fn",
116
- "rms_norm_fn",
117
- ]
 
1
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
2
 
3
+ from . import layers
4
 
5
+ __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
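
The trimmed `__init__` keeps only the functional exports plus the `layers` submodule; the module-style wrappers (e.g. `RMSNorm`, whose constructor appears later in this diff) remain reachable through `triton_layer_norm.layer_norm`. A minimal, hedged sketch assuming a CUDA device:

```python
# Hedged sketch of the module-style wrapper; constructor/forward signatures
# are taken from the RMSNorm class shown further down in layer_norm.py.
import torch
from triton_layer_norm.layer_norm import RMSNorm

norm = RMSNorm(hidden_size=1024, eps=1e-5).to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.float16)

y = norm(x)                                        # plain RMSNorm
y2, residual = norm(x, residual=x, prenorm=True)   # fused residual add, prenorm output
```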
 
build/torch-universal/triton_layer_norm/_ops.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
- ops = torch.ops._triton_layer_norm_9b61b27_dirty
3
-
4
- def add_op_namespace_prefix(op_name: str):
5
- """
6
- Prefix op by namespace.
7
- """
8
- return f"_triton_layer_norm_9b61b27_dirty::{op_name}"
 
build/torch-universal/triton_layer_norm/layer_norm.py CHANGED
@@ -7,40 +7,14 @@
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
10
- from typing import Optional, List
11
 
12
  import torch
13
  import torch.nn.functional as F
14
- from torch import Tensor
15
 
16
  import triton
17
  import triton.language as tl
18
 
19
- from ._ops import add_op_namespace_prefix
20
- from .utils.torch import custom_fwd, custom_bwd
21
- from .utils.library import triton_op
22
-
23
-
24
- def maybe_contiguous_lastdim(x):
25
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
26
-
27
-
28
- def maybe_contiguous(x):
29
- return x.contiguous() if x is not None else None
30
-
31
-
32
- def triton_autotune_configs():
33
- # Return configs with a valid warp count for the current device
34
- configs = []
35
- # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
36
- max_threads_per_block = 1024
37
- # Default to warp size 32 if not defined by device
38
- warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
39
- # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
40
- return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32]
41
- if warp_count * warp_size <= max_threads_per_block]
42
- # return [triton.Config({}, num_warps=8)]
43
-
44
 
45
  def layer_norm_ref(
46
  x,
@@ -54,7 +28,6 @@ def layer_norm_ref(
54
  dropout_p=0.0,
55
  rowscale=None,
56
  prenorm=False,
57
- zero_centered_weight=False,
58
  dropout_mask=None,
59
  dropout_mask1=None,
60
  upcast=False,
@@ -68,10 +41,6 @@ def layer_norm_ref(
68
  x1 = x1.float() if x1 is not None else None
69
  weight1 = weight1.float() if weight1 is not None else None
70
  bias1 = bias1.float() if bias1 is not None else None
71
- if zero_centered_weight:
72
- weight = weight + 1.0
73
- if weight1 is not None:
74
- weight1 = weight1 + 1.0
75
  if x1 is not None:
76
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
77
  if rowscale is not None:
@@ -90,9 +59,9 @@ def layer_norm_ref(
90
  x = x + x1
91
  if residual is not None:
92
  x = (x + residual).to(x.dtype)
93
- out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
94
- dtype
95
- )
96
  if weight1 is None:
97
  return out if not prenorm else (out, x)
98
  else:
@@ -114,7 +83,6 @@ def rms_norm_ref(
114
  dropout_p=0.0,
115
  rowscale=None,
116
  prenorm=False,
117
- zero_centered_weight=False,
118
  dropout_mask=None,
119
  dropout_mask1=None,
120
  upcast=False,
@@ -128,10 +96,6 @@ def rms_norm_ref(
128
  x1 = x1.float() if x1 is not None else None
129
  weight1 = weight1.float() if weight1 is not None else None
130
  bias1 = bias1.float() if bias1 is not None else None
131
- if zero_centered_weight:
132
- weight = weight + 1.0
133
- if weight1 is not None:
134
- weight1 = weight1 + 1.0
135
  if x1 is not None:
136
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
137
  if rowscale is not None:
@@ -151,26 +115,34 @@ def rms_norm_ref(
151
  if residual is not None:
152
  x = (x + residual).to(x.dtype)
153
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
154
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
 
 
155
  if weight1 is None:
156
  return out if not prenorm else (out, x)
157
  else:
158
- out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
159
- dtype
160
- )
161
  return (out, out1) if not prenorm else (out, out1, x)
162
 
163
 
164
  @triton.autotune(
165
- configs=triton_autotune_configs(),
166
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"],
 
167
  )
168
- # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
169
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
170
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
171
- # @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
172
- # @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
173
- # @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
174
  @triton.jit
175
  def _layer_norm_fwd_1pass_kernel(
176
  X, # pointer to the input
@@ -186,7 +158,6 @@ def _layer_norm_fwd_1pass_kernel(
186
  ROWSCALE,
187
  SEEDS, # Dropout seeds for each row
188
  DROPOUT_MASK,
189
- DROPOUT_MASK1,
190
  Mean, # pointer to the mean
191
  Rstd, # pointer to the 1/std
192
  stride_x_row, # how much to increase the pointer when moving by 1 row
@@ -199,7 +170,6 @@ def _layer_norm_fwd_1pass_kernel(
199
  N, # number of columns in X
200
  eps, # epsilon to avoid division by zero
201
  dropout_p, # Dropout probability
202
- zero_centered_weight, # If true, add 1.0 to the weight
203
  IS_RMS_NORM: tl.constexpr,
204
  BLOCK_N: tl.constexpr,
205
  HAS_RESIDUAL: tl.constexpr,
@@ -233,7 +203,9 @@ def _layer_norm_fwd_1pass_kernel(
233
  if HAS_DROPOUT:
234
  # Compute dropout mask
235
  # 7 rounds is good enough, and reduces register pressure
236
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
237
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
238
  if STORE_DROPOUT_MASK:
239
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
@@ -246,11 +218,12 @@ def _layer_norm_fwd_1pass_kernel(
246
  # Compute dropout mask
247
  # 7 rounds is good enough, and reduces register pressure
248
  keep_mask = (
249
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
250
  )
251
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
252
  if STORE_DROPOUT_MASK:
253
- tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
254
  x += x1
255
  if HAS_RESIDUAL:
256
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
@@ -270,8 +243,6 @@ def _layer_norm_fwd_1pass_kernel(
270
  # Normalize and apply linear transformation
271
  mask = cols < N
272
  w = tl.load(W + cols, mask=mask).to(tl.float32)
273
- if zero_centered_weight:
274
- w += 1.0
275
  if HAS_BIAS:
276
  b = tl.load(B + cols, mask=mask).to(tl.float32)
277
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
@@ -280,8 +251,6 @@ def _layer_norm_fwd_1pass_kernel(
280
  tl.store(Y + cols, y, mask=mask)
281
  if HAS_W1:
282
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
- if zero_centered_weight:
284
- w1 += 1.0
285
  if HAS_B1:
286
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
287
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
@@ -289,87 +258,25 @@ def _layer_norm_fwd_1pass_kernel(
289
 
290
 
291
  def _layer_norm_fwd(
292
- x: Tensor,
293
- weight: Tensor,
294
- bias: Tensor,
295
- eps: float,
296
- residual: Optional[Tensor] = None,
297
- x1: Optional[Tensor] = None,
298
- weight1: Optional[Tensor] = None,
299
- bias1: Optional[Tensor] = None,
300
- dropout_p: float = 0.0,
301
- rowscale: Optional[Tensor] = None,
302
- out_dtype: Optional[torch.dtype] = None,
303
- residual_dtype: Optional[torch.dtype] = None,
304
- zero_centered_weight: bool = False,
305
- is_rms_norm: bool = False,
306
- return_dropout_mask: bool = False,
307
- out: Optional[Tensor] = None,
308
- residual_out: Optional[Tensor] = None
309
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
310
- # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library
311
- # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
312
- # so that _layer_norm_fwd_impl doesn't have to return them.
313
- if out is None:
314
- out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
315
  if residual is not None:
316
  residual_dtype = residual.dtype
317
- if residual_out is None and (
318
- residual is not None
319
- or (residual_dtype is not None and residual_dtype != x.dtype)
320
- or dropout_p > 0.0
321
- or rowscale is not None
322
- or x1 is not None
323
- ):
324
- residual_out = torch.empty_like(
325
- x, dtype=residual_dtype if residual_dtype is not None else x.dtype
326
- )
327
- else:
328
- residual_out = None
329
- y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
330
- x,
331
- weight,
332
- bias,
333
- eps,
334
- out,
335
- residual=residual,
336
- x1=x1,
337
- weight1=weight1,
338
- bias1=bias1,
339
- dropout_p=dropout_p,
340
- rowscale=rowscale,
341
- zero_centered_weight=zero_centered_weight,
342
- is_rms_norm=is_rms_norm,
343
- return_dropout_mask=return_dropout_mask,
344
- residual_out=residual_out,
345
- )
346
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
347
- if residual_out is None:
348
- residual_out = x
349
- return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
350
-
351
-
352
- # [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
353
- # since we're returning a tuple of tensors
354
- @triton_op(add_op_namespace_prefix("layer_norm_fwd_impl"), mutates_args={"out", "residual_out"},
355
- schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)")
356
- def _layer_norm_fwd_impl(
357
- x: Tensor,
358
- weight: Tensor,
359
- bias: Tensor,
360
- eps: float,
361
- out: Tensor,
362
- residual: Optional[Tensor] = None,
363
- x1: Optional[Tensor] = None,
364
- weight1: Optional[Tensor] = None,
365
- bias1: Optional[Tensor] = None,
366
- dropout_p: float = 0.0,
367
- rowscale: Optional[Tensor] = None,
368
- zero_centered_weight: bool = False,
369
- is_rms_norm: bool = False,
370
- return_dropout_mask: bool = False,
371
- residual_out: Optional[Tensor] = None
372
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
373
  M, N = x.shape
374
  assert x.stride(-1) == 1
375
  if residual is not None:
@@ -393,17 +300,41 @@ def _layer_norm_fwd_impl(
393
  if rowscale is not None:
394
  assert rowscale.is_contiguous()
395
  assert rowscale.shape == (M,)
396
- assert out.shape == x.shape
 
 
 
 
397
  assert out.stride(-1) == 1
398
- if residual_out is not None:
399
- assert residual_out.shape == x.shape
400
- assert residual_out.stride(-1) == 1
401
  if weight1 is not None:
402
  y1 = torch.empty_like(out)
403
  assert y1.stride(-1) == 1
404
  else:
405
  y1 = None
406
- mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
407
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
408
  if dropout_p > 0.0:
409
  seeds = torch.randint(
@@ -412,20 +343,18 @@ def _layer_norm_fwd_impl(
412
  else:
413
  seeds = None
414
  if return_dropout_mask and dropout_p > 0.0:
415
- dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
416
- if x1 is not None:
417
- dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
418
- else:
419
- dropout_mask1 = None
420
  else:
421
- dropout_mask, dropout_mask1 = None, None
422
  # Less than 64KB per feature: enqueue fused kernel
423
  MAX_FUSED_SIZE = 65536 // x.element_size()
424
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
425
  if N > BLOCK_N:
426
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
427
  with torch.cuda.device(x.device.index):
428
- torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
429
  x,
430
  out,
431
  weight,
@@ -439,7 +368,6 @@ def _layer_norm_fwd_impl(
439
  rowscale,
440
  seeds,
441
  dropout_mask,
442
- dropout_mask1,
443
  mean,
444
  rstd,
445
  x.stride(0),
@@ -452,8 +380,6 @@ def _layer_norm_fwd_impl(
452
  N,
453
  eps,
454
  dropout_p,
455
- # Passing bool make torch inductor very unhappy since it then tries to compare to int_max
456
- int(zero_centered_weight),
457
  is_rms_norm,
458
  BLOCK_N,
459
  residual is not None,
@@ -462,26 +388,50 @@ def _layer_norm_fwd_impl(
462
  dropout_p > 0.0,
463
  dropout_mask is not None,
464
  rowscale is not None,
465
- HAS_X1=x1 is not None,
466
- HAS_W1=weight1 is not None,
467
- HAS_B1=bias1 is not None,
468
  )
469
- return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
470
 
471
 
472
  @triton.autotune(
473
- configs=triton_autotune_configs(),
474
- key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
475
  )
476
- # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
477
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
478
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
479
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
480
- # @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
481
- # @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
482
- # @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
483
- # @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
484
- # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
485
  @triton.jit
486
  def _layer_norm_bwd_kernel(
487
  X, # pointer to the input
@@ -515,7 +465,6 @@ def _layer_norm_bwd_kernel(
515
  N, # number of columns in X
516
  eps, # epsilon to avoid division by zero
517
  dropout_p,
518
- zero_centered_weight,
519
  rows_per_program,
520
  IS_RMS_NORM: tl.constexpr,
521
  BLOCK_N: tl.constexpr,
@@ -549,14 +498,10 @@ def _layer_norm_bwd_kernel(
549
  if RECOMPUTE_OUTPUT:
550
  Y += row_start * stride_y_row
551
  w = tl.load(W + cols, mask=mask).to(tl.float32)
552
- if zero_centered_weight:
553
- w += 1.0
554
  if RECOMPUTE_OUTPUT and HAS_BIAS:
555
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
556
  if HAS_DY1:
557
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
558
- if zero_centered_weight:
559
- w1 += 1.0
560
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
561
  if HAS_BIAS:
562
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
@@ -605,14 +550,18 @@ def _layer_norm_bwd_kernel(
605
  if HAS_DX1:
606
  if HAS_DROPOUT:
607
  keep_mask = (
608
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
609
  )
610
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
611
  else:
612
  dx1 = dx
613
  tl.store(DX1 + cols, dx1, mask=mask)
614
  if HAS_DROPOUT:
615
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
 
616
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
617
  if HAS_ROWSCALE:
618
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
@@ -642,93 +591,31 @@ def _layer_norm_bwd_kernel(
642
 
643
 
644
  def _layer_norm_bwd(
645
- dy: Tensor,
646
- x: Tensor,
647
- weight: Tensor,
648
- bias: Tensor,
649
- eps: float,
650
- mean: Tensor,
651
- rstd: Tensor,
652
- dresidual: Optional[Tensor] = None,
653
- dy1: Optional[Tensor] = None,
654
- weight1: Optional[Tensor] = None,
655
- bias1: Optional[Tensor] = None,
656
- seeds: Optional[Tensor] = None,
657
- dropout_p: float = 0.0,
658
- rowscale: Optional[Tensor] = None,
659
- has_residual: bool = False,
660
- has_x1: bool = False,
661
- zero_centered_weight: bool = False,
662
- is_rms_norm: bool = False,
663
- x_dtype: Optional[torch.dtype] = None,
664
- recompute_output: bool = False,
665
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
666
- # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x,
667
- # which makes torch.library unhappy
668
- dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl(
669
- dy,
670
- x,
671
- weight,
672
- bias,
673
- eps,
674
- mean,
675
- rstd,
676
- dresidual,
677
- dy1,
678
- weight1,
679
- bias1,
680
- seeds,
681
- dropout_p,
682
- rowscale,
683
- has_residual,
684
- has_x1,
685
- zero_centered_weight,
686
- is_rms_norm,
687
- x_dtype=x_dtype,
688
- recompute_output=recompute_output,
689
- )
690
- # Don't need to compute dresidual_in separately in this case
691
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
692
- dresidual_in = dx
693
- if has_x1 and dropout_p == 0.0:
694
- dx1 = dx
695
- return dx, dw, db, dresidual_in, dx1, dw1, db1, y
696
-
697
-
698
-
699
- @triton_op(add_op_namespace_prefix("layer_norm_bwd_impl"), mutates_args={},
700
- schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
701
- allow_decomposition=False, # Don't let torch.compile trace inside
702
- )
703
- def _layer_norm_bwd_impl(
704
- dy: Tensor,
705
- x: Tensor,
706
- weight: Tensor,
707
- bias: Tensor,
708
- eps: float,
709
- mean: Tensor,
710
- rstd: Tensor,
711
- dresidual: Optional[Tensor] = None,
712
- dy1: Optional[Tensor] = None,
713
- weight1: Optional[Tensor] = None,
714
- bias1: Optional[Tensor] = None,
715
- seeds: Optional[Tensor] = None,
716
- dropout_p: float = 0.0,
717
- rowscale: Optional[Tensor] = None,
718
- has_residual: bool = False,
719
- has_x1: bool = False,
720
- zero_centered_weight: bool = False,
721
- is_rms_norm: bool = False,
722
- x_dtype: Optional[torch.dtype] = None,
723
- recompute_output: bool = False,
724
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
725
  M, N = x.shape
726
  assert x.stride(-1) == 1
727
- dy = maybe_contiguous_lastdim(dy)
728
  assert dy.stride(-1) == 1
729
  assert dy.shape == (M, N)
730
  if dresidual is not None:
731
- dresidual = maybe_contiguous_lastdim(dresidual)
732
  assert dresidual.stride(-1) == 1
733
  assert dresidual.shape == (M, N)
734
  assert weight.shape == (N,)
@@ -737,7 +624,6 @@ def _layer_norm_bwd_impl(
737
  assert bias.stride(-1) == 1
738
  assert bias.shape == (N,)
739
  if dy1 is not None:
740
- dy1 = maybe_contiguous_lastdim(dy1)
741
  assert weight1 is not None
742
  assert dy1.shape == dy.shape
743
  assert dy1.stride(-1) == 1
@@ -766,18 +652,22 @@ def _layer_norm_bwd_impl(
766
  else None
767
  )
768
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
769
- y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
 
 
 
 
770
  if recompute_output:
771
- assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
 
 
772
 
773
  # Less than 64KB per feature: enqueue fused kernel
774
  MAX_FUSED_SIZE = 65536 // x.element_size()
775
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
776
  if N > BLOCK_N:
777
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
778
- # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
779
- # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
780
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
781
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
782
  _db = (
783
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
@@ -789,7 +679,7 @@ def _layer_norm_bwd_impl(
789
  rows_per_program = math.ceil(M / sm_count)
790
  grid = (sm_count,)
791
  with torch.cuda.device(x.device.index):
792
- torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid](
793
  x,
794
  weight,
795
  bias,
@@ -821,8 +711,6 @@ def _layer_norm_bwd_impl(
821
  N,
822
  eps,
823
  dropout_p,
824
- # Passing bool make torch inductor very unhappy since it then tries to compare to int_max
825
- int(zero_centered_weight),
826
  rows_per_program,
827
  is_rms_norm,
828
  BLOCK_N,
@@ -830,22 +718,24 @@ def _layer_norm_bwd_impl(
830
  dresidual_in is not None,
831
  bias is not None,
832
  dropout_p > 0.0,
833
- HAS_ROWSCALE=rowscale is not None,
834
- HAS_DY1=dy1 is not None,
835
- HAS_DX1=dx1 is not None,
836
- HAS_B1=bias1 is not None,
837
- RECOMPUTE_OUTPUT=y is not None,
838
  )
839
  dw = _dw.sum(0).to(weight.dtype)
840
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
841
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
842
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
843
- # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx
844
- return dx, dw, db, dresidual_in, dx1, dw1, db1, y
845
 
846
 
847
  class LayerNormFn(torch.autograd.Function):
848
-
849
  @staticmethod
850
  def forward(
851
  ctx,
@@ -861,27 +751,34 @@ class LayerNormFn(torch.autograd.Function):
861
  rowscale=None,
862
  prenorm=False,
863
  residual_in_fp32=False,
864
- zero_centered_weight=False,
865
  is_rms_norm=False,
866
  return_dropout_mask=False,
867
- out_dtype=None,
868
  out=None,
869
- residual_out=None
870
  ):
871
  x_shape_og = x.shape
872
  # reshape input data into 2D tensor
873
- x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
874
  if residual is not None:
875
  assert residual.shape == x_shape_og
876
- residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
877
  if x1 is not None:
878
  assert x1.shape == x_shape_og
879
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
880
- x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
 
 
881
  weight = weight.contiguous()
882
- bias = maybe_contiguous(bias)
883
- weight1 = maybe_contiguous(weight1)
884
- bias1 = maybe_contiguous(bias1)
 
 
 
885
  if rowscale is not None:
886
  rowscale = rowscale.reshape(-1).contiguous()
887
  residual_dtype = (
@@ -893,24 +790,24 @@ class LayerNormFn(torch.autograd.Function):
893
  out = out.reshape(-1, out.shape[-1])
894
  if residual_out is not None:
895
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
896
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
897
- x,
898
- weight,
899
- bias,
900
- eps,
901
- residual,
902
- x1,
903
- weight1,
904
- bias1,
905
- dropout_p=dropout_p,
906
- rowscale=rowscale,
907
- out_dtype=out_dtype,
908
- residual_dtype=residual_dtype,
909
- zero_centered_weight=zero_centered_weight,
910
- is_rms_norm=is_rms_norm,
911
- return_dropout_mask=return_dropout_mask,
912
- out=out,
913
- residual_out=residual_out,
914
  )
915
  ctx.save_for_backward(
916
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -923,12 +820,17 @@ class LayerNormFn(torch.autograd.Function):
923
  ctx.has_x1 = x1 is not None
924
  ctx.prenorm = prenorm
925
  ctx.x_dtype = x.dtype
926
- ctx.zero_centered_weight = zero_centered_weight
927
  y = y.reshape(x_shape_og)
928
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
929
- residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
930
- dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
931
- dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
932
  if not return_dropout_mask:
933
  if weight1 is None:
934
  return y if not prenorm else (y, residual_out)
@@ -952,19 +854,26 @@ class LayerNormFn(torch.autograd.Function):
952
  def backward(ctx, dy, *args):
953
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
954
  dy = dy.reshape(-1, dy.shape[-1])
 
 
 
955
  if weight1 is not None:
956
  dy1, args = args[0], args[1:]
957
  dy1 = dy1.reshape(-1, dy1.shape[-1])
 
 
958
  assert dy1.shape == x.shape
959
  else:
960
  dy1 = None
961
  if ctx.prenorm:
962
  dresidual = args[0]
963
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 
 
964
  assert dresidual.shape == x.shape
965
  else:
966
  dresidual = None
967
- dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd(
968
  dy,
969
  x,
970
  weight,
@@ -981,10 +890,8 @@ class LayerNormFn(torch.autograd.Function):
981
  rowscale,
982
  ctx.has_residual,
983
  ctx.has_x1,
984
- ctx.zero_centered_weight,
985
  ctx.is_rms_norm,
986
  x_dtype=ctx.x_dtype,
987
- recompute_output=False,
988
  )
989
  return (
990
  dx.reshape(ctx.x_shape_og),
@@ -1003,8 +910,6 @@ class LayerNormFn(torch.autograd.Function):
1003
  None,
1004
  None,
1005
  None,
1006
- None,
1007
- None,
1008
  )
1009
 
1010
 
@@ -1021,12 +926,10 @@ def layer_norm_fn(
1021
  rowscale=None,
1022
  prenorm=False,
1023
  residual_in_fp32=False,
1024
- zero_centered_weight=False,
1025
  is_rms_norm=False,
1026
  return_dropout_mask=False,
1027
- out_dtype=None,
1028
  out=None,
1029
- residual_out=None
1030
  ):
1031
  return LayerNormFn.apply(
1032
  x,
@@ -1041,12 +944,10 @@ def layer_norm_fn(
1041
  rowscale,
1042
  prenorm,
1043
  residual_in_fp32,
1044
- zero_centered_weight,
1045
  is_rms_norm,
1046
  return_dropout_mask,
1047
- out_dtype,
1048
  out,
1049
- residual_out
1050
  )
1051
 
1052
 
@@ -1063,11 +964,9 @@ def rms_norm_fn(
1063
  rowscale=None,
1064
  prenorm=False,
1065
  residual_in_fp32=False,
1066
- zero_centered_weight=False,
1067
  return_dropout_mask=False,
1068
- out_dtype=None,
1069
  out=None,
1070
- residual_out=None
1071
  ):
1072
  return LayerNormFn.apply(
1073
  x,
@@ -1082,19 +981,16 @@ def rms_norm_fn(
1082
  rowscale,
1083
  prenorm,
1084
  residual_in_fp32,
1085
- zero_centered_weight,
1086
  True,
1087
  return_dropout_mask,
1088
- out_dtype,
1089
  out,
1090
- residual_out
1091
  )
1092
 
1093
 
1094
  class RMSNorm(torch.nn.Module):
1095
 
1096
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
1097
- device=None, dtype=None):
1098
  factory_kwargs = {"device": device, "dtype": dtype}
1099
  super().__init__()
1100
  self.eps = eps
@@ -1102,16 +998,12 @@ class RMSNorm(torch.nn.Module):
1102
  self.drop = torch.nn.Dropout(dropout_p)
1103
  else:
1104
  self.drop = None
1105
- self.zero_centered_weight = zero_centered_weight
1106
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1107
  self.register_parameter("bias", None)
1108
  self.reset_parameters()
1109
 
1110
  def reset_parameters(self):
1111
- if not self.zero_centered_weight:
1112
- torch.nn.init.ones_(self.weight)
1113
- else:
1114
- torch.nn.init.zeros_(self.weight)
1115
 
1116
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1117
  return rms_norm_fn(
@@ -1123,14 +1015,12 @@ class RMSNorm(torch.nn.Module):
1123
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1124
  prenorm=prenorm,
1125
  residual_in_fp32=residual_in_fp32,
1126
- zero_centered_weight=self.zero_centered_weight,
1127
  )
1128
 
1129
 
1130
  class LayerNormLinearFn(torch.autograd.Function):
1131
-
1132
  @staticmethod
1133
- @custom_fwd
1134
  def forward(
1135
  ctx,
1136
  x,
@@ -1146,12 +1036,17 @@ class LayerNormLinearFn(torch.autograd.Function):
1146
  ):
1147
  x_shape_og = x.shape
1148
  # reshape input data into 2D tensor
1149
- x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
1150
  if residual is not None:
1151
  assert residual.shape == x_shape_og
1152
- residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
1153
  norm_weight = norm_weight.contiguous()
1154
- norm_bias = maybe_contiguous(norm_bias)
 
1155
  residual_dtype = (
1156
  residual.dtype
1157
  if residual is not None
@@ -1163,17 +1058,25 @@ class LayerNormLinearFn(torch.autograd.Function):
1163
  norm_bias,
1164
  eps,
1165
  residual,
1166
- out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
 
 
 
 
1167
  residual_dtype=residual_dtype,
1168
  is_rms_norm=is_rms_norm,
1169
  )
1170
  y = y.reshape(x_shape_og)
1171
- dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
 
 
1172
  linear_weight = linear_weight.to(dtype)
1173
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1174
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1175
  # We don't store y, will be recomputed in the backward pass to save memory
1176
- ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
 
 
1177
  ctx.x_shape_og = x_shape_og
1178
  ctx.eps = eps
1179
  ctx.is_rms_norm = is_rms_norm
@@ -1184,17 +1087,20 @@ class LayerNormLinearFn(torch.autograd.Function):
1184
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1185
 
1186
  @staticmethod
1187
- @custom_bwd
1188
  def backward(ctx, dout, *args):
1189
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1190
  dout = dout.reshape(-1, dout.shape[-1])
1191
  dy = F.linear(dout, linear_weight.t())
1192
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1193
- dy = maybe_contiguous_lastdim(dy)
 
1194
  assert dy.shape == x.shape
1195
  if ctx.prenorm:
1196
  dresidual = args[0]
1197
- dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
 
 
1198
  assert dresidual.shape == x.shape
1199
  else:
1200
  dresidual = None
 
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
 
10
 
11
  import torch
12
  import torch.nn.functional as F
13
+ from torch.amp import custom_fwd, custom_bwd
14
 
15
  import triton
16
  import triton.language as tl
17
 
 
18
 
19
  def layer_norm_ref(
20
  x,
 
28
  dropout_p=0.0,
29
  rowscale=None,
30
  prenorm=False,
 
31
  dropout_mask=None,
32
  dropout_mask1=None,
33
  upcast=False,
 
41
  x1 = x1.float() if x1 is not None else None
42
  weight1 = weight1.float() if weight1 is not None else None
43
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
44
  if x1 is not None:
45
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
46
  if rowscale is not None:
 
59
  x = x + x1
60
  if residual is not None:
61
  x = (x + residual).to(x.dtype)
62
+ out = F.layer_norm(
63
+ x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
64
+ ).to(dtype)
65
  if weight1 is None:
66
  return out if not prenorm else (out, x)
67
  else:
 
83
  dropout_p=0.0,
84
  rowscale=None,
85
  prenorm=False,
 
86
  dropout_mask=None,
87
  dropout_mask1=None,
88
  upcast=False,
 
96
  x1 = x1.float() if x1 is not None else None
97
  weight1 = weight1.float() if weight1 is not None else None
98
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
99
  if x1 is not None:
100
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
101
  if rowscale is not None:
 
115
  if residual is not None:
116
  x = (x + residual).to(x.dtype)
117
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
118
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
119
+ dtype
120
+ )
121
  if weight1 is None:
122
  return out if not prenorm else (out, x)
123
  else:
124
+ out1 = (
125
+ (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
126
+ ).to(dtype)
127
  return (out, out1) if not prenorm else (out, out1, x)
128
 
129
 
130
  @triton.autotune(
131
+ configs=[
132
+ triton.Config({}, num_warps=1),
133
+ triton.Config({}, num_warps=2),
134
+ triton.Config({}, num_warps=4),
135
+ triton.Config({}, num_warps=8),
136
+ triton.Config({}, num_warps=16),
137
+ triton.Config({}, num_warps=32),
138
+ ],
139
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
140
  )
 
141
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
142
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
143
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
144
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
145
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
146
  @triton.jit
147
  def _layer_norm_fwd_1pass_kernel(
148
  X, # pointer to the input
 
158
  ROWSCALE,
159
  SEEDS, # Dropout seeds for each row
160
  DROPOUT_MASK,
 
161
  Mean, # pointer to the mean
162
  Rstd, # pointer to the 1/std
163
  stride_x_row, # how much to increase the pointer when moving by 1 row
 
170
  N, # number of columns in X
171
  eps, # epsilon to avoid division by zero
172
  dropout_p, # Dropout probability
 
173
  IS_RMS_NORM: tl.constexpr,
174
  BLOCK_N: tl.constexpr,
175
  HAS_RESIDUAL: tl.constexpr,
 
203
  if HAS_DROPOUT:
204
  # Compute dropout mask
205
  # 7 rounds is good enough, and reduces register pressure
206
+ keep_mask = (
207
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
208
+ )
209
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
210
  if STORE_DROPOUT_MASK:
211
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
 
218
  # Compute dropout mask
219
  # 7 rounds is good enough, and reduces register pressure
220
  keep_mask = (
221
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
222
+ > dropout_p
223
  )
224
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
225
  if STORE_DROPOUT_MASK:
226
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
227
  x += x1
228
  if HAS_RESIDUAL:
229
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
 
243
  # Normalize and apply linear transformation
244
  mask = cols < N
245
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
246
  if HAS_BIAS:
247
  b = tl.load(B + cols, mask=mask).to(tl.float32)
248
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 
251
  tl.store(Y + cols, y, mask=mask)
252
  if HAS_W1:
253
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
254
  if HAS_B1:
255
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
256
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
 
258
 
259
 
260
  def _layer_norm_fwd(
261
+ x,
262
+ weight,
263
+ bias,
264
+ eps,
265
+ residual=None,
266
+ x1=None,
267
+ weight1=None,
268
+ bias1=None,
269
+ dropout_p=0.0,
270
+ rowscale=None,
271
+ out_dtype=None,
272
+ residual_dtype=None,
273
+ is_rms_norm=False,
274
+ return_dropout_mask=False,
275
+ out=None,
276
+ residual_out=None,
277
+ ):
278
  if residual is not None:
279
  residual_dtype = residual.dtype
280
  M, N = x.shape
281
  assert x.stride(-1) == 1
282
  if residual is not None:
 
300
  if rowscale is not None:
301
  assert rowscale.is_contiguous()
302
  assert rowscale.shape == (M,)
303
+ # allocate output
304
+ if out is None:
305
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
306
+ else:
307
+ assert out.shape == x.shape
308
  assert out.stride(-1) == 1
 
 
 
309
  if weight1 is not None:
310
  y1 = torch.empty_like(out)
311
  assert y1.stride(-1) == 1
312
  else:
313
  y1 = None
314
+ if (
315
+ residual is not None
316
+ or (residual_dtype is not None and residual_dtype != x.dtype)
317
+ or dropout_p > 0.0
318
+ or rowscale is not None
319
+ or x1 is not None
320
+ ):
321
+ if residual_out is None:
322
+ residual_out = torch.empty(
323
+ M,
324
+ N,
325
+ device=x.device,
326
+ dtype=residual_dtype if residual_dtype is not None else x.dtype,
327
+ )
328
+ else:
329
+ assert residual_out.shape == x.shape
330
+ assert residual_out.stride(-1) == 1
331
+ else:
332
+ residual_out = None
333
+ mean = (
334
+ torch.empty((M,), dtype=torch.float32, device=x.device)
335
+ if not is_rms_norm
336
+ else None
337
+ )
338
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
339
  if dropout_p > 0.0:
340
  seeds = torch.randint(
 
343
  else:
344
  seeds = None
345
  if return_dropout_mask and dropout_p > 0.0:
346
+ dropout_mask = torch.empty(
347
+ M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
348
+ )
 
 
349
  else:
350
+ dropout_mask = None
351
  # Less than 64KB per feature: enqueue fused kernel
352
  MAX_FUSED_SIZE = 65536 // x.element_size()
353
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
354
  if N > BLOCK_N:
355
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
356
  with torch.cuda.device(x.device.index):
357
+ _layer_norm_fwd_1pass_kernel[(M,)](
358
  x,
359
  out,
360
  weight,
 
368
  rowscale,
369
  seeds,
370
  dropout_mask,
 
371
  mean,
372
  rstd,
373
  x.stride(0),
 
380
  N,
381
  eps,
382
  dropout_p,
 
 
383
  is_rms_norm,
384
  BLOCK_N,
385
  residual is not None,
 
388
  dropout_p > 0.0,
389
  dropout_mask is not None,
390
  rowscale is not None,
 
 
 
391
  )
392
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
393
+ if dropout_mask is not None and x1 is not None:
394
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
395
+ else:
396
+ dropout_mask1 = None
397
+ return (
398
+ out,
399
+ y1,
400
+ mean,
401
+ rstd,
402
+ residual_out if residual_out is not None else x,
403
+ seeds,
404
+ dropout_mask,
405
+ dropout_mask1,
406
+ )
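
For clarity, the mask bookkeeping in the new `_layer_norm_fwd` can be pictured with a small host-side sketch (illustrative only): when `x1` is given, a single `(2*M, N)` boolean buffer holds both dropout masks, and `tensor_split` recovers the per-input views.

```python
# Hedged sketch of the (2*M, N) dropout-mask buffer used above.
import torch

M, N = 4, 8
mask_buffer = torch.zeros(2 * M, N, dtype=torch.bool)             # allocated when x1 is not None
dropout_mask, dropout_mask1 = mask_buffer.tensor_split(2, dim=0)  # rows 0..M-1 and M..2M-1
assert dropout_mask.shape == (M, N) and dropout_mask1.shape == (M, N)
```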
407
 
408
 
409
  @triton.autotune(
410
+ configs=[
411
+ triton.Config({}, num_warps=1),
412
+ triton.Config({}, num_warps=2),
413
+ triton.Config({}, num_warps=4),
414
+ triton.Config({}, num_warps=8),
415
+ triton.Config({}, num_warps=16),
416
+ triton.Config({}, num_warps=32),
417
+ ],
418
+ key=[
419
+ "N",
420
+ "HAS_DRESIDUAL",
421
+ "STORE_DRESIDUAL",
422
+ "IS_RMS_NORM",
423
+ "HAS_BIAS",
424
+ "HAS_DROPOUT",
425
+ ],
426
  )
 
427
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
428
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
429
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
430
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
431
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
432
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
433
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
434
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
435
  @triton.jit
436
  def _layer_norm_bwd_kernel(
437
  X, # pointer to the input
 
465
  N, # number of columns in X
466
  eps, # epsilon to avoid division by zero
467
  dropout_p,
 
468
  rows_per_program,
469
  IS_RMS_NORM: tl.constexpr,
470
  BLOCK_N: tl.constexpr,
 
498
  if RECOMPUTE_OUTPUT:
499
  Y += row_start * stride_y_row
500
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
501
  if RECOMPUTE_OUTPUT and HAS_BIAS:
502
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
503
  if HAS_DY1:
504
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
505
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
  if HAS_BIAS:
507
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
 
550
  if HAS_DX1:
551
  if HAS_DROPOUT:
552
  keep_mask = (
553
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
554
+ > dropout_p
555
  )
556
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
557
  else:
558
  dx1 = dx
559
  tl.store(DX1 + cols, dx1, mask=mask)
560
  if HAS_DROPOUT:
561
+ keep_mask = (
562
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
563
+ > dropout_p
564
+ )
565
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
566
  if HAS_ROWSCALE:
567
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
 
591
 
592
 
593
  def _layer_norm_bwd(
594
+ dy,
595
+ x,
596
+ weight,
597
+ bias,
598
+ eps,
599
+ mean,
600
+ rstd,
601
+ dresidual=None,
602
+ dy1=None,
603
+ weight1=None,
604
+ bias1=None,
605
+ seeds=None,
606
+ dropout_p=0.0,
607
+ rowscale=None,
608
+ has_residual=False,
609
+ has_x1=False,
610
+ is_rms_norm=False,
611
+ x_dtype=None,
612
+ recompute_output=False,
613
+ ):
614
  M, N = x.shape
615
  assert x.stride(-1) == 1
 
616
  assert dy.stride(-1) == 1
617
  assert dy.shape == (M, N)
618
  if dresidual is not None:
 
619
  assert dresidual.stride(-1) == 1
620
  assert dresidual.shape == (M, N)
621
  assert weight.shape == (N,)
 
624
  assert bias.stride(-1) == 1
625
  assert bias.shape == (N,)
626
  if dy1 is not None:
 
627
  assert weight1 is not None
628
  assert dy1.shape == dy.shape
629
  assert dy1.stride(-1) == 1
 
652
  else None
653
  )
654
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
655
+ y = (
656
+ torch.empty(M, N, dtype=dy.dtype, device=dy.device)
657
+ if recompute_output
658
+ else None
659
+ )
660
  if recompute_output:
661
+ assert (
662
+ weight1 is None
663
+ ), "recompute_output is not supported with parallel LayerNorm"
664
 
665
  # Less than 64KB per feature: enqueue fused kernel
666
  MAX_FUSED_SIZE = 65536 // x.element_size()
667
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
668
  if N > BLOCK_N:
669
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
670
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
 
 
671
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
672
  _db = (
673
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
 
679
  rows_per_program = math.ceil(M / sm_count)
680
  grid = (sm_count,)
681
  with torch.cuda.device(x.device.index):
682
+ _layer_norm_bwd_kernel[grid](
683
  x,
684
  weight,
685
  bias,
 
711
  N,
712
  eps,
713
  dropout_p,
 
 
714
  rows_per_program,
715
  is_rms_norm,
716
  BLOCK_N,
 
718
  dresidual_in is not None,
719
  bias is not None,
720
  dropout_p > 0.0,
 
 
 
 
 
721
  )
722
  dw = _dw.sum(0).to(weight.dtype)
723
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
724
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
725
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
726
+ # Don't need to compute dresidual_in separately in this case
727
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
728
+ dresidual_in = dx
729
+ if has_x1 and dropout_p == 0.0:
730
+ dx1 = dx
731
+ return (
732
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
733
+ if not recompute_output
734
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
735
+ )
736
 
737
 
738
  class LayerNormFn(torch.autograd.Function):
 
739
  @staticmethod
740
  def forward(
741
  ctx,
 
751
  rowscale=None,
752
  prenorm=False,
753
  residual_in_fp32=False,
 
754
  is_rms_norm=False,
755
  return_dropout_mask=False,
 
756
  out=None,
757
+ residual_out=None,
758
  ):
759
  x_shape_og = x.shape
760
  # reshape input data into 2D tensor
761
+ x = x.reshape(-1, x.shape[-1])
762
+ if x.stride(-1) != 1:
763
+ x = x.contiguous()
764
  if residual is not None:
765
  assert residual.shape == x_shape_og
766
+ residual = residual.reshape(-1, residual.shape[-1])
767
+ if residual.stride(-1) != 1:
768
+ residual = residual.contiguous()
769
  if x1 is not None:
770
  assert x1.shape == x_shape_og
771
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
772
+ x1 = x1.reshape(-1, x1.shape[-1])
773
+ if x1.stride(-1) != 1:
774
+ x1 = x1.contiguous()
775
  weight = weight.contiguous()
776
+ if bias is not None:
777
+ bias = bias.contiguous()
778
+ if weight1 is not None:
779
+ weight1 = weight1.contiguous()
780
+ if bias1 is not None:
781
+ bias1 = bias1.contiguous()
782
  if rowscale is not None:
783
  rowscale = rowscale.reshape(-1).contiguous()
784
  residual_dtype = (
 
790
  out = out.reshape(-1, out.shape[-1])
791
  if residual_out is not None:
792
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
793
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
794
+ _layer_norm_fwd(
795
+ x,
796
+ weight,
797
+ bias,
798
+ eps,
799
+ residual,
800
+ x1,
801
+ weight1,
802
+ bias1,
803
+ dropout_p=dropout_p,
804
+ rowscale=rowscale,
805
+ residual_dtype=residual_dtype,
806
+ is_rms_norm=is_rms_norm,
807
+ return_dropout_mask=return_dropout_mask,
808
+ out=out,
809
+ residual_out=residual_out,
810
+ )
811
  )
812
  ctx.save_for_backward(
813
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
 
820
  ctx.has_x1 = x1 is not None
821
  ctx.prenorm = prenorm
822
  ctx.x_dtype = x.dtype
 
823
  y = y.reshape(x_shape_og)
824
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
825
+ residual_out = (
826
+ residual_out.reshape(x_shape_og) if residual_out is not None else None
827
+ )
828
+ dropout_mask = (
829
+ dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
830
+ )
831
+ dropout_mask1 = (
832
+ dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
833
+ )
834
  if not return_dropout_mask:
835
  if weight1 is None:
836
  return y if not prenorm else (y, residual_out)
 
854
  def backward(ctx, dy, *args):
855
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
856
  dy = dy.reshape(-1, dy.shape[-1])
857
+ if dy.stride(-1) != 1:
858
+ dy = dy.contiguous()
859
+ assert dy.shape == x.shape
860
  if weight1 is not None:
861
  dy1, args = args[0], args[1:]
862
  dy1 = dy1.reshape(-1, dy1.shape[-1])
863
+ if dy1.stride(-1) != 1:
864
+ dy1 = dy1.contiguous()
865
  assert dy1.shape == x.shape
866
  else:
867
  dy1 = None
868
  if ctx.prenorm:
869
  dresidual = args[0]
870
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
871
+ if dresidual.stride(-1) != 1:
872
+ dresidual = dresidual.contiguous()
873
  assert dresidual.shape == x.shape
874
  else:
875
  dresidual = None
876
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
877
  dy,
878
  x,
879
  weight,
 
890
  rowscale,
891
  ctx.has_residual,
892
  ctx.has_x1,
 
893
  ctx.is_rms_norm,
894
  x_dtype=ctx.x_dtype,
 
895
  )
896
  return (
897
  dx.reshape(ctx.x_shape_og),
 
910
  None,
911
  None,
912
  None,
 
 
913
  )
914
 
915
 
 
926
  rowscale=None,
927
  prenorm=False,
928
  residual_in_fp32=False,
 
929
  is_rms_norm=False,
930
  return_dropout_mask=False,
 
931
  out=None,
932
+ residual_out=None,
933
  ):
934
  return LayerNormFn.apply(
935
  x,
 
944
  rowscale,
945
  prenorm,
946
  residual_in_fp32,
 
947
  is_rms_norm,
948
  return_dropout_mask,
 
949
  out,
950
+ residual_out,
951
  )
952
 
953
 
 
964
  rowscale=None,
965
  prenorm=False,
966
  residual_in_fp32=False,
 
967
  return_dropout_mask=False,
 
968
  out=None,
969
+ residual_out=None,
970
  ):
971
  return LayerNormFn.apply(
972
  x,
 
981
  rowscale,
982
  prenorm,
983
  residual_in_fp32,
 
984
  True,
985
  return_dropout_mask,
 
986
  out,
987
+ residual_out,
988
  )
989
 
990
 
991
  class RMSNorm(torch.nn.Module):
992
 
993
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
 
994
  factory_kwargs = {"device": device, "dtype": dtype}
995
  super().__init__()
996
  self.eps = eps
 
998
  self.drop = torch.nn.Dropout(dropout_p)
999
  else:
1000
  self.drop = None
 
1001
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1002
  self.register_parameter("bias", None)
1003
  self.reset_parameters()
1004
 
1005
  def reset_parameters(self):
1006
+ torch.nn.init.ones_(self.weight)
 
 
 
1007
 
1008
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1009
  return rms_norm_fn(
 
1015
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1016
  prenorm=prenorm,
1017
  residual_in_fp32=residual_in_fp32,
 
1018
  )
1019
 
1020
 
1021
  class LayerNormLinearFn(torch.autograd.Function):
 
1022
  @staticmethod
1023
+ @custom_fwd(device_type="cuda")
1024
  def forward(
1025
  ctx,
1026
  x,
 
1036
  ):
1037
  x_shape_og = x.shape
1038
  # reshape input data into 2D tensor
1039
+ x = x.reshape(-1, x.shape[-1])
1040
+ if x.stride(-1) != 1:
1041
+ x = x.contiguous()
1042
  if residual is not None:
1043
  assert residual.shape == x_shape_og
1044
+ residual = residual.reshape(-1, residual.shape[-1])
1045
+ if residual.stride(-1) != 1:
1046
+ residual = residual.contiguous()
1047
  norm_weight = norm_weight.contiguous()
1048
+ if norm_bias is not None:
1049
+ norm_bias = norm_bias.contiguous()
1050
  residual_dtype = (
1051
  residual.dtype
1052
  if residual is not None
 
1058
  norm_bias,
1059
  eps,
1060
  residual,
1061
+ out_dtype=(
1062
+ None
1063
+ if not torch.is_autocast_enabled()
1064
+ else torch.get_autocast_gpu_dtype()
1065
+ ),
1066
  residual_dtype=residual_dtype,
1067
  is_rms_norm=is_rms_norm,
1068
  )
1069
  y = y.reshape(x_shape_og)
1070
+ dtype = (
1071
+ torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1072
+ )
1073
  linear_weight = linear_weight.to(dtype)
1074
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1075
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1076
  # We don't store y, will be recomputed in the backward pass to save memory
1077
+ ctx.save_for_backward(
1078
+ residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1079
+ )
1080
  ctx.x_shape_og = x_shape_og
1081
  ctx.eps = eps
1082
  ctx.is_rms_norm = is_rms_norm
 
1087
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1088
 
1089
  @staticmethod
1090
+ @custom_bwd(device_type="cuda")
1091
  def backward(ctx, dout, *args):
1092
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1093
  dout = dout.reshape(-1, dout.shape[-1])
1094
  dy = F.linear(dout, linear_weight.t())
1095
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1096
+ if dy.stride(-1) != 1:
1097
+ dy = dy.contiguous()
1098
  assert dy.shape == x.shape
1099
  if ctx.prenorm:
1100
  dresidual = args[0]
1101
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1102
+ if dresidual.stride(-1) != 1:
1103
+ dresidual = dresidual.contiguous()
1104
  assert dresidual.shape == x.shape
1105
  else:
1106
  dresidual = None
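
For orientation, here is a minimal sketch of how the functional entry point touched in this file (`layer_norm_fn`, whose signature is visible further down in this diff) is typically invoked. This is illustrative only, not code from this PR; it assumes a CUDA device and that the built package is importable as `triton_layer_norm`.

```python
# Hedged usage sketch for layer_norm_fn (assumes CUDA and the package on PYTHONPATH).
import torch
from triton_layer_norm import layer_norm_fn

x = torch.randn(4, 256, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

# Plain fused layer norm.
y = layer_norm_fn(x, weight, bias, eps=1e-6)

# Fused residual add + norm; with prenorm=True the un-normalized x + residual
# is returned alongside the normalized output, matching the reshape and
# contiguity handling shown in the forward/backward above.
residual = torch.randn_like(x)
y, pre = layer_norm_fn(x, weight, bias, residual=residual, prenorm=True)
```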
build/torch-universal/triton_layer_norm/layers.py CHANGED
@@ -1,46 +1,4 @@
1
- import torch
2
- from torch import nn
3
 
4
- from .layer_norm import rms_norm_fn
5
 
6
-
7
- class LlamaRMSNorm(nn.Module):
8
- """
9
- RMS Layer Norm for Llama models.
10
-
11
- Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
12
- `transformers`.
13
-
14
- Attributes:
15
- weight (`torch.Tensor`): The learnable scaling parameter.
16
- variance_epsilon (`float`): The epsilon value for numerical stability.
17
- """
18
- weight: torch.Tensor
19
- variance_epsilon: float
20
-
21
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
22
- """
23
- Apply RMS normalization to the input hidden states.
24
-
25
- Args:
26
- hidden_states (`torch.Tensor`):
27
- Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
28
- where the last dimension is the feature dimension to be normalized.
29
-
30
- Returns:
31
- `torch.Tensor`:
32
- The normalized tensor with the same shape as the input `hidden_states`.
33
- """
34
- return rms_norm_fn(
35
- hidden_states,
36
- self.weight,
37
- bias=None,
38
- residual=None,
39
- eps=self.variance_epsilon,
40
- dropout_p=0.0,
41
- prenorm=False,
42
- residual_in_fp32=False,
43
- )
44
-
45
-
46
- __all__ = ["LlamaRMSNorm"]
 
1
+ from .layer_norm import RMSNorm
 
2
 
 
3
 
4
+ __all__ = ["RMSNorm"]
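
With `layers.py` reduced to a re-export, downstream code reaches the Triton-backed module directly. A hedged sketch of that usage follows (assumes a CUDA device and an importable `triton_layer_norm`; shapes are arbitrary), based on the `RMSNorm.__init__` and `forward` signatures shown later in this diff.

```python
# Hedged usage sketch for the re-exported RMSNorm module.
import torch
from triton_layer_norm.layers import RMSNorm

norm = RMSNorm(hidden_size=2048, eps=1e-5, device="cuda", dtype=torch.float16)
x = torch.randn(2, 128, 2048, device="cuda", dtype=torch.float16)

y = norm(x)                                 # RMS norm only
y, pre = norm(x, residual=x, prenorm=True)  # fused residual add; also returns x + residual
```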
build/torch-universal/triton_layer_norm/utils/__init__.py DELETED
File without changes
build/torch-universal/triton_layer_norm/utils/library.py DELETED
@@ -1,66 +0,0 @@
1
- # Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
2
- # The PyTorch implementation simply ignores the schema argument, we simply modify it to use schema.
3
-
4
- from typing import Optional, Callable, Iterable, Union
5
-
6
- from torch.library import custom_op, CustomOpDef
7
- from torch._library.triton import set_wrap_triton_enabled
8
-
9
-
10
- def triton_op(
11
- name: str,
12
- fn: Optional[Callable] = None,
13
- /,
14
- *,
15
- mutates_args: Union[str, Iterable[str]],
16
- schema: Optional[str] = None,
17
- # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,
18
- # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator
19
- # and so inductor can't trace inside.
20
- allow_decomposition=True,
21
- ) -> Callable:
22
- def dec(fn: Callable[..., object]) -> CustomOpDef:
23
- def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
24
- # Optimization: we're passing regular Tensors into the triton kernel, so
25
- # no need to go through HOP dispatch
26
- with set_wrap_triton_enabled(False):
27
- return fn(*args, **kwargs)
28
-
29
- result = custom_op(
30
- name,
31
- backend_fn,
32
- mutates_args=mutates_args,
33
- # This is the only difference with the PyTorch implementation
34
- schema=schema,
35
- )
36
- from torch._subclasses.functional_tensor import FunctionalTensorMode
37
-
38
- # We require that the user pass us a function that is make_fx traceable,
39
- # so we can just register it as the Fake/meta kernel.
40
- result.register_fake(fn)
41
-
42
- if allow_decomposition:
43
- # We decompose the operator when FunctionalTensorMode is active.
44
- # The goal is to decompose the operator in AOTDispatcher.
45
- # - With torch.compile, this means that the backend (usually Inductor)
46
- # can see a call to the triton kernel(s) and so it can directly optimize
47
- # them by inlining them into the lowering process.
48
- def functional_decomp( # type: ignore[no-untyped-def]
49
- mode, op, types, args, kwargs
50
- ):
51
- from torch.export._trace import custom_triton_ops_decomposition_disabled
52
-
53
- if custom_triton_ops_decomposition_disabled():
54
- return mode.__torch_dispatch__(op, types, args, kwargs)
55
- else:
56
- with mode:
57
- return fn(*args, **kwargs)
58
-
59
- result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
60
-
61
- return result
62
-
63
- if fn is None:
64
- return dec
65
- else:
66
- return dec(fn)
 
build/torch-universal/triton_layer_norm/utils/torch.py DELETED
@@ -1,21 +0,0 @@
1
- import torch
2
- from typing import Callable
3
-
4
-
5
- def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6
- def decorator(*args, **kwargs):
7
- if cuda_amp_deprecated:
8
- kwargs["device_type"] = "cuda"
9
- return dec(*args, **kwargs)
10
- return decorator
11
-
12
-
13
- if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
14
- deprecated = True
15
- from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
16
- else:
17
- deprecated = False
18
- from torch.cuda.amp import custom_fwd, custom_bwd
19
-
20
- custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
21
- custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
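
This deleted shim only papered over the `torch.cuda.amp` → `torch.amp` move. With it gone, the `@custom_fwd(device_type="cuda")` / `@custom_bwd(device_type="cuda")` decorators seen in the updated kernel can come straight from `torch.amp`. A minimal sketch of that pattern, assuming a recent PyTorch (≥ 2.4, where these live under `torch.amp`); this is not code from this repo:

```python
# Minimal sketch of the torch.amp decorator pattern that replaces the deleted shim.
import torch
from torch.amp import custom_fwd, custom_bwd


class DoubleFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # records the autocast state for the matching custom_bwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd(device_type="cuda")  # pairs with custom_fwd so backward sees consistent autocast state
    def backward(ctx, grad_out):
        return grad_out * 2
```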
flake.lock DELETED
@@ -1,168 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1733328505,
21
- "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1754038838,
77
- "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1756320464,
102
- "narHash": "sha256-x9LI4h87/Z9UgTQjgeG0fRcdeXl91xIqBlTauGKZM70=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "b4accba4496b28faef19a0487fbcf9686b14e2ef",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "repo": "kernel-builder",
111
- "type": "github"
112
- }
113
- },
114
- "nixpkgs": {
115
- "locked": {
116
- "lastModified": 1752785354,
117
- "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
118
- "owner": "nixos",
119
- "repo": "nixpkgs",
120
- "rev": "d38025438a6ee456758dc03188ca6873a415463b",
121
- "type": "github"
122
- },
123
- "original": {
124
- "owner": "nixos",
125
- "repo": "nixpkgs",
126
- "rev": "d38025438a6ee456758dc03188ca6873a415463b",
127
- "type": "github"
128
- }
129
- },
130
- "root": {
131
- "inputs": {
132
- "kernel-builder": "kernel-builder"
133
- }
134
- },
135
- "systems": {
136
- "locked": {
137
- "lastModified": 1681028828,
138
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
- "owner": "nix-systems",
140
- "repo": "default",
141
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
- "type": "github"
143
- },
144
- "original": {
145
- "owner": "nix-systems",
146
- "repo": "default",
147
- "type": "github"
148
- }
149
- },
150
- "systems_2": {
151
- "locked": {
152
- "lastModified": 1681028828,
153
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
- "owner": "nix-systems",
155
- "repo": "default",
156
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
- "type": "github"
158
- },
159
- "original": {
160
- "owner": "nix-systems",
161
- "repo": "default",
162
- "type": "github"
163
- }
164
- }
165
- },
166
- "root": "root",
167
- "version": 7
168
- }
flake.nix CHANGED
@@ -10,11 +10,5 @@
10
  self,
11
  kernel-builder,
12
  }:
13
- kernel-builder.lib.genFlakeOutputs {
14
- path = ./.;
15
- rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
- # Import-time autotune.
17
- doGetKernelCheck = false;
18
- pythonCheckInputs = pkgs: with pkgs; [ einops ];
19
- };
20
  }
 
10
  self,
11
  kernel-builder,
12
  }:
13
+ kernel-builder.lib.genFlakeOutputs ./.;
14
  }
tests/test_layer_norm.py DELETED
@@ -1,373 +0,0 @@
1
- # Copyright (c) 2024, Tri Dao.
2
-
3
- import pytest
4
- import torch
5
- import torch.nn.functional as F
6
- from einops import rearrange, repeat
7
-
8
- from triton_layer_norm import (
9
- layer_norm_fn,
10
- layer_norm_linear_fn,
11
- )
12
- from triton_layer_norm.layer_norm import layer_norm_ref, rms_norm_ref
13
-
14
-
15
- is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
16
-
17
-
18
- # @pytest.mark.parametrize("zero_centered_weight", [False, True])
19
- @pytest.mark.parametrize("zero_centered_weight", [False])
20
- @pytest.mark.parametrize("has_weight1", [False, True])
21
- # @pytest.mark.parametrize("has_weight1", [False])
22
- @pytest.mark.parametrize("has_x1", [False, True])
23
- # @pytest.mark.parametrize("has_x1", [False])
24
- @pytest.mark.parametrize("has_rowscale", [False, True])
25
- # @pytest.mark.parametrize("has_rowscale", [False])
26
- @pytest.mark.parametrize("dropout_p", [0.0, 0.27])
27
- # @pytest.mark.parametrize("dropout_p", [0.0])
28
- @pytest.mark.parametrize("prenorm", [True, False])
29
- # @pytest.mark.parametrize("prenorm", [True])
30
- @pytest.mark.parametrize("is_rms_norm", [False, True])
31
- # @pytest.mark.parametrize("is_rms_norm", [True])
32
- @pytest.mark.parametrize("has_residual", [True, False])
33
- # @pytest.mark.parametrize("has_residual", [True])
34
- @pytest.mark.parametrize(
35
- "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
36
- )
37
- # @pytest.mark.parametrize("weight_dtype", [torch.float32])
38
- @pytest.mark.parametrize(
39
- "input_dtype,residual_dtype",
40
- [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
41
- + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
42
- )
43
- # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
44
- @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
45
- # @pytest.mark.parametrize("hidden_size", [1024])
46
- def test_layer_norm(
47
- hidden_size,
48
- input_dtype,
49
- residual_dtype,
50
- weight_dtype,
51
- has_residual,
52
- is_rms_norm,
53
- prenorm,
54
- dropout_p,
55
- has_rowscale,
56
- has_x1,
57
- has_weight1,
58
- zero_centered_weight,
59
- ):
60
- if has_rowscale and has_x1:
61
- pytest.skip("Not supported")
62
- device = "cuda"
63
- if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
64
- atol = 5e-2
65
- elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
66
- atol = 1e-2
67
- else:
68
- atol = 1e-4
69
- # set seed
70
- torch.random.manual_seed(0)
71
- batch_size = 8
72
- seqlen = 512
73
- layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
74
- allclose = (
75
- # Sometimes x0_pt.grad is NaN
76
- lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
77
- <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
78
- or (
79
- # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb is a bit
80
- # by multiply and divide by 0.3
81
- (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
82
- and (x - x_ref).abs().max()
83
- <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
84
- )
85
- )
86
- x0 = torch.randn(
87
- batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
88
- )
89
- x0_pt = x0.detach().clone().requires_grad_()
90
- x0_ref = x0.detach().clone().requires_grad_()
91
- if has_residual:
92
- res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
93
- res_pt = res.detach().clone().requires_grad_()
94
- res_ref = res.detach().clone().requires_grad_()
95
- else:
96
- res, res_pt, res_ref = None, None, None
97
- weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
98
- if not is_rms_norm:
99
- bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
100
- else:
101
- bias = None
102
- weight_pt = weight.detach().clone().requires_grad_()
103
- weight_ref = weight.detach().clone().requires_grad_()
104
- bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
105
- bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
106
- if has_x1:
107
- x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
108
- x1_pt = x1.detach().clone().requires_grad_()
109
- x1_ref = x1.detach().clone().requires_grad_()
110
- else:
111
- x1, x1_pt, x1_ref = None, None, None
112
- if has_weight1:
113
- weight1 = torch.randn(
114
- hidden_size, device=device, dtype=weight_dtype, requires_grad=True
115
- )
116
- weight1_pt = weight1.detach().clone().requires_grad_()
117
- weight1_ref = weight1.detach().clone().requires_grad_()
118
- if not is_rms_norm:
119
- bias1 = torch.randn(
120
- hidden_size, device=device, dtype=weight_dtype, requires_grad=True
121
- )
122
- else:
123
- bias1 = None
124
- bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
125
- bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
126
- else:
127
- weight1, weight1_pt, weight1_ref = None, None, None
128
- bias1, bias1_pt, bias1_ref = None, None, None
129
-
130
- rowscale = (
131
- torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
132
- if has_rowscale
133
- else None
134
- )
135
-
136
- residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
137
- out, *rest = layer_norm_fn(
138
- x0,
139
- weight,
140
- bias,
141
- residual=res,
142
- x1=x1,
143
- weight1=weight1,
144
- bias1=bias1,
145
- eps=1e-6,
146
- dropout_p=dropout_p,
147
- rowscale=rowscale,
148
- prenorm=prenorm,
149
- residual_in_fp32=residual_in_fp32,
150
- zero_centered_weight=zero_centered_weight,
151
- is_rms_norm=is_rms_norm,
152
- return_dropout_mask=True,
153
- )
154
- dropout_mask = rest[-2] if dropout_p > 0.0 else None
155
- dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
156
- out_pt = layer_norm_ref_fn(
157
- x0_pt,
158
- weight_pt,
159
- bias_pt,
160
- residual=res_pt,
161
- x1=x1_pt,
162
- weight1=weight1_pt,
163
- bias1=bias1_pt,
164
- eps=1e-6,
165
- dropout_p=dropout_p,
166
- rowscale=rowscale,
167
- prenorm=prenorm,
168
- zero_centered_weight=zero_centered_weight,
169
- dropout_mask=dropout_mask,
170
- dropout_mask1=dropout_mask1,
171
- )
172
- out_ref = layer_norm_ref_fn(
173
- x0_ref,
174
- weight_ref,
175
- bias_ref,
176
- residual=res_ref,
177
- x1=x1_ref,
178
- weight1=weight1_ref,
179
- bias1=bias1_ref,
180
- eps=1e-6,
181
- dropout_p=dropout_p,
182
- rowscale=rowscale,
183
- prenorm=prenorm,
184
- zero_centered_weight=zero_centered_weight,
185
- dropout_mask=dropout_mask,
186
- dropout_mask1=dropout_mask1,
187
- upcast=True,
188
- )
189
- if not has_weight1:
190
- if prenorm:
191
- residual = rest[0]
192
- out_pt, residual_pt = out_pt
193
- out_ref, residual_ref = out_ref
194
- out1, out1_pt, out1_ref = None, None, None
195
- else:
196
- out1 = rest.pop(0)
197
- if prenorm:
198
- residual = rest[0]
199
- out_pt, out1_pt, residual_pt = out_pt
200
- out_ref, out1_ref, residual_ref = out_ref
201
- else:
202
- out_pt, out1_pt = out_pt
203
- out_ref, out1_ref = out_ref
204
- assert out.dtype == input_dtype
205
- if prenorm:
206
- assert residual.dtype == residual_dtype
207
- assert allclose(residual, residual_pt, residual_ref)
208
- assert allclose(out, out_pt, out_ref)
209
- if out1 is not None:
210
- assert out1.dtype == input_dtype
211
- assert allclose(out1, out1_pt, out1_ref)
212
- if dropout_mask is not None:
213
- dropout_fraction = 1.0 - dropout_mask.float().mean()
214
- assert abs(dropout_fraction - dropout_p) < 0.01
215
- if dropout_mask1 is not None:
216
- dropout_fraction = 1.0 - dropout_mask1.float().mean()
217
- assert abs(dropout_fraction - dropout_p) < 0.01
218
- assert not torch.equal(dropout_mask, dropout_mask1)
219
-
220
- g = torch.randn_like(out) / batch_size
221
- if has_weight1:
222
- out = out * F.gelu(out1)
223
- out_pt = out_pt * F.gelu(out1_pt)
224
- out_ref = out_ref * F.gelu(out1_ref)
225
- if not prenorm:
226
- out.backward(g)
227
- out_pt.backward(g)
228
- out_ref.backward(g)
229
- else:
230
- (out * F.sigmoid(residual)).backward(g)
231
- (out_pt * F.sigmoid(residual_pt)).backward(g)
232
- (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
233
- assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
234
- if has_residual:
235
- assert allclose(res.grad, res_pt.grad, res_ref.grad)
236
- if has_x1:
237
- assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
238
- assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
239
- if bias is not None:
240
- assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
241
- if has_weight1:
242
- assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
243
- if bias1 is not None:
244
- assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
245
-
246
-
247
- @pytest.mark.parametrize("prenorm", [True, False])
248
- # @pytest.mark.parametrize("prenorm", [True])
249
- @pytest.mark.parametrize("is_rms_norm", [False, True])
250
- # @pytest.mark.parametrize("is_rms_norm", [True])
251
- @pytest.mark.parametrize("has_residual", [True, False])
252
- # @pytest.mark.parametrize("has_residual", [False])
253
- @pytest.mark.parametrize("weight_dtype", [torch.float32])
254
- @pytest.mark.parametrize(
255
- "input_dtype,residual_dtype",
256
- [(torch.float16, torch.float16), (torch.float16, torch.float32)]
257
- + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
258
- )
259
- # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
260
- @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
261
- # @pytest.mark.parametrize("hidden_size", [256])
262
- def test_layer_norm_linear(
263
- hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
264
- ):
265
- device = "cuda"
266
- if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
267
- atol = 5e-2
268
- elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
269
- atol = 1e-2
270
- else:
271
- atol = 1e-4
272
- # set seed
273
- torch.random.manual_seed(0)
274
- batch_size = 4
275
- seqlen = 512
276
- # batch_size = 1
277
- # seqlen = 1
278
- layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
279
- allclose = (
280
- lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
281
- <= 2 * (x_pt - x_ref).abs().max() + atol
282
- )
283
- x0 = torch.randn(
284
- batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
285
- )
286
- x0_pt = x0.detach().clone().requires_grad_()
287
- x0_ref = x0.detach().clone().requires_grad_()
288
- if has_residual:
289
- res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
290
- res_pt = res.detach().clone().requires_grad_()
291
- res_ref = res.detach().clone().requires_grad_()
292
- else:
293
- res, res_pt, res_ref = None, None, None
294
- norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
295
- if not is_rms_norm:
296
- norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
297
- else:
298
- norm_bias = None
299
- norm_weight_pt = norm_weight.detach().clone().requires_grad_()
300
- norm_weight_ref = norm_weight.detach().clone().requires_grad_()
301
- norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
302
- norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
303
- linear_weight = torch.empty(
304
- 2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
305
- )
306
- torch.nn.init.xavier_uniform_(linear_weight)
307
- if not is_rms_norm:
308
- linear_bias = torch.randn(
309
- 2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
310
- )
311
- else:
312
- linear_bias = None
313
- linear_weight_pt = linear_weight.detach().clone().requires_grad_()
314
- linear_weight_ref = linear_weight.detach().clone().requires_grad_()
315
- linear_bias_pt = (
316
- linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
317
- )
318
- linear_bias_ref = (
319
- linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
320
- )
321
-
322
- residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
323
- with torch.autocast(device_type="cuda", dtype=input_dtype):
324
- out, *rest = layer_norm_linear_fn(
325
- x0,
326
- norm_weight,
327
- norm_bias,
328
- linear_weight,
329
- linear_bias,
330
- residual=res,
331
- eps=1e-6,
332
- prenorm=prenorm,
333
- residual_in_fp32=residual_in_fp32,
334
- is_rms_norm=is_rms_norm,
335
- )
336
- out_pt, *rest_pt = layer_norm_ref_fn(
337
- x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
338
- )
339
- with torch.autocast(device_type="cuda", dtype=input_dtype):
340
- out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
341
- out_ref, *rest_ref = layer_norm_ref_fn(
342
- x0_ref,
343
- norm_weight_ref,
344
- norm_bias_ref,
345
- residual=res_ref,
346
- eps=1e-6,
347
- prenorm=prenorm,
348
- upcast=True,
349
- )
350
- out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
351
- if prenorm:
352
- residual = rest[0]
353
- residual_pt = rest_pt[0]
354
- residual_ref = rest_ref[0]
355
- assert out.dtype == input_dtype
356
- if prenorm:
357
- assert residual.dtype == residual_dtype
358
- assert allclose(residual, residual_pt, residual_ref)
359
- assert allclose(out, out_pt, out_ref)
360
-
361
- g = torch.randn_like(out) / batch_size
362
- out.backward(g)
363
- out_pt.backward(g)
364
- out_ref.backward(g)
365
- assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
366
- if has_residual:
367
- assert allclose(res.grad, res_pt.grad, res_ref.grad)
368
- assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
369
- if norm_bias is not None:
370
- assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
371
- assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
372
- if linear_bias is not None:
373
- assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
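
The full parametrized suite is removed here. For a quick local sanity check against the pure-PyTorch reference that still ships in `layer_norm.py`, a hedged sketch follows (illustrative only, not a replacement for the deleted tests; assumes a CUDA device, and the tolerances are a guess in the spirit of the fp16 settings above):

```python
# Hedged smoke check: Triton kernel vs. the pure-PyTorch reference.
import torch
from triton_layer_norm import rms_norm_fn
from triton_layer_norm.layer_norm import rms_norm_ref

torch.manual_seed(0)
x = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)

out = rms_norm_fn(x, w, bias=None, eps=1e-6)
ref = rms_norm_ref(x, w, bias=None, eps=1e-6, upcast=True)  # reference computed in fp32, cast back
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```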
torch-ext/triton_layer_norm/__init__.py CHANGED
@@ -1,117 +1,5 @@
1
- """Triton layer normalization kernels
2
-
3
- This kernel implements layers normalization using Triton. This kernel is from
4
- the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
5
- """
6
-
7
- from typing import Optional
8
-
9
- import torch
10
-
11
- from . import layers
12
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
13
 
 
14
 
15
- def layer_norm(
16
- x: torch.Tensor,
17
- weight: torch.Tensor,
18
- bias: torch.Tensor,
19
- residual: Optional[torch.Tensor] = None,
20
- x1: Optional[torch.Tensor] = None,
21
- weight1: Optional[torch.Tensor] = None,
22
- bias1: Optional[torch.Tensor] = None,
23
- eps: float = 1e-6,
24
- dropout_p: float = 0.0,
25
- rowscale=None,
26
- prenorm: bool = False,
27
- residual_in_fp32: bool = False,
28
- zero_centered_weight: bool = False,
29
- is_rms_norm: bool = False,
30
- return_dropout_mask: bool = False,
31
- out: Optional[torch.Tensor] = None,
32
- residual_out: Optional[torch.Tensor] = None,
33
- ):
34
- """
35
- Apply layer normalization to the input tensor with Triton acceleration.
36
-
37
- Args:
38
- x (`torch.Tensor`):
39
- Input tensor to normalize.
40
- weight (`torch.Tensor`):
41
- Scale parameter for normalization.
42
- bias (`torch.Tensor`):
43
- Shift parameter for normalization.
44
- residual (`torch.Tensor`, *optional*):
45
- Optional residual tensor to add to the input before normalization.
46
- x1 (`torch.Tensor`, *optional*):
47
- Optional second input tensor to combine with `x`. When provided, the function
48
- first adds `x1` to `x` and then applies normalization.
49
- weight1 (`torch.Tensor`, *optional*):
50
- Scale parameter for the second normalization.
51
- bias1 (`torch.Tensor`, *optional*):
52
- Shift parameter for the second normalization.
53
- eps (`float`, *optional*, defaults to 1e-6):
54
- Small constant added for numerical stability in normalization.
55
- dropout_p (`float`, *optional*, defaults to 0.0):
56
- Dropout probability. If greater than 0, applies dropout to the input before
57
- normalization and residual addition.
58
- rowscale (`torch.Tensor`, *optional*):
59
- Optional scaling factor applied to each row of the input tensor.
60
- Not compatible with the use of `x1`.
61
- prenorm (`bool`, *optional*, defaults to False):
62
- If True, returns both the normalized output and the unnormalized input+residual.
63
- residual_in_fp32 (`bool`, *optional*, defaults to False):
64
- If True, performs the residual connection in FP32 precision.
65
- zero_centered_weight (`bool`, *optional*, defaults to False):
66
- When set to true, 1.0 is added to the weight before applying it.
67
- is_rms_norm (`bool`, *optional*, defaults to False):
68
- If True, uses RMS normalization instead of layer normalization.
69
- return_dropout_mask (`bool`, *optional*, defaults to False):
70
- If True, returns the dropout mask used for the computation.
71
- out (`torch.Tensor`, *optional*):
72
- Output tensor for the normalized result. If `None`, a new tensor is allocated.
73
- residual_out (`torch.Tensor`, *optional*):
74
- Output tensor for the residual result when using prenorm. If `None`, a new tensor
75
- is allocated when needed.
76
-
77
- Returns:
78
- `torch.Tensor` or tuple of `torch.Tensor`:
79
- - The normalized input.
80
- - The second normalization of the input if `weight1` is provided.
81
- - The residual tensor if `prenorm` is set.
82
- - The dropout mask if `return_dropout_mask` is set.
83
- - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
84
- """
85
- return layer_norm_fn(
86
- x,
87
- weight,
88
- bias,
89
- residual,
90
- x1,
91
- weight1,
92
- bias1,
93
- eps,
94
- dropout_p,
95
- rowscale,
96
- prenorm,
97
- residual_in_fp32,
98
- is_rms_norm,
99
- return_dropout_mask,
100
- out=out,
101
- residual_out=residual_out,
102
- )
103
-
104
-
105
- __kernel_metadata__ = {
106
- "license": "bsd-3-clause",
107
- }
108
-
109
-
110
- __all__ = [
111
- "__kernel_metadata__",
112
- "layers",
113
- "layer_norm",
114
- "layer_norm_fn",
115
- "layer_norm_linear_fn",
116
- "rms_norm_fn",
117
- ]
 
 
1
  from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
2
 
3
+ from . import layers
4
 
5
+ __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
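
Of the three entry points kept in `__all__`, `layer_norm_linear_fn` (the fused norm + matmul path) is the least obvious to call. A hedged sketch following the argument order used by the test file removed above (assumes a CUDA device; shapes and the autocast context mirror that test and are otherwise arbitrary):

```python
# Hedged usage sketch for layer_norm_linear_fn under autocast.
import torch
from triton_layer_norm import layer_norm_linear_fn

x = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16)
norm_w = torch.ones(1024, device="cuda", dtype=torch.float32)
norm_b = torch.zeros(1024, device="cuda", dtype=torch.float32)
lin_w = torch.randn(2048, 1024, device="cuda", dtype=torch.float32) * 0.02

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = layer_norm_linear_fn(x, norm_w, norm_b, lin_w, None, eps=1e-6)  # None = no linear bias
```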
torch-ext/triton_layer_norm/layer_norm.py CHANGED
@@ -7,40 +7,14 @@
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
10
- from typing import Optional, List
11
 
12
  import torch
13
  import torch.nn.functional as F
14
- from torch import Tensor
15
 
16
  import triton
17
  import triton.language as tl
18
 
19
- from ._ops import add_op_namespace_prefix
20
- from .utils.torch import custom_fwd, custom_bwd
21
- from .utils.library import triton_op
22
-
23
-
24
- def maybe_contiguous_lastdim(x):
25
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
26
-
27
-
28
- def maybe_contiguous(x):
29
- return x.contiguous() if x is not None else None
30
-
31
-
32
- def triton_autotune_configs():
33
- # Return configs with a valid warp count for the current device
34
- configs = []
35
- # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
36
- max_threads_per_block = 1024
37
- # Default to warp size 32 if not defined by device
38
- warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
39
- # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
40
- return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32]
41
- if warp_count * warp_size <= max_threads_per_block]
42
- # return [triton.Config({}, num_warps=8)]
43
-
44
 
45
  def layer_norm_ref(
46
  x,
@@ -54,7 +28,6 @@ def layer_norm_ref(
54
  dropout_p=0.0,
55
  rowscale=None,
56
  prenorm=False,
57
- zero_centered_weight=False,
58
  dropout_mask=None,
59
  dropout_mask1=None,
60
  upcast=False,
@@ -68,10 +41,6 @@ def layer_norm_ref(
68
  x1 = x1.float() if x1 is not None else None
69
  weight1 = weight1.float() if weight1 is not None else None
70
  bias1 = bias1.float() if bias1 is not None else None
71
- if zero_centered_weight:
72
- weight = weight + 1.0
73
- if weight1 is not None:
74
- weight1 = weight1 + 1.0
75
  if x1 is not None:
76
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
77
  if rowscale is not None:
@@ -90,9 +59,9 @@ def layer_norm_ref(
90
  x = x + x1
91
  if residual is not None:
92
  x = (x + residual).to(x.dtype)
93
- out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
94
- dtype
95
- )
96
  if weight1 is None:
97
  return out if not prenorm else (out, x)
98
  else:
@@ -114,7 +83,6 @@ def rms_norm_ref(
114
  dropout_p=0.0,
115
  rowscale=None,
116
  prenorm=False,
117
- zero_centered_weight=False,
118
  dropout_mask=None,
119
  dropout_mask1=None,
120
  upcast=False,
@@ -128,10 +96,6 @@ def rms_norm_ref(
128
  x1 = x1.float() if x1 is not None else None
129
  weight1 = weight1.float() if weight1 is not None else None
130
  bias1 = bias1.float() if bias1 is not None else None
131
- if zero_centered_weight:
132
- weight = weight + 1.0
133
- if weight1 is not None:
134
- weight1 = weight1 + 1.0
135
  if x1 is not None:
136
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
137
  if rowscale is not None:
@@ -151,26 +115,34 @@ def rms_norm_ref(
151
  if residual is not None:
152
  x = (x + residual).to(x.dtype)
153
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
154
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
 
 
155
  if weight1 is None:
156
  return out if not prenorm else (out, x)
157
  else:
158
- out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
159
- dtype
160
- )
161
  return (out, out1) if not prenorm else (out, out1, x)
162
 
163
 
164
  @triton.autotune(
165
- configs=triton_autotune_configs(),
166
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"],
 
 
 
 
 
 
 
167
  )
168
- # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
169
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
170
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
171
- # @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
172
- # @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
173
- # @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
174
  @triton.jit
175
  def _layer_norm_fwd_1pass_kernel(
176
  X, # pointer to the input
@@ -186,7 +158,6 @@ def _layer_norm_fwd_1pass_kernel(
186
  ROWSCALE,
187
  SEEDS, # Dropout seeds for each row
188
  DROPOUT_MASK,
189
- DROPOUT_MASK1,
190
  Mean, # pointer to the mean
191
  Rstd, # pointer to the 1/std
192
  stride_x_row, # how much to increase the pointer when moving by 1 row
@@ -199,7 +170,6 @@ def _layer_norm_fwd_1pass_kernel(
199
  N, # number of columns in X
200
  eps, # epsilon to avoid division by zero
201
  dropout_p, # Dropout probability
202
- zero_centered_weight, # If true, add 1.0 to the weight
203
  IS_RMS_NORM: tl.constexpr,
204
  BLOCK_N: tl.constexpr,
205
  HAS_RESIDUAL: tl.constexpr,
@@ -233,7 +203,9 @@ def _layer_norm_fwd_1pass_kernel(
233
  if HAS_DROPOUT:
234
  # Compute dropout mask
235
  # 7 rounds is good enough, and reduces register pressure
236
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
237
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
238
  if STORE_DROPOUT_MASK:
239
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
@@ -246,11 +218,12 @@ def _layer_norm_fwd_1pass_kernel(
246
  # Compute dropout mask
247
  # 7 rounds is good enough, and reduces register pressure
248
  keep_mask = (
249
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
250
  )
251
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
252
  if STORE_DROPOUT_MASK:
253
- tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
254
  x += x1
255
  if HAS_RESIDUAL:
256
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
@@ -270,8 +243,6 @@ def _layer_norm_fwd_1pass_kernel(
270
  # Normalize and apply linear transformation
271
  mask = cols < N
272
  w = tl.load(W + cols, mask=mask).to(tl.float32)
273
- if zero_centered_weight:
274
- w += 1.0
275
  if HAS_BIAS:
276
  b = tl.load(B + cols, mask=mask).to(tl.float32)
277
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
@@ -280,8 +251,6 @@ def _layer_norm_fwd_1pass_kernel(
280
  tl.store(Y + cols, y, mask=mask)
281
  if HAS_W1:
282
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
- if zero_centered_weight:
284
- w1 += 1.0
285
  if HAS_B1:
286
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
287
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
@@ -289,87 +258,25 @@ def _layer_norm_fwd_1pass_kernel(
289
 
290
 
291
  def _layer_norm_fwd(
292
- x: Tensor,
293
- weight: Tensor,
294
- bias: Tensor,
295
- eps: float,
296
- residual: Optional[Tensor] = None,
297
- x1: Optional[Tensor] = None,
298
- weight1: Optional[Tensor] = None,
299
- bias1: Optional[Tensor] = None,
300
- dropout_p: float = 0.0,
301
- rowscale: Optional[Tensor] = None,
302
- out_dtype: Optional[torch.dtype] = None,
303
- residual_dtype: Optional[torch.dtype] = None,
304
- zero_centered_weight: bool = False,
305
- is_rms_norm: bool = False,
306
- return_dropout_mask: bool = False,
307
- out: Optional[Tensor] = None,
308
- residual_out: Optional[Tensor] = None
309
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
310
- # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library
311
- # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
312
- # so that _layer_norm_fwd_impl doesn't have to return them.
313
- if out is None:
314
- out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
315
  if residual is not None:
316
  residual_dtype = residual.dtype
317
- if residual_out is None and (
318
- residual is not None
319
- or (residual_dtype is not None and residual_dtype != x.dtype)
320
- or dropout_p > 0.0
321
- or rowscale is not None
322
- or x1 is not None
323
- ):
324
- residual_out = torch.empty_like(
325
- x, dtype=residual_dtype if residual_dtype is not None else x.dtype
326
- )
327
- else:
328
- residual_out = None
329
- y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
330
- x,
331
- weight,
332
- bias,
333
- eps,
334
- out,
335
- residual=residual,
336
- x1=x1,
337
- weight1=weight1,
338
- bias1=bias1,
339
- dropout_p=dropout_p,
340
- rowscale=rowscale,
341
- zero_centered_weight=zero_centered_weight,
342
- is_rms_norm=is_rms_norm,
343
- return_dropout_mask=return_dropout_mask,
344
- residual_out=residual_out,
345
- )
346
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
347
- if residual_out is None:
348
- residual_out = x
349
- return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
350
-
351
-
352
- # [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
353
- # since we're returning a tuple of tensors
354
- @triton_op(add_op_namespace_prefix("layer_norm_fwd_impl"), mutates_args={"out", "residual_out"},
355
- schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)")
356
- def _layer_norm_fwd_impl(
357
- x: Tensor,
358
- weight: Tensor,
359
- bias: Tensor,
360
- eps: float,
361
- out: Tensor,
362
- residual: Optional[Tensor] = None,
363
- x1: Optional[Tensor] = None,
364
- weight1: Optional[Tensor] = None,
365
- bias1: Optional[Tensor] = None,
366
- dropout_p: float = 0.0,
367
- rowscale: Optional[Tensor] = None,
368
- zero_centered_weight: bool = False,
369
- is_rms_norm: bool = False,
370
- return_dropout_mask: bool = False,
371
- residual_out: Optional[Tensor] = None
372
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
373
  M, N = x.shape
374
  assert x.stride(-1) == 1
375
  if residual is not None:
@@ -393,17 +300,41 @@ def _layer_norm_fwd_impl(
393
  if rowscale is not None:
394
  assert rowscale.is_contiguous()
395
  assert rowscale.shape == (M,)
396
- assert out.shape == x.shape
 
 
 
 
397
  assert out.stride(-1) == 1
398
- if residual_out is not None:
399
- assert residual_out.shape == x.shape
400
- assert residual_out.stride(-1) == 1
401
  if weight1 is not None:
402
  y1 = torch.empty_like(out)
403
  assert y1.stride(-1) == 1
404
  else:
405
  y1 = None
406
- mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
408
  if dropout_p > 0.0:
409
  seeds = torch.randint(
@@ -412,20 +343,18 @@ def _layer_norm_fwd_impl(
412
  else:
413
  seeds = None
414
  if return_dropout_mask and dropout_p > 0.0:
415
- dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
416
- if x1 is not None:
417
- dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
418
- else:
419
- dropout_mask1 = None
420
  else:
421
- dropout_mask, dropout_mask1 = None, None
422
  # Less than 64KB per feature: enqueue fused kernel
423
  MAX_FUSED_SIZE = 65536 // x.element_size()
424
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
425
  if N > BLOCK_N:
426
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
427
  with torch.cuda.device(x.device.index):
428
- torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
429
  x,
430
  out,
431
  weight,
@@ -439,7 +368,6 @@ def _layer_norm_fwd_impl(
439
  rowscale,
440
  seeds,
441
  dropout_mask,
442
- dropout_mask1,
443
  mean,
444
  rstd,
445
  x.stride(0),
@@ -452,8 +380,6 @@ def _layer_norm_fwd_impl(
452
  N,
453
  eps,
454
  dropout_p,
455
- # Passing bool make torch inductor very unhappy since it then tries to compare to int_max
456
- int(zero_centered_weight),
457
  is_rms_norm,
458
  BLOCK_N,
459
  residual is not None,
@@ -462,26 +388,50 @@ def _layer_norm_fwd_impl(
462
  dropout_p > 0.0,
463
  dropout_mask is not None,
464
  rowscale is not None,
465
- HAS_X1=x1 is not None,
466
- HAS_W1=weight1 is not None,
467
- HAS_B1=bias1 is not None,
468
  )
469
- return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
 
472
  @triton.autotune(
473
- configs=triton_autotune_configs(),
474
- key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  )
476
- # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
477
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
478
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
479
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
480
- # @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
481
- # @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
482
- # @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
483
- # @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
484
- # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
485
  @triton.jit
486
  def _layer_norm_bwd_kernel(
487
  X, # pointer to the input
@@ -515,7 +465,6 @@ def _layer_norm_bwd_kernel(
515
  N, # number of columns in X
516
  eps, # epsilon to avoid division by zero
517
  dropout_p,
518
- zero_centered_weight,
519
  rows_per_program,
520
  IS_RMS_NORM: tl.constexpr,
521
  BLOCK_N: tl.constexpr,
@@ -549,14 +498,10 @@ def _layer_norm_bwd_kernel(
549
  if RECOMPUTE_OUTPUT:
550
  Y += row_start * stride_y_row
551
  w = tl.load(W + cols, mask=mask).to(tl.float32)
552
- if zero_centered_weight:
553
- w += 1.0
554
  if RECOMPUTE_OUTPUT and HAS_BIAS:
555
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
556
  if HAS_DY1:
557
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
558
- if zero_centered_weight:
559
- w1 += 1.0
560
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
561
  if HAS_BIAS:
562
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
@@ -605,14 +550,18 @@ def _layer_norm_bwd_kernel(
605
  if HAS_DX1:
606
  if HAS_DROPOUT:
607
  keep_mask = (
608
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
609
  )
610
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
611
  else:
612
  dx1 = dx
613
  tl.store(DX1 + cols, dx1, mask=mask)
614
  if HAS_DROPOUT:
615
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
 
616
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
617
  if HAS_ROWSCALE:
618
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
@@ -642,93 +591,31 @@ def _layer_norm_bwd_kernel(
642
 
643
 
644
  def _layer_norm_bwd(
645
- dy: Tensor,
646
- x: Tensor,
647
- weight: Tensor,
648
- bias: Tensor,
649
- eps: float,
650
- mean: Tensor,
651
- rstd: Tensor,
652
- dresidual: Optional[Tensor] = None,
653
- dy1: Optional[Tensor] = None,
654
- weight1: Optional[Tensor] = None,
655
- bias1: Optional[Tensor] = None,
656
- seeds: Optional[Tensor] = None,
657
- dropout_p: float = 0.0,
658
- rowscale: Optional[Tensor] = None,
659
- has_residual: bool = False,
660
- has_x1: bool = False,
661
- zero_centered_weight: bool = False,
662
- is_rms_norm: bool = False,
663
- x_dtype: Optional[torch.dtype] = None,
664
- recompute_output: bool = False,
665
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
666
- # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x,
667
- # which makes torch.library unhappy
668
- dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl(
669
- dy,
670
- x,
671
- weight,
672
- bias,
673
- eps,
674
- mean,
675
- rstd,
676
- dresidual,
677
- dy1,
678
- weight1,
679
- bias1,
680
- seeds,
681
- dropout_p,
682
- rowscale,
683
- has_residual,
684
- has_x1,
685
- zero_centered_weight,
686
- is_rms_norm,
687
- x_dtype=x_dtype,
688
- recompute_output=recompute_output,
689
- )
690
- # Don't need to compute dresidual_in separately in this case
691
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
692
- dresidual_in = dx
693
- if has_x1 and dropout_p == 0.0:
694
- dx1 = dx
695
- return dx, dw, db, dresidual_in, dx1, dw1, db1, y
696
-
697
-
698
-
699
- @triton_op(add_op_namespace_prefix("layer_norm_bwd_impl"), mutates_args={},
700
- schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
701
- allow_decomposition=False, # Don't let torch.compile trace inside
702
- )
703
- def _layer_norm_bwd_impl(
704
- dy: Tensor,
705
- x: Tensor,
706
- weight: Tensor,
707
- bias: Tensor,
708
- eps: float,
709
- mean: Tensor,
710
- rstd: Tensor,
711
- dresidual: Optional[Tensor] = None,
712
- dy1: Optional[Tensor] = None,
713
- weight1: Optional[Tensor] = None,
714
- bias1: Optional[Tensor] = None,
715
- seeds: Optional[Tensor] = None,
716
- dropout_p: float = 0.0,
717
- rowscale: Optional[Tensor] = None,
718
- has_residual: bool = False,
719
- has_x1: bool = False,
720
- zero_centered_weight: bool = False,
721
- is_rms_norm: bool = False,
722
- x_dtype: Optional[torch.dtype] = None,
723
- recompute_output: bool = False,
724
- ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
725
  M, N = x.shape
726
  assert x.stride(-1) == 1
727
- dy = maybe_contiguous_lastdim(dy)
728
  assert dy.stride(-1) == 1
729
  assert dy.shape == (M, N)
730
  if dresidual is not None:
731
- dresidual = maybe_contiguous_lastdim(dresidual)
732
  assert dresidual.stride(-1) == 1
733
  assert dresidual.shape == (M, N)
734
  assert weight.shape == (N,)
@@ -737,7 +624,6 @@ def _layer_norm_bwd_impl(
737
  assert bias.stride(-1) == 1
738
  assert bias.shape == (N,)
739
  if dy1 is not None:
740
- dy1 = maybe_contiguous_lastdim(dy1)
741
  assert weight1 is not None
742
  assert dy1.shape == dy.shape
743
  assert dy1.stride(-1) == 1
@@ -766,18 +652,22 @@ def _layer_norm_bwd_impl(
766
  else None
767
  )
768
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
769
- y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
 
 
 
 
770
  if recompute_output:
771
- assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
 
 
772
 
773
  # Less than 64KB per feature: enqueue fused kernel
774
  MAX_FUSED_SIZE = 65536 // x.element_size()
775
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
776
  if N > BLOCK_N:
777
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
778
- # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
779
- # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
780
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
781
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
782
  _db = (
783
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
@@ -789,7 +679,7 @@ def _layer_norm_bwd_impl(
789
  rows_per_program = math.ceil(M / sm_count)
790
  grid = (sm_count,)
791
  with torch.cuda.device(x.device.index):
792
- torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid](
793
  x,
794
  weight,
795
  bias,
@@ -821,8 +711,6 @@ def _layer_norm_bwd_impl(
821
  N,
822
  eps,
823
  dropout_p,
824
- # Passing bool make torch inductor very unhappy since it then tries to compare to int_max
825
- int(zero_centered_weight),
826
  rows_per_program,
827
  is_rms_norm,
828
  BLOCK_N,
@@ -830,22 +718,24 @@ def _layer_norm_bwd_impl(
830
  dresidual_in is not None,
831
  bias is not None,
832
  dropout_p > 0.0,
833
- HAS_ROWSCALE=rowscale is not None,
834
- HAS_DY1=dy1 is not None,
835
- HAS_DX1=dx1 is not None,
836
- HAS_B1=bias1 is not None,
837
- RECOMPUTE_OUTPUT=y is not None,
838
  )
839
  dw = _dw.sum(0).to(weight.dtype)
840
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
841
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
842
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
843
- # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx
844
- return dx, dw, db, dresidual_in, dx1, dw1, db1, y
 
 
 
 
 
 
 
 
845
 
846
 
847
  class LayerNormFn(torch.autograd.Function):
848
-
849
  @staticmethod
850
  def forward(
851
  ctx,
@@ -861,27 +751,34 @@ class LayerNormFn(torch.autograd.Function):
861
  rowscale=None,
862
  prenorm=False,
863
  residual_in_fp32=False,
864
- zero_centered_weight=False,
865
  is_rms_norm=False,
866
  return_dropout_mask=False,
867
- out_dtype=None,
868
  out=None,
869
- residual_out=None
870
  ):
871
  x_shape_og = x.shape
872
  # reshape input data into 2D tensor
873
- x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
874
  if residual is not None:
875
  assert residual.shape == x_shape_og
876
- residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
877
  if x1 is not None:
878
  assert x1.shape == x_shape_og
879
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
880
- x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
 
 
881
  weight = weight.contiguous()
882
- bias = maybe_contiguous(bias)
883
- weight1 = maybe_contiguous(weight1)
884
- bias1 = maybe_contiguous(bias1)
 
 
 
885
  if rowscale is not None:
886
  rowscale = rowscale.reshape(-1).contiguous()
887
  residual_dtype = (
@@ -893,24 +790,24 @@ class LayerNormFn(torch.autograd.Function):
893
  out = out.reshape(-1, out.shape[-1])
894
  if residual_out is not None:
895
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
896
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
897
- x,
898
- weight,
899
- bias,
900
- eps,
901
- residual,
902
- x1,
903
- weight1,
904
- bias1,
905
- dropout_p=dropout_p,
906
- rowscale=rowscale,
907
- out_dtype=out_dtype,
908
- residual_dtype=residual_dtype,
909
- zero_centered_weight=zero_centered_weight,
910
- is_rms_norm=is_rms_norm,
911
- return_dropout_mask=return_dropout_mask,
912
- out=out,
913
- residual_out=residual_out,
914
  )
915
  ctx.save_for_backward(
916
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -923,12 +820,17 @@ class LayerNormFn(torch.autograd.Function):
923
  ctx.has_x1 = x1 is not None
924
  ctx.prenorm = prenorm
925
  ctx.x_dtype = x.dtype
926
- ctx.zero_centered_weight = zero_centered_weight
927
  y = y.reshape(x_shape_og)
928
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
929
- residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
930
- dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
931
- dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
 
 
 
 
 
 
932
  if not return_dropout_mask:
933
  if weight1 is None:
934
  return y if not prenorm else (y, residual_out)
@@ -952,19 +854,26 @@ class LayerNormFn(torch.autograd.Function):
952
  def backward(ctx, dy, *args):
953
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
954
  dy = dy.reshape(-1, dy.shape[-1])
 
 
 
955
  if weight1 is not None:
956
  dy1, args = args[0], args[1:]
957
  dy1 = dy1.reshape(-1, dy1.shape[-1])
 
 
958
  assert dy1.shape == x.shape
959
  else:
960
  dy1 = None
961
  if ctx.prenorm:
962
  dresidual = args[0]
963
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 
 
964
  assert dresidual.shape == x.shape
965
  else:
966
  dresidual = None
967
- dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd(
968
  dy,
969
  x,
970
  weight,
@@ -981,10 +890,8 @@ class LayerNormFn(torch.autograd.Function):
981
  rowscale,
982
  ctx.has_residual,
983
  ctx.has_x1,
984
- ctx.zero_centered_weight,
985
  ctx.is_rms_norm,
986
  x_dtype=ctx.x_dtype,
987
- recompute_output=False,
988
  )
989
  return (
990
  dx.reshape(ctx.x_shape_og),
@@ -1003,8 +910,6 @@ class LayerNormFn(torch.autograd.Function):
1003
  None,
1004
  None,
1005
  None,
1006
- None,
1007
- None,
1008
  )
1009
 
1010
 
@@ -1021,12 +926,10 @@ def layer_norm_fn(
1021
  rowscale=None,
1022
  prenorm=False,
1023
  residual_in_fp32=False,
1024
- zero_centered_weight=False,
1025
  is_rms_norm=False,
1026
  return_dropout_mask=False,
1027
- out_dtype=None,
1028
  out=None,
1029
- residual_out=None
1030
  ):
1031
  return LayerNormFn.apply(
1032
  x,
@@ -1041,12 +944,10 @@ def layer_norm_fn(
1041
  rowscale,
1042
  prenorm,
1043
  residual_in_fp32,
1044
- zero_centered_weight,
1045
  is_rms_norm,
1046
  return_dropout_mask,
1047
- out_dtype,
1048
  out,
1049
- residual_out
1050
  )
1051
 
1052
 
@@ -1063,11 +964,9 @@ def rms_norm_fn(
1063
  rowscale=None,
1064
  prenorm=False,
1065
  residual_in_fp32=False,
1066
- zero_centered_weight=False,
1067
  return_dropout_mask=False,
1068
- out_dtype=None,
1069
  out=None,
1070
- residual_out=None
1071
  ):
1072
  return LayerNormFn.apply(
1073
  x,
@@ -1082,19 +981,16 @@ def rms_norm_fn(
1082
  rowscale,
1083
  prenorm,
1084
  residual_in_fp32,
1085
- zero_centered_weight,
1086
  True,
1087
  return_dropout_mask,
1088
- out_dtype,
1089
  out,
1090
- residual_out
1091
  )
1092
 
1093
 
1094
  class RMSNorm(torch.nn.Module):
1095
 
1096
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
1097
- device=None, dtype=None):
1098
  factory_kwargs = {"device": device, "dtype": dtype}
1099
  super().__init__()
1100
  self.eps = eps
@@ -1102,16 +998,12 @@ class RMSNorm(torch.nn.Module):
1102
  self.drop = torch.nn.Dropout(dropout_p)
1103
  else:
1104
  self.drop = None
1105
- self.zero_centered_weight = zero_centered_weight
1106
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1107
  self.register_parameter("bias", None)
1108
  self.reset_parameters()
1109
 
1110
  def reset_parameters(self):
1111
- if not self.zero_centered_weight:
1112
- torch.nn.init.ones_(self.weight)
1113
- else:
1114
- torch.nn.init.zeros_(self.weight)
1115
 
1116
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1117
  return rms_norm_fn(
@@ -1123,14 +1015,12 @@ class RMSNorm(torch.nn.Module):
1123
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1124
  prenorm=prenorm,
1125
  residual_in_fp32=residual_in_fp32,
1126
- zero_centered_weight=self.zero_centered_weight,
1127
  )
1128
 
1129
 
1130
  class LayerNormLinearFn(torch.autograd.Function):
1131
-
1132
  @staticmethod
1133
- @custom_fwd
1134
  def forward(
1135
  ctx,
1136
  x,
@@ -1146,12 +1036,17 @@ class LayerNormLinearFn(torch.autograd.Function):
1146
  ):
1147
  x_shape_og = x.shape
1148
  # reshape input data into 2D tensor
1149
- x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
1150
  if residual is not None:
1151
  assert residual.shape == x_shape_og
1152
- residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
1153
  norm_weight = norm_weight.contiguous()
1154
- norm_bias = maybe_contiguous(norm_bias)
 
1155
  residual_dtype = (
1156
  residual.dtype
1157
  if residual is not None
@@ -1163,17 +1058,25 @@ class LayerNormLinearFn(torch.autograd.Function):
1163
  norm_bias,
1164
  eps,
1165
  residual,
1166
- out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
1167
  residual_dtype=residual_dtype,
1168
  is_rms_norm=is_rms_norm,
1169
  )
1170
  y = y.reshape(x_shape_og)
1171
- dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
 
 
1172
  linear_weight = linear_weight.to(dtype)
1173
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1174
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1175
  # We don't store y, will be recomputed in the backward pass to save memory
1176
- ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
 
 
1177
  ctx.x_shape_og = x_shape_og
1178
  ctx.eps = eps
1179
  ctx.is_rms_norm = is_rms_norm
@@ -1184,17 +1087,20 @@ class LayerNormLinearFn(torch.autograd.Function):
1184
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1185
 
1186
  @staticmethod
1187
- @custom_bwd
1188
  def backward(ctx, dout, *args):
1189
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1190
  dout = dout.reshape(-1, dout.shape[-1])
1191
  dy = F.linear(dout, linear_weight.t())
1192
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1193
- dy = maybe_contiguous_lastdim(dy)
 
1194
  assert dy.shape == x.shape
1195
  if ctx.prenorm:
1196
  dresidual = args[0]
1197
- dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
 
 
1198
  assert dresidual.shape == x.shape
1199
  else:
1200
  dresidual = None
 
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
 
10
 
11
  import torch
12
  import torch.nn.functional as F
13
+ from torch.amp import custom_fwd, custom_bwd
14
 
15
  import triton
16
  import triton.language as tl
17
 
18
 
19
  def layer_norm_ref(
20
  x,
 
28
  dropout_p=0.0,
29
  rowscale=None,
30
  prenorm=False,
 
31
  dropout_mask=None,
32
  dropout_mask1=None,
33
  upcast=False,
 
41
  x1 = x1.float() if x1 is not None else None
42
  weight1 = weight1.float() if weight1 is not None else None
43
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
44
  if x1 is not None:
45
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
46
  if rowscale is not None:
 
59
  x = x + x1
60
  if residual is not None:
61
  x = (x + residual).to(x.dtype)
62
+ out = F.layer_norm(
63
+ x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
64
+ ).to(dtype)
65
  if weight1 is None:
66
  return out if not prenorm else (out, x)
67
  else:
 
83
  dropout_p=0.0,
84
  rowscale=None,
85
  prenorm=False,
 
86
  dropout_mask=None,
87
  dropout_mask1=None,
88
  upcast=False,
 
96
  x1 = x1.float() if x1 is not None else None
97
  weight1 = weight1.float() if weight1 is not None else None
98
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
99
  if x1 is not None:
100
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
101
  if rowscale is not None:
 
115
  if residual is not None:
116
  x = (x + residual).to(x.dtype)
117
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
118
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
119
+ dtype
120
+ )
121
  if weight1 is None:
122
  return out if not prenorm else (out, x)
123
  else:
124
+ out1 = (
125
+ (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
126
+ ).to(dtype)
127
  return (out, out1) if not prenorm else (out, out1, x)
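For reference, the math implemented by the `rms_norm_ref` lines above, with an optional second projection sharing the same statistic, is:

$$
\mathrm{rstd} = \frac{1}{\sqrt{\tfrac{1}{N}\sum_{j=1}^{N} x_j^2 + \varepsilon}}, \qquad
y_j = x_j \,\mathrm{rstd}\, w_j + b_j, \qquad
y^{(1)}_j = x_j \,\mathrm{rstd}\, w^{(1)}_j + b^{(1)}_j,
$$

where $x$ is the row after the optional `x1`/residual additions and dropout, and the bias terms are dropped when `bias`/`bias1` is `None`.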
128
 
129
 
130
  @triton.autotune(
131
+ configs=[
132
+ triton.Config({}, num_warps=1),
133
+ triton.Config({}, num_warps=2),
134
+ triton.Config({}, num_warps=4),
135
+ triton.Config({}, num_warps=8),
136
+ triton.Config({}, num_warps=16),
137
+ triton.Config({}, num_warps=32),
138
+ ],
139
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
140
  )
 
141
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
142
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
143
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
144
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
145
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
146
  @triton.jit
147
  def _layer_norm_fwd_1pass_kernel(
148
  X, # pointer to the input
 
158
  ROWSCALE,
159
  SEEDS, # Dropout seeds for each row
160
  DROPOUT_MASK,
 
161
  Mean, # pointer to the mean
162
  Rstd, # pointer to the 1/std
163
  stride_x_row, # how much to increase the pointer when moving by 1 row
 
170
  N, # number of columns in X
171
  eps, # epsilon to avoid division by zero
172
  dropout_p, # Dropout probability
 
173
  IS_RMS_NORM: tl.constexpr,
174
  BLOCK_N: tl.constexpr,
175
  HAS_RESIDUAL: tl.constexpr,
 
203
  if HAS_DROPOUT:
204
  # Compute dropout mask
205
  # 7 rounds is good enough, and reduces register pressure
206
+ keep_mask = (
207
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
208
+ )
209
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
210
  if STORE_DROPOUT_MASK:
211
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
 
218
  # Compute dropout mask
219
  # 7 rounds is good enough, and reduces register pressure
220
  keep_mask = (
221
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
222
+ > dropout_p
223
  )
224
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
225
  if STORE_DROPOUT_MASK:
226
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
227
  x += x1
228
  if HAS_RESIDUAL:
229
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
 
243
  # Normalize and apply linear transformation
244
  mask = cols < N
245
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
246
  if HAS_BIAS:
247
  b = tl.load(B + cols, mask=mask).to(tl.float32)
248
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 
251
  tl.store(Y + cols, y, mask=mask)
252
  if HAS_W1:
253
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
254
  if HAS_B1:
255
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
256
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
 
258
 
259
 
260
  def _layer_norm_fwd(
261
+ x,
262
+ weight,
263
+ bias,
264
+ eps,
265
+ residual=None,
266
+ x1=None,
267
+ weight1=None,
268
+ bias1=None,
269
+ dropout_p=0.0,
270
+ rowscale=None,
271
+ out_dtype=None,
272
+ residual_dtype=None,
273
+ is_rms_norm=False,
274
+ return_dropout_mask=False,
275
+ out=None,
276
+ residual_out=None,
277
+ ):
278
  if residual is not None:
279
  residual_dtype = residual.dtype
 
280
  M, N = x.shape
281
  assert x.stride(-1) == 1
282
  if residual is not None:
 
300
  if rowscale is not None:
301
  assert rowscale.is_contiguous()
302
  assert rowscale.shape == (M,)
303
+ # allocate output
304
+ if out is None:
305
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
306
+ else:
307
+ assert out.shape == x.shape
308
  assert out.stride(-1) == 1
 
 
 
309
  if weight1 is not None:
310
  y1 = torch.empty_like(out)
311
  assert y1.stride(-1) == 1
312
  else:
313
  y1 = None
314
+ if (
315
+ residual is not None
316
+ or (residual_dtype is not None and residual_dtype != x.dtype)
317
+ or dropout_p > 0.0
318
+ or rowscale is not None
319
+ or x1 is not None
320
+ ):
321
+ if residual_out is None:
322
+ residual_out = torch.empty(
323
+ M,
324
+ N,
325
+ device=x.device,
326
+ dtype=residual_dtype if residual_dtype is not None else x.dtype,
327
+ )
328
+ else:
329
+ assert residual_out.shape == x.shape
330
+ assert residual_out.stride(-1) == 1
331
+ else:
332
+ residual_out = None
333
+ mean = (
334
+ torch.empty((M,), dtype=torch.float32, device=x.device)
335
+ if not is_rms_norm
336
+ else None
337
+ )
338
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
339
  if dropout_p > 0.0:
340
  seeds = torch.randint(
 
343
  else:
344
  seeds = None
345
  if return_dropout_mask and dropout_p > 0.0:
346
+ dropout_mask = torch.empty(
347
+ M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
348
+ )
 
 
349
  else:
350
+ dropout_mask = None
351
  # Less than 64KB per feature: enqueue fused kernel
352
  MAX_FUSED_SIZE = 65536 // x.element_size()
353
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
354
  if N > BLOCK_N:
355
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
356
  with torch.cuda.device(x.device.index):
357
+ _layer_norm_fwd_1pass_kernel[(M,)](
358
  x,
359
  out,
360
  weight,
 
368
  rowscale,
369
  seeds,
370
  dropout_mask,
 
371
  mean,
372
  rstd,
373
  x.stride(0),
 
380
  N,
381
  eps,
382
  dropout_p,
 
 
383
  is_rms_norm,
384
  BLOCK_N,
385
  residual is not None,
 
388
  dropout_p > 0.0,
389
  dropout_mask is not None,
390
  rowscale is not None,
 
 
 
391
  )
392
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
393
+ if dropout_mask is not None and x1 is not None:
394
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
395
+ else:
396
+ dropout_mask1 = None
397
+ return (
398
+ out,
399
+ y1,
400
+ mean,
401
+ rstd,
402
+ residual_out if residual_out is not None else x,
403
+ seeds,
404
+ dropout_mask,
405
+ dropout_mask1,
406
+ )
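A note on the dropout-mask plumbing above: when `x1` is present, the forward pass allocates a single `(2 * M, N)` boolean buffer, the kernel writes the mask for `x` into rows `[0, M)` and the mask for `x1` into rows `[M, 2 * M)`, and `tensor_split(2, dim=0)` recovers the two per-input masks. A minimal CPU sketch of that layout (sizes are illustrative only, not taken from the kernel):

```python
import torch

M, N = 4, 8                                    # illustrative sizes only
combined = torch.zeros(2 * M, N, dtype=torch.bool)
combined[:M] = True                            # rows the kernel would write for x
combined[M:] = False                           # rows the kernel would write for x1

mask_x, mask_x1 = combined.tensor_split(2, dim=0)
assert mask_x.shape == (M, N) and mask_x1.shape == (M, N)
assert bool(mask_x.all()) and not bool(mask_x1.any())
```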
407
 
408
 
409
  @triton.autotune(
410
+ configs=[
411
+ triton.Config({}, num_warps=1),
412
+ triton.Config({}, num_warps=2),
413
+ triton.Config({}, num_warps=4),
414
+ triton.Config({}, num_warps=8),
415
+ triton.Config({}, num_warps=16),
416
+ triton.Config({}, num_warps=32),
417
+ ],
418
+ key=[
419
+ "N",
420
+ "HAS_DRESIDUAL",
421
+ "STORE_DRESIDUAL",
422
+ "IS_RMS_NORM",
423
+ "HAS_BIAS",
424
+ "HAS_DROPOUT",
425
+ ],
426
  )
 
427
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
428
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
429
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
430
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
431
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
432
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
433
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
434
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
435
  @triton.jit
436
  def _layer_norm_bwd_kernel(
437
  X, # pointer to the input
 
465
  N, # number of columns in X
466
  eps, # epsilon to avoid division by zero
467
  dropout_p,
 
468
  rows_per_program,
469
  IS_RMS_NORM: tl.constexpr,
470
  BLOCK_N: tl.constexpr,
 
498
  if RECOMPUTE_OUTPUT:
499
  Y += row_start * stride_y_row
500
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
501
  if RECOMPUTE_OUTPUT and HAS_BIAS:
502
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
503
  if HAS_DY1:
504
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
505
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
  if HAS_BIAS:
507
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
 
550
  if HAS_DX1:
551
  if HAS_DROPOUT:
552
  keep_mask = (
553
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
554
+ > dropout_p
555
  )
556
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
557
  else:
558
  dx1 = dx
559
  tl.store(DX1 + cols, dx1, mask=mask)
560
  if HAS_DROPOUT:
561
+ keep_mask = (
562
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
563
+ > dropout_p
564
+ )
565
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
566
  if HAS_ROWSCALE:
567
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
 
591
 
592
 
593
  def _layer_norm_bwd(
594
+ dy,
595
+ x,
596
+ weight,
597
+ bias,
598
+ eps,
599
+ mean,
600
+ rstd,
601
+ dresidual=None,
602
+ dy1=None,
603
+ weight1=None,
604
+ bias1=None,
605
+ seeds=None,
606
+ dropout_p=0.0,
607
+ rowscale=None,
608
+ has_residual=False,
609
+ has_x1=False,
610
+ is_rms_norm=False,
611
+ x_dtype=None,
612
+ recompute_output=False,
613
+ ):
614
  M, N = x.shape
615
  assert x.stride(-1) == 1
 
616
  assert dy.stride(-1) == 1
617
  assert dy.shape == (M, N)
618
  if dresidual is not None:
 
619
  assert dresidual.stride(-1) == 1
620
  assert dresidual.shape == (M, N)
621
  assert weight.shape == (N,)
 
624
  assert bias.stride(-1) == 1
625
  assert bias.shape == (N,)
626
  if dy1 is not None:
 
627
  assert weight1 is not None
628
  assert dy1.shape == dy.shape
629
  assert dy1.stride(-1) == 1
 
652
  else None
653
  )
654
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
655
+ y = (
656
+ torch.empty(M, N, dtype=dy.dtype, device=dy.device)
657
+ if recompute_output
658
+ else None
659
+ )
660
  if recompute_output:
661
+ assert (
662
+ weight1 is None
663
+ ), "recompute_output is not supported with parallel LayerNorm"
664
 
665
  # Less than 64KB per feature: enqueue fused kernel
666
  MAX_FUSED_SIZE = 65536 // x.element_size()
667
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
668
  if N > BLOCK_N:
669
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
670
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
 
 
671
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
672
  _db = (
673
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
 
679
  rows_per_program = math.ceil(M / sm_count)
680
  grid = (sm_count,)
681
  with torch.cuda.device(x.device.index):
682
+ _layer_norm_bwd_kernel[grid](
683
  x,
684
  weight,
685
  bias,
 
711
  N,
712
  eps,
713
  dropout_p,
 
 
714
  rows_per_program,
715
  is_rms_norm,
716
  BLOCK_N,
 
718
  dresidual_in is not None,
719
  bias is not None,
720
  dropout_p > 0.0,
 
  )
722
  dw = _dw.sum(0).to(weight.dtype)
723
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
724
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
725
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
726
+ # Don't need to compute dresidual_in separately in this case
727
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
728
+ dresidual_in = dx
729
+ if has_x1 and dropout_p == 0.0:
730
+ dx1 = dx
731
+ return (
732
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
733
+ if not recompute_output
734
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
735
+ )
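A note on the `dw`/`db` reduction pattern used above: the backward kernel launches one program per SM, each program accumulates the weight/bias gradient over its `rows_per_program` rows into one row of the float32 buffers `_dw`/`_db`, and the host finishes the reduction with `.sum(0)`. A minimal CPU sketch of the same split reduction (sizes and the `x_hat` stand-in are illustrative, not the kernel's actual intermediates):

```python
import math
import torch

M, N, sm_count = 10, 6, 4                      # illustrative sizes only
dy = torch.randn(M, N)
x_hat = torch.randn(M, N)                      # stand-in for the normalized activations

rows_per_program = math.ceil(M / sm_count)
_dw = torch.zeros(sm_count, N)                 # one partial-gradient row per program
for pid in range(sm_count):
    rows = slice(pid * rows_per_program, min((pid + 1) * rows_per_program, M))
    _dw[pid] = (dy[rows] * x_hat[rows]).sum(0)  # partial dw for this program's rows

dw = _dw.sum(0)                                # host-side reduction, mirroring `_dw.sum(0).to(weight.dtype)`
assert torch.allclose(dw, (dy * x_hat).sum(0), atol=1e-5)
```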
736
 
737
 
738
  class LayerNormFn(torch.autograd.Function):
 
739
  @staticmethod
740
  def forward(
741
  ctx,
 
751
  rowscale=None,
752
  prenorm=False,
753
  residual_in_fp32=False,
 
754
  is_rms_norm=False,
755
  return_dropout_mask=False,
 
756
  out=None,
757
+ residual_out=None,
758
  ):
759
  x_shape_og = x.shape
760
  # reshape input data into 2D tensor
761
+ x = x.reshape(-1, x.shape[-1])
762
+ if x.stride(-1) != 1:
763
+ x = x.contiguous()
764
  if residual is not None:
765
  assert residual.shape == x_shape_og
766
+ residual = residual.reshape(-1, residual.shape[-1])
767
+ if residual.stride(-1) != 1:
768
+ residual = residual.contiguous()
769
  if x1 is not None:
770
  assert x1.shape == x_shape_og
771
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
772
+ x1 = x1.reshape(-1, x1.shape[-1])
773
+ if x1.stride(-1) != 1:
774
+ x1 = x1.contiguous()
775
  weight = weight.contiguous()
776
+ if bias is not None:
777
+ bias = bias.contiguous()
778
+ if weight1 is not None:
779
+ weight1 = weight1.contiguous()
780
+ if bias1 is not None:
781
+ bias1 = bias1.contiguous()
782
  if rowscale is not None:
783
  rowscale = rowscale.reshape(-1).contiguous()
784
  residual_dtype = (
 
790
  out = out.reshape(-1, out.shape[-1])
791
  if residual_out is not None:
792
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
793
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
794
+ _layer_norm_fwd(
795
+ x,
796
+ weight,
797
+ bias,
798
+ eps,
799
+ residual,
800
+ x1,
801
+ weight1,
802
+ bias1,
803
+ dropout_p=dropout_p,
804
+ rowscale=rowscale,
805
+ residual_dtype=residual_dtype,
806
+ is_rms_norm=is_rms_norm,
807
+ return_dropout_mask=return_dropout_mask,
808
+ out=out,
809
+ residual_out=residual_out,
810
+ )
811
  )
812
  ctx.save_for_backward(
813
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
 
820
  ctx.has_x1 = x1 is not None
821
  ctx.prenorm = prenorm
822
  ctx.x_dtype = x.dtype
 
823
  y = y.reshape(x_shape_og)
824
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
825
+ residual_out = (
826
+ residual_out.reshape(x_shape_og) if residual_out is not None else None
827
+ )
828
+ dropout_mask = (
829
+ dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
830
+ )
831
+ dropout_mask1 = (
832
+ dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
833
+ )
834
  if not return_dropout_mask:
835
  if weight1 is None:
836
  return y if not prenorm else (y, residual_out)
 
854
  def backward(ctx, dy, *args):
855
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
856
  dy = dy.reshape(-1, dy.shape[-1])
857
+ if dy.stride(-1) != 1:
858
+ dy = dy.contiguous()
859
+ assert dy.shape == x.shape
860
  if weight1 is not None:
861
  dy1, args = args[0], args[1:]
862
  dy1 = dy1.reshape(-1, dy1.shape[-1])
863
+ if dy1.stride(-1) != 1:
864
+ dy1 = dy1.contiguous()
865
  assert dy1.shape == x.shape
866
  else:
867
  dy1 = None
868
  if ctx.prenorm:
869
  dresidual = args[0]
870
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
871
+ if dresidual.stride(-1) != 1:
872
+ dresidual = dresidual.contiguous()
873
  assert dresidual.shape == x.shape
874
  else:
875
  dresidual = None
876
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
877
  dy,
878
  x,
879
  weight,
 
890
  rowscale,
891
  ctx.has_residual,
892
  ctx.has_x1,
 
893
  ctx.is_rms_norm,
894
  x_dtype=ctx.x_dtype,
 
895
  )
896
  return (
897
  dx.reshape(ctx.x_shape_og),
 
910
  None,
911
  None,
912
  None,
 
 
913
  )
914
 
915
 
 
926
  rowscale=None,
927
  prenorm=False,
928
  residual_in_fp32=False,
 
929
  is_rms_norm=False,
930
  return_dropout_mask=False,
 
931
  out=None,
932
+ residual_out=None,
933
  ):
934
  return LayerNormFn.apply(
935
  x,
 
944
  rowscale,
945
  prenorm,
946
  residual_in_fp32,
 
947
  is_rms_norm,
948
  return_dropout_mask,
 
949
  out,
950
+ residual_out,
951
  )
952
 
953
 
 
964
  rowscale=None,
965
  prenorm=False,
966
  residual_in_fp32=False,
 
967
  return_dropout_mask=False,
 
968
  out=None,
969
+ residual_out=None,
970
  ):
971
  return LayerNormFn.apply(
972
  x,
 
981
  rowscale,
982
  prenorm,
983
  residual_in_fp32,
 
984
  True,
985
  return_dropout_mask,
 
986
  out,
987
+ residual_out,
988
  )
989
 
990
 
991
  class RMSNorm(torch.nn.Module):
992
 
993
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
 
994
  factory_kwargs = {"device": device, "dtype": dtype}
995
  super().__init__()
996
  self.eps = eps
 
998
  self.drop = torch.nn.Dropout(dropout_p)
999
  else:
1000
  self.drop = None
 
1001
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1002
  self.register_parameter("bias", None)
1003
  self.reset_parameters()
1004
 
1005
  def reset_parameters(self):
1006
+ torch.nn.init.ones_(self.weight)
 
 
 
1007
 
1008
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1009
  return rms_norm_fn(
 
1015
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1016
  prenorm=prenorm,
1017
  residual_in_fp32=residual_in_fp32,
 
1018
  )
1019
 
1020
 
1021
  class LayerNormLinearFn(torch.autograd.Function):
 
1022
  @staticmethod
1023
+ @custom_fwd(device_type="cuda")
1024
  def forward(
1025
  ctx,
1026
  x,
 
1036
  ):
1037
  x_shape_og = x.shape
1038
  # reshape input data into 2D tensor
1039
+ x = x.reshape(-1, x.shape[-1])
1040
+ if x.stride(-1) != 1:
1041
+ x = x.contiguous()
1042
  if residual is not None:
1043
  assert residual.shape == x_shape_og
1044
+ residual = residual.reshape(-1, residual.shape[-1])
1045
+ if residual.stride(-1) != 1:
1046
+ residual = residual.contiguous()
1047
  norm_weight = norm_weight.contiguous()
1048
+ if norm_bias is not None:
1049
+ norm_bias = norm_bias.contiguous()
1050
  residual_dtype = (
1051
  residual.dtype
1052
  if residual is not None
 
1058
  norm_bias,
1059
  eps,
1060
  residual,
1061
+ out_dtype=(
1062
+ None
1063
+ if not torch.is_autocast_enabled()
1064
+ else torch.get_autocast_gpu_dtype()
1065
+ ),
1066
  residual_dtype=residual_dtype,
1067
  is_rms_norm=is_rms_norm,
1068
  )
1069
  y = y.reshape(x_shape_og)
1070
+ dtype = (
1071
+ torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1072
+ )
1073
  linear_weight = linear_weight.to(dtype)
1074
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1075
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1076
  # We don't store y, will be recomputed in the backward pass to save memory
1077
+ ctx.save_for_backward(
1078
+ residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1079
+ )
1080
  ctx.x_shape_og = x_shape_og
1081
  ctx.eps = eps
1082
  ctx.is_rms_norm = is_rms_norm
 
1087
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1088
 
1089
  @staticmethod
1090
+ @custom_bwd(device_type="cuda")
1091
  def backward(ctx, dout, *args):
1092
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1093
  dout = dout.reshape(-1, dout.shape[-1])
1094
  dy = F.linear(dout, linear_weight.t())
1095
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1096
+ if dy.stride(-1) != 1:
1097
+ dy = dy.contiguous()
1098
  assert dy.shape == x.shape
1099
  if ctx.prenorm:
1100
  dresidual = args[0]
1101
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1102
+ if dresidual.stride(-1) != 1:
1103
+ dresidual = dresidual.contiguous()
1104
  assert dresidual.shape == x.shape
1105
  else:
1106
  dresidual = None
torch-ext/triton_layer_norm/layers.py CHANGED
@@ -1,46 +1,4 @@
- import torch
- from torch import nn
-
- from .layer_norm import rms_norm_fn
-
-
- class LlamaRMSNorm(nn.Module):
-     """
-     RMS Layer Norm for Llama models.
-
-     Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
-     `transformers`.
-
-     Attributes:
-         weight (`torch.Tensor`): The learnable scaling parameter.
-         variance_epsilon (`float`): The epsilon value for numerical stability.
-     """
-     weight: torch.Tensor
-     variance_epsilon: float
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         """
-         Apply RMS normalization to the input hidden states.
-
-         Args:
-             hidden_states (`torch.Tensor`):
-                 Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
-                 where the last dimension is the feature dimension to be normalized.
-
-         Returns:
-             `torch.Tensor`:
-                 The normalized tensor with the same shape as the input `hidden_states`.
-         """
-         return rms_norm_fn(
-             hidden_states,
-             self.weight,
-             bias=None,
-             residual=None,
-             eps=self.variance_epsilon,
-             dropout_p=0.0,
-             prenorm=False,
-             residual_in_fp32=False,
-         )
-
-
- __all__ = ["LlamaRMSNorm"]
+ from .layer_norm import RMSNorm
+
+
+ __all__ = ["RMSNorm"]
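With `layers.py` now re-exporting `RMSNorm` from `layer_norm.py`, callers use the generic module instead of the Llama-specific wrapper. A hedged usage sketch (the import path assumes this repo's `torch-ext/triton_layer_norm` package layout and a CUDA device; adapt to however the kernel is loaded on your side):

```python
import torch
from triton_layer_norm.layers import RMSNorm  # assumed import path for the torch-ext package

norm = RMSNorm(hidden_size=4096, eps=1e-5).to("cuda", dtype=torch.float16)
x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)

y = norm(x)                                                          # plain RMS norm
y2, res = norm(x, residual=x, prenorm=True, residual_in_fp32=True)   # fused residual add, prenorm output
```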
 
torch-ext/triton_layer_norm/utils/__init__.py DELETED
File without changes
torch-ext/triton_layer_norm/utils/library.py DELETED
@@ -1,66 +0,0 @@
- # Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
- # The PyTorch implementation simply ignores the schema argument, we simply modify it to use schema.
-
- from typing import Optional, Callable, Iterable, Union
-
- from torch.library import custom_op, CustomOpDef
- from torch._library.triton import set_wrap_triton_enabled
-
-
- def triton_op(
-     name: str,
-     fn: Optional[Callable] = None,
-     /,
-     *,
-     mutates_args: Union[str, Iterable[str]],
-     schema: Optional[str] = None,
-     # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,
-     # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator
-     # and so inductor can't trace inside.
-     allow_decomposition=True,
- ) -> Callable:
-     def dec(fn: Callable[..., object]) -> CustomOpDef:
-         def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
-             # Optimization: we're passing regular Tensors into the triton kernel, so
-             # no need to go through HOP dispatch
-             with set_wrap_triton_enabled(False):
-                 return fn(*args, **kwargs)
-
-         result = custom_op(
-             name,
-             backend_fn,
-             mutates_args=mutates_args,
-             # This is the only difference with the PyTorch implementation
-             schema=schema,
-         )
-         from torch._subclasses.functional_tensor import FunctionalTensorMode
-
-         # We require that the user pass us a function that is make_fx traceable,
-         # so we can just register it as the Fake/meta kernel.
-         result.register_fake(fn)
-
-         if allow_decomposition:
-             # We decompose the operator when FunctionalTensorMode is active.
-             # The goal is to decompose the operator in AOTDispatcher.
-             # - With torch.compile, this means that the backend (usually Inductor)
-             #   can see a call to the triton kernel(s) and so it can directly optimize
-             #   them by inlining them into the lowering process.
-             def functional_decomp(  # type: ignore[no-untyped-def]
-                 mode, op, types, args, kwargs
-             ):
-                 from torch.export._trace import custom_triton_ops_decomposition_disabled
-
-                 if custom_triton_ops_decomposition_disabled():
-                     return mode.__torch_dispatch__(op, types, args, kwargs)
-                 else:
-                     with mode:
-                         return fn(*args, **kwargs)
-
-             result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
-
-         return result
-
-     if fn is None:
-         return dec
-     else:
-         return dec(fn)
 
torch-ext/triton_layer_norm/utils/torch.py DELETED
@@ -1,21 +0,0 @@
- import torch
- from typing import Callable
-
-
- def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
-     def decorator(*args, **kwargs):
-         if cuda_amp_deprecated:
-             kwargs["device_type"] = "cuda"
-         return dec(*args, **kwargs)
-     return decorator
-
-
- if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
-     deprecated = True
-     from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
- else:
-     deprecated = False
-     from torch.cuda.amp import custom_fwd, custom_bwd
-
- custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
- custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
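This shim existed only to paper over the `torch.cuda.amp` → `torch.amp` move; the diff drops it and decorates the autograd functions directly with `torch.amp.custom_fwd`/`custom_bwd` and an explicit `device_type`, as seen in the `LayerNormLinearFn` hunks above. A minimal sketch of the replacement pattern, assuming a PyTorch version where `torch.amp.custom_fwd` accepts `device_type` (roughly 2.4+):

```python
import torch
from torch.amp import custom_fwd, custom_bwd


class ScaleFn(torch.autograd.Function):
    """Toy autograd.Function illustrating the decorator style the diff switches to."""

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None
```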