mxnet
base.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_BASE_H_
27 #define MSHADOW_BASE_H_
// MSVC only: quiet the CRT's "unsafe function" deprecation diagnostics and
// define NOMINMAX so Windows headers do not introduce min/max macros.
// Each macro is left alone if the build already defines it.
#if defined(_MSC_VER)
#  if !defined(_CRT_SECURE_NO_WARNINGS)
#    define _CRT_SECURE_NO_WARNINGS
#  endif
#  if !defined(_CRT_SECURE_NO_DEPRECATE)
#    define _CRT_SECURE_NO_DEPRECATE
#  endif
#  if !defined(NOMINMAX)
#    define NOMINMAX
#  endif
#endif  // defined(_MSC_VER)
39 #include <cmath>
40 #include <cstdio>
41 #include <cfloat>
42 #include <climits>
43 #include <algorithm>
44 #include <functional>
45 #include <sstream>
46 #include <string>
47 
#if defined(_MSC_VER)
// MSVC: spell the fixed-width integer names directly from the compiler's
// built-in sized integer types.
typedef signed char      int8_t;
typedef __int16          int16_t;
typedef __int32          int32_t;
typedef __int64          int64_t;
typedef unsigned char    uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
#else
// Other toolchains: take the fixed-width types from the C library header.
#include <inttypes.h>
#endif
// ---------------------------------------------------------------------------
// Macro definitions: build-time feature switches.
// Every switch below is only given a default when the build has not already
// defined it, so any of them can be overridden on the compiler command line.
// ---------------------------------------------------------------------------
#if !defined(MSHADOW_STAND_ALONE)
#define MSHADOW_STAND_ALONE 0
#endif

#if !defined(MSHADOW_ALLOC_PAD)
#define MSHADOW_ALLOC_PAD true
#endif

#if !defined(MSHADOW_MIN_PAD_RATIO)
#define MSHADOW_MIN_PAD_RATIO 2
#endif

// Stand-alone mode forces CBLAS, MKL and CUDA off. This must come before the
// defaults below so that they see these values as already defined.
#if MSHADOW_STAND_ALONE
#define MSHADOW_USE_CBLAS 0
#define MSHADOW_USE_MKL 0
#define MSHADOW_USE_CUDA 0
#endif

#if !defined(MSHADOW_FORCE_STREAM)
#define MSHADOW_FORCE_STREAM 1
#endif

#if !defined(MSHADOW_USE_CBLAS)
#define MSHADOW_USE_CBLAS 0
#endif

#if !defined(MSHADOW_USE_MKL)
#define MSHADOW_USE_MKL 1
#endif

#if !defined(MSHADOW_USE_ARMPL)
#define MSHADOW_USE_ARMPL 0
#endif

#if !defined(MSHADOW_USE_CUDA)
#define MSHADOW_USE_CUDA 1
#endif

#if !defined(MSHADOW_USE_CUDNN)
#define MSHADOW_USE_CUDNN 0
#endif

// cuSOLVER follows the CUDA switch unless set explicitly.
#if !defined(MSHADOW_USE_CUSOLVER)
#define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA
#endif

#if !defined(MSHADOW_OLD_CUDA)
#define MSHADOW_OLD_CUDA 0
#endif
142 
// MSHADOW_IN_CXX11: 1 when the compiler advertises C++11 (or is MSVC).
#if !defined(MSHADOW_IN_CXX11)
#  if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || defined(_MSC_VER)
#    define MSHADOW_IN_CXX11 1
#  else
#    define MSHADOW_IN_CXX11 0
#  endif
#endif

#if !defined(MSHADOW_USE_SSE)
#  define MSHADOW_USE_SSE 1
#endif

// F16C is off for MSVC, for nvcc (__CUDACC__), and for clang older than 8.1;
// otherwise it is on.
#if !defined(MSHADOW_USE_F16C)
#  if defined(_MSC_VER) || defined(__CUDACC__) ||                               \
      (defined(__clang__) &&                                                    \
       ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1))))
#    define MSHADOW_USE_F16C 0
#  else
#    define MSHADOW_USE_F16C 1
#  endif
#endif

#if !defined(MSHADOW_USE_NVML)
#  define MSHADOW_USE_NVML 0
#endif

