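"""Triton implementation of SimpleRMSNorm.

SimpleRMSNorm normalizes each row by its root mean square, with no
learnable weight:

    y = x / sqrt(mean(x ** 2, dim=-1) + eps)

The forward kernel caches the per-row reciprocal standard deviation so
the fused backward kernel can reuse it.
"""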

import torch
import triton
import triton.language as tl
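

# Forward kernel: for each row,
#     rstd = 1 / sqrt(mean(x ** 2) + eps)
#     y    = x * rstd
# The per-row rstd is also written out so the backward pass can reuse it.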
@triton.jit
def srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):
    # Each program instance handles one row of the (M, N) input.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Load the row, upcasting to float32 for a stable reduction.
    x_ptrs = X + row * stride + cols
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

    # Zero the padding lanes so they don't contribute to the mean of squares.
    x_zm = tl.where(mask, x, 0.0)

    x_var = tl.sum(x_zm * x_zm, axis=0) / N
    rstd = 1.0 / tl.sqrt(x_var + eps)

    # Normalize, and cache the per-row rstd for the backward pass.
    y = x_zm * rstd
    tl.store(V + row, rstd)

    y_ptrs = Y + row * stride + cols
    tl.store(y_ptrs, y, mask=mask)
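

# Backward kernel: differentiating y = x * rstd with respect to x
# (rstd itself depends on x through mean(x ** 2)) gives, per row,
#     dx = (dy - xhat * mean(xhat * dy)) * rstd,  where xhat = x * rstd.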
@triton.jit
def srms_norm_bwd_dx_fused(
    DX, DY,
    X, V,
    stride, N,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Each program instance handles one row, mirroring the forward kernel.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Pointers into the current row of the input and incoming gradient.
    x_ptrs = X + row * stride + cols
    dy_ptrs = DY + row * stride + cols

    x = tl.load(x_ptrs, mask=mask, other=0.0)
    dy = tl.load(dy_ptrs, mask=mask, other=0.0)
    rstd = tl.load(V + row)

    # Recompute the normalized input from the cached rstd; zero the padding
    # lanes so they don't pollute the row-wise reduction.
    xhat = x * rstd
    wdy = dy
    xhat = tl.where(mask, xhat, 0.0)
    wdy = tl.where(mask, wdy, 0.0)

    mean1 = tl.sum(xhat * wdy, axis=0) / N
    dx = (wdy - (xhat * mean1)) * rstd

    dx_ptrs = DX + row * stride + cols
    tl.store(dx_ptrs, dx, mask=mask)
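

# Autograd wrapper: launches the Triton kernels and threads the cached
# per-row rstd from the forward pass into the backward pass.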
class _SrmsNorm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, eps):
        # fp16 can't represent a very small epsilon; clamp it for stability.
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)

        y = torch.empty_like(x)

        # Collapse leading dimensions: the kernels operate on (M, N) rows.
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Per-row reciprocal standard deviation, filled in by the kernel.
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # A whole row must fit in a single block, capped at 64KB.
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError(
                "This RMS norm doesn't support feature dims above 64KB per row.")

        if not x_arg.is_contiguous() or not y.is_contiguous():
            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # Heuristic: roughly one warp per 256 columns, clamped to [1, 16].
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        # Launch one program per row.
        srms_norm_fw[(M,)](
            x_arg, y, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        ctx.save_for_backward(x, rstd)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    def backward(ctx, dy):
        x, rstd = ctx.saved_tensors

        # As in the forward pass, operate on flattened (M, N) rows.
        x = x.reshape(-1, x.size(-1))
        M, N = x.size()

        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the norm"

        # Same warp heuristic as the forward pass.
        num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16)

        srms_norm_bwd_dx_fused[(M,)](
            dx, dy, x,
            rstd,
            x.stride(0),
            N,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=num_warps,
        )

        dx = dx.reshape_as(dy)
        return dx, None, None


class SimpleRMSNorm(torch.nn.Module):
    """RMS normalization without a learnable weight, backed by the fused Triton kernels above."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim

    def forward(self, x):
        return _SrmsNorm.apply(x, self.eps)
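

if __name__ == "__main__":
    # Minimal smoke test; illustrative only and assumes a CUDA device,
    # since the Triton kernels run on GPU. Checks the Triton path against
    # an eager-mode reference of the same formula, then exercises backward.
    torch.manual_seed(0)
    norm = SimpleRMSNorm(dim=512)
    x = torch.randn(4, 512, device="cuda", requires_grad=True)
    y = norm(x)
    ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + norm.eps)
    torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)
    y.sum().backward()  # runs srms_norm_bwd_dx_fused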