#pragma once

#include <cstdint>
#include <cstring>
#include <limits>

// Forward declarations of the raw conversion routines so that the member
// functions of `half` below can call them; both are defined further down in
// this header.
//
// float2half: returns the 16-bit half-precision bit pattern for `val`.
// half2float: returns the float value encoded by the 16-bit pattern `val`.
uint16_t float2half(float val);
float half2float(uint16_t val);

/// 16-bit half-precision floating-point value, stored as its raw bit pattern.
///
/// Conversions to and from `float` are implicit and delegate to
/// `float2half` / `half2float`.
class half
{
public:
	/// Defaulted constructor: the stored bits are not initialized.
	half() = default;

	/// Converts `val` to half precision and stores the resulting bits.
	half(float val) : _x(float2half(val)) {}

	/// Widens the stored half-precision bits back to a `float`.
	operator float() const { return half2float(_x); }

	/// Unary minus: returns a copy with bit 15 (mask 0x8000) toggled.
	inline half operator - () const
	{
		return construct(static_cast<uint16_t>(_x ^ 0x8000));
	}

	/// Wraps an already-encoded 16-bit pattern without any conversion.
	static half construct(const uint16_t half_data)
	{
		half wrapped;
		wrapped._x = half_data;
		return wrapped;
	}

private:
	uint16_t _x; // raw half-precision bit pattern

	// These free functions inspect or rewrite the raw bits directly.
	friend bool isinf(const half a);
	friend bool isnan(const half a);
	friend half abs(const half a);
};

// Arithmetic
/// Sum of two half values: both operands are widened to float, added, and the
/// float result is converted back to half.
inline half operator + (const half a, const half b)
{
	const float lhs = static_cast<float>(a);
	const float rhs = static_cast<float>(b);
	return half(lhs + rhs);
}

/// Difference of two half values: both operands are widened to float,
/// subtracted, and the float result is converted back to half.
inline half operator - (const half a, const half b)
{
	const float lhs = static_cast<float>(a);
	const float rhs = static_cast<float>(b);
	return half(lhs - rhs);
}

/// Product of two half values: both operands are widened to float,
/// multiplied, and the float result is converted back to half.
inline half operator * (const half a, const half b)
{
	const float lhs = static_cast<float>(a);
	const float rhs = static_cast<float>(b);
	return half(lhs * rhs);
}

/// Quotient of two half values: both operands are widened to float, divided,
/// and the float result is converted back to half.
inline half operator / (const half a, const half b)
{
	const float numerator = static_cast<float>(a);
	const float denominator = static_cast<float>(b);
	return half(numerator / denominator);
}

/// In-place addition: stores `a + b` into `a` and returns a reference to `a`.
inline half& operator += (half& a, const half b)
{
	return a = a + b;
}

/// In-place subtraction: stores `a - b` into `a` and returns a reference to `a`.
inline half& operator -= (half& a, const half b)
{
	return a = a - b;
}

/// In-place multiplication: stores `a * b` into `a` and returns a reference to `a`.
inline half& operator *= (half& a, const half b)
{
	return a = a * b;
}

/// In-place division: stores `a / b` into `a` and returns a reference to `a`.
inline half& operator /= (half& a, const half b)
{
	return a = a / b;
}

// Comparison operators
inline half operator == (const half a, const half b)
{
	return float(a) == float(b);
}

inline half operator != (const half a, const half b)
{
	return !(a == b);
}

inline half operator < (const half a, const half b)
{
	return float(a) < float(b);
}

inline half operator <= (const half a, const half b)
{
	return float(a) <= float(b);
}

inline half operator > (const half a, const half b)
{
	return float(a) > float(b);
}

inline half operator >= (const half a, const half b)
{
	return float(a) >= float(b);
}

/// True when, ignoring bit 15, the stored pattern is exactly 0x7C00
/// (all exponent bits set, mantissa zero).
inline bool isinf(const half a)
{
	const unsigned magnitude_bits = a._x & 0x7FFFu;
	return magnitude_bits == 0x7C00u;
}

/// True when, ignoring bit 15, the stored pattern is strictly greater than
/// 0x7C00 (all exponent bits set, mantissa non-zero).
inline bool isnan(const half a)
{
	const unsigned magnitude_bits = a._x & 0x7FFFu;
	return magnitude_bits > 0x7C00u;
}

/// True when `a` is neither an infinity nor a NaN.
inline bool isfinite(const half a)
{
	if (isinf(a)) {
		return false;
	}
	return !isnan(a);
}

/// Absolute value: returns a copy of `a` with bit 15 (mask 0x8000) cleared.
inline half abs(const half a)
{
	const uint16_t magnitude_bits = static_cast<uint16_t>(a._x & 0x7FFF);
	return half::construct(magnitude_bits);
}

namespace detail
{
	// Overlay of a 32-bit unsigned integer and a float in the same storage,
	// used to view a float's bit pattern as an integer and vice versa.
	// NOTE(review): reading a member other than the one most recently written
	// is not guaranteed by the C++ standard; this relies on compiler support
	// for union-based type punning.
	union uif
	{
		uint32_t u; // integer view of the bits
		float f;    // floating-point view of the bits
	};
}