// SSE code conflicts with nvcc, so force SSE off under __CUDACC__ regardless
// of any earlier setting.
#if defined(__CUDACC__)
#  undef MSHADOW_USE_SSE
#  define MSHADOW_USE_SSE 0
#endif
181 
182 #if MSHADOW_USE_CBLAS
183 extern "C" {
184  #if MSHADOW_USE_ARMPL
185  #define armpl_singlecomplex_t float _Complex
186  #define armpl_doublecomplex_t double _Complex
187  #endif
188  #include <cblas.h>
189 }
190 #elif MSHADOW_USE_MKL
191  #include <mkl_blas.h>
192  #include <mkl_cblas.h>
193  #include <mkl_vsl.h>
194  #include <mkl_vsl_functions.h>
195  #include <mkl_version.h>
196 #endif
197 
198 #if MSHADOW_USE_CUDA
199  #include <cuda.h>
200  #include <cublas_v2.h>
201  #include <curand.h>
202 #endif
203 
204 #if MSHADOW_USE_CUDNN == 1
205  #include <cudnn.h>
206 #endif
207 
208 #if MSHADOW_USE_CUSOLVER == 1
209  #include <cusolverDn.h>
210 #endif
211 
212 #if MSHADOW_USE_NVML
213  #include <nvml.h>
214 #endif
215 
// --------------------------------
// Inlining helpers.
// MSHADOW_XINLINE marks template code that must be force-inlined and, under
// nvcc, compiled for both host and device. It is defined here only, so an
// earlier definition is treated as a configuration error.
#if defined(MSHADOW_XINLINE)
#  error "MSHADOW_XINLINE must not be defined"
#endif

#if defined(_MSC_VER)
#  define MSHADOW_FORCE_INLINE __forceinline
// Disable MSVC warning C4068 (unknown pragma).
#  pragma warning(disable : 4068)
#else
#  define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
#endif

#if defined(__CUDACC__)
#  define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
#else
#  define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
#endif

// Host-only force-inline.
#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE

// MSHADOW_CONSTEXPR is `constexpr` when C++11 is detected, plain `const`
// otherwise. Both spellings of the GCC experimental-C++0x macro are checked.
#if defined(__GXX_EXPERIMENTAL_CXX0X) || defined(__GXX_EXPERIMENTAL_CXX0X__) || \
    __cplusplus >= 201103L
#  define MSHADOW_CONSTEXPR constexpr
#else
#  define MSHADOW_CONSTEXPR const
#endif
241 
// MSHADOW_DEFAULT_DTYPE expands to `= ::mshadow::default_real_t`. It looks
// intended for use as a default template argument (TODO confirm at call sites).
#if !defined(MSHADOW_DEFAULT_DTYPE)
#define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t
#endif

// Logging backend follows dmlc's DMLC_USE_GLOG unless overridden.
#if !defined(MSHADOW_USE_GLOG)
#define MSHADOW_USE_GLOG DMLC_USE_GLOG
#endif  // MSHADOW_USE_GLOG

// Exception-specification helpers: explicit noexcept(...) with C++11 (as
// reported by DMLC_USE_CXX11), nothing otherwise.
#if DMLC_USE_CXX11
#define MSHADOW_THROW_EXCEPTION noexcept(false)
#define MSHADOW_NO_EXCEPTION noexcept(true)
#else
#define MSHADOW_THROW_EXCEPTION
#define MSHADOW_NO_EXCEPTION
#endif

// Alignment attribute in each compiler's own spelling.
#if !defined(_MSC_VER)
#define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x)))
#else
#define MSHADOW_ALIGNED(x) __declspec(align(x))
#endif
272 
/*!
 * \brief Evaluate a CUDA runtime call and fail loudly if it did not succeed.
 *
 *  A status of cudaErrorCudartUnloading is thrown as dmlc::Error. Any other
 *  status other than cudaSuccess fails the CHECK with the CUDA error string.
 *
 *  The call's result is first stored in a uniquely named local and only then
 *  copied into `e`. Writing `cudaError_t e = (func);` directly puts the new
 *  `e` in scope inside its own initializer. A call such as
 *  MSHADOW_CUDA_CALL(cudaEventSynchronize(e)) would then silently read the
 *  macro's uninitialized local instead of the caller's variable. The inner
 *  name `e` is kept so the stringified CHECK condition is unchanged.
 */
#define MSHADOW_CUDA_CALL(func)                                  \
  {                                                              \
    const cudaError_t mshadow_cuda_call_status_ = (func);        \
    {                                                            \
      const cudaError_t e = mshadow_cuda_call_status_;           \
      if (e == cudaErrorCudartUnloading) {                       \
        throw dmlc::Error(cudaGetErrorString(e));                \
      }                                                          \
      CHECK(e == cudaSuccess)                                    \
        << "CUDA: " << cudaGetErrorString(e);                    \
    }                                                            \
  }
