import torch
from torch import nn

from .layer_norm import rms_norm_fn


class LlamaRMSNorm(nn.Module):
    """RMSNorm-style layer whose forward delegates to ``rms_norm_fn``.

    No ``__init__`` is defined, so ``weight`` and ``variance_epsilon`` are
    only declared (annotated) here and never assigned by this class. They
    must already exist on the instance when ``forward`` runs. Presumably they
    are provided by the module whose forward this layer stands in for.
    NOTE(review): confirm against the code that maps/loads this layer.
    Before adding an ``__init__``, extra methods or class attributes, check
    that loader: it may require this class to stay stateless.
    """

    # Passed to rms_norm_fn as its second positional argument.
    weight: torch.Tensor
    # Passed to rms_norm_fn as ``eps``.
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``rms_norm_fn`` applied to ``hidden_states`` and ``self.weight``.

        Every optional argument is explicitly disabled: ``bias=None``,
        ``residual=None``, ``dropout_p=0.0``, ``prenorm=False`` and
        ``residual_in_fp32=False``. The epsilon comes from
        ``self.variance_epsilon``.
        """
        scale = self.weight
        eps = self.variance_epsilon
        # Only the input, scale and epsilon vary. All extras are pinned to
        # their "off" values.
        return rms_norm_fn(
            hidden_states,
            scale,
            residual=None,
            bias=None,
            eps=eps,
            prenorm=False,
            residual_in_fp32=False,
            dropout_p=0.0,
        )


__all__ = ["LlamaRMSNorm"]