davidhd committed on
Commit
0d451f3
·
verified ·
1 Parent(s): 016629d

Update rmsnorm.py

Browse files

Modifies the forward pass of RMSNorm to avoid mixed-precision issues as described in https://github.com/chandar-lab/AMPLIFY/issues/19

Files changed (1) hide show
  1. rmsnorm.py +11 -7
rmsnorm.py CHANGED
@@ -6,29 +6,33 @@ class RMSNorm(nn.Module):
6
  def __init__(self, dim: int, eps: float = 1e-6):
7
  """
8
  Initialize the RMSNorm normalization layer.
9
-
10
  Args:
11
  dim (int): The dimension of the input tensor.
12
  eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13
-
14
  Attributes:
15
  eps (float): A small value added to the denominator for numerical stability.
16
  weight (nn.Parameter): Learnable scaling parameter.
17
-
18
  """
19
  super().__init__()
20
  self.eps = eps
21
  self.weight = nn.Parameter(torch.ones(dim))
22
 
 
 
 
23
  def forward(self, x):
24
  """
25
  Forward pass through the RMSNorm layer.
26
-
27
  Args:
28
  x (torch.Tensor): The input tensor.
29
-
30
  Returns:
31
  torch.Tensor: The output tensor after applying RMSNorm.
32
-
33
  """
34
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
 
 
6
  def __init__(self, dim: int, eps: float = 1e-6):
7
  """
8
  Initialize the RMSNorm normalization layer.
9
+
10
  Args:
11
  dim (int): The dimension of the input tensor.
12
  eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
13
+
14
  Attributes:
15
  eps (float): A small value added to the denominator for numerical stability.
16
  weight (nn.Parameter): Learnable scaling parameter.
17
+
18
  """
19
  super().__init__()
20
  self.eps = eps
21
  self.weight = nn.Parameter(torch.ones(dim))
22
 
23
+ def _norm(self, x):
24
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
25
+
26
  def forward(self, x):
27
  """
28
  Forward pass through the RMSNorm layer.
29
+
30
  Args:
31
  x (torch.Tensor): The input tensor.
32
+
33
  Returns:
34
  torch.Tensor: The output tensor after applying RMSNorm.
35
+
36
  """
37
+ output = self._norm(x.float()).type_as(x) # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
38
+ return output * self.weight