triton-scaled-mm / tests /test_triton_scaled_mm.py
danieldk's picture
danieldk HF staff
Tests: actually gate FP8 use based on capability
9baccbe
# SPDX-License-Identifier: Apache-2.0
"""Tests for the triton_scaled_mm kernel
Run `pytest tests/test_triton_scaled_mm.py`.
"""
from typing import Optional
import pytest
import torch
from triton_scaled_mm import triton_scaled_mm
# Device on which all test inputs (a, b, scales, bias) are created.
device = "cuda"
def scaled_mm_torch(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled matrix multiply built from plain PyTorch ops.

    ``a`` and ``b`` are converted to float32 and multiplied with ``torch.mm``.
    The product is then multiplied by ``scale_a`` and by ``scale_b.T`` (both
    via broadcasting), cast to ``out_dtype``, and finally ``bias`` is added
    if one is given.

    Args:
        a: Left operand of the matrix product.
        b: Right operand of the matrix product.
        scale_a: Scale broadcast against the product as-is.
        scale_b: Scale that is transposed before being broadcast against the
            product, so a column of per-output-column scales becomes a row.
        out_dtype: Dtype the scaled product is cast to.
        bias: Optional tensor added after the cast to ``out_dtype``.

    Returns:
        The scaled (and optionally biased) product.
    """
    product = torch.mm(a.to(torch.float32), b.to(torch.float32))
    product = scale_a * product
    product = scale_b.T * product

    # The bias is applied after the cast, i.e. in the output precision.
    result = product.to(out_dtype)
    if bias is not None:
        result = result + bias
    return result
def get_8bit_types():
    """Return the 8-bit input dtypes to parametrize the test over.

    ``torch.int8`` is always included. One FP8 dtype is added when a GPU is
    available and its compute capability, encoded as ``major * 10 + minor``,
    is at least 89: ``torch.float8_e4m3fnuz`` on HIP builds of PyTorch,
    ``torch.float8_e4m3fn`` on CUDA builds.

    Returns:
        A list of ``torch.dtype`` objects, starting with ``torch.int8``.
    """
    types = [torch.int8]

    # The capability query needs a usable GPU; bail out first so that merely
    # collecting the tests does not fail on a machine without one.
    if not torch.cuda.is_available():
        return types

    # get_device_capability() returns (major, minor), in that order.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < 89:
        return types

    if torch.version.hip is not None:
        types.append(torch.float8_e4m3fnuz)
    elif torch.version.cuda is not None:
        types.append(torch.float8_e4m3fn)
    return types
@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("in_dtype", get_8bit_types())
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(
    M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
):
    """Check ``triton_scaled_mm`` against the PyTorch reference.

    Random ``(M, K)`` and ``(K, N)`` inputs of ``in_dtype`` are created on
    ``device`` together with scalar ``(1, 1)`` or per-row/per-column
    ``(M, 1)`` / ``(N, 1)`` scales and an optional ``(N,)`` bias. The kernel
    output is copied to the CPU and compared with ``scaled_mm_torch``
    evaluated on CPU copies of the same inputs.
    """
    torch.manual_seed(0)

    # NOTE: If the inputs are too large, an output value such as 65504.4 can
    # turn into inf in float16/bfloat16. The kernel under test and the
    # reference may then disagree on whether a given element is inf, which
    # makes assert_close fail even though the values would have been "close"
    # without the overflow.
    #
    # So, the values here are kept small enough to avoid this situation.
    if in_dtype.is_floating_point:
        a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
        b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
    else:
        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)

    # One scale for the whole tensor, or one per row of `a` / column of `b`.
    if use_scalar_scale_a:
        scale_a = torch.rand((1, 1), device=device)
    else:
        scale_a = 0.25 * torch.rand((M, 1), device=device)

    if use_scalar_scale_b:
        scale_b = torch.rand((1, 1), device=device)
    else:
        scale_b = 0.25 * torch.rand((N, 1), device=device)

    bias = torch.rand((N,), device=device, dtype=out_dtype) if use_bias else None

    # Kernel under test, evaluated on `device`.
    actual = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias).cpu()

    # Reference result, evaluated on CPU copies of the same inputs.
    expected = scaled_mm_torch(
        a.cpu(),
        b.cpu(),
        scale_a.cpu(),
        scale_b.cpu(),
        out_dtype,
        None if bias is None else bias.cpu(),
    )

    torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-1)