import torch

from typing import Tuple, Optional, Union


def is_hip() -> bool:
    """Return True if this is a ROCm (HIP) build of PyTorch."""
    return torch.version.hip is not None


def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor
    along with its scale.

    This function supports both static and dynamic quantization: if a
    scale is provided, it is used as-is (static scaling); if it is
    omitted, the scale is computed from the input (dynamic scaling).
    The output tensor can optionally be padded along the first
    dimension for downstream kernels that benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        num_token_padding: If specified, pad the first dimension of the
            output to at least this value.
        scale_ub: Optional upper bound for the scaling factor in the
            dynamic per-token case.
        use_per_token_if_dynamic: Whether to use per-token (rather than
            per-tensor) scaling in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
        the scaling factor.
    """

    # Only 2D inputs are handled: (num_tokens, hidden_size).
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape

    # ROCm (HIP) builds use the e4m3fnuz FP8 format; CUDA builds use e4m3fn.
    out_dtype: torch.dtype = (
        torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
    )
    if num_token_padding:
        # Pad the token dimension up to at least num_token_padding.
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        # Dynamic quantization: compute the scale(s) on the fly.
        if use_per_token_if_dynamic:
            # One scale per token (row).
            scale = torch.empty(
                (shape[0], 1), device=input.device, dtype=torch.float32
            )
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub
            )
        else:
            # Single per-tensor scale, zero-initialized so the kernel can
            # reduce the observed maximum into it.
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # Static quantization: padding is only supported with a single
        # per-tensor scale.
        assert scale.numel() == 1 or num_token_padding is None
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
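

# Minimal usage sketch: it assumes a CUDA or ROCm device with FP8 support
# and that the compiled `_C` extension providing these custom ops is loaded.
if __name__ == "__main__":
    x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

    # Dynamic per-tensor quantization: the scale is computed from x.
    q, s = scaled_fp8_quant(x)
    print(q.dtype, s.shape)  # FP8 dtype, scale of shape (1,)

    # Dynamic per-token quantization: one scale per row of x.
    q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
    print(s_tok.shape)  # torch.Size([16, 1])

    # Static quantization: reuse a previously computed per-tensor scale.
    q_static, _ = scaled_fp8_quant(x, scale=s)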