287 
/*!
 * \brief Evaluate `func`, logging and swallowing any dmlc::Error it throws.
 *
 *  An error whose message contains "driver shutting down" is dropped
 *  without logging. Any other error is reported via LOG(ERROR) and then
 *  ignored. Errors are never propagated to the caller.
 */
#define MSHADOW_CATCH_ERROR(func)                                      \
  {                                                                    \
    try {                                                              \
      (func);                                                          \
    } catch (const dmlc::Error &e) {                                   \
      const std::string what = e.what();                               \
      const bool shutting_down =                                       \
          what.find("driver shutting down") != std::string::npos;      \
      if (!shutting_down) {                                            \
        LOG(ERROR) << "Ignore CUDA Error " << what;                    \
      }                                                                \
    }                                                                  \
  }
303 
304 #include "./half.h"
305 #include "./half2.h"
306 #include "./bfloat.h"
307 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
308  MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
309  return float(a) OP float(b); /* NOLINT(*) */ \
310  } \
311  MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \
312  return float(a) OP float(b); /* NOLINT(*) */ \
313  }
314 
316 MSHADOW_HALF_BF_OPERATOR(float, +)
318 MSHADOW_HALF_BF_OPERATOR(float, -)
320 MSHADOW_HALF_BF_OPERATOR(float, *)
322 MSHADOW_HALF_BF_OPERATOR(float, /)
328 MSHADOW_HALF_BF_OPERATOR(bool, >=)
330 MSHADOW_HALF_BF_OPERATOR(bool, <=)
331 
332 #include "./logging.h"
333 
334 namespace mshadow {
/*! \brief buffer size for each random number generator */
const unsigned kRandBufferSize = 1000000;
/*! \brief pi, as a float literal */
const float kPi = 3.1415926f;

/*!
 * \brief signed index type used for tensor shapes.
 *  64-bit only when the build sets MSHADOW_INT64_TENSOR_SIZE to 1.
 */
#if MSHADOW_INT64_TENSOR_SIZE != 1
typedef int32_t index_t;
#else
typedef int64_t index_t;
#endif

/*! \brief loop-index type for OpenMP loops: always int64_t on Windows, index_t elsewhere */
#if !defined(_WIN32)
typedef index_t openmp_index_t;
#else
typedef int64_t openmp_index_t;
#endif

/*! \brief default floating-point element type */
typedef float default_real_t;
356 
/*! \brief runtime data type flag, one value per supported element type */
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kBool = 7,
  kInt16 = 8,
  kUint16 = 9,
  kUint32 = 10,
  kUint64 = 11,
  // Used by DataType<bfloat::bf16_t> and by the bfloat16 cases of the
  // type-switch macros, so it must be declared here. It takes the next
  // free value after kUint64.
  kBfloat16 = 12
};
373 
374 template<typename DType>
375 struct DataType;
376 
377 template<>
378 struct DataType<float> {
379  static const int kFlag = kFloat32;
380  static const int kLanes = 1;
381 #if MSHADOW_USE_CUDA
382 #if (CUDA_VERSION >= 8000)
383  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
384 #endif
385 #if MSHADOW_USE_CUDNN
386  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
387  typedef float ScaleType;
388 #endif
389 #endif
390 };
391 template<>
392 struct DataType<double> {
393  static const int kFlag = kFloat64;
394  static const int kLanes = 1;
395 #if MSHADOW_USE_CUDA
396 #if (CUDA_VERSION >= 8000)
397  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
398 #endif
399 #if MSHADOW_USE_CUDNN
400  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
401  typedef double ScaleType;
402 #endif
403 #endif
404 };
405 template<>
406 struct DataType<half::half_t> {
407  static const int kFlag = kFloat16;
408  static const int kLanes = 1;
409 #if MSHADOW_USE_CUDA
410 #if (CUDA_VERSION >= 8000)
411  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
412 #endif
413 #if MSHADOW_USE_CUDNN
414  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
415  typedef float ScaleType;
416 #endif
417 #endif
418 };
419 template<>
420 struct DataType<half::half2_t> {
421  static const int kFlag = kFloat16;
422  static const int kLanes = 2;
423 };
424 template<>
425 struct DataType<bfloat::bf16_t> {
426  static const int kFlag = kBfloat16;
427  static const int kLanes = 1;
428 };
429 template<>
430 struct DataType<uint8_t> {
431  static const int kFlag = kUint8;
432  static const int kLanes = 1;
433 #if MSHADOW_USE_CUDA
434 #if (CUDA_VERSION >= 8000)
435  static const cudaDataType_t kCudaFlag = CUDA_R_8U;
436 #endif
437 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
438  // no uint8 in cudnn for now
439  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
440  typedef uint8_t ScaleType;
441 #endif
442 #endif
443 };
444 template<>
445 struct DataType<int8_t> {
446  static const int kFlag = kInt8;
447  static const int kLanes = 1;
448 #if MSHADOW_USE_CUDA
449 #if (CUDA_VERSION >= 8000)
450  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
451 #endif
452 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
453  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
454  typedef int8_t ScaleType;
455 #endif
456 #endif
457 };
458 template<>
459 struct DataType<int32_t> {
460  static const int kFlag = kInt32;
461  static const int kLanes = 1;
462 #if MSHADOW_USE_CUDA
463 #if (CUDA_VERSION >= 8000)
464  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
465 #endif
466 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
467  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
468  typedef int32_t ScaleType;
469 #endif
470 #endif
471 };
472 template<>
473 struct DataType<int64_t> {
474  static const int kFlag = kInt64;
475  static const int kLanes = 1;
476 };
477 template<>
478 struct DataType<bool> {
479  static const int kFlag = kBool;
480  static const int kLanes = 1;
481 };
482 
485 
488  kNCHW = 0,
491 
492  kNCW = 1 << 3,
495 
496  kNCDHW = 1 << 5,
499 };
500 
501 template<int layout>
502 struct LayoutType;
503 
504 template<>
505 struct LayoutType<kNCHW> {
506  static const index_t kNdim = 4;
507 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
508  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
509 #else
510  static const int kCudnnFlag = -1;
511 #endif
512 };
513 
514 template<>
515 struct LayoutType<kNHWC> {
516  static const index_t kNdim = 4;
517 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
518  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
519 #else
520  static const int kCudnnFlag = -1;
521 #endif
522 };
523 
525 const int default_layout = kNCHW;
526 
527 template<>
529  static const index_t kNdim = 5;
530 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
531  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
532 #else
533  static const int kCudnnFlag = -1;
534 #endif
535 };
536 
537 template<>
539  static const index_t kNdim = 5;
540 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
541  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
542 #else
543  static const int kCudnnFlag = -1;
544 #endif
545 };
546 
549 
551 namespace op {
552 // binary operator
554 struct mul{
556  template<typename DType>
557  MSHADOW_XINLINE static DType Map(DType a, DType b) {
558  return a * b;
559  }
560 };
562 struct plus {
564  template<typename DType>
565  MSHADOW_XINLINE static DType Map(DType a, DType b) {
566  return a + b;
567  }
568 };
570 struct minus {
572  template<typename DType>
573  MSHADOW_XINLINE static DType Map(DType a, DType b) {
574  return a - b;
575  }
576 };
578 struct div {
580  template<typename DType>
581  MSHADOW_XINLINE static DType Map(DType a, DType b) {
582  return a / b;
583  }
584 };
586 struct right {
588  template<typename DType>
589  MSHADOW_XINLINE static DType Map(DType a, DType b) {
590  return b;
591  }
592 };
593 // unary operator/ function: example
594 // these operators can be defined by user,
595 // in the same style as binary and unary operator
596 // to use, simply write F<op::identity>( src )
598 struct identity{
600  template<typename DType>
601  MSHADOW_XINLINE static DType Map(DType a) {
602  return a;
603  }
604 };
605 } // namespace op
607 namespace sv {
609 struct saveto {
611  template<typename DType>
612  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
613  a = b;
614  }
616  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
618  inline static default_real_t BetaBLAS(void) { return 0.0f; }
620  typedef op::right OPType;
621 };
623 struct plusto {
625  template<typename DType>
626  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
627  a += b;
628  }
630  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
632  inline static default_real_t BetaBLAS(void) { return 1.0f; }
634  typedef op::plus OPType;
635 };
637 struct minusto {
639  template<typename DType>
640  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
641  a -= b;
642  }
644  inline static default_real_t AlphaBLAS(void) { return -1.0f; }
646  inline static default_real_t BetaBLAS(void) { return 1.0f; }
648  typedef op::minus OPType;
649 };
651 struct multo {
653  template<typename DType>
654  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
655  a *= b;
656  }
658  typedef op::mul OPType;
659 };
661 struct divto {
663  template<typename DType>
664  MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*)
665  a /= b;
666  }
668  typedef op::div OPType;
669 };
670 } // namespace sv
671 
// Outside the CUDA device pass (__CUDA_ARCH__ undefined), bring std::isnan
// and std::isinf into this namespace so the unqualified calls below resolve.
#if !defined(__CUDA_ARCH__)
using std::isinf;
using std::isnan;
#endif
676 
680 namespace isnan_typed {
681  template<typename DType>
682  MSHADOW_XINLINE bool IsNan(volatile DType val) {
683  return false;
684  }
685  template<>
686  MSHADOW_XINLINE bool IsNan(volatile float val) {
687  return isnan(val);
688  }
689  template<>
690  MSHADOW_XINLINE bool IsNan(volatile double val) {
691  return isnan(val);
692  }
693  template<>
694  MSHADOW_XINLINE bool IsNan(volatile long double val) {
695  return isnan(val);
696  }
697  template<>
698  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
699  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
700  }
701 } // namespace isnan_typed
702 
706 namespace isinf_typed {
707  template<typename DType>
708  MSHADOW_XINLINE bool IsInf(volatile DType val) {
709  return false;
710  }
711  template<>
712  MSHADOW_XINLINE bool IsInf(volatile float val) {
713  return isinf(val);
714  }
715  template<>
716  MSHADOW_XINLINE bool IsInf(volatile double val) {
717  return isinf(val);
718  }
719  template<>
720  MSHADOW_XINLINE bool IsInf(volatile long double val) {
721  return isinf(val);
722  }
723  template<>
724  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
725  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
726  }
727 } // namespace isinf_typed
728 
730 namespace red {
731 namespace limits {
736 template<typename DType>
737 MSHADOW_XINLINE DType MinValue(void);
739 template<>
741  return -FLT_MAX;
742 }
744 template<>
746  return -DBL_MAX;
747 }
749 template<>
750 MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
751  return MSHADOW_HALF_MIN;
752 }
754 template<>
755 MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) {
756  return MSHADOW_BF16_MIN;
757 }
759 template<>
761  return 0;
762 }
764 template<>
766  return SCHAR_MIN;
767 }
769 template<>
771  return INT_MIN;
772 }
774 template<>
776  return LLONG_MIN;
777 }
779 template<>
781  return false;
782 }
784 template<>
786  return 0;
787 }
788 
793 template<typename DType>
795  return MinValue<DType>();
796 }
798 template<>
800  return -HUGE_VALF;
801 }
803 template<>
805  return -HUGE_VAL;
806 }
808 template<>
809 MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
810  return half::half_t::Binary(
812 }
813 
818 template<typename DType>
819 MSHADOW_XINLINE DType MaxValue(void);
821 template<>
823  return FLT_MAX;
824 }
826 template<>
828  return DBL_MAX;
829 }
831 template<>
832 MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
833  return MSHADOW_HALF_MAX;
834 }
836 template<>
837 MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) {
838  return MSHADOW_BF16_MAX;
839 }
841 template<>
843  return UCHAR_MAX;
844 }
846 template<>
848  return SCHAR_MAX;
849 }
851 template<>
853  return INT_MAX;
854 }
856 template<>
858  return LLONG_MAX;
859 }
861 template<>
863  return true;
864 }
866 template<>
868  return -1;
869 }
870 
875 template<typename DType>
877  return MaxValue<DType>();
878 }
880 template<>
882  return HUGE_VALF;
883 }
885 template<>
887  return HUGE_VAL;
888 }
890 template<>
891 MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
892  return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
893 }
894 
895 } // namespace limits
896 
898 struct sum {
900  template<typename DType>
901  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
902  dst += src;
903  }
905  template<typename DType>
906  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
907  DType y = src - residual;
908  DType t = dst + y;
909  if (isinf_typed::IsInf(t)) {
910  residual = 0;
911  } else {
912  residual = (t - dst) - y;
913  }
914  dst = t;
915  }
917  template<typename DType>
918  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
919  Reduce(dst_val, src_val);
920  }
922  template<typename DType>
923  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
924  DType t1 = dst_val + src_val;
925  if (isinf_typed::IsInf(t1)) {
926  dst_val = t1;
927  dst_residual = 0;
928  } else {
929  DType e = t1 - dst_val;
930  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
931  dst_val = t1 + t2;
932  dst_residual = t2 - (dst_val - t1);
933  }
934  }
936  template<typename DType>
937  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
939  template<typename DType>
940  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
945  template<typename DType>
946  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
947  return 1;
948  }
952  template<typename DType>
953  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
954  initv = 0;
955  }
959  template<typename DType>
960  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
961  SetInitValue(initv);
962  residual = 0;
963  }
964 };
966 struct maximum {
968  template<typename DType>
969  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
970  if (!isnan_typed::IsNan(dst)) {
971  if (!(dst >= src)) dst = src;
972  }
973  }
975  template<typename DType>
976  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
977  Reduce(dst, src);
978  }
980  template<typename DType>
981  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
982  Reduce(dst_val, src_val);
983  }
985  template<typename DType>
986  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
987  Reduce(dst_val, src_val);
988  }
990  template<typename DType>
991  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
993  template<typename DType>
994  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
999  template<typename DType>
1000  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1001  return redres == redsrc ? 1: 0;
1002  }
1006  template<typename DType>
1007  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1008  initv = limits::NegInfValue<DType>();
1009  }
1013  template<typename DType>
1014  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1015  SetInitValue(initv);
1016  }
1017 };
1019 struct minimum {
1021  template<typename DType>
1022  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
1023  if (!isnan_typed::IsNan(dst)) {
1024  if (!(dst <= src)) dst = src;
1025  }
1026  }
1028  template<typename DType>
1029  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
1030  Reduce(dst, src);
1031  }
1033  template<typename DType>
1034  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1035  Reduce(dst_val, src_val);
1036  }
1038  template<typename DType>
1039  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1040  Reduce(dst_val, src_val);
1041  }
1043  template<typename DType>
1044  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1046  template<typename DType>
1047  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1052  template<typename DType>
1053  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1054  return redres == redsrc ? 1: 0;
1055  }
1059  template<typename DType>
1060  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1061  initv = limits::PosInfValue<DType>();
1062  }
1066  template<typename DType>
1067  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1068  SetInitValue(initv);
1069  }
1070 };
1071 } // namespace red
1072 
// ---------------------------------------------------------------------------
// Runtime type/layout dispatch macros.
// Each macro switches on a runtime flag and runs the trailing code
// (__VA_ARGS__) in its own scope, with a type alias (or int constant) bound
// to the matching C++ type (or layout value). Unsupported flags go to
// LOG(FATAL). The cases are produced by the internal helpers below.
// ---------------------------------------------------------------------------

