#pragma once

#include <cmath>
#include <cstdint>
#include <cstring>

/// Brain floating point ("bfloat16"): 1 sign bit, 8 exponent bits, 7 mantissa bits.
/// The stored pattern is exactly the upper 16 bits of the IEEE-754 binary32 encoding,
/// so widening to float is a 16-bit shift. Arithmetic is done in float and rounded back.
class bfloat16 {
    uint16_t value;  // raw bit pattern; NOT initialized by the default constructor (mirrors plain float)

    // Bit-exact float <-> uint32_t reinterpretation. memcpy is the well-defined way to
    // type-pun in C++17 (reading an inactive union member is undefined behaviour), and
    // working on a uint32_t with shifts removes any dependence on host endianness.
    static uint32_t to_bits(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof bits);
        return bits;
    }

    static float from_bits(uint32_t bits) {
        float f;
        std::memcpy(&f, &bits, sizeof f);
        return f;
    }

    // Construct directly from a raw bit pattern; the unused bool only disambiguates this
    // overload from the converting constructor bfloat16(float).
    bfloat16(uint16_t bits, bool) : value(bits) {}

public:
    /// Leaves the value uninitialized, like a default-initialized float.
    bfloat16() = default;

    /// Implicit conversion from float (see operator=(float) for the rounding rules).
    bfloat16(float x) { *this = x; }

    /// Round a float to bfloat16.
    /// - zero and subnormal inputs become a zero of the same sign (subnormals are flushed);
    /// - infinities are kept as infinities;
    /// - NaNs are truncated and forced quiet, so a payload living only in the low 16 bits
    ///   cannot collapse into an infinity;
    /// - normal values use round-to-nearest-even; values that round past the largest
    ///   finite bfloat16 become infinity through the exponent carry.
    bfloat16& operator=(float x) {
        const uint32_t bits = to_bits(x);
        const uint16_t upper = static_cast<uint16_t>(bits >> 16);

        switch (std::fpclassify(x)) {
            case FP_SUBNORMAL:
            case FP_ZERO:
                value = static_cast<uint16_t>(upper & 0x8000u);  // keep only the sign bit
                break;
            case FP_INFINITE:
                value = upper;
                break;
            case FP_NAN:
                value = static_cast<uint16_t>(upper | 0x0040u);  // MSB of the 7-bit mantissa = quiet bit
                break;
            case FP_NORMAL:
            default: {
                // Adding 0x7FFF (+1 when the surviving LSB is odd) and truncating implements
                // round-half-to-even on the 16 discarded bits. No uint32_t wraparound is
                // possible: the largest normal magnitude is 0x7F7FFFFF.
                const uint32_t rounding_bias = 0x7FFFu + (static_cast<uint32_t>(upper) & 1u);
                value = static_cast<uint16_t>((bits + rounding_bias) >> 16);
                break;
            }
        }
        return *this;
    }

    /// Exact widening to float: the stored bits become the high half, the low half is zero.
    operator float() const {
        return from_bits(static_cast<uint32_t>(value) << 16);
    }

    /// Negation by flipping the sign bit (exact, also for zeros, infinities and NaNs).
    inline bfloat16 operator-() const {
        return bfloat16(static_cast<uint16_t>(value ^ 0x8000u), true);
    }
};

// Arithmetic: computed in float, then rounded back to bfloat16.
inline bfloat16 operator+(bfloat16 a, bfloat16 b) { return float(a) + float(b); }
inline bfloat16 operator-(bfloat16 a, bfloat16 b) { return float(a) - float(b); }
inline bfloat16 operator*(bfloat16 a, bfloat16 b) { return float(a) * float(b); }
inline bfloat16 operator/(bfloat16 a, bfloat16 b) { return float(a) / float(b); }

inline bfloat16& operator+=(bfloat16& a, const bfloat16 b) { a = a + b; return a; }
inline bfloat16& operator-=(bfloat16& a, const bfloat16 b) { a = a - b; return a; }
inline bfloat16& operator*=(bfloat16& a, const bfloat16 b) { a = a * b; return a; }
inline bfloat16& operator/=(bfloat16& a, const bfloat16 b) { a = a / b; return a; }

// Comparison operators: compare the exact float values, so IEEE semantics apply
// (NaN compares unequal to everything, +0 == -0). They return bool, not bfloat16.
inline bool operator==(const bfloat16 a, const bfloat16 b) { return float(a) == float(b); }
inline bool operator!=(const bfloat16 a, const bfloat16 b) { return !(a == b); }
inline bool operator<(const bfloat16 a, const bfloat16 b) { return float(a) < float(b); }
inline bool operator<=(const bfloat16 a, const bfloat16 b) { return float(a) <= float(b); }
inline bool operator>(const bfloat16 a, const bfloat16 b) { return float(a) > float(b); }
inline bool operator>=(const bfloat16 a, const bfloat16 b) { return float(a) >= float(b); }