# Copyright (C) 2023, Tri Dao.

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange, repeat

from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
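
# Both tests below compare the fused Triton kernel (selective_state_update)
# against the pure-PyTorch reference (selective_state_update_ref) on a single
# decoding step of the selective SSM recurrence:
#     state <- state * exp(dt * A) + dt * B * x,    out = C . state + D * x,
# with dt passed through softplus (dt_softplus=True) and an optional SiLU(z)
# gate on the output. The state is updated in place, so each test checks both
# the returned output and the mutated state.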

@pytest.mark.parametrize("itype", [torch.float16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048])
def test_selective_state_update(dim, dstate, has_z, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            # ROCm needs a looser bfloat16 tolerance
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    # dt_bias in [-4, -3) so softplus(dt + dt_bias) yields small positive steps
    dt_bias = torch.rand(dim, device=device) - 4.0
    # A in (-2, -1] keeps exp(dt * A) a decay factor
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()
    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    # the kernel updates `state` in place; compare it as well as the output
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
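

# For reference, a minimal sketch (hypothetical helper, not part of mamba_ssm)
# of the recurrence the non-head test exercises, written in plain PyTorch.
# It mirrors what selective_state_update_ref computes for this layout.
def _naive_selective_state_update(state, x, dt, A, B, C, D, z, dt_bias):
    dt = F.softplus(dt + dt_bias)  # dt_softplus=True path
    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # discretized decay, (b, d, n)
    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (b, d, n)
    state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # in-place update
    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) + D * x  # readout + skip
    if z is not None:
        out = out * F.silu(z)  # optional gating
    return out.to(x.dtype)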


@pytest.mark.parametrize("itype", [torch.float16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("tie_hdim", [True])
@pytest.mark.parametrize("ngroups", [2])
@pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048])
def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 1e-1
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    headdim = 64
    nheads = dim // headdim
    state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    if not tie_hdim:
        # per-(head, channel) parameters
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        # parameters tied across the head dimension, broadcast with repeat
        dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
        A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    # B and C are shared across heads within each of the `ngroups` groups
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()
    state_og = state.detach().clone()  # snapshot of the initial state
    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
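

# Minimal sketch for running one configuration directly (without pytest);
# assumes a CUDA device and reuses the values from the parametrize lists above.
if __name__ == "__main__":
    test_selective_state_update(dim=2048, dstate=16, has_z=True, itype=torch.float16)
    test_selective_state_update_with_heads(
        dim=2048, dstate=16, ngroups=2, has_z=True, tie_hdim=True, itype=torch.float16
    )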