/*! \brief internal: one case binding alias DType to type T */
#define MSHADOW_TYPE_CASE_(flag, T, DType, ...) \
  case flag:                                    \
    {                                           \
      typedef T DType;                          \
      {__VA_ARGS__}                             \
    }                                           \
    break;

/*! \brief internal: one case binding DType to T and DLargeType to LT */
#define MSHADOW_TYPE_CASE_EX_(flag, T, LT, DType, DLargeType, ...) \
  case flag:                                                       \
    {                                                              \
      typedef T DType;                                             \
      typedef LT DLargeType;                                       \
      {__VA_ARGS__}                                                \
    }                                                              \
    break;

/*! \brief internal: one case that rejects its flag with LOG(FATAL) << msg */
#define MSHADOW_TYPE_CASE_FATAL_(flag, msg) \
  case flag:                                \
    LOG(FATAL) << msg;                      \
    break;

/*! \brief internal: one case binding `const int Layout = value` */
#define MSHADOW_LAYOUT_CASE_(flag, value, Layout, ...) \
  case flag:                                           \
    {                                                  \
      const int Layout = value;                        \
      {__VA_ARGS__}                                    \
    }                                                  \
    break;

/*!
 * \brief dispatch over float, double, half, (bf16), uint8, int8, int32, int64.
 *  The variant compiled under nvcc (__NVCC__ defined) has no bfloat16 case.
 */
