# Provenance (from Hugging Face file page): danieldk (HF Staff),
# "Add Triton-based layer norm from flash-attention", commit 0f75957, 161 bytes.
from .layer_norm import (
    RMSNorm,
    layer_norm_fn,
    layer_norm_linear_fn,
    rms_norm_fn,
)

# Names re-exported as this package's public API (controls `import *`).
__all__ = [
    "RMSNorm",
    "layer_norm_fn",
    "layer_norm_linear_fn",
    "rms_norm_fn",
]