|
|
- #pragma once
-
#include <cmath>
#include <cstdint>
#include <cstring>
-
/// 16-bit floating point value holding the upper 16 bits of a 32-bit float
/// (sign bit, exponent bits and the top 7 mantissa bits).
///
/// The code assumes `float` is a 32-bit IEEE-754 binary32 value; only the
/// size is checked at compile time.
///
/// Conversion from float:
///  - zero and subnormal inputs become a zero with the input's sign,
///  - infinities are kept,
///  - NaNs are truncated and forced to be quiet,
///  - normal values are rounded to nearest, ties to even.
class bfloat16
{
    /// Raw bfloat16 bit pattern (the high half of the equivalent float).
    uint16_t value;

    static_assert(sizeof(float) == sizeof(uint32_t),
                  "bfloat16 requires a 32-bit float");

    /// Returns the object representation of `f` as a 32-bit integer.
    /// memcpy is used instead of a union so that no inactive union member is
    /// read and the result does not depend on host byte order.
    static uint32_t float_to_bits(float f)
    {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof bits);
        return bits;
    }

    /// Inverse of float_to_bits(): reinterprets a 32-bit pattern as a float.
    static float bits_to_float(uint32_t bits)
    {
        float f;
        std::memcpy(&f, &bits, sizeof f);
        return f;
    }

    /// Builds a bfloat16 directly from a raw bit pattern. The unused bool only
    /// distinguishes this overload from the converting constructor.
    bfloat16(uint16_t x, bool) : value(x)
    {}

public:
    /// Leaves the stored bits uninitialised.
    bfloat16() = default;

    /// Converts a float to bfloat16 (see operator=(float) for the rules).
    bfloat16(float x)
    {
        *this = x;
    }

    /// Assigns the bfloat16 encoding of `x` and returns *this.
    bfloat16& operator = (float x)
    {
        const uint32_t bits = float_to_bits(x);
        const uint16_t upper_half = static_cast<uint16_t>(bits >> 16);

        switch (std::fpclassify(x))
        {
        case FP_SUBNORMAL:
        case FP_ZERO:
            // Sign-preserving zero (subnormals are flushed to zero).
            value = static_cast<uint16_t>(upper_half & 0x8000u);
            break;
        case FP_INFINITE:
            // The low 16 bits of an infinity are zero, so truncation is exact.
            value = upper_half;
            break;
        case FP_NAN:
            // Truncate and set the MSB of the 7-bit mantissa to force a quiet
            // NaN (also keeps the mantissa non-zero after truncation).
            value = static_cast<uint16_t>(upper_half | (1u << 6));
            break;
        case FP_NORMAL:
        default:
        {
            // Round to nearest even, then truncate. The bias is 0x7FFF when
            // the kept LSB is 0 and 0x8000 when it is 1, so exact ties round
            // towards the even result. Any implementation-defined class
            // reported by fpclassify is handled here as well, so `value` is
            // always assigned.
            const uint32_t rounding_bias = 0x00007FFFu + (upper_half & 0x1u);
            value = static_cast<uint16_t>((bits + rounding_bias) >> 16);
            break;
        }
        }

        return *this;
    }

    /// Widens to float by placing the stored bits in the high half of a
    /// 32-bit pattern and zeroing the low half; this conversion is exact.
    operator float() const
    {
        return bits_to_float(static_cast<uint32_t>(value) << 16);
    }

    /// Negation: flips the sign bit of the stored pattern.
    inline bfloat16 operator - () const
    {
        return bfloat16(static_cast<uint16_t>(value ^ 0x8000u), true);
    }
};
-
- inline bfloat16 operator + (bfloat16 a, bfloat16 b)
- {
- return float(a) + float(b);
- }
-
- inline bfloat16 operator - (bfloat16 a, bfloat16 b)
- {
- return float(a) + float(-b);
- }
-
- inline bfloat16 operator * (bfloat16 a, bfloat16 b)
- {
- return float(a) * float(b);
- }
-
- inline bfloat16 operator / (bfloat16 a, bfloat16 b)
- {
- return float(a) / float(b);
- }
-
- inline bfloat16& operator += (bfloat16& a, const bfloat16 b)
- {
- a = a + b;
- return a;
- }
-
- inline bfloat16& operator -= (bfloat16& a, const bfloat16 b)
- {
- a = a - b;
- return a;
- }
-
- inline bfloat16& operator *= (bfloat16& a, const bfloat16 b)
- {
- a = a * b;
- return a;
- }
-
- inline bfloat16& operator /= (bfloat16& a, const bfloat16 b)
- {
- a = a / b;
- return a;
- }
-
- // Comparison operators
- inline bfloat16 operator == (const bfloat16 a, const bfloat16 b)
- {
- return float(a) == float(b);
- }
-
- inline bfloat16 operator != (const bfloat16 a, const bfloat16 b)
- {
- return !(a == b);
- }
-
- inline bfloat16 operator < (const bfloat16 a, const bfloat16 b)
- {
- return float(a) < float(b);
- }
-
- inline bfloat16 operator <= (const bfloat16 a, const bfloat16 b)
- {
- return float(a) <= float(b);
- }
-
- inline bfloat16 operator > (const bfloat16 a, const bfloat16 b)
- {
- return float(a) > float(b);
- }
-
- inline bfloat16 operator >= (const bfloat16 a, const bfloat16 b)
- {
- return float(a) >= float(b);
- }
|