|
#pragma once |
|
|
|
#ifdef __HIPCC__ |
|
#include <hip/hip_runtime.h> |
|
#else |
|
#include <type_traits> |
|
#include <stdint.h> |
|
#include <math.h> |
|
#include <iostream> |
|
#endif |
|
|
|
#include "hip_float8_impl.h" |
|
|
|
struct alignas(1) hip_fp8 { |
|
struct from_bits_t {}; |
|
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { |
|
return from_bits_t(); |
|
} |
|
uint8_t data; |
|
|
|
hip_fp8() = default; |
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; |
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; |
|
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) |
|
: data(v) {} |
|
|
|
#ifdef __HIP__MI300__ |
|
|
|
explicit HIP_FP8_DEVICE hip_fp8(float v) |
|
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {} |
|
|
|
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) |
|
: hip_fp8(static_cast<float>(v)) {} |
|
|
|
|
|
explicit HIP_FP8_HOST |
|
#else |
|
|
|
explicit HIP_FP8_HOST_DEVICE |
|
#endif |
|
hip_fp8(float v) { |
|
data = hip_fp8_impl::to_float8<4, 3, float, true , |
|
true >(v); |
|
} |
|
|
|
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) |
|
: hip_fp8(static_cast<float>(v)) {} |
|
|
|
#ifdef __HIP__MI300__ |
|
|
|
explicit inline HIP_FP8_DEVICE operator float() const { |
|
float fval; |
|
uint32_t i32val = static_cast<uint32_t>(data); |
|
|
|
|
|
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" |
|
: "=v"(fval) |
|
: "v"(i32val)); |
|
|
|
return fval; |
|
} |
|
|
|
explicit inline HIP_FP8_HOST operator float() const |
|
#else |
|
explicit inline HIP_FP8_HOST_DEVICE operator float() const |
|
#endif |
|
{ |
|
return hip_fp8_impl::from_float8<4, 3, float, true >( |
|
data); |
|
} |
|
}; |
|
|
|
namespace std { |
|
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } |
|
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } |
|
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } |
|
} |
|
|
|
|
|
// Streams a hip_fp8 by writing its float value to `os`; returns `os`.
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  os << static_cast<float>(f8);
  return os;
}
|
|
|
|
|
|
|
|
|
// float + hip_fp8: the hip_fp8 operand is converted to float and the sum is
// returned as float.
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  const float rhs = static_cast<float>(b);
  return fa + rhs;
}
|
|
|
// hip_fp8 + float: the hip_fp8 operand is converted to float and the sum is
// returned as float.
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  const float lhs = static_cast<float>(a);
  return lhs + fb;
}
|
|
|
// hip_fp8 + hip_fp8: both operands are converted to float, added, and the sum
// is converted back to hip_fp8.
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return hip_fp8(lhs + rhs);
}
|
|
|
// Adds `b` to `a` in place: the sum is formed in float, converted back to
// hip_fp8, and stored in `a`. Returns a reference to `a`.
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  a = hip_fp8(lhs + rhs);
  return a;
}
|
|
|
|
|
// hip_fp8 * hip_fp8: both operands are converted to float; the product is
// returned as float and is not converted back to hip_fp8.
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}
|
|
|
// float * hip_fp8: the hip_fp8 operand is converted to float; returns float.
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  const float rhs = static_cast<float>(b);
  return a * rhs;
}
|
|
|
// hip_fp8 * float: the hip_fp8 operand is converted to float; returns float.
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs * b;
}
|
|
|
// int32_t * hip_fp8: both operands are converted to float before multiplying;
// returns float.
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}
|
|
|
// double * hip_fp8: the double is narrowed to float first, so the
// multiplication is carried out in float; returns float.
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}
|
|
|
|
|
// Equality on the raw encoded bytes: two values are equal exactly when their
// `data` members match (bitwise, not numeric, comparison).
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  const uint8_t lhs_bits = a.data;
  const uint8_t rhs_bits = b.data;
  return lhs_bits == rhs_bits;
}
|
// Inequality on the raw encoded bytes: true exactly when the `data` members
// differ (bitwise, not numeric, comparison).
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  const uint8_t lhs_bits = a.data;
  const uint8_t rhs_bits = b.data;
  return lhs_bits != rhs_bits;
}
|
|
|
// Greater-or-equal, evaluated on the float values of both operands.
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs >= rhs;
}
|
// Strictly-greater, evaluated on the float values of both operands.
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs > rhs;
}
|
|