danieldk HF staff commited on
Commit
4a07eb0
·
1 Parent(s): 1e347f0

Make `RMSNorm` layer pure, rename to `LlamaRMSNorm`

Browse files
build/torch-universal/triton_layer_norm/layers.py CHANGED
@@ -1,4 +1,24 @@
1
- from .layer_norm import RMSNorm
 
2
 
 
3
 
4
- __all__ = ["RMSNorm"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
 
4
+ from .layer_norm import rms_norm_fn
5
 
6
+
7
+ class LlamaRMSNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return rms_norm_fn(
13
+ hidden_states,
14
+ self.weight,
15
+ bias=None,
16
+ residual=None,
17
+ eps=self.variance_epsilon,
18
+ dropout_p=0.0,
19
+ prenorm=False,
20
+ residual_in_fp32=False,
21
+ )
22
+
23
+
24
+ __all__ = ["LlamaRMSNorm"]
torch-ext/triton_layer_norm/layers.py CHANGED
@@ -1,4 +1,24 @@
1
- from .layer_norm import RMSNorm
 
2
 
 
3
 
4
- __all__ = ["RMSNorm"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
 
4
+ from .layer_norm import rms_norm_fn
5
 
6
+
7
+ class LlamaRMSNorm(nn.Module):
8
+ weight: torch.Tensor
9
+ variance_epsilon: float
10
+
11
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
12
+ return rms_norm_fn(
13
+ hidden_states,
14
+ self.weight,
15
+ bias=None,
16
+ residual=None,
17
+ eps=self.variance_epsilon,
18
+ dropout_p=0.0,
19
+ prenorm=False,
20
+ residual_in_fp32=False,
21
+ )
22
+
23
+
24
+ __all__ = ["LlamaRMSNorm"]