You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

144 lines
2.6 KiB

#pragma once
#include <cmath>
#include <cstdint>
#include <cstring>
/// Brain floating-point value stored as the upper 16 bits of an IEEE-754
/// binary32 float: 1 sign bit, 8 exponent bits, 7 mantissa bits.
///
/// Conversion from float rounds to nearest-even, flushes subnormals to a
/// sign-preserving zero and keeps NaNs quiet; conversion back to float is exact.
class bfloat16
{
    uint16_t value;

    /// Raw bit pattern of a float. memcpy is the well-defined way to type-pun
    /// (reading an inactive union member is UB) and is independent of byte order.
    static uint32_t bits_of(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof u);
        return u;
    }

    /// Float whose bit pattern is `u`.
    static float float_from_bits(uint32_t u)
    {
        float f;
        std::memcpy(&f, &u, sizeof f);
        return f;
    }

    /// Builds a value directly from its 16-bit encoding; the bool only
    /// disambiguates this from the converting constructor.
    bfloat16(uint16_t x, bool) : value(x)
    {}
public:
    /// Leaves the value uninitialized (like a plain float).
    bfloat16() = default;

    /// Converting constructor; see operator=(float) for the rounding rules.
    bfloat16(float x)
    {
        *this = x;
    }

    /// Assigns a float, keeping the high 16 bits of its encoding.
    /// - zero / subnormal: sign-preserving zero (subnormals are flushed)
    /// - infinity: kept as infinity
    /// - NaN: truncated, with the mantissa MSB forced on so it stays a (quiet) NaN
    /// - normal: round to nearest, ties to even; may overflow to infinity
    bfloat16& operator = (float x)
    {
        const uint32_t bits = bits_of(x);
        const uint16_t upper = static_cast<uint16_t>(bits >> 16);
        switch (std::fpclassify(x))
        {
        case FP_SUBNORMAL:
        case FP_ZERO:
            value = static_cast<uint16_t>(upper & 0x8000u);
            break;
        case FP_INFINITE:
            value = upper;
            break;
        case FP_NAN:
            // Truncation could leave an all-zero mantissa (i.e. infinity);
            // setting bit 6 (mantissa MSB) guarantees a quiet NaN.
            value = static_cast<uint16_t>(upper | 0x0040u);
            break;
        case FP_NORMAL:
        default: // any implementation-defined extra category: treat as normal
        {
            // Adding 0x7FFF (+1 if the kept LSB is odd) to the discarded
            // low half implements round-half-to-even before truncation.
            const uint32_t rounding_bias = 0x7FFFu + (upper & 0x1u);
            value = static_cast<uint16_t>((bits + rounding_bias) >> 16);
            break;
        }
        }
        return *this;
    }

    /// Exact widening: the stored bits become the high half of a float.
    operator float() const
    {
        return float_from_bits(static_cast<uint32_t>(value) << 16);
    }

    /// Negation by flipping the sign bit (also applies to zeros and NaNs).
    inline bfloat16 operator - () const
    {
        return bfloat16(static_cast<uint16_t>(value ^ 0x8000u), true);
    }
};
// Binary arithmetic: both operands are widened to float, the operation is
// performed in single precision, and the float result is rounded back to
// bfloat16 by the implicit conversion on return.
inline bfloat16 operator + (bfloat16 a, bfloat16 b)
{
    const float lhs = a;
    const float rhs = b;
    return lhs + rhs;
}
inline bfloat16 operator - (bfloat16 a, bfloat16 b)
{
    // Expressed as adding the sign-flipped right operand (bfloat16 negation).
    const float lhs = a;
    const float negated_rhs = -b;
    return lhs + negated_rhs;
}
inline bfloat16 operator * (bfloat16 a, bfloat16 b)
{
    const float lhs = a;
    const float rhs = b;
    return lhs * rhs;
}
inline bfloat16 operator / (bfloat16 a, bfloat16 b)
{
    const float lhs = a;
    const float rhs = b;
    return lhs / rhs;
}
// Compound assignment: delegate to the matching binary operator and return
// the updated left operand by reference.
inline bfloat16& operator += (bfloat16& a, const bfloat16 b)
{
    return a = a + b;
}
inline bfloat16& operator -= (bfloat16& a, const bfloat16 b)
{
    return a = a - b;
}
inline bfloat16& operator *= (bfloat16& a, const bfloat16 b)
{
    return a = a * b;
}
inline bfloat16& operator /= (bfloat16& a, const bfloat16 b)
{
    return a = a / b;
}
// Comparison operators
inline bfloat16 operator == (const bfloat16 a, const bfloat16 b)
{
return float(a) == float(b);
}
inline bfloat16 operator != (const bfloat16 a, const bfloat16 b)
{
return !(a == b);
}
inline bfloat16 operator < (const bfloat16 a, const bfloat16 b)
{
return float(a) < float(b);
}
inline bfloat16 operator <= (const bfloat16 a, const bfloat16 b)
{
return float(a) <= float(b);
}
inline bfloat16 operator > (const bfloat16 a, const bfloat16 b)
{
return float(a) > float(b);
}
inline bfloat16 operator >= (const bfloat16 a, const bfloat16 b)
{
return float(a) >= float(b);
}