A fast implementation of linear attention.
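For orientation, the idea behind linear attention can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the kernel in this directory: it assumes a ReLU feature map on the queries and keys, and the helper names (`matmul`, `quadratic_attention`, `linear_attention`) are hypothetical. The point is that `(Q K^T) V` can be reassociated to `Q (K^T V)`, so the N x N score matrix never has to be built.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def relu(m):
    # Assumed feature map; the real kernel may differ.
    return [[max(x, 0.0) for x in row] for row in m]


def quadratic_attention(q, k, v, eps=1e-6):
    """Reference form: builds the full N x N score matrix first."""
    q, k = relu(q), relu(k)
    scores = matmul(q, transpose(k))          # N x N
    out = matmul(scores, v)                   # N x D
    return [[x / (sum(s) + eps) for x in row] for row, s in zip(out, scores)]


def linear_attention(q, k, v, eps=1e-6):
    """Linear form: computes Q (K^T V), so cost grows linearly with N."""
    q, k = relu(q), relu(k)
    kv = matmul(transpose(k), v)              # D x D, independent of N
    k_sum = [sum(col) for col in zip(*k)]     # length D, used for normalization
    out = matmul(q, kv)                       # N x D
    norm = [sum(a * b for a, b in zip(row, k_sum)) for row in q]
    return [[x / (n + eps) for x in row] for row, n in zip(out, norm)]


# Tiny example: 3 tokens, feature dimension 2.
q = [[1.0, -2.0], [0.5, 3.0], [2.0, 1.0]]
k = [[0.5, 1.0], [-1.0, 2.0], [1.5, -0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

ref = quadratic_attention(q, k, v)
lin = linear_attention(q, k, v)
max_diff = max(abs(a - b) for ra, rb in zip(ref, lin) for a, b in zip(ra, rb))
print(f"max abs difference: {max_diff:.2e}")
```

Both functions return the same values up to floating-point rounding; only the order of the matrix products differs.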

Test configuration: 64x64, fp16

# validate correctness
## fp16 vs fp32
python -m develop_triton_litemla attn_type=LiteMLA test_correctness=True
## triton fp16 vs fp32
python -m develop_triton_litemla attn_type=TritonLiteMLA test_correctness=True

# test performance
## fp16, forward
python -m develop_triton_litemla attn_type=LiteMLA
each step takes 10.81 ms
max memory allocated: 2.2984 GB

## triton fp16, forward
python -m develop_triton_litemla attn_type=TritonLiteMLA
each step takes 4.70 ms
max memory allocated: 1.6480 GB

## fp16, backward
python -m develop_triton_litemla attn_type=LiteMLA backward=True
each step takes 35.34 ms
max memory allocated: 3.4412 GB

## triton fp16, backward
python -m develop_triton_litemla attn_type=TritonLiteMLA backward=True
each step takes 14.25 ms
max memory allocated: 2.4704 GB
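
The relative gains follow directly from the numbers above. The short script below is only a convenience for that arithmetic; the dictionary layout and the `summarize` helper are illustrative and not part of `develop_triton_litemla`.

```python
# (time per step in ms, peak memory in GB), copied from the results above.
results = {
    "forward": {"LiteMLA": (10.81, 2.2984), "TritonLiteMLA": (4.70, 1.6480)},
    "backward": {"LiteMLA": (35.34, 3.4412), "TritonLiteMLA": (14.25, 2.4704)},
}


def summarize(baseline, triton):
    """Return (speedup factor, fraction of peak memory saved)."""
    speedup = baseline[0] / triton[0]
    mem_saving = 1.0 - triton[1] / baseline[1]
    return speedup, mem_saving


for mode, r in results.items():
    speedup, saving = summarize(r["LiteMLA"], r["TritonLiteMLA"])
    print(f"{mode}: {speedup:.2f}x faster, {saving:.1%} less peak memory")

# forward: 2.30x faster, 28.3% less peak memory
# backward: 2.48x faster, 28.2% less peak memory
```

Here "forward" and "backward" refer to the runs without and with `backward=True`, respectively.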