#ifndef __NVCC__
#define MSHADOW_TYPE_SWITCH(type, DType, ...)                                          \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    MSHADOW_TYPE_CASE_(mshadow::kFloat16, mshadow::half::half_t, DType, __VA_ARGS__)   \
    MSHADOW_TYPE_CASE_(mshadow::kBfloat16, mshadow::bfloat::bf16_t, DType, __VA_ARGS__) \
    MSHADOW_TYPE_CASE_(mshadow::kUint8, uint8_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt8, int8_t, DType, __VA_ARGS__)                     \
    MSHADOW_TYPE_CASE_(mshadow::kInt32, int32_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt64, int64_t, DType, __VA_ARGS__)                   \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }
#else
#define MSHADOW_TYPE_SWITCH(type, DType, ...)                                          \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    MSHADOW_TYPE_CASE_(mshadow::kFloat16, mshadow::half::half_t, DType, __VA_ARGS__)   \
    MSHADOW_TYPE_CASE_(mshadow::kUint8, uint8_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt8, int8_t, DType, __VA_ARGS__)                     \
    MSHADOW_TYPE_CASE_(mshadow::kInt32, int32_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt64, int64_t, DType, __VA_ARGS__)                   \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }
#endif

/*!
 * \brief like MSHADOW_TYPE_SWITCH, but kFloat16 binds the two-lane half2_t.
 *  No int8 or bfloat16 case.
 */
