You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

144 lines
2.6 KiB

1 year ago
  1. #pragma once
  #include <cmath>
  #include <cstdint>
  #include <cstring>
  // 16-bit "brain floating point" value type.
  //
  // The stored pattern is the upper 16 bits of a 32-bit float: the sign bit
  // (0x8000), the full float exponent, and the top 7 mantissa bits. Arithmetic
  // is not implemented here; the type only converts to and from float and
  // supports sign negation.
  class bfloat16
  {
      // Raw bfloat16 bit pattern. Left uninitialized by the default constructor.
      uint16_t value;

      // Tag constructor: adopts `x` as the raw bit pattern without any
      // conversion. The unused bool only distinguishes it from bfloat16(float).
      bfloat16(uint16_t x, bool) : value(x)
      {}

      // Returns the object representation of `x` as a 32-bit integer.
      // memcpy is used instead of union punning so the conversion is well
      // defined and independent of host byte order.
      static uint32_t float_to_bits(float x)
      {
          static_assert(sizeof(float) == sizeof(uint32_t),
                        "bfloat16 requires a 32-bit float");
          uint32_t bits;
          std::memcpy(&bits, &x, sizeof bits);
          return bits;
      }

      // Inverse of float_to_bits().
      static float bits_to_float(uint32_t bits)
      {
          float x;
          std::memcpy(&x, &bits, sizeof x);
          return x;
      }

  public:
      bfloat16() = default;

      // Implicit conversion from float; see operator=(float) for the rules.
      bfloat16(float x)
      {
          *this = x;
      }

      // Converts `x` to bfloat16:
      //  - zero and subnormal inputs become a zero that keeps the sign of x,
      //  - infinities are truncated (their low 16 bits are already zero),
      //  - NaNs are truncated and forced quiet,
      //  - everything else is rounded to nearest, ties to even.
      // Returns *this.
      bfloat16& operator = (float x)
      {
          uint32_t bits = float_to_bits(x);
          const uint32_t high = bits >> 16;   // sign, exponent, top 7 mantissa bits

          switch (std::fpclassify(x))
          {
          case FP_SUBNORMAL:
          case FP_ZERO:
              // Sign-preserving zero (denormals are flushed to zero).
              value = static_cast<uint16_t>(high & 0x8000u);
              break;
          case FP_INFINITE:
              value = static_cast<uint16_t>(high);
              break;
          case FP_NAN:
              // Truncate and set the mantissa MSB so the result is a quiet NaN
              // even when all surviving payload bits would otherwise be zero.
              value = static_cast<uint16_t>(high | 0x0040u);
              break;
          case FP_NORMAL:
          default:
              // Round to nearest even: add 0x7FFF plus the current LSB of the
              // kept half, then truncate. A carry may propagate into the
              // exponent, which correctly rounds the largest floats to infinity.
              bits += 0x00007FFFu + (high & 0x1u);
              value = static_cast<uint16_t>(bits >> 16);
              break;
          }
          return *this;
      }

      // Widens to float by placing the stored pattern in the upper 16 bits and
      // zero-filling the lower 16 bits. The conversion is exact.
      operator float() const
      {
          return bits_to_float(static_cast<uint32_t>(value) << 16);
      }

      // Negation: flips the sign bit only, so it also applies to zeros,
      // infinities and NaNs without any rounding.
      inline bfloat16 operator - () const
      {
          return bfloat16(static_cast<uint16_t>(value ^ 0x8000u), true);
      }
  };
  60. inline bfloat16 operator + (bfloat16 a, bfloat16 b)
  61. {
  62. return float(a) + float(b);
  63. }
  64. inline bfloat16 operator - (bfloat16 a, bfloat16 b)
  65. {
  66. return float(a) + float(-b);
  67. }
  68. inline bfloat16 operator * (bfloat16 a, bfloat16 b)
  69. {
  70. return float(a) * float(b);
  71. }
  72. inline bfloat16 operator / (bfloat16 a, bfloat16 b)
  73. {
  74. return float(a) / float(b);
  75. }
  76. inline bfloat16& operator += (bfloat16& a, const bfloat16 b)
  77. {
  78. a = a + b;
  79. return a;
  80. }
  81. inline bfloat16& operator -= (bfloat16& a, const bfloat16 b)
  82. {
  83. a = a - b;
  84. return a;
  85. }
  86. inline bfloat16& operator *= (bfloat16& a, const bfloat16 b)
  87. {
  88. a = a * b;
  89. return a;
  90. }
  91. inline bfloat16& operator /= (bfloat16& a, const bfloat16 b)
  92. {
  93. a = a / b;
  94. return a;
  95. }
  96. // Comparison operators
  97. inline bfloat16 operator == (const bfloat16 a, const bfloat16 b)
  98. {
  99. return float(a) == float(b);
  100. }
  101. inline bfloat16 operator != (const bfloat16 a, const bfloat16 b)
  102. {
  103. return !(a == b);
  104. }
  105. inline bfloat16 operator < (const bfloat16 a, const bfloat16 b)
  106. {
  107. return float(a) < float(b);
  108. }
  109. inline bfloat16 operator <= (const bfloat16 a, const bfloat16 b)
  110. {
  111. return float(a) <= float(b);
  112. }
  113. inline bfloat16 operator > (const bfloat16 a, const bfloat16 b)
  114. {
  115. return float(a) > float(b);
  116. }
  117. inline bfloat16 operator >= (const bfloat16 a, const bfloat16 b)
  118. {
  119. return float(a) >= float(b);
  120. }