kernel
moe / ext-torch /fp8.py
danieldk's picture
danieldk HF Staff
Add MoE kernels from vLLM
29e93ec
raw
history blame
2.48 kB
import torch
from typing import Tuple, Optional, Union
def is_hip() -> bool:
return torch.version.hip is not None
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert input.ndim == 2
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale, scale_ub
)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
# num_token_padding not implemented for this case
assert scale.numel() == 1 or num_token_padding is None
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale