Move `RMSNorm` to `layers`
Browse files
torch-ext/triton_layer_norm/__init__.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
-
from .layer_norm import
|
2 |
|
3 |
-
|
|
|
|
|
|
1 |
+
from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
|
2 |
|
3 |
+
from . import layers
|
4 |
+
|
5 |
+
__all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
|
torch-ext/triton_layer_norm/layers.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .layer_norm import RMSNorm
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ["RMSNorm"]
|