You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

234 lines
4.5 KiB

1 year ago
  #pragma once
  #include <cstdint>
  #include <cstring>
  #include <limits>
  uint16_t float2half(float val);
  float half2float(uint16_t val);

  /// IEEE-754 binary16 ("half") value stored as its raw 16-bit pattern.
  ///
  /// Values enter and leave through float2half()/half2float(). The defaulted
  /// constructor has no member initializer, so a default-constructed half
  /// holds an indeterminate bit pattern until it is assigned.
  class half
  {
  public:
      half() = default;

      /// Implicit narrowing from float, performed by float2half().
      half(float val)
          : _x(float2half(val))
      {
      }

      /// Implicit widening to float, performed by half2float().
      operator float() const
      {
          return half2float(_x);
      }

      /// Negation by flipping only the sign bit; no float round trip.
      half operator - () const
      {
          return construct(static_cast<uint16_t>(_x ^ 0x8000u));
      }

      /// Wraps an existing binary16 bit pattern without any conversion.
      static half construct(const uint16_t half_data)
      {
          half bits_only;
          bits_only._x = half_data;
          return bits_only;
      }

  private:
      uint16_t _x;

      // Bit-level helpers defined after the class read _x directly.
      friend bool isinf(const half a);
      friend bool isnan(const half a);
      friend half abs(const half a);
  };
  35. // Arithmetic
  36. inline half operator + (const half a, const half b)
  37. {
  38. return half(float(a) + float(b));
  39. }
  40. inline half operator - (const half a, const half b)
  41. {
  42. return half(float(a) - float(b));
  43. }
  44. inline half operator * (const half a, const half b)
  45. {
  46. return half(float(a) * float(b));
  47. }
  48. inline half operator / (const half a, const half b)
  49. {
  50. return half(float(a) / float(b));
  51. }
  52. inline half& operator += (half& a, const half b)
  53. {
  54. a = a + b;
  55. return a;
  56. }
  57. inline half& operator -= (half& a, const half b)
  58. {
  59. a = a - b;
  60. return a;
  61. }
  62. inline half& operator *= (half& a, const half b)
  63. {
  64. a = a * b;
  65. return a;
  66. }
  67. inline half& operator /= (half& a, const half b)
  68. {
  69. a = a / b;
  70. return a;
  71. }
  72. // Comparison operators
  73. inline half operator == (const half a, const half b)
  74. {
  75. return float(a) == float(b);
  76. }
  77. inline half operator != (const half a, const half b)
  78. {
  79. return !(a == b);
  80. }
  81. inline half operator < (const half a, const half b)
  82. {
  83. return float(a) < float(b);
  84. }
  85. inline half operator <= (const half a, const half b)
  86. {
  87. return float(a) <= float(b);
  88. }
  89. inline half operator > (const half a, const half b)
  90. {
  91. return float(a) > float(b);
  92. }
  93. inline half operator >= (const half a, const half b)
  94. {
  95. return float(a) >= float(b);
  96. }
  97. inline bool isinf(const half a)
  98. {
  99. return (a._x & 0x7FFF) == 0x7C00;
  100. }
  101. inline bool isnan(const half a)
  102. {
  103. return (a._x & 0x7fff) > 0x7c00;
  104. }
  105. inline bool isfinite(const half a)
  106. {
  107. return !isinf(a) && !isnan(a);
  108. }
  109. inline half abs(const half a)
  110. {
  111. half result;
  112. result._x = a._x & 0x7FFF;
  113. return result;
  114. }
  namespace detail
  {
  // Overlays the bits of a 32-bit unsigned integer and a float.
  // Member order matters: brace-initialisation (`uif x = { bits };`) sets the
  // first member, `u`.
  // NOTE: writing one member and then reading the other is undefined
  // behaviour in ISO C++ (only the active member may be read). GCC documents
  // union type-punning as a supported extension. std::memcpy (or C++20
  // std::bit_cast) is the portable way to reinterpret the bits.
  union uif
  {
      uint32_t u;
      float f;
  };
  // The overlay is only meaningful if both members occupy the same 32 bits.
  static_assert(sizeof(float) == sizeof(uint32_t), "detail::uif requires a 32-bit float");
  } // namespace detail
  123. #ifdef __F16C__
  124. #include <immintrin.h>
  125. #endif
  /// Converts a float to the bit pattern of the nearest IEEE-754 binary16
  /// value, rounding halfway cases to even.
  ///
  /// @param val  value to convert. Finite magnitudes >= 65520 round to
  ///             +/-Inf; magnitudes <= 2^-25 round to +/-0.
  /// @return     raw binary16 bits (sign bit 0x8000, exponent mask 0x7c00,
  ///             10 mantissa bits). On the software path every NaN becomes
  ///             the quiet NaN 0x7e00 carrying the input's sign; the F16C
  ///             path uses the hardware instruction, whose NaN payload
  ///             handling may differ.
  inline uint16_t float2half(float val)
  {
  #ifdef __F16C__
      // Hardware conversion; rounding-control immediate 0 = round to nearest even.
      return _cvtss_sh(val, 0);
  #else
      // Software fallback. Bits move between float and uint32_t through
      // std::memcpy, the well-defined pre-C++20 spelling of std::bit_cast.
      // Type-punning through a union instead reads an inactive member,
      // which ISO C++ leaves undefined.
      const auto to_bits = [](float f) {
          uint32_t u;
          std::memcpy(&u, &f, sizeof u);
          return u;
      };
      const auto from_bits = [](uint32_t u) {
          float f;
          std::memcpy(&f, &u, sizeof f);
          return f;
      };

      const uint32_t f32infty = 255u << 23;                                  // float +Inf
      const uint32_t f16max = (127u + 16u) << 23;                            // 2^16: always overflows half
      const uint32_t denorm_magic = ((127u - 15u) + (23u - 10u) + 1u) << 23; // 0.5f; its ulp (2^-24) is the half-subnormal step
      const uint32_t sign_mask = 0x80000000u;

      uint32_t bits = to_bits(val);
      const uint32_t sign = bits & sign_mask;
      bits ^= sign; // continue with the magnitude; the sign is re-applied at the end

      // NOTE all the integer compares below can safely be compiled as signed
      // compares since every operand is below 0x80000000. This matters for
      // straight SSE2 code, which has no unsigned PCMPGTD.
      uint16_t o;
      if (bits >= f16max) {
          // Result is Inf or NaN (all exponent bits set): NaN -> qNaN, Inf/overflow -> Inf.
          o = (bits > f32infty) ? 0x7e00 : 0x7c00;
      } else if (bits < (113u << 23)) {
          // |val| < 2^-14: the result is subnormal or zero. Adding 0.5f makes
          // the FPU round |val| to a multiple of 2^-24, which leaves the half
          // bits in the low mantissa bits; subtracting the 0.5f pattern
          // extracts them. Relies on the default round-to-nearest-even mode.
          const float aligned = from_bits(bits) + from_bits(denorm_magic);
          o = static_cast<uint16_t>(to_bits(aligned) - denorm_magic);
      } else {
          // Normal result: rebias the exponent (127 -> 15) and round the
          // 23-bit mantissa to 10 bits. Adding 0xfff (just under half an
          // output ulp) plus the kept mantissa's lowest bit implements
          // round-half-to-even; a carry may ripple into the exponent.
          const uint32_t mant_odd = (bits >> 13) & 1u;
          bits += (static_cast<uint32_t>(15 - 127) << 23) + 0xfffu;
          bits += mant_odd;
          o = static_cast<uint16_t>(bits >> 13);
      }
      o = static_cast<uint16_t>(o | (sign >> 16));
      return o;
  #endif
  }
  /// Widens raw IEEE-754 binary16 bits to float.
  ///
  /// @param _x  binary16 bit pattern (sign 0x8000, exponent 0x7c00, mantissa 0x03ff).
  /// @return    the same value as a float. Zeros keep their sign, subnormals
  ///            are renormalised, and Inf stays Inf. On the software path a
  ///            NaN keeps its sign and its 10 mantissa bits, shifted into the
  ///            top of the float mantissa.
  inline float half2float(uint16_t _x)
  {
  #ifdef __F16C__
      // Hardware conversion. _cvtsh_ss takes exactly one argument; only
      // _cvtss_sh has a rounding-control immediate, and passing a second
      // argument here does not compile.
      return _cvtsh_ss(_x);
  #else
      // Software fallback. Bits move between float and uint32_t through
      // std::memcpy (well-defined, unlike reading an inactive union member).
      const auto to_bits = [](float f) {
          uint32_t u;
          std::memcpy(&u, &f, sizeof u);
          return u;
      };
      const auto from_bits = [](uint32_t u) {
          float f;
          std::memcpy(&f, &u, sizeof f);
          return f;
      };

      const uint32_t shifted_exp = 0x7c00u << 13; // half exponent mask, moved to the float exponent position
      const float magic = from_bits(113u << 23);  // 2^-14, the smallest normal half

      uint32_t bits = static_cast<uint32_t>(_x & 0x7fffu) << 13; // exponent/mantissa bits, sign dropped
      const uint32_t exponent = shifted_exp & bits;
      bits += (127u - 15u) << 23; // rebias the exponent from half (15) to float (127)

      if (exponent == shifted_exp) {
          // Inf/NaN: raise the exponent the rest of the way to all ones.
          bits += (128u - 16u) << 23;
      } else if (exponent == 0) {
          // Zero/subnormal: give the value the exponent of 2^-14, then
          // subtract 2^-14 in float arithmetic. This leaves exactly
          // mantissa * 2^-24, normalised by the FPU.
          bits += 1u << 23;
          bits = to_bits(from_bits(bits) - magic);
      }

      // Restore the sign in unsigned arithmetic; shifting the int-promoted
      // 0x8000 left by 16 would push a bit into the signed int's sign position.
      bits |= static_cast<uint32_t>(_x & 0x8000u) << 16;
      return from_bits(bits);
  #endif
  }