#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...)                               \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    MSHADOW_TYPE_CASE_(mshadow::kFloat16, mshadow::half::half2_t, DType, __VA_ARGS__)  \
    MSHADOW_TYPE_CASE_(mshadow::kUint8, uint8_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt32, int32_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt64, int64_t, DType, __VA_ARGS__)                   \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }

/*! \brief dispatch over float and double only */
#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...)                                  \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    default:                                                                           \
      LOG(FATAL) << "This operation only supports "                                    \
                    "32-bit and 64-bit floating point";                                \
  }

/*!
 * \brief dispatch over float, double and half; integer flags are rejected with
 *  a type-specific message. kBfloat16 reaches the "Unknown type enum" default.
 */
#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...)                                     \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    MSHADOW_TYPE_CASE_(mshadow::kFloat16, mshadow::half::half_t, DType, __VA_ARGS__)   \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kUint8, "This operation only support "           \
                                              "floating point types not uint8")        \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt8, "This operation only support "            \
                                             "floating point types not int8")          \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt32, "This operation only support "           \
                                              "floating point types, not int32")       \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt64, "This operation only support "           \
                                              "floating point types, not int64")       \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }

/*!
 * \brief floating-point dispatch that also binds DLargeType$, the accumulation
 *  type: float for float/half/(bf16), double for double. The nvcc variant has
 *  no bfloat16 case.
 */
