mxnet
half2.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_HALF2_H_
27 #define MSHADOW_HALF2_H_
28 
29 #if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
30  #define MSHADOW_CUDA_HALF2 1
31  #include <cuda_fp16.h>
32 #else
33  #define MSHADOW_CUDA_HALF2 0
34 #endif
35 
36 #include<math.h>
37 
39 namespace mshadow {
40 /*! \brief namespace for host/device portable half-precision floats */
41 namespace half {
42 
/*!
 * \brief Define a member compound-assignment operator for half2_t in terms of
 *        the corresponding binary operator (e.g. `+=` via `+`).
 * \param AOP the compound-assignment operator token to define (e.g. `+=`).
 * \param OP  the binary operator used to compute the new value (e.g. `+`).
 *
 * The generated member stores `half2_t(*this OP a)` into `*this` and returns a
 * reference to `*this`, as is conventional for compound assignment. Assigning
 * first and then returning `*this` keeps this correct whatever half2_t's
 * operator= happens to return.
 *
 * The macro body deliberately ends without a trailing backslash so the line
 * after the definition can never be swallowed into it.
 */
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)                      \
  template<typename T>                                        \
  MSHADOW_XINLINE half2_t& operator AOP (const T& a) {        \
    *this = half2_t(*this OP a); /* NOLINT(*)*/               \
    return *this;                                             \
  }
49 class MSHADOW_ALIGNED(4) half2_t {
50  public:
#if MSHADOW_CUDA_HALF2
  // Native path (MSHADOW_CUDA_HALF2 set: CUDA >= 7.5 compiling for arch >= 530):
  // both lanes are packed into a single CUDA half2 value.
  half2 half2_;
#else
  // Portable path: two half_t lanes. The arithmetic operators pair index 0 with
  // the native path's low lane and index 1 with its high lane.
  half_t half_t2[2];
#endif

  // Default constructor: leaves both lanes uninitialized.
  MSHADOW_XINLINE half2_t() {}
58 
#if MSHADOW_CUDA_HALF2
  /*! \brief Wrap an existing CUDA half2 value. */
  MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {}

  /*!
   * \brief Construct from two lanes.
   * \param a value for the low lane.
   * \param b value for the high lane.
   *
   * Also provided on the native path so portable callers need not branch on
   * MSHADOW_CUDA_HALF2. Widening half -> float is exact, and rounding back
   * with round-to-nearest-even reproduces any value already representable in
   * half precision.
   */
  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b)
      : half2_(__floats2half2_rn(static_cast<float>(a), static_cast<float>(b))) {}
#else
  /*!
   * \brief Construct from two lanes.
   * \param a value for the low lane (index 0).
   * \param b value for the high lane (index 1).
   */
  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) {
    half_t2[0] = a;
    half_t2[1] = b;
  }
#endif
67 
68  MSHADOW_XINLINE explicit half2_t(int a) {
69 #if MSHADOW_CUDA_HALF2
70  half2_ = __half2half2(__int2half_rz(a));
71 #else
72  half_t2[0] = (half_t)a;
73  half_t2[1] = (half_t)a;
74 #endif
75  }
76 
77  MSHADOW_XINLINE half2_t operator+() {
78  return *this;
79  }
80 
81  MSHADOW_XINLINE half2_t operator-() {
82 #if MSHADOW_CUDA_HALF2
83  return half2_t(__hneg2(half2_));
84 #else
85  return half2_t(-half_t2[0], -half_t2[1]);
86 #endif
87  }
88 
89  MSHADOW_XINLINE half2_t operator=(const half2_t& a) {
90 #if MSHADOW_CUDA_HALF2
91  half2_ = a.half2_;
92 #else
93  half_t2[0] = a.half_t2[0];
94  half_t2[1] = a.half_t2[1];
95 #endif
96  return a;
97  }
98 
103 };
104 
/*!
 * \brief Lane-wise addition of two half2_t values.
 *
 * Native path: each lane is widened to float, summed, and the pair is rounded
 * back to half2 with round-to-nearest-even. Portable path: the half_t lanes
 * are added one by one.
 */
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  const float lo = __low2float(a.half2_) + __low2float(b.half2_);
  const float hi = __high2float(a.half2_) + __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lo = a.half_t2[0] + b.half_t2[0];
  const half_t hi = a.half_t2[1] + b.half_t2[1];
  return half2_t(lo, hi);
