Spaces:
Running
Running
File size: 4,623 Bytes
306b4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# Copyright (C) 2023, Tri Dao.
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange, repeat
from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
# Narrow any of these lists to a single value to debug one configuration.
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    """Compare `selective_state_update` with `selective_state_update_ref`.

    Inputs are built without a head dimension: `state` is
    (batch, dim, dstate), `x`/`dt` are (batch, dim), `A` is (dim, dstate),
    `B`/`C` are (batch, dstate) and `D`/`dt_bias` are (dim,). Only `state`,
    `x`, `dt` (and `z`, via `randn_like(x)`) are created in `itype`; the
    remaining tensors are created without an explicit dtype.

    Both the returned output and the `state` tensor passed to each function
    are compared afterwards, so each function gets its own copy of the
    initial state.

    Args:
        dim: Size of the channel dimension.
        dstate: Size of the state dimension.
        has_z: Whether a `z` tensor is passed (otherwise `z=None`).
        itype: dtype used for `state`, `x`, `dt` and `z`.
    """
    device = "cuda"

    # Tolerances are chosen per input dtype; anything that is neither
    # float32 nor bfloat16 gets the half-precision pair.
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    else:
        rtol, atol = 5e-3, 1e-2
    if torch.version.hip:
        # Allow twice the absolute error on HIP builds of PyTorch.
        atol *= 2

    # Fixed seed; the creation order below determines the drawn values.
    torch.random.manual_seed(0)
    batch = 2
    state = torch.randn(batch, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch, dim, device=device, dtype=itype)
    dt = torch.randn(batch, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0  # values in [-4, -3)
    A = -torch.rand(dim, dstate, device=device) - 1.0  # values in (-2, -1]
    B = torch.randn(batch, dstate, device=device)
    C = torch.randn(batch, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None

    # The reference works on a separate copy so both start from the same state.
    state_ref = state.detach().clone()

    shared_kwargs = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out = selective_state_update(state, x, dt, A, B, C, **shared_kwargs)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **shared_kwargs)

    abs_err = (out - out_ref).abs()
    print(f"Output max diff: {abs_err.max().item()}")
    print(f"Output mean diff: {abs_err.mean().item()}")

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
# Narrow any of these lists to a single value to debug one configuration.
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
    """Compare `selective_state_update` with its reference on multi-head inputs.

    `dim` is split into `nheads = dim // 64` heads of size 64, so `state` is
    (batch, nheads, headdim, dstate) and `x` is (batch, nheads, headdim).
    `B` and `C` are (batch, ngroups, dstate).

    Both the returned output and the `state` tensor passed to each function
    are compared afterwards, so each function gets its own copy of the
    initial state.

    Args:
        dim: Total channel size; divided by the fixed head size of 64.
        dstate: Size of the state dimension.
        ngroups: Size of the group dimension of `B` and `C`.
        has_z: Whether a `z` tensor is passed (otherwise `z=None`).
        tie_hdim: If True, `dt`, `dt_bias`, `A` and `D` hold one random value
            per head that is repeated along the head dimension (and along
            `dstate` for `A`); if False, every element is drawn independently.
        itype: dtype used for `state`, `x`, `dt` and `z`.
    """
    device = "cuda"
    # Tolerances are chosen per input dtype.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 1e-1

    # Fixed seed; the creation order below determines the drawn values.
    torch.random.manual_seed(0)
    batch_size = 2
    headdim = 64
    nheads = dim // headdim
    state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    if not tie_hdim:
        # Independent parameters for every (head, headdim[, dstate]) element.
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        # One value per head, repeated to the same shapes as the untied case.
        dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
        A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None

    # The reference works on a separate copy so both start from the same state.
    state_ref = state.detach().clone()

    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|