#ifndef __NVCC__
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...)                   \
  switch (type$) {                                                                     \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat32, float, float,                             \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat64, double, double,                           \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat16, mshadow::half::half_t, float,             \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_EX_(mshadow::kBfloat16, mshadow::bfloat::bf16_t, float,          \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kUint8, "This operation only support "           \
                                              "floating point types not uint8")        \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt8, "This operation only support "            \
                                             "floating point types not int8")          \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt32, "This operation only support "           \
                                              "floating point types, not int32")       \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt64, "This operation only support "           \
                                              "floating point types, not int64")       \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type$;                                     \
  }
#else
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...)                   \
  switch (type$) {                                                                     \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat32, float, float,                             \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat64, double, double,                           \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_EX_(mshadow::kFloat16, mshadow::half::half_t, float,             \
                          DType$, DLargeType$, __VA_ARGS__)                            \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kUint8, "This operation only support "           \
                                              "floating point types not uint8")        \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt8, "This operation only support "            \
                                             "floating point types not int8")          \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt32, "This operation only support "           \
                                              "floating point types, not int32")       \
    MSHADOW_TYPE_CASE_FATAL_(mshadow::kInt64, "This operation only support "           \
                                              "floating point types, not int64")       \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type$;                                     \
  }
#endif

/*! \brief dispatch on a layout flag, binding `const int Layout` */
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...)                                     \
  switch (layout) {                                                                    \
    MSHADOW_LAYOUT_CASE_(mshadow::kNCHW, kNCHW, Layout, __VA_ARGS__)                   \
    MSHADOW_LAYOUT_CASE_(mshadow::kNHWC, kNHWC, Layout, __VA_ARGS__)                   \
    MSHADOW_LAYOUT_CASE_(mshadow::kNCDHW, kNCDHW, Layout, __VA_ARGS__)                 \
    MSHADOW_LAYOUT_CASE_(mshadow::kNDHWC, kNDHWC, Layout, __VA_ARGS__)                 \
    default:                                                                           \
      LOG(FATAL) << "Unknown layout enum " << layout;                                  \
  }

/*! \brief index-type dispatch: only int64 is accepted */
#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...)                                      \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kInt64, int64_t, DType, __VA_ARGS__)                   \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }

/*! \brief MSHADOW_TYPE_SWITCH's full type set plus bool */
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...)                                \
  switch (type) {                                                                      \
    MSHADOW_TYPE_CASE_(mshadow::kFloat32, float, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kFloat64, double, DType, __VA_ARGS__)                  \
    MSHADOW_TYPE_CASE_(mshadow::kFloat16, mshadow::half::half_t, DType, __VA_ARGS__)   \
    MSHADOW_TYPE_CASE_(mshadow::kBfloat16, mshadow::bfloat::bf16_t, DType, __VA_ARGS__) \
    MSHADOW_TYPE_CASE_(mshadow::kUint8, uint8_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt8, int8_t, DType, __VA_ARGS__)                     \
    MSHADOW_TYPE_CASE_(mshadow::kInt32, int32_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kInt64, int64_t, DType, __VA_ARGS__)                   \
    MSHADOW_TYPE_CASE_(mshadow::kBool, bool, DType, __VA_ARGS__)                       \
    default:                                                                           \
      LOG(FATAL) << "Unknown type enum " << type;                                      \
  }
