Dakerqi commited on
Commit
a3e0ef3
·
verified ·
1 Parent(s): ded609e

Delete components.py

Browse files
Files changed (1) hide show
  1. components.py +0 -54
components.py DELETED
@@ -1,54 +0,0 @@
1
- import warnings
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- try:
7
- from apex.normalization import FusedRMSNorm as RMSNorm
8
- except ImportError:
9
- warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10
-
11
- class RMSNorm(torch.nn.Module):
12
- def __init__(self, dim: int, eps: float = 1e-6):
13
- """
14
- Initialize the RMSNorm normalization layer.
15
-
16
- Args:
17
- dim (int): The dimension of the input tensor.
18
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19
-
20
- Attributes:
21
- eps (float): A small value added to the denominator for numerical stability.
22
- weight (nn.Parameter): Learnable scaling parameter.
23
-
24
- """
25
- super().__init__()
26
- self.eps = eps
27
- self.weight = nn.Parameter(torch.ones(dim))
28
-
29
- def _norm(self, x):
30
- """
31
- Apply the RMSNorm normalization to the input tensor.
32
-
33
- Args:
34
- x (torch.Tensor): The input tensor.
35
-
36
- Returns:
37
- torch.Tensor: The normalized tensor.
38
-
39
- """
40
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41
-
42
- def forward(self, x):
43
- """
44
- Forward pass through the RMSNorm layer.
45
-
46
- Args:
47
- x (torch.Tensor): The input tensor.
48
-
49
- Returns:
50
- torch.Tensor: The output tensor after applying RMSNorm.
51
-
52
- """
53
- output = self._norm(x.float()).type_as(x)
54
- return output * self.weight