26 #ifndef MSHADOW_BFLOAT_H_ 27 #define MSHADOW_BFLOAT_H_ 35 #define MSHADOW_BF16_OPERATOR_TYPE(RTYPE, ITYPE, OP) \ 36 MSHADOW_XINLINE RTYPE operator OP (ITYPE a, bf16_t b) { \ 37 return RTYPE(a OP float(b)); \ 39 MSHADOW_XINLINE RTYPE operator OP (bf16_t a, ITYPE b) { \ 40 return RTYPE(float(a) OP b); \ 43 #define MSHADOW_BF16_OPERATOR(RTYPE, OP) \ 44 MSHADOW_XINLINE RTYPE operator OP (bf16_t a, bf16_t b) { \ 45 return RTYPE(static_cast<float>(a) OP float(b)); \ 47 MSHADOW_BF16_OPERATOR_TYPE(float, float, OP) \ 48 MSHADOW_BF16_OPERATOR_TYPE(double, double, OP) \ 49 MSHADOW_BF16_OPERATOR_TYPE(float, int8_t, OP) \ 50 MSHADOW_BF16_OPERATOR_TYPE(float, uint8_t, OP) \ 51 MSHADOW_BF16_OPERATOR_TYPE(float, int32_t, OP) \ 52 MSHADOW_BF16_OPERATOR_TYPE(float, uint32_t, OP) \ 53 MSHADOW_BF16_OPERATOR_TYPE(float, int64_t, OP) \ 54 MSHADOW_BF16_OPERATOR_TYPE(float, uint64_t, OP) 56 #define MSHADOW_BF16_ASSIGNOP(AOP, OP) \ 57 template<typename T> \ 58 MSHADOW_XINLINE bf16_t operator AOP (const T& a) { \ 59 return *this = bf16_t(float(*this) OP float(a)); \ 61 template<typename T> \ 62 MSHADOW_XINLINE bf16_t operator AOP (const volatile T& a) volatile { \ 63 return *this = bf16_t(float(*this) OP float(a)); \ 66 #define MSHADOW_BF16_CONVERSIONOP(T) \ 67 MSHADOW_XINLINE operator T() const { \ 68 return T(BF16ToFloat(bf16_)); \ 70 MSHADOW_XINLINE operator T() const volatile { \ 71 return T(BF16ToFloat(bf16_)); \ 87 MSHADOW_XINLINE explicit bf16_t(
const double& value) { constructor(value); }
/*! \brief explicit conversion from int8_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so int8_t never converts to bf16_t implicitly. */
88 MSHADOW_XINLINE explicit bf16_t(
const int8_t& value) { constructor(value); }
/*! \brief explicit conversion from uint8_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so uint8_t never converts to bf16_t implicitly. */
89 MSHADOW_XINLINE explicit bf16_t(
const uint8_t& value) { constructor(value); }
/*! \brief explicit conversion from int32_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so int32_t never converts to bf16_t implicitly. */
90 MSHADOW_XINLINE explicit bf16_t(
const int32_t& value) { constructor(value); }
/*! \brief explicit conversion from uint32_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so uint32_t never converts to bf16_t implicitly. */
91 MSHADOW_XINLINE explicit bf16_t(
const uint32_t& value) { constructor(value); }
/*! \brief explicit conversion from int64_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so int64_t never converts to bf16_t implicitly. */
92 MSHADOW_XINLINE explicit bf16_t(
const int64_t& value) { constructor(value); }
/*! \brief explicit conversion from uint64_t; forwards the value to the constructor() helper
 *  (helper defined outside this view). Marked explicit so uint64_t never converts to bf16_t implicitly. */
93 MSHADOW_XINLINE explicit bf16_t(
const uint64_t& value) { constructor(value); }
107 return bf16_t(-
float(*
this));
117 return *
this = bf16_t(a);
127 return *
this = bf16_t(a);
138 return reinterpret_cast<const uint16_t*
>(&value)[1];
142 MSHADOW_XINLINE uint16_t FloatToBF16(
const volatile float& value)
const volatile {
143 return reinterpret_cast<const volatile uint16_t*
>(&value)[1];
148 reinterpret_cast<uint16_t*
>(&ret)[1] = value;
152 MSHADOW_XINLINE float BF16ToFloat(
const volatile uint16_t& value)
const volatile {
154 reinterpret_cast<uint16_t*
>(&ret)[1] = value;
160 bf16_ = FloatToBF16(
float(value));
/*! \brief lowest (most negative) finite bf16_t value: bit pattern 0xFF7F
 *  (sign = 1, exponent = 0xFE, mantissa all ones).
 *  Defined WITHOUT a trailing ';' so the macro can appear inside expressions,
 *  e.g. f(MSHADOW_BF16_MIN) or (x < MSHADOW_BF16_MIN ? a : b); a ';' in the
 *  macro body would terminate the enclosing expression and fail to compile. */
181 #define MSHADOW_BF16_MIN mshadow::bfloat::bf16_t::Binary(0xFF7F)
/*! \brief largest finite bf16_t value: bit pattern 0x7F7F
 *  (sign = 0, exponent = 0xFE, mantissa all ones). No trailing ';' for the same reason as MSHADOW_BF16_MIN. */
182 #define MSHADOW_BF16_MAX mshadow::bfloat::bf16_t::Binary(0x7F7F)
185 #endif // MSHADOW_BFLOAT_H_ BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator- overload for expression minus scalar (op::minus)
Definition: expr_scalar-inl.h:101
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator+ overload for expression plus scalar (op::plus)
Definition: expr_scalar-inl.h:93
#define MSHADOW_BF16_ASSIGNOP(AOP, OP)
Definition: bfloat.h:56
#define MSHADOW_XINLINE
Definition: base.h:230
#define MSHADOW_BF16_CONVERSIONOP(T)
Definition: bfloat.h:66
#define MSHADOW_BF16_OPERATOR(RTYPE, OP)
Definition: bfloat.h:43
class MSHADOW_ALIGNED(2) bf16_t
Definition: bfloat.h:74
overloaded + operator between half_t and bf16_t
Definition: base.h:334