1477 
1479 inline size_t mshadow_sizeof(int type) {
1480  int size = 0;
1481  MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType););
1482  return size;
1483 }
1484 
1485 /*/ \brief get string with the type name from type enum */
1486 inline std::string dtype_string(const int dtype) {
1487  switch (dtype) {
1488  case mshadow::kFloat32:
1489  return "float";
1490  case mshadow::kFloat64:
1491  return "double";
1492  case mshadow::kFloat16:
1493  return "half";
1494  case mshadow::kUint8:
1495  return "unsigned char";
1496  case mshadow::kInt8:
1497  return "char";
1498  case mshadow::kInt32:
1499  return "int";
1500  case mshadow::kInt64:
1501  return "long long";
1502  case mshadow::kBool:
1503  return "bool";
1504  default:
1505  LOG(FATAL) << "Unknown type enum " << dtype;
1506  }
1507  return "unknown";
1508 }
1509 
1510 } // namespace mshadow
1511 #endif // MSHADOW_BASE_H_
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1047
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType,...)
Definition: base.h:1418
Definition: base.h:359
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1053
const int default_type_flag
type enum value for default real type
Definition: base.h:484
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:994
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
Definition: base.h:496
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:847
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:646
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:799
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:626
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:770
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:640
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:740
definition of vector float16, half2 type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:946
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:960
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:867
save to saver: =
Definition: base.h:609
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:760
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:708
definition of bfloat type.
definition of half (float16) type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1000
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:976
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:371
divide operator
Definition: base.h:578
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:794
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:654
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:881
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:953
op::right OPType
corresponding binary operator type
Definition: base.h:620
op::minus OPType
corresponding binary operator type
Definition: base.h:648
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:581
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val)
Definition: base.h:698
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1014
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1022
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:969
Definition: base.h:488
Definition: base.h:367
identity function that maps a real number to itself
Definition: base.h:598
op::mul OPType
corresponding binary operator type
Definition: base.h:658
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:612
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:644
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:940
Definition: base.h:502
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:765
Definition: base.h:489
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:991
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val)
Definition: base.h:724
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:780
#define MSHADOW_XINLINE
Definition: base.h:230
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:906
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:664
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:336
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:827
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:682
const int default_layout
default layout for 4d tensor
Definition: base.h:525
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1067
LayoutFlag
Definition: base.h:487
get rhs
Definition: base.h:586
#define MSHADOW_BF16_MAX
Definition: bfloat.h:182
std::string dtype_string(const int dtype)
Definition: base.h:1486
minus to saver: -=
Definition: base.h:637
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:852
int32_t index_t
type that will be used for index
Definition: base.h:343
multiply to saver: *=
Definition: base.h:651
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:565
const float kPi
pi
Definition: base.h:338
Definition: base.h:364
op::plus OPType
corresponding binary operator type
Definition: base.h:634
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:822
float default_real_t
float point type that will be used in default by mshadow
Definition: base.h:355
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:548
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1029
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:775
minimum reducer
Definition: base.h:1019
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:745
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:368
Definition: base.h:366
Definition: base.h:360
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
divide to saver: /=
Definition: base.h:661
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:557
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:307
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1044
Definition: base.h:370
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:876
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1479
Definition: base.h:363
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:981
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:804
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:601
Definition: base.h:497
Definition: base.h:369
maximum reducer
Definition: base.h:966
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:923
TypeFlag
data type flag
Definition: base.h:358
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:857
Definition: base.h:492
Definition: base.h:498
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:573
plus operator
Definition: base.h:562
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:616
Definition: base.h:362
save to saver: +=
Definition: base.h:623
Definition: base.h:494
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1039
sum reducer
Definition: base.h:898
Definition: base.h:493
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:918
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:632
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1007
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:785
mul operator
Definition: base.h:554
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:370
#define MSHADOW_HALF_MAX
Definition: half.h:369
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:618
Definition: base.h:375
overloaded + operator between half_t and bf16_t
Definition: base.h:334
Definition: base.h:368
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1060
Definition: base.h:361
Definition: base.h:371
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:862
Definition: base.h:365
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:901
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:842
Definition: base.h:490
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:589
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1034
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:986
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:937
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:630
index_t openmp_index_t
openmp index for linux
Definition: base.h:351
minus operator
Definition: base.h:570
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:886
op::div OPType
corresponding binary operator type
Definition: base.h:668
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:181