danieldk's picture
danieldk HF staff
Make `RMSNorm` layer pure, rename to `LlamaRMSNorm`
4a07eb0
import torch
from torch import nn
from .layer_norm import rms_norm_fn
class LlamaRMSNorm(nn.Module):
    """Layer whose forward pass delegates to ``rms_norm_fn``.

    The class defines no ``__init__``; it only declares the two attributes
    that ``forward`` reads. Both must therefore already be available on the
    instance by the time ``forward`` is called.
    """

    # Forwarded to ``rms_norm_fn`` as its second positional argument.
    weight: torch.Tensor
    # Forwarded to ``rms_norm_fn`` as the ``eps`` keyword argument.
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``rms_norm_fn`` to ``hidden_states`` with this layer's state.

        The call supplies no bias and no residual, sets the dropout
        probability to 0.0, and leaves ``prenorm`` and ``residual_in_fp32``
        switched off.

        Args:
            hidden_states: Input tensor, passed through as the first
                positional argument of ``rms_norm_fn``.

        Returns:
            The value returned by ``rms_norm_fn`` for these arguments
            (annotated as a ``torch.Tensor``).
        """
        # Read the instance state up front so the call below shows only
        # which inputs vary and which options are fixed.
        scale = self.weight
        epsilon = self.variance_epsilon

        normalized = rms_norm_fn(
            hidden_states,
            scale,
            bias=None,
            residual=None,
            eps=epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
        return normalized
# Public API of this module: the names exported by a star-import.
__all__ = ["LlamaRMSNorm"]