Spaces:
Running
Running
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
namespace c10 { | |
namespace detail { | |
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { | |
float res = 0; | |
uint32_t tmp = src; | |
tmp <<= 16; | |
float* tempRes; | |
// We should be using memcpy in order to respect the strict aliasing rule | |
// but it fails in the HIP environment. | |
tempRes = reinterpret_cast<float*>(&tmp); | |
res = *tempRes; | |
std::memcpy(&res, &tmp, sizeof(tmp)); | |
return res; | |
} | |
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { | |
uint32_t res = 0; | |
// We should be using memcpy in order to respect the strict aliasing rule | |
// but it fails in the HIP environment. | |
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); | |
res = *tempRes; | |
std::memcpy(&res, &src, sizeof(res)); | |
return res >> 16; | |
} | |
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { | |
if (src != src) { | |
if (isnan(src)) { | |
if (std::isnan(src)) { | |
return UINT16_C(0x7FC0); | |
} else { | |
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) | |
union { | |
uint32_t U32; | |
float F32; | |
}; | |
F32 = src; | |
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); | |
return static_cast<uint16_t>((U32 + rounding_bias) >> 16); | |
} | |
} | |
} // namespace detail | |
struct alignas(2) BFloat16 { | |
uint16_t x; | |
// HIP wants __host__ __device__ tag, CUDA does not | |
C10_HOST_DEVICE BFloat16() = default; | |
BFloat16() = default; | |
struct from_bits_t {}; | |
static constexpr C10_HOST_DEVICE from_bits_t from_bits() { | |
return from_bits_t(); | |
} | |
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) | |
: x(bits){}; | |
inline C10_HOST_DEVICE BFloat16(float value); | |
inline C10_HOST_DEVICE operator float() const; | |
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); | |
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; | |
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); | |
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; | |
}; | |
} // namespace c10 | |