26 #ifndef MSHADOW_HALF2_H_ 27 #define MSHADOW_HALF2_H_ 29 #if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) 30 #define MSHADOW_CUDA_HALF2 1 31 #include <cuda_fp16.h> 33 #define MSHADOW_CUDA_HALF2 0 43 #define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \ 44 template<typename T> \ 45 MSHADOW_XINLINE half2_t operator AOP (const T& a) { \ 46 return *this = half2_t(*this OP a); \ 51 #if MSHADOW_CUDA_HALF2 59 #if MSHADOW_CUDA_HALF2 69 #if MSHADOW_CUDA_HALF2 70 half2_ = __half2half2(__int2half_rz(a));
72 half_t2[0] = (half_t)a;
73 half_t2[1] = (half_t)a;
82 #if MSHADOW_CUDA_HALF2 83 return half2_t(__hneg2(half2_));
85 return half2_t(-half_t2[0], -half_t2[1]);
90 #if MSHADOW_CUDA_HALF2 93 half_t2[0] = a.half_t2[0];
94 half_t2[1] = a.half_t2[1];
107 #if MSHADOW_CUDA_HALF2 108 return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_),
109 __high2float(a.half2_) + __high2float(b.half2_)));
111 return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]);
116 #if MSHADOW_CUDA_HALF2 117 return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_),
118 __high2float(a.half2_) - __high2float(b.half2_)));
120 return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]);
125 #if MSHADOW_CUDA_HALF2 126 return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_),
127 __high2float(a.half2_) * __high2float(b.half2_)));
129 return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]);
134 #if MSHADOW_CUDA_HALF2 135 return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_),
136 __high2float(a.half2_) / __high2float(b.half2_)));
138 return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]);
143 #if MSHADOW_CUDA_HALF2 144 return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)),
145 ::fmod(__high2float(a.half2_), __high2float(b.half2_))));
147 return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1]));
152 #if MSHADOW_CUDA_HALF2 153 return __hbeq2(a.half2_, b.half2_);
155 return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]);
161 #endif // MSHADOW_HALF2_H_ class MSHADOW_ALIGNED(2) half_t
Definition: half.h:112
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)
Definition: half2.h:43
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:106
#define MSHADOW_XINLINE
Definition: base.h:230
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:115
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b)
overloaded == operator for half2_t
Definition: half2.h:151
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b)
overloaded * operator for half2_t
Definition: half2.h:124
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b)
overloaded / operator for half2_t
Definition: half2.h:133
overloaded + operator between half_t and bf16_t
Definition: base.h:334
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b)
overloaded % operator for half2_t
Definition: half2.h:142