26 #ifndef MSHADOW_BASE_H_ 27 #define MSHADOW_BASE_H_ 29 #ifndef _CRT_SECURE_NO_WARNINGS 30 #define _CRT_SECURE_NO_WARNINGS 32 #ifndef _CRT_SECURE_NO_DEPRECATE 33 #define _CRT_SECURE_NO_DEPRECATE 49 typedef signed char int8_t;
51 typedef __int16 int16_t;
52 typedef __int32 int32_t;
53 typedef __int64 int64_t;
54 typedef unsigned char uint8_t;
55 typedef unsigned __int16 uint16_t;
56 typedef unsigned __int32 uint32_t;
57 typedef unsigned __int64 uint64_t;
67 #ifndef MSHADOW_STAND_ALONE 68 #define MSHADOW_STAND_ALONE 0 71 #ifndef MSHADOW_ALLOC_PAD 72 #define MSHADOW_ALLOC_PAD true 82 #ifndef MSHADOW_MIN_PAD_RATIO 83 #define MSHADOW_MIN_PAD_RATIO 2 86 #if MSHADOW_STAND_ALONE 87 #define MSHADOW_USE_CBLAS 0 88 #define MSHADOW_USE_MKL 0 89 #define MSHADOW_USE_CUDA 0 96 #ifndef MSHADOW_FORCE_STREAM 97 #define MSHADOW_FORCE_STREAM 1 101 #ifndef MSHADOW_USE_CBLAS 102 #define MSHADOW_USE_CBLAS 0 105 #ifndef MSHADOW_USE_MKL 106 #define MSHADOW_USE_MKL 1 109 #ifndef MSHADOW_USE_ARMPL 110 #define MSHADOW_USE_ARMPL 0 117 #ifndef MSHADOW_USE_CUDA 118 #define MSHADOW_USE_CUDA 1 124 #ifndef MSHADOW_USE_CUDNN 125 #define MSHADOW_USE_CUDNN 0 131 #ifndef MSHADOW_USE_CUSOLVER 132 #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA 139 #ifndef MSHADOW_OLD_CUDA 140 #define MSHADOW_OLD_CUDA 0 146 #ifndef MSHADOW_IN_CXX11 147 #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ 148 __cplusplus >= 201103L || defined(_MSC_VER)) 149 #define MSHADOW_IN_CXX11 1 151 #define MSHADOW_IN_CXX11 0 156 #ifndef MSHADOW_USE_SSE 157 #define MSHADOW_USE_SSE 1 161 #ifndef MSHADOW_USE_F16C 162 #if defined(_MSC_VER) || defined(__CUDACC__) 163 #define MSHADOW_USE_F16C 0 164 #elif defined(__clang__) && \ 165 ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1))) 166 #define MSHADOW_USE_F16C 0 168 #define MSHADOW_USE_F16C 1 173 #ifndef MSHADOW_USE_NVML 174 #define MSHADOW_USE_NVML 0 178 #undef MSHADOW_USE_SSE 179 #define MSHADOW_USE_SSE 0 182 #if MSHADOW_USE_CBLAS 184 #if MSHADOW_USE_ARMPL 185 #define armpl_singlecomplex_t float _Complex 186 #define armpl_doublecomplex_t double _Complex 190 #elif MSHADOW_USE_MKL 191 #include <mkl_blas.h> 192 #include <mkl_cblas.h> 194 #include <mkl_vsl_functions.h> 195 #include <mkl_version.h> 200 #include <cublas_v2.h> 204 #if MSHADOW_USE_CUDNN == 1 208 #if MSHADOW_USE_CUSOLVER == 1 209 #include <cusolverDn.h> 218 #ifdef MSHADOW_XINLINE 219 #error "MSHADOW_XINLINE must not be defined" 222 #define MSHADOW_FORCE_INLINE __forceinline 223 #pragma warning(disable : 4068) 225 #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) 228 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ 230 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE 233 #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE 235 #if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ 236 defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L 237 #define MSHADOW_CONSTEXPR constexpr 239 #define MSHADOW_CONSTEXPR const 248 #ifndef MSHADOW_DEFAULT_DTYPE 249 #define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t 255 #ifndef MSHADOW_USE_GLOG 256 #define MSHADOW_USE_GLOG DMLC_USE_GLOG 257 #endif // MSHADOW_USE_GLOG 260 #define MSHADOW_THROW_EXCEPTION noexcept(false) 261 #define MSHADOW_NO_EXCEPTION noexcept(true) 263 #define MSHADOW_THROW_EXCEPTION 264 #define MSHADOW_NO_EXCEPTION 267 #if defined(_MSC_VER) 268 #define MSHADOW_ALIGNED(x) __declspec(align(x)) 270 #define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x))) 278 #define MSHADOW_CUDA_CALL(func) \ 280 cudaError_t e = (func); \ 281 if (e == cudaErrorCudartUnloading) { \ 282 throw dmlc::Error(cudaGetErrorString(e)); \ 284 CHECK(e == cudaSuccess) \ 285 << "CUDA: " << cudaGetErrorString(e); \ 292 #define MSHADOW_CATCH_ERROR(func) \ 296 } catch (const dmlc::Error &e) { \ 297 std::string what = e.what(); \ 298 if (what.find("driver shutting down") == std::string::npos) { \ 299 LOG(ERROR) << "Ignore CUDA Error " << what; \ 307 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \ 308 MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \ 309 return float(a) OP float(b); \ 311 MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \ 312 return float(a) OP float(b); \ 332 #include "./logging.h" 338 const float kPi = 3.1415926f;
340 #if MSHADOW_INT64_TENSOR_SIZE == 1 374 template<
typename DType>
380 static const int kLanes = 1;
382 #if (CUDA_VERSION >= 8000) 383 static const cudaDataType_t kCudaFlag = CUDA_R_32F;
385 #if MSHADOW_USE_CUDNN 386 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
387 typedef float ScaleType;
394 static const int kLanes = 1;
396 #if (CUDA_VERSION >= 8000) 397 static const cudaDataType_t kCudaFlag = CUDA_R_64F;
399 #if MSHADOW_USE_CUDNN 400 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
401 typedef double ScaleType;
408 static const int kLanes = 1;
410 #if (CUDA_VERSION >= 8000) 411 static const cudaDataType_t kCudaFlag = CUDA_R_16F;
413 #if MSHADOW_USE_CUDNN 414 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
415 typedef float ScaleType;
422 static const int kLanes = 2;
427 static const int kLanes = 1;
432 static const int kLanes = 1;
434 #if (CUDA_VERSION >= 8000) 435 static const cudaDataType_t kCudaFlag = CUDA_R_8U;
437 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 439 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
440 typedef uint8_t ScaleType;
447 static const int kLanes = 1;
449 #if (CUDA_VERSION >= 8000) 450 static const cudaDataType_t kCudaFlag = CUDA_R_8I;
452 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 453 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
454 typedef int8_t ScaleType;
461 static const int kLanes = 1;
463 #if (CUDA_VERSION >= 8000) 464 static const cudaDataType_t kCudaFlag = CUDA_R_32I;
466 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 467 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
468 typedef int32_t ScaleType;
475 static const int kLanes = 1;
480 static const int kLanes = 1;
506 static const index_t kNdim = 4;
507 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 508 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
510 static const int kCudnnFlag = -1;
516 static const index_t kNdim = 4;
517 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 518 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
520 static const int kCudnnFlag = -1;
529 static const index_t kNdim = 5;
530 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 531 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
533 static const int kCudnnFlag = -1;
539 static const index_t kNdim = 5;
540 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 541 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
543 static const int kCudnnFlag = -1;
556 template<
typename DType>
564 template<
typename DType>
572 template<
typename DType>
580 template<
typename DType>
588 template<
typename DType>
600 template<
typename DType>
611 template<
typename DType>
616 inline static default_real_t
AlphaBLAS(
void) {
return 1.0f; }
618 inline static default_real_t
BetaBLAS(
void) {
return 0.0f; }
625 template<
typename DType>
630 inline static default_real_t
AlphaBLAS(
void) {
return 1.0f; }
632 inline static default_real_t
BetaBLAS(
void) {
return 1.0f; }
639 template<
typename DType>
644 inline static default_real_t
AlphaBLAS(
void) {
return -1.0f; }
646 inline static default_real_t
BetaBLAS(
void) {
return 1.0f; }
653 template<
typename DType>
663 template<
typename DType>
672 #ifndef __CUDA_ARCH__ 680 namespace isnan_typed {
681 template<
typename DType>
706 namespace isinf_typed {
707 template<
typename DType>
736 template<
typename DType>
793 template<
typename DType>
795 return MinValue<DType>();
810 return half::half_t::Binary(
818 template<
typename DType>
875 template<
typename DType>
877 return MaxValue<DType>();
900 template<
typename DType>
905 template<
typename DType>
907 DType y = src - residual;
912 residual = (t - dst) - y;
917 template<
typename DType>
919 Reduce(dst_val, src_val);
922 template<
typename DType>
923 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
924 DType t1 = dst_val + src_val;
929 DType e = t1 - dst_val;
930 DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
932 dst_residual = t2 - (dst_val - t1);
936 template<
typename DType>
939 template<
typename DType>
945 template<
typename DType>
952 template<
typename DType>
959 template<
typename DType>
968 template<
typename DType>
971 if (!(dst >= src)) dst = src;
975 template<
typename DType>
980 template<
typename DType>
982 Reduce(dst_val, src_val);
985 template<
typename DType>
986 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
987 Reduce(dst_val, src_val);
990 template<
typename DType>
993 template<
typename DType>
999 template<
typename DType>
1001 return redres == redsrc ? 1: 0;
1006 template<
typename DType>
1008 initv = limits::NegInfValue<DType>();
1013 template<
typename DType>
1015 SetInitValue(initv);
1021 template<
typename DType>
1024 if (!(dst <= src)) dst = src;
1028 template<
typename DType>
1033 template<
typename DType>
1035 Reduce(dst_val, src_val);
1038 template<
typename DType>
1039 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
1040 Reduce(dst_val, src_val);
1043 template<
typename DType>
1046 template<
typename DType>
1052 template<
typename DType>
1054 return redres == redsrc ? 1: 0;
1059 template<
typename DType>
1061 initv = limits::PosInfValue<DType>();
1066 template<
typename DType>
1068 SetInitValue(initv);
1074 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \ 1076 case mshadow::kFloat32: \ 1078 typedef float DType; \ 1082 case mshadow::kFloat64: \ 1084 typedef double DType; \ 1088 case mshadow::kFloat16: \ 1090 typedef mshadow::half::half_t DType; \ 1094 case mshadow::kBfloat16: \ 1096 typedef mshadow::bfloat::bf16_t DType; \ 1100 case mshadow::kUint8: \ 1102 typedef uint8_t DType; \ 1106 case mshadow::kInt8: \ 1108 typedef int8_t DType; \ 1112 case mshadow::kInt32: \ 1114 typedef int32_t DType; \ 1118 case mshadow::kInt64: \ 1120 typedef int64_t DType; \ 1125 LOG(FATAL) << "Unknown type enum " << type; \ 1128 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \ 1130 case mshadow::kFloat32: \ 1132 typedef float DType; \ 1136 case mshadow::kFloat64: \ 1138 typedef double DType; \ 1142 case mshadow::kFloat16: \ 1144 typedef mshadow::half::half_t DType; \ 1148 case mshadow::kUint8: \ 1150 typedef uint8_t DType; \ 1154 case mshadow::kInt8: \ 1156 typedef int8_t DType; \ 1160 case mshadow::kInt32: \ 1162 typedef int32_t DType; \ 1166 case mshadow::kInt64: \ 1168 typedef int64_t DType; \ 1173 LOG(FATAL) << "Unknown type enum " << type; \ 1177 #define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ 1179 case mshadow::kFloat32: \ 1181 typedef float DType; \ 1185 case mshadow::kFloat64: \ 1187 typedef double DType; \ 1191 case mshadow::kFloat16: \ 1193 typedef mshadow::half::half2_t DType; \ 1197 case mshadow::kUint8: \ 1199 typedef uint8_t DType; \ 1203 case mshadow::kInt32: \ 1205 typedef int32_t DType; \ 1209 case mshadow::kInt64: \ 1211 typedef int64_t DType; \ 1216 LOG(FATAL) << "Unknown type enum " << type; \ 1219 #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ 1221 case mshadow::kFloat32: \ 1223 typedef float DType; \ 1227 case mshadow::kFloat64: \ 1229 typedef double DType; \ 1234 LOG(FATAL) << "This operation only supports " \ 1235 "32-bit and 64-bit floating point"; \ 1238 #define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \ 1240 case mshadow::kFloat32: \ 1242 typedef float DType; \ 1246 case mshadow::kFloat64: \ 1248 typedef double DType; \ 1252 case mshadow::kFloat16: \ 1254 typedef mshadow::half::half_t DType; \ 1258 case mshadow::kUint8: \ 1259 LOG(FATAL) << "This operation only support " \ 1260 "floating point types not uint8"; \ 1262 case mshadow::kInt8: \ 1263 LOG(FATAL) << "This operation only support " \ 1264 "floating point types not int8"; \ 1266 case mshadow::kInt32: \ 1267 LOG(FATAL) << "This operation only support " \ 1268 "floating point types, not int32";\ 1270 case mshadow::kInt64: \ 1271 LOG(FATAL) << "This operation only support " \ 1272 "floating point types, not int64";\ 1275 LOG(FATAL) << "Unknown type enum " << type; \ 1279 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ 1281 case mshadow::kFloat32: \ 1283 typedef float DType$; \ 1284 typedef float DLargeType$; \ 1288 case mshadow::kFloat64: \ 1290 typedef double DType$; \ 1291 typedef double DLargeType$; \ 1295 case mshadow::kFloat16: \ 1297 typedef mshadow::half::half_t DType$; \ 1298 typedef float DLargeType$; \ 1302 case mshadow::kBfloat16: \ 1304 typedef mshadow::bfloat::bf16_t DType$; \ 1305 typedef float DLargeType$; \ 1309 case mshadow::kUint8: \ 1310 LOG(FATAL) << "This operation only support " \ 1311 "floating point types not uint8"; \ 1313 case mshadow::kInt8: \ 1314 LOG(FATAL) << "This operation only support " \ 1315 "floating point types not int8"; \ 1317 case mshadow::kInt32: \ 1318 LOG(FATAL) << "This operation only support " \ 1319 "floating point types, not int32";\ 1321 case mshadow::kInt64: \ 1322 LOG(FATAL) << "This operation only support " \ 1323 "floating point types, not int64";\ 1326 LOG(FATAL) << "Unknown type enum " << type$; \ 1329 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ 1331 case mshadow::kFloat32: \ 1333 typedef float DType$; \ 1334 typedef float DLargeType$; \ 1338 case mshadow::kFloat64: \ 1340 typedef double DType$; \ 1341 typedef double DLargeType$; \ 1345 case mshadow::kFloat16: \ 1347 typedef mshadow::half::half_t DType$; \ 1348 typedef float DLargeType$; \ 1352 case mshadow::kUint8: \ 1353 LOG(FATAL) << "This operation only support " \ 1354 "floating point types not uint8"; \ 1356 case mshadow::kInt8: \ 1357 LOG(FATAL) << "This operation only support " \ 1358 "floating point types not int8"; \ 1360 case mshadow::kInt32: \ 1361 LOG(FATAL) << "This operation only support " \ 1362 "floating point types, not int32";\ 1364 case mshadow::kInt64: \ 1365 LOG(FATAL) << "This operation only support " \ 1366 "floating point types, not int64";\ 1369 LOG(FATAL) << "Unknown type enum " << type$; \ 1372 #define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ 1374 case mshadow::kNCHW: \ 1376 const int Layout = kNCHW; \ 1380 case mshadow::kNHWC: \ 1382 const int Layout = kNHWC; \ 1386 case mshadow::kNCDHW: \ 1388 const int Layout = kNCDHW; \ 1392 case mshadow::kNDHWC: \ 1394 const int Layout = kNDHWC; \ 1399 LOG(FATAL) << "Unknown layout enum " << layout; \ 1406 #define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \ 1408 case mshadow::kInt64: \ 1410 typedef int64_t DType; \ 1415 LOG(FATAL) << "Unknown type enum " << type; \ 1418 #define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \ 1420 case mshadow::kFloat32: \ 1422 typedef float DType; \ 1426 case mshadow::kFloat64: \ 1428 typedef double DType; \ 1432 case mshadow::kFloat16: \ 1434 typedef mshadow::half::half_t DType; \ 1438 case mshadow::kBfloat16: \ 1440 typedef mshadow::bfloat::bf16_t DType; \ 1444 case mshadow::kUint8: \ 1446 typedef uint8_t DType; \ 1450 case mshadow::kInt8: \ 1452 typedef int8_t DType; \ 1456 case mshadow::kInt32: \ 1458 typedef int32_t DType; \ 1462 case mshadow::kInt64: \ 1464 typedef int64_t DType; \ 1468 case mshadow::kBool: \ 1470 typedef bool DType; \ 1475 LOG(FATAL) << "Unknown type enum " << type; \ 1495 return "unsigned char";
1505 LOG(FATAL) <<
"Unknown type enum " << dtype;
1511 #endif // MSHADOW_BASE_H_ static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1047
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType,...)
Definition: base.h:1418
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1053
const int default_type_flag
type enum value for default real type
Definition: base.h:484
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:994
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:847
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:646
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:799
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:626
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:770
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:640
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:740
definition of vector float16, half2 type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:946
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:960
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:867
save to saver: =
Definition: base.h:609
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:760
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:708
definition of bfloat type.
definition of half (float16) type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1000
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:976
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:371
divide operator
Definition: base.h:578
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:794
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:654
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:881
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:953
op::right OPType
corresponding binary operator type
Definition: base.h:620
op::minus OPType
corresponding binary operator type
Definition: base.h:648
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:581
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val)
Definition: base.h:698
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1014
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1022
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:969
identity function that maps a real number to it self
Definition: base.h:598
op::mul OPType
corresponding binary operator type
Definition: base.h:658
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:612
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:644
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:940
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:765
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:991
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val)
Definition: base.h:724
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:780
#define MSHADOW_XINLINE
Definition: base.h:230
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:906
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:664
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:336
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:827
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:682
const int default_layout
default layout for 4d tensor
Definition: base.h:525
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1067
LayoutFlag
Definition: base.h:487
get rhs
Definition: base.h:586
#define MSHADOW_BF16_MAX
Definition: bfloat.h:182
std::string dtype_string(const int dtype)
Definition: base.h:1486
minus to saver: -=
Definition: base.h:637
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:852
int32_t index_t
type that will be used for index
Definition: base.h:343
multiply to saver: *=
Definition: base.h:651
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:565
const float kPi
pi
Definition: base.h:338
op::plus OPType
corresponding binary operator type
Definition: base.h:634
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:822
float default_real_t
float point type that will be used in default by mshadow
Definition: base.h:355
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:548
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1029
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:775
minimum reducer
Definition: base.h:1019
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:745
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:368
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
divide to saver: /=
Definition: base.h:661
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:557
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:307
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1044
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:876
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1479
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:981
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:804
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:601
maximum reducer
Definition: base.h:966
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:923
TypeFlag
data type flag
Definition: base.h:358
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:857
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:573
plus operator
Definition: base.h:562
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:616
save to saver: +=
Definition: base.h:623
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1039
sum reducer
Definition: base.h:898
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:918
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:632
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1007
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:785
mul operator
Definition: base.h:554
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:370
#define MSHADOW_HALF_MAX
Definition: half.h:369
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:618
overloaded + operator between half_t and bf16_t
Definition: base.h:334
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1060
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:862
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:901
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:842
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:589
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1034
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:986
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:937
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:630
index_t openmp_index_t
openmp index for linux
Definition: base.h:351
minus operator
Definition: base.h:570
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:886
op::div OPType
corresponding binary operator type
Definition: base.h:668
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:181