from typing import Optional

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fall back to the extension registered as a top-level module
    # (e.g. installed system-wide) if the bundled one is unavailable.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)

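# Usage sketch (assumption: the capability integer follows the usual
# major * 10 + minor encoding, e.g. 90 for SM90 / Hopper):
#
#   major, minor = torch.cuda.get_device_capability()
#   if cutlass_scaled_mm_supports_fp8(major * 10 + minor):
#       ...  # fp8 path is available on this device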


def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Both dimensions of b must be multiples of 16, the output dtype must
    # be a 16-bit float, and a bias, if given, is per-output-channel.
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.shape[0] == b.shape[1] and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out

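# Usage sketch (assumptions: an fp8-capable GPU, and that the underlying
# kernel expects `a` row-major and `b` column-major; shapes and scales
# here are illustrative only):
#
#   m, k, n = 32, 64, 128
#   a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
#   b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # (k, n), column-major
#   scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales
#   scale_b = torch.tensor(1.0, device="cuda")
#   out = cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16)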


def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    :param azp_adj: Zero-point adjustment term, always per-output-channel.
        In the per-tensor case it must already incorporate the azp itself.
    :param azp: Per-token zero points, with one element per row of `a`.
        Set only in the per-token case; None means per-tensor.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.numel() == b.shape[1] and bias.dtype == out_dtype)
    assert azp is None or azp.numel() == a.shape[0]

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return out

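# Usage sketch for the per-token asymmetric int8 case (the quantized inputs
# and the azp_adj construction below are illustrative assumptions, not part
# of this module):
#
#   a_q: int8 tensor of shape (m, k); azp: int32 zero point per row, shape (m,)
#   b_q: int8 tensor of shape (k, n), column-major
#
#   azp_adj = b_q.to(torch.int32).sum(dim=0)  # per-channel adjustment term
#   out = cutlass_scaled_mm_azp(a_q, b_q, scale_a, scale_b,
#                               torch.float16, azp_adj, azp=azp)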