#endif
}
/*!
 * \brief Lane-wise subtraction `a - b` of two half2_t values.
 *
 * Native path: lanes are widened to float, subtracted, and rounded back to
 * half2 with round-to-nearest-even. Portable path: the half_t lanes are
 * subtracted one by one.
 */
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  const float lo = __low2float(a.half2_) - __low2float(b.half2_);
  const float hi = __high2float(a.half2_) - __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lo = a.half_t2[0] - b.half_t2[0];
  const half_t hi = a.half_t2[1] - b.half_t2[1];
  return half2_t(lo, hi);
#endif
}
/*!
 * \brief Lane-wise multiplication of two half2_t values.
 *
 * Native path: lanes are widened to float, multiplied, and rounded back to
 * half2 with round-to-nearest-even. Portable path: the half_t lanes are
 * multiplied one by one.
 */
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  const float lo = __low2float(a.half2_) * __low2float(b.half2_);
  const float hi = __high2float(a.half2_) * __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lo = a.half_t2[0] * b.half_t2[0];
  const half_t hi = a.half_t2[1] * b.half_t2[1];
  return half2_t(lo, hi);
#endif
}
/*!
 * \brief Lane-wise division `a / b` of two half2_t values.
 *
 * Native path: lanes are widened to float, divided, and rounded back to
 * half2 with round-to-nearest-even. Portable path: the half_t lanes are
 * divided one by one.
 */
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  const float lo = __low2float(a.half2_) / __low2float(b.half2_);
  const float hi = __high2float(a.half2_) / __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lo = a.half_t2[0] / b.half_t2[0];
  const half_t hi = a.half_t2[1] / b.half_t2[1];
  return half2_t(lo, hi);
#endif
}
/*!
 * \brief Lane-wise floating-point remainder `a % b`, computed with ::fmod.
 *
 * Native path: lanes are widened to float, passed to ::fmod, and the results
 * are rounded back to half2 with round-to-nearest-even. Portable path: ::fmod
 * is applied to each half_t lane pair.
 */
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  const float lo = ::fmod(__low2float(a.half2_), __low2float(b.half2_));
  const float hi = ::fmod(__high2float(a.half2_), __high2float(b.half2_));
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lo = ::fmod(a.half_t2[0], b.half_t2[0]);
  const half_t hi = ::fmod(a.half_t2[1], b.half_t2[1]);
  return half2_t(lo, hi);
#endif
}
/*!
 * \brief Lane-wise equality: true only when both lanes compare equal.
 */
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  return __hbeq2(a.half2_, b.half2_);
#else
  return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]);
#endif
}

/*!
 * \brief Inequality, defined as the negation of operator== so the two can never
 *        disagree: true whenever operator== returns false.
 */
MSHADOW_XINLINE bool operator!=(half2_t a, half2_t b) {
  return !(a == b);
}
158 
159 } // namespace half
160 } // namespace mshadow
161 #endif // MSHADOW_HALF2_H_
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:112
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)
Definition: half2.h:43
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:106
#define MSHADOW_XINLINE
Definition: base.h:230
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:115
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b)
overloaded == operator for half2_t
Definition: half2.h:151
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b)
overloaded * operator for half2_t
Definition: half2.h:124
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b)
overloaded / operator for half2_t
Definition: half2.h:133
overloaded + operator between half_t and bf16_t
Definition: base.h:334
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b)
overloaded % operator for half2_t
Definition: half2.h:142