#ifdef __F16C__
#include <immintrin.h>
#endif

/// Converts a 32-bit float to a 16-bit half-precision bit pattern.
///
/// @param val  value to convert.
/// @return     the half-precision encoding of `val`. Values too large for
///             half become infinity, NaN inputs become the quiet NaN pattern
///             0x7e00, and the sign bit of the input is always carried over.
///
/// With F16C available the hardware conversion is used. Otherwise a software
/// path rounds to nearest-even. Bit patterns are moved between `float` and
/// `uint32_t` with `std::memcpy` instead of reading an inactive union member,
/// which is undefined behavior in standard C++.
inline uint16_t float2half(float val)
{
#ifdef __F16C__
	return _cvtss_sh(val, 0);
#else
	static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

	// Well-defined bit-level views of a float.
	const auto to_bits = [](float f) {
		uint32_t u;
		std::memcpy(&u, &f, sizeof u);
		return u;
	};
	const auto to_float = [](uint32_t u) {
		float f;
		std::memcpy(&f, &u, sizeof f);
		return f;
	};

	const uint32_t sign_mask = 0x80000000u;
	const uint32_t f32_infinity = 255u << 23;            // float +Inf pattern
	const uint32_t f16_overflow = (127u + 16u) << 23;    // first magnitude too big for half
	const uint32_t f16_min_normal = 113u << 23;          // smallest magnitude that is a normal half
	const uint32_t denorm_magic = ((127u - 15u) + (23u - 10u) + 1u) << 23;

	uint32_t bits = to_bits(val);
	const uint32_t sign = bits & sign_mask;
	bits ^= sign; // work on the magnitude; the sign is re-attached at the end

	// NOTE: after the sign is stripped every operand below is < 0x80000000,
	// so the integer compares may equally be evaluated as signed compares.
	uint16_t result;
	if (bits >= f16_overflow) {
		// Result is Inf or NaN (all exponent bits set): NaN -> qNaN, else Inf.
		result = (bits > f32_infinity) ? 0x7e00 : 0x7c00;
	}
	else if (bits < f16_min_normal) {
		// Resulting half is subnormal or zero. Adding the magic value aligns
		// the 10 mantissa bits at the bottom of the float; this relies on
		// float addition rounding to nearest-even.
		const float aligned = to_float(bits) + to_float(denorm_magic);

		// Subtracting the magic bias as an integer leaves the final bits.
		result = static_cast<uint16_t>(to_bits(aligned) - denorm_magic);
	}
	else {
		const uint32_t mant_odd = (bits >> 13) & 1u; // resulting mantissa is odd

		// Re-bias the exponent (modulo 2^32) and add rounding bias part 1.
		bits += (static_cast<uint32_t>(15 - 127) << 23) + 0xfffu;
		// Rounding bias part 2: ties go to the even mantissa.
		bits += mant_odd;
		// Keep the top bits.
		result = static_cast<uint16_t>(bits >> 13);
	}

	return static_cast<uint16_t>(result | (sign >> 16));
#endif
}

/// Converts a 16-bit half-precision bit pattern to a 32-bit float.
///
/// @param _x  half-precision bit pattern.
/// @return    the float with the value encoded by `_x`; zero, subnormal,
///            Inf and NaN patterns are handled explicitly and the sign bit
///            is carried over.
///
/// With F16C available the hardware conversion is used. Otherwise the bits
/// are rebuilt in software. Bit patterns are moved between `float` and
/// `uint32_t` with `std::memcpy` instead of reading an inactive union member,
/// which is undefined behavior in standard C++. All shifts are done in
/// `uint32_t` so no signed value is shifted into its sign bit.
inline float half2float(uint16_t _x)
{
#ifdef __F16C__
	return _cvtsh_ss(_x, 0);
#else
	static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

	// Well-defined bit-level views of a float.
	const auto to_bits = [](float f) {
		uint32_t u;
		std::memcpy(&u, &f, sizeof u);
		return u;
	};
	const auto to_float = [](uint32_t u) {
		float f;
		std::memcpy(&f, &u, sizeof f);
		return f;
	};

	const uint32_t magic = 113u << 23;
	const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
	const uint32_t source = _x;

	uint32_t bits = (source & 0x7fffu) << 13;   // exponent/mantissa bits
	const uint32_t exp = shifted_exp & bits;    // just the exponent
	bits += (127u - 15u) << 23;                 // exponent adjust

	// Handle exponent special cases.
	if (exp == shifted_exp) {        // Inf/NaN?
		bits += (128u - 16u) << 23;  // extra exponent adjust
	}
	else if (exp == 0) {             // Zero/Denormal?
		bits += 1u << 23;            // extra exponent adjust
		bits = to_bits(to_float(bits) - to_float(magic)); // renormalize
	}

	bits |= (source & 0x8000u) << 16; // sign bit
	return to_float(bits);
#endif
}