|
#pragma once |
|
|
|
|
|
/**
 * Quantization utilities, including:
 *   - Adjusted maximum values for qtypes.
 *   - Minimum scaling factors for qtypes.
 */
|
|
|
#include <cmath> |
|
#include <torch/types.h> |
|
|
|
#ifdef USE_ROCM
// ROCm build: HIP context plus both FP8 e4m3 flavours (the "fnuz" variant is
// the one referenced by the quant_type_max specialization further down).
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
// On ROCm, MAYBE_HOST_DEVICE expands to nothing.
// NOTE(review): presumably the HIP toolchain does not need a host/device
// annotation on the static constexpr variable template that uses this
// macro; confirm before changing.
#define MAYBE_HOST_DEVICE
#else
// Non-ROCm build: MAYBE_HOST_DEVICE forwards to c10's host/device marker.
#include <c10/util/Float8_e4m3fn.h>
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#endif
|
|
|
template <typename T, |
|
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> || |
|
std::is_same_v<T, c10::Float8_e4m3fnuz> || |
|
std::is_same_v<T, int8_t>>> |
|
struct quant_type_max { |
|
static constexpr T val() { return std::numeric_limits<T>::max(); } |
|
}; |
|
|
|
|
|
|
|
template <> |
|
struct quant_type_max<c10::Float8_e4m3fnuz> { |
|
static constexpr c10::Float8_e4m3fnuz val() { |
|
return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); |
|
} |
|
}; |
|
|
|
/// Compile-time shorthand for quant_type_max<T>::val(): the quantization
/// upper bound for T (including any explicit specialization of
/// quant_type_max).
///
/// NOTE(review): `static` at namespace scope in a header gives each
/// translation unit its own copy. The usual C++17 alternative is
/// `inline constexpr`, but confirm it is accepted together with
/// MAYBE_HOST_DEVICE by both the CUDA and ROCm compilers before changing it.
template <typename T>
MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
    quant_type_max<T>::val();
|
|
|
template <typename T, |
|
typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> || |
|
std::is_same_v<T, c10::Float8_e4m3fnuz> || |
|
std::is_same_v<T, int8_t>>> |
|
struct min_scaling_factor { |
|
C10_DEVICE C10_ALWAYS_INLINE static float val() { |
|
return 1.0f / (quant_type_max_v<T> * 512.0f); |
|
} |
|
}; |
|
|
|
template <> |
|
struct min_scaling_factor<int8_t> { |
|
C10_DEVICE C10_ALWAYS_INLINE static float val() { |
|
return std::numeric_limits<float>::epsilon(); |
|
} |
|
}; |