File size: 524 Bytes
4a07eb0 02bea52 4a07eb0 02bea52 4a07eb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
from torch import nn
from .layer_norm import rms_norm_fn
class LlamaRMSNorm(nn.Module):
    """RMSNorm layer whose forward pass delegates to ``rms_norm_fn``.

    The class declares its state only through annotations and defines no
    ``__init__``. ``weight`` and ``variance_epsilon`` must therefore already be
    set on the instance by whoever owns it before ``forward`` is called.

    NOTE(review): the name and attribute names mirror Hugging Face
    transformers' ``LlamaRMSNorm``. The "annotations + forward only" shape is
    presumably deliberate, so that this class can stand in for an existing
    module whose attributes are already populated (e.g. kernel layer
    substitution). Confirm before adding a constructor or extra methods.
    """

    # Scale tensor handed to ``rms_norm_fn`` as its second positional argument.
    weight: torch.Tensor
    # Forwarded to ``rms_norm_fn`` as ``eps``.
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization to ``hidden_states``.

        Args:
            hidden_states: Input activations. Assumes the trailing dimension
                matches ``self.weight`` — TODO confirm against ``rms_norm_fn``.

        Returns:
            The tensor produced by ``rms_norm_fn``. With ``prenorm=False``,
            presumably only the normalized output and no residual is returned
            (consistent with the ``torch.Tensor`` annotation) — confirm.
        """
        scale = self.weight
        epsilon = self.variance_epsilon
        # Every optional input of rms_norm_fn is switched off here: no bias,
        # no residual, zero dropout, and prenorm/residual_in_fp32 disabled.
        return rms_norm_fn(
            hidden_states,
            scale,
            eps=epsilon,
            bias=None,
            residual=None,
            residual_in_fp32=False,
            dropout_p=0.0,
            prenorm=False,
        )
# Public API of this module: ``from ... import *`` exports only LlamaRMSNorm.
__all__ = [
    "LlamaRMSNorm",
]
|