Update rmsnorm.py
Modifies the forward pass of RMSNorm to avoid mixed precision issues as described in https://github.com/chandar-lab/AMPLIFY/issues/19.
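For context on the failure mode, here is a minimal sketch (illustrative only, not code from this repository; shapes and scales are made up): when the RMS statistic x.pow(2).mean(-1) is accumulated directly in float16, the squared activations can overflow the fp16 range and the normalization scale collapses to zero, whereas the same statistic computed in float32 stays finite.

import torch

# Minimal sketch of one mixed-precision failure mode (illustrative, not repository code).
x = torch.randn(4, 640, dtype=torch.float16) * 300            # made-up pre-norm activation scale
ms_fp16 = x.pow(2).mean(-1, keepdim=True)                     # squares overflow fp16 -> inf
ms_fp32 = x.float().pow(2).mean(-1, keepdim=True)             # accumulated in fp32 -> finite

print(torch.rsqrt(ms_fp16 + 1e-6))   # rows collapse to 0.0 where the fp16 mean overflowed
print(torch.rsqrt(ms_fp32 + 1e-6))   # correct finite normalization scales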
rmsnorm.py CHANGED (+11 −7)

@@ -6,29 +6,33 @@ class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
 
         Args:
             dim (int): The dimension of the input tensor.
             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
 
         Attributes:
             eps (float): A small value added to the denominator for numerical stability.
             weight (nn.Parameter): Learnable scaling parameter.
 
         """
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
     def forward(self, x):
         """
         Forward pass through the RMSNorm layer.
 
         Args:
             x (torch.Tensor): The input tensor.
 
         Returns:
             torch.Tensor: The output tensor after applying RMSNorm.
 
         """
-
+        output = self._norm(x.float()).type_as(x)  # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
+        return output * self.weight
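With the change applied, a quick sanity check (a sketch assuming the post-change RMSNorm class from the diff above is in scope): the module still consumes and produces low-precision tensors, and only the internal RMS statistic is promoted to float32. This float-then-cast-back pattern matches the LLaMA reference RMSNorm.

import torch
from torch import nn

# Assumes the post-change RMSNorm class shown above is defined in scope.
norm = RMSNorm(dim=640).to(torch.bfloat16)          # weights in bf16, as in mixed-precision training
x = torch.randn(2, 16, 640, dtype=torch.bfloat16)

out = norm(x)
print(out.dtype)   # torch.bfloat16 -- only the internal RMS statistic ran in float32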