A fast Triton implementation of linear attention (LiteMLA).
Benchmark setting: 64x64, fp16.
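
For context, linear attention replaces the softmax attention map with a kernel feature map so the matmuls can be reassociated and the cost scales linearly in sequence length. A minimal PyTorch sketch of that core, assuming the ReLU feature map used in LiteMLA-style linear attention (shapes and names here are illustrative, not this repo's API):

```python
import torch

def relu_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (batch, heads, tokens, dim) -- assumed layout for illustration
    q, k = torch.relu(q), torch.relu(k)
    # reassociate: build a (dim, dim) context per head instead of a
    # (tokens, tokens) attention map, so cost is linear in tokens
    kv = torch.einsum("bhnd,bhne->bhde", k, v)
    out = torch.einsum("bhnd,bhde->bhne", q, kv)
    # normalizer: q' @ (sum of k' over the sequence)
    z = torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2))
    return out / (z.unsqueeze(-1) + eps)
```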
# validate correctness
## fp16 vs fp32
python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
## triton fp16 vs fp32
python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True
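
The `test_correctness=True` runs presumably compare each fp16 implementation against an fp32 reference. A hand-rolled check along those lines might look like this (the function names, shapes, and tolerances are assumptions, not this repo's API):

```python
import torch

def check_correctness(impl, ref, shape=(1, 8, 64 * 64, 64), atol=1e-2, rtol=1e-2):
    """Compare an fp16 implementation against an fp32 reference on random inputs."""
    torch.manual_seed(0)
    q, k, v = (torch.randn(shape, device="cuda") for _ in range(3))
    out_ref = ref(q, k, v)                             # fp32 reference
    out = impl(q.half(), k.half(), v.half()).float()   # fp16 under test
    print(f"max abs error: {(out - out_ref).abs().max().item():.4e}")
    assert torch.allclose(out, out_ref, atol=atol, rtol=rtol)
```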
# test performance
## fp16, forward
python -m develop_triton_litemla attn_type=LiteMLA
each step takes 10.81 ms
max memory allocated: 2.2984 GB
## triton fp16, forward
python -m develop_triton_litemla attn_type=TritonLiteMLA
each step takes 4.70 ms
max memory allocated: 1.6480 GB
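
A common way to produce per-step latency and peak-memory numbers like the ones above on CUDA (a sketch of a typical harness, not necessarily the script's exact method):

```python
import torch

def benchmark(fn, warmup=10, iters=50):
    """Return average per-step latency (ms) and peak memory (GB) for fn()."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # measure steady state, not warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    gb = torch.cuda.max_memory_allocated() / 1024**3
    return ms, gb
```

Here `fn` would wrap a forward call of the attention module; for the backward numbers below, `fn` would presumably also run something like `out.sum().backward()`.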
## fp16, backward
python -m develop_triton_litemla attn_type=LiteMLA backward=True
each step takes 35.34 ms
max memory allocated: 3.4412 GB
## triton fp16, backward
python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
each step takes 14.25 ms
max memory allocated: 2.4704 GB
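
Summarizing the numbers above: the Triton kernel is roughly 2.3x faster forward and 2.5x faster backward, with about 28% lower peak memory in both passes.

| pass | LiteMLA (ms / GB) | TritonLiteMLA (ms / GB) |
|---|---|---|
| forward | 10.81 / 2.2984 | 4.70 / 1.6480 |
| backward | 35.34 / 3.4412 | 14.25 / 2.4704 |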