medmekk (HF Staff) committed bf4f478 · verified · Parent: f70965d

Upload custom kernels

build/torch-universal/liger_kernels/__init__.py CHANGED
@@ -6,7 +6,7 @@ from .group_norm import LigerGroupNormFunction
 from .kl_div import LigerKLDivLossFunction
 from .layer_norm import LigerLayerNormFunction
 from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
-from .rms_norm import LigerRMSNormFunction
+from .rms_norm import LigerRMSNormFunction, LigerRMSNorm
 from .jsd import LigerJSDFunction
 from .rope import LigerRopeFunction
 from .swiglu import LigerSiLUMulFunction
@@ -22,6 +22,7 @@ __all__ = [
     "LigerLayerNormFunction",
     "LigerQwen2VLMRopeFunction",
     "LigerRMSNormFunction",
+    "LigerRMSNorm",
     "LigerJSDFunction",
     "LigerRopeFunction",
     "LigerSiLUMulFunction",
build/torch-universal/liger_kernels/_ops.py CHANGED
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops._liger_kernels_20250505101012
+ops = torch.ops._liger_kernels_20250507090511

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_liger_kernels_20250505101012::{op_name}"
+    return f"_liger_kernels_20250507090511::{op_name}"
build/torch-universal/liger_kernels/rms_norm.py CHANGED
@@ -362,4 +362,44 @@ class LigerRMSNormFunction(torch.autograd.Function):
             ctx.num_warps,
             ctx.in_place,
         )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None
+
+
+class LigerRMSNorm(torch.nn.Module):
+    """
+    RMSNorm module that uses the optimized LigerRMSNormFunction.
+
+    Args:
+        hidden_size (int): The size of the hidden dimension.
+        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
+        offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
+        casting_mode (str, optional): The casting mode to use. Defaults to "llama".
+        in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
+    """
+
+    def __init__(self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", in_place=True):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.offset = offset
+        self.casting_mode = casting_mode
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        """
+        Apply RMS normalization to the input tensor.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
+
+        Returns:
+            torch.Tensor: Normalized tensor of the same shape as input
+        """
+        return LigerRMSNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.variance_epsilon,
+            self.offset,
+            self.casting_mode,
+            self.in_place
+        )
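
The new class exposes the autograd function behind the standard torch.nn.Module interface. A minimal usage sketch (sizes and device placement are illustrative; the underlying Triton kernel needs a GPU):

    import torch
    from liger_kernels import LigerRMSNorm

    # LigerRMSNorm normalizes over the last (hidden) dimension.
    norm = LigerRMSNorm(hidden_size=4096, eps=1e-6).to("cuda")
    x = torch.randn(2, 128, 4096, device="cuda")
    y = norm(x)  # same shape as x: (2, 128, 4096)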
torch-ext/liger_kernels/__init__.py CHANGED
@@ -6,7 +6,7 @@ from .group_norm import LigerGroupNormFunction
 from .kl_div import LigerKLDivLossFunction
 from .layer_norm import LigerLayerNormFunction
 from .qwen2vl_mrope import LigerQwen2VLMRopeFunction
-from .rms_norm import LigerRMSNormFunction
+from .rms_norm import LigerRMSNormFunction, LigerRMSNorm
 from .jsd import LigerJSDFunction
 from .rope import LigerRopeFunction
 from .swiglu import LigerSiLUMulFunction
@@ -22,6 +22,7 @@ __all__ = [
     "LigerLayerNormFunction",
     "LigerQwen2VLMRopeFunction",
     "LigerRMSNormFunction",
+    "LigerRMSNorm",
     "LigerJSDFunction",
     "LigerRopeFunction",
     "LigerSiLUMulFunction",
torch-ext/liger_kernels/rms_norm.py CHANGED
@@ -362,4 +362,44 @@ class LigerRMSNormFunction(torch.autograd.Function):
             ctx.num_warps,
             ctx.in_place,
         )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None
+
+
+class LigerRMSNorm(torch.nn.Module):
+    """
+    RMSNorm module that uses the optimized LigerRMSNormFunction.
+
+    Args:
+        hidden_size (int): The size of the hidden dimension.
+        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
+        offset (float, optional): Offset value to shift the weight tensor. Defaults to 0.0.
+        casting_mode (str, optional): The casting mode to use. Defaults to "llama".
+        in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True.
+    """
+
+    def __init__(self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", in_place=True):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.offset = offset
+        self.casting_mode = casting_mode
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        """
+        Apply RMS normalization to the input tensor.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H)
+
+        Returns:
+            torch.Tensor: Normalized tensor of the same shape as input
+        """
+        return LigerRMSNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.variance_epsilon,
+            self.offset,
+            self.casting_mode,
+            self.in_place
+        )