#pragma once
|
|
|
|
#include <cmath>
#include <cstdint>
#include <cstring>
|
|
|
|
/// 16-bit floating-point value stored as the upper 16 bits of a 32-bit float
/// (assumes `float` is IEEE-754 binary32: 1 sign, 8 exponent, 23 mantissa bits,
/// of which the top 7 mantissa bits are kept).
///
/// Conversion from `float` and back to `float` is implicit. Bit patterns are
/// moved with `std::memcpy` and shifts, so the result does not depend on host
/// byte order and does not read an inactive union member.
class bfloat16
{
    /// Raw bfloat16 bit pattern (sign | exponent | 7-bit mantissa).
    uint16_t value;

    static_assert(sizeof(float) == sizeof(uint32_t),
                  "bfloat16 requires a 32-bit float");

    /// Adopts a raw bit pattern without any conversion. The unused `bool`
    /// only distinguishes this overload from the `float` constructor.
    bfloat16(uint16_t x, bool) : value(x)
    {}

    /// Returns the object representation of `x` as a 32-bit integer.
    static uint32_t float_to_bits(float x)
    {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof bits);
        return bits;
    }

    /// Returns the float whose object representation is `bits`.
    static float bits_to_float(uint32_t bits)
    {
        float x;
        std::memcpy(&x, &bits, sizeof x);
        return x;
    }

public:
    /// Leaves `value` uninitialized when the object is default-initialized.
    bfloat16() = default;

    /// Converts from `float`; see `operator=(float)` for the conversion rules.
    bfloat16(float x)
    {
        *this = x;
    }

    /// Assigns the bfloat16 conversion of `x`.
    ///
    /// - zero and subnormal inputs become a zero carrying the input's sign;
    /// - infinities keep their upper 16 bits unchanged;
    /// - NaNs are truncated and the top mantissa bit is set so the result
    ///   is still a NaN even if the payload lived only in the low bits;
    /// - normal values are rounded to nearest, ties to even.
    ///
    /// @return reference to `*this`.
    bfloat16& operator = (float x)
    {
        const uint32_t bits = float_to_bits(x);
        const uint16_t upper = static_cast<uint16_t>(bits >> 16);

        switch (std::fpclassify(x))
        {
        case FP_SUBNORMAL:
        case FP_ZERO:
            // sign preserving zero (denormals go to zero)
            value = static_cast<uint16_t>(upper & 0x8000u);
            break;

        case FP_INFINITE:
            value = upper;
            break;

        case FP_NAN:
            // truncate and set the MSB of the mantissa to force a quiet NaN
            value = static_cast<uint16_t>(upper | (1u << 6));
            break;

        case FP_NORMAL:
        {
            // round to nearest even: add just under half an ULP, plus one
            // more when the kept LSB is odd, then truncate
            const uint32_t rounding_bias = 0x00007FFFu + (upper & 0x1u);
            value = static_cast<uint16_t>((bits + rounding_bias) >> 16);
            break;
        }
        }

        return *this;
    }

    /// Widens to `float` by placing the stored bits in the upper half of a
    /// 32-bit pattern whose lower half is zero.
    operator float() const
    {
        return bits_to_float(static_cast<uint32_t>(value) << 16);
    }

    /// Returns a copy with the sign bit flipped; all other bits are untouched.
    inline bfloat16 operator - () const
    {
        return bfloat16(static_cast<uint16_t>(value ^ 0x8000u), true);
    }
};
|
|
|
|
/// Adds two bfloat16 values: both operands are widened to float, summed,
/// and the sum is converted back to bfloat16.
inline bfloat16 operator + (bfloat16 a, bfloat16 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return bfloat16(lhs + rhs);
}
|
|
|
|
/// Subtracts `b` from `a` by adding the negation of `b` in float
/// arithmetic, then converts the result back to bfloat16.
inline bfloat16 operator - (bfloat16 a, bfloat16 b)
{
    const float lhs = static_cast<float>(a);
    const float negated_rhs = static_cast<float>(-b);
    return bfloat16(lhs + negated_rhs);
}
|
|
|
|
/// Multiplies two bfloat16 values in float arithmetic and converts the
/// product back to bfloat16.
inline bfloat16 operator * (bfloat16 a, bfloat16 b)
{
    const float lhs = static_cast<float>(a);
    const float rhs = static_cast<float>(b);
    return bfloat16(lhs * rhs);
}
|
|
|
|
/// Divides `a` by `b` in float arithmetic and converts the quotient back
/// to bfloat16.
inline bfloat16 operator / (bfloat16 a, bfloat16 b)
{
    const float numerator = static_cast<float>(a);
    const float denominator = static_cast<float>(b);
    return bfloat16(numerator / denominator);
}
|
|
|
|
/// In-place addition: stores `a + b` into `a` and returns `a`.
inline bfloat16& operator += (bfloat16& a, const bfloat16 b)
{
    return a = a + b;
}
|
|
|
|
/// In-place subtraction: stores `a - b` into `a` and returns `a`.
inline bfloat16& operator -= (bfloat16& a, const bfloat16 b)
{
    return a = a - b;
}
|
|
|
|
/// In-place multiplication: stores `a * b` into `a` and returns `a`.
inline bfloat16& operator *= (bfloat16& a, const bfloat16 b)
{
    return a = a * b;
}
|
|
|
|
/// In-place division: stores `a / b` into `a` and returns `a`.
inline bfloat16& operator /= (bfloat16& a, const bfloat16 b)
{
    return a = a / b;
}
|
|
|
|
// Comparison operators
|
|
inline bfloat16 operator == (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return float(a) == float(b);
|
|
}
|
|
|
|
inline bfloat16 operator != (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return !(a == b);
|
|
}
|
|
|
|
inline bfloat16 operator < (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return float(a) < float(b);
|
|
}
|
|
|
|
inline bfloat16 operator <= (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return float(a) <= float(b);
|
|
}
|
|
|
|
inline bfloat16 operator > (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return float(a) > float(b);
|
|
}
|
|
|
|
inline bfloat16 operator >= (const bfloat16 a, const bfloat16 b)
|
|
{
|
|
return float(a) >= float(b);
|
|
}
|