danieldk HF staff commited on
Commit
02bea52
·
1 Parent(s): 5fdbebc

Move `RMSNorm` to `layers`

Browse files
torch-ext/triton_layer_norm/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
- from .layer_norm import RMSNorm, layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
2
 
3
- __all__ = ["RMSNorm", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
 
 
 
1
+ from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
2
 
3
+ from . import layers
4
+
5
+ __all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
torch-ext/triton_layer_norm/layers.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .layer_norm import RMSNorm
2
+
3
+
4
+ __all__ = ["RMSNorm"]