"""Tests for the triton_scaled_mm kernel.

Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
|
from typing import Optional

import pytest
import torch

from triton_scaled_mm import triton_scaled_mm

device = "cuda"
|
|
def scaled_mm_torch(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Reference implementation: accumulate in float32, apply the row-wise
    # scale for `a` and the column-wise scale for `b` (scale_b arrives as
    # (N, 1), so its transpose broadcasts across columns), then cast to the
    # output dtype and add the optional bias.
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    out = scale_b.T * out
    out = out.to(out_dtype)
    if bias is not None:
        out = out + bias
    return out
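
# Shape sketch for the reference above (illustrative comment only, not
# executed as part of the test):
#   a: (M, K), b: (K, N)        -> out: (M, N), accumulated in float32
#   scale_a: (M, 1) or (1, 1)   -> one scale per row of out (or one global)
#   scale_b: (N, 1) or (1, 1)   -> .T gives (1, N): one scale per column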


def get_8bit_types():
    types = [torch.int8]
    # torch.cuda.get_device_capability() returns (major, minor),
    # e.g. (8, 9) for Ada-generation GPUs.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    supports_fp8 = capability >= 89

    if supports_fp8 and torch.version.hip is not None:
        types.append(torch.float8_e4m3fnuz)
    elif supports_fp8 and torch.version.cuda is not None and torch.cuda.is_available():
        types.append(torch.float8_e4m3fn)
    return types
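
# Note: get_8bit_types() is evaluated inside the parametrize decorator below,
# i.e. at import/collection time, so this module assumes a CUDA or ROCm
# device is present. A minimal guard, if import-time skipping were desired
# (an assumption, not part of the original test):
#
#   if not torch.cuda.is_available():
#       pytest.skip("requires a GPU", allow_module_level=True)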


@pytest.mark.parametrize("M", [1, 33, 64, 512]) |
|
@pytest.mark.parametrize("N", [256, 971, 20486]) |
|
@pytest.mark.parametrize("K", [128, 496, 1024]) |
|
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) |
|
@pytest.mark.parametrize("in_dtype", get_8bit_types()) |
|
@pytest.mark.parametrize("use_scalar_scale_a", [True, False]) |
|
@pytest.mark.parametrize("use_scalar_scale_b", [True, False]) |
|
@pytest.mark.parametrize("use_bias", [True, False]) |
|
def test_scaled_mm( |
|
M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias |
|
): |
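    """Compare triton_scaled_mm against the scaled_mm_torch reference.

    Covers scalar and per-row/per-column scales, optional bias, int8 and
    (where supported) fp8 inputs, and fp16/bf16 outputs.
    """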
|
    # Probe whether the dtype is floating point (covers the fp8 types) by
    # building a tiny tensor of that dtype.
    is_floating_point_type = lambda t: torch.tensor(
        [1, 1], dtype=t
    ).is_floating_point()

    torch.manual_seed(0)

    if is_floating_point_type(in_dtype):
        a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
        b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
    else:
        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
|
    if use_scalar_scale_a:
        scale_a = torch.rand((1, 1), device=device)
    else:
        scale_a = 0.25 * torch.rand((M, 1), device=device)

    if use_scalar_scale_b:
        scale_b = torch.rand((1, 1), device=device)
    else:
        scale_b = 0.25 * torch.rand((N, 1), device=device)

    bias = None
    if use_bias:
        bias = torch.rand((N,), device=device, dtype=out_dtype)
|
    c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    # Compare the Triton result against the float32 reference on the CPU.
    a_cpu = a.cpu()
    b_cpu = b.cpu()
    scale_a_cpu = scale_a.cpu()
    scale_b_cpu = scale_b.cpu()
    bias_cpu = None if bias is None else bias.cpu()

    c_actual = scaled_mm_torch(
        a_cpu, b_cpu, scale_a_cpu, scale_b_cpu, out_dtype, bias_cpu
    )

    c_check_cpu = c_check.cpu()
    torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
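

if __name__ == "__main__":
    # Convenience entry point (an addition, not in the original file): run
    # this module directly as an alternative to the documented
    # `pytest tests/kernels/test_triton_scaled_mm.py` invocation.
    import sys

    sys.exit(pytest.main([__file__]))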
|
|