// Include guard plus core dependencies: dmlc logging/CHECK macros, env-var
// parsing (GetEnv), optional<T>, and mshadow base types (half_t).
// NOTE(review): the section below is compiled only when the JetBrains IDE
// parser defines __JETBRAINS_IDE__ — these are no-op stand-ins so device
// code indexes cleanly in the IDE; nvcc never sees them.
24 #ifndef MXNET_COMMON_CUDA_UTILS_H_ 25 #define MXNET_COMMON_CUDA_UTILS_H_ 27 #include <dmlc/logging.h> 28 #include <dmlc/parameter.h> 29 #include <dmlc/optional.h> 30 #include <mshadow/base.h> 33 #ifdef __JETBRAINS_IDE__ 38 #define __forceinline__ 40 inline void __syncthreads() {}
// IDE stub: real __threadfence_block() is a CUDA memory fence; here a no-op.
41 inline void __threadfence_block() {}
// IDE stub for __clz: returns its argument unchanged — NOT a real
// count-leading-zeros; exists only so the symbol resolves during indexing.
42 template<
class T>
inline T __clz(
const T val) {
return val; }
// Fake dim3-like struct so blockDim/threadIdx/blockIdx resolve in the IDE.
43 struct __cuda_fake_struct {
int x;
int y;
int z; };
44 extern __cuda_fake_struct blockDim;
45 extern __cuda_fake_struct threadIdx;
46 extern __cuda_fake_struct blockIdx;
51 #include <cuda_runtime.h> 52 #include <cublas_v2.h> 66 case CUBLAS_STATUS_SUCCESS:
67 return "CUBLAS_STATUS_SUCCESS";
68 case CUBLAS_STATUS_NOT_INITIALIZED:
69 return "CUBLAS_STATUS_NOT_INITIALIZED";
70 case CUBLAS_STATUS_ALLOC_FAILED:
71 return "CUBLAS_STATUS_ALLOC_FAILED";
72 case CUBLAS_STATUS_INVALID_VALUE:
73 return "CUBLAS_STATUS_INVALID_VALUE";
74 case CUBLAS_STATUS_ARCH_MISMATCH:
75 return "CUBLAS_STATUS_ARCH_MISMATCH";
76 case CUBLAS_STATUS_MAPPING_ERROR:
77 return "CUBLAS_STATUS_MAPPING_ERROR";
78 case CUBLAS_STATUS_EXECUTION_FAILED:
79 return "CUBLAS_STATUS_EXECUTION_FAILED";
80 case CUBLAS_STATUS_INTERNAL_ERROR:
81 return "CUBLAS_STATUS_INTERNAL_ERROR";
82 case CUBLAS_STATUS_NOT_SUPPORTED:
83 return "CUBLAS_STATUS_NOT_SUPPORTED";
87 return "Unknown cuBLAS status";
97 case CUSOLVER_STATUS_SUCCESS:
98 return "CUSOLVER_STATUS_SUCCESS";
99 case CUSOLVER_STATUS_NOT_INITIALIZED:
100 return "CUSOLVER_STATUS_NOT_INITIALIZED";
101 case CUSOLVER_STATUS_ALLOC_FAILED:
102 return "CUSOLVER_STATUS_ALLOC_FAILED";
103 case CUSOLVER_STATUS_INVALID_VALUE:
104 return "CUSOLVER_STATUS_INVALID_VALUE";
105 case CUSOLVER_STATUS_ARCH_MISMATCH:
106 return "CUSOLVER_STATUS_ARCH_MISMATCH";
107 case CUSOLVER_STATUS_EXECUTION_FAILED:
108 return "CUSOLVER_STATUS_EXECUTION_FAILED";
109 case CUSOLVER_STATUS_INTERNAL_ERROR:
110 return "CUSOLVER_STATUS_INTERNAL_ERROR";
111 case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
112 return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
116 return "Unknown cuSOLVER status";
126 case CURAND_STATUS_SUCCESS:
127 return "CURAND_STATUS_SUCCESS";
128 case CURAND_STATUS_VERSION_MISMATCH:
129 return "CURAND_STATUS_VERSION_MISMATCH";
130 case CURAND_STATUS_NOT_INITIALIZED:
131 return "CURAND_STATUS_NOT_INITIALIZED";
132 case CURAND_STATUS_ALLOCATION_FAILED:
133 return "CURAND_STATUS_ALLOCATION_FAILED";
134 case CURAND_STATUS_TYPE_ERROR:
135 return "CURAND_STATUS_TYPE_ERROR";
136 case CURAND_STATUS_OUT_OF_RANGE:
137 return "CURAND_STATUS_OUT_OF_RANGE";
138 case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
139 return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
140 case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
141 return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
142 case CURAND_STATUS_LAUNCH_FAILURE:
143 return "CURAND_STATUS_LAUNCH_FAILURE";
144 case CURAND_STATUS_PREEXISTING_FAILURE:
145 return "CURAND_STATUS_PREEXISTING_FAILURE";
146 case CURAND_STATUS_INITIALIZATION_FAILED:
147 return "CURAND_STATUS_INITIALIZATION_FAILED";
148 case CURAND_STATUS_ARCH_MISMATCH:
149 return "CURAND_STATUS_ARCH_MISMATCH";
150 case CURAND_STATUS_INTERNAL_ERROR:
151 return "CURAND_STATUS_INTERNAL_ERROR";
153 return "Unknown cuRAND status";
156 template <
typename DType>
157 inline DType __device__
CudaMax(DType a, DType b) {
158 return a > b ? a : b;
161 template <
typename DType>
162 inline DType __device__
CudaMin(DType a, DType b) {
163 return a < b ? a : b;
// Error-checking macro family (extraction collapsed them onto one line; the
// do { ... } while (0) wrappers were lost — compare against the full source):
//   CHECK_CUDA_ERROR(msg)  — inspects cudaGetLastError() after a launch.
//   CUDA_CALL(func)        — CHECKs cudaSuccess, tolerating
//                            cudaErrorCudartUnloading during shutdown.
//   CUBLAS_CALL / CUSOLVER_CALL / CURAND_CALL — CHECK_EQ against each
//     library's SUCCESS status, formatting failures via the
//     mxnet::common::cuda::*GetErrorString helpers above.
//   NVRTC_CALL(x)          — CHECK_EQ against NVRTC_SUCCESS with #x in the message.
//   CUDA_DRIVER_CALL(func) — driver-API variant; falls back to "Unknown error"
//     when cuGetErrorString itself rejects the code.
// Then CUDA_UNROLL/CUDA_NOUNROLL map to _Pragma("unroll")/_Pragma("nounroll")
// on non-MSVC compilers (empty otherwise), and the trailing fragment is the
// tail of ComputeCapabilityMajor querying cudaDevAttrComputeCapabilityMajor.
174 #define CHECK_CUDA_ERROR(msg) \ 176 cudaError_t e = cudaGetLastError(); \ 177 CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ 186 #define CUDA_CALL(func) \ 188 cudaError_t e = (func); \ 189 CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ 190 << "CUDA: " << cudaGetErrorString(e); \ 199 #define CUBLAS_CALL(func) \ 201 cublasStatus_t e = (func); \ 202 CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ 203 << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \ 212 #define CUSOLVER_CALL(func) \ 214 cusolverStatus_t e = (func); \ 215 CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ 216 << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \ 225 #define CURAND_CALL(func) \ 227 curandStatus_t e = (func); \ 228 CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ 229 << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \ 238 #define NVRTC_CALL(x) \ 240 nvrtcResult result = x; \ 241 CHECK_EQ(result, NVRTC_SUCCESS) \ 242 << #x " failed with error " \ 243 << nvrtcGetErrorString(result); \ 252 #define CUDA_DRIVER_CALL(func) \ 254 CUresult e = (func); \ 255 if (e != CUDA_SUCCESS) { \ 256 char const * err_msg = nullptr; \ 257 if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ 258 LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ 260 LOG(FATAL) << "CUDA Driver: " << err_msg; \ 266 #if !defined(_MSC_VER) 267 #define CUDA_UNROLL _Pragma("unroll") 268 #define CUDA_NOUNROLL _Pragma("nounroll") 271 #define CUDA_NOUNROLL 282 cudaDevAttrComputeCapabilityMajor, device_id));
// Tail of ComputeCapabilityMinor: queries cudaDevAttrComputeCapabilityMinor
// for device_id via cudaDeviceGetAttribute (opening lines lost in extraction).
294 cudaDevAttrComputeCapabilityMinor, device_id));
// SMArch: encodes major/minor compute capability as one int, e.g. 7.0 -> 70
// (matches the residue doc: "Return the integer SM architecture (e.g. Volta = 70)").
306 return 10 * major + minor;
// Fragment of SupportsFloat16Compute: true for compute capability major > 5;
// the continuation handling the 5.x boundary case was lost in extraction —
// TODO(review): confirm the full condition against the original source.
321 return (computeCapabilityMajor > 5) ||
// Fragment of SupportsTensorCore: requires a valid (non-negative) device_id
// ANDed with an architecture check whose text was lost in extraction —
// TODO(review): confirm the full condition against the original source.
334 return device_id >= 0 &&
339 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true 349 static bool allow_tensor_core =
false;
350 static bool is_set =
false;
354 allow_tensor_core = dmlc::GetEnv(
"MXNET_CUDA_ALLOW_TENSOR_CORE",
355 dmlc::optional<bool>(default_value)).value();
358 return allow_tensor_core;
#if CUDA_VERSION >= 9000
/*!
 * \brief Set a cuBLAS handle's math mode (e.g. to permit TensorCore math)
 *        and report the mode previously in effect so callers can restore it.
 * \param blas_handle   Handle whose math mode is changed.
 * \param new_math_type Mode to install on the handle.
 * \return The math mode the handle had before this call.
 */
inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle,
                                      cublasMath_t new_math_type) {
  cublasMath_t previous_mode = CUBLAS_DEFAULT_MATH;
  CUBLAS_CALL(cublasGetMathMode(blas_handle, &previous_mode));
  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
  return previous_mode;
}
#endif  // CUDA_VERSION >= 9000
// End of the MXNET_USE_CUDA section, then the cuDNN helpers. CUDNN_CALL
// CHECK_EQs a cudnnStatus_t against CUDNN_STATUS_SUCCESS, formatting failures
// with cudnnGetErrorString (its do/while wrapper was lost in extraction).
// MaxForwardAlgos begins at the end of this line.
371 #endif // MXNET_USE_CUDA 377 #define CUDNN_CALL(func) \ 379 cudnnStatus_t e = (func); \ 380 CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ 390 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
// Body fragment: queries cuDNN for the number of forward-convolution algos;
// the declaration of max_algos and the return were lost in extraction —
// presumably `int max_algos; ...; return max_algos;` — verify against source.
393 CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
// Same pattern for backward-filter convolution algorithms.
407 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
410 CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
// Same pattern for backward-data convolution algorithms.
424 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
427 CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
// End of the MXNET_USE_CUDNN section; the trailing fragment opens the
// pre-SM60 double atomicAdd overload (signature continues on the next line).
434 #endif // MXNET_USE_CUDNN 437 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 439 static inline __device__
void atomicAdd(
double *address,
double val) {
440 unsigned long long* address_as_ull =
441 reinterpret_cast<unsigned long long*
>(address);
442 unsigned long long old = *address_as_ull;
443 unsigned long long assumed;
447 old = atomicCAS(address_as_ull, assumed,
448 __double_as_longlong(val +
449 __longlong_as_double(assumed)));
452 }
while (assumed != old);
#if defined(__CUDA_ARCH__)
// Device-side atomicAdd overload for mshadow half_t, emulated with a CAS
// loop on the aligned 32-bit word that contains the 16-bit half.
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // Round the pointer down to its containing 4-byte word: subtract 2 when
  // the half sits in the upper 16 bits (address & 2 set).
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(
      reinterpret_cast<char *>(address) -
      (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    mshadow::half::half_t hsum;
    // Pull the current half out of the upper or lower 16 bits of the word.
    hsum.half_ =
        reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum += val;
    // Merge the updated half back into its 16-bit slot, leaving the
    // neighboring half untouched.
    old = reinterpret_cast<size_t>(address) & 2
              ? (old & 0xffff) | (hsum.half_ << 16)
              : (old & 0xffff0000) | hsum.half_;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
#endif  // defined(__CUDA_ARCH__)
481 template <
typename DType>
482 __device__
inline DType ldg(
const DType* address) {
483 #if __CUDA_ARCH__ >= 350 484 return __ldg(address);
491 #endif // MXNET_COMMON_CUDA_UTILS_H_ #define CUBLAS_CALL(func)
Protected cuBLAS call.
Definition: cuda_utils.h:199
int ComputeCapabilityMajor(int device_id)
Determine major version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:279
namespace of mxnet
Definition: base.h:126
bool GetEnvAllowTensorCore()
Returns global policy for TensorCore algo use.
Definition: cuda_utils.h:345
int SMArch(int device_id)
Return the integer SM architecture (e.g. Volta = 70).
Definition: cuda_utils.h:303
DType __device__ CudaMin(DType a, DType b)
Definition: cuda_utils.h:162
bool SupportsFloat16Compute(int device_id)
Determine whether a cuda-capable gpu's architecture supports float16 math. Assume not if device_id is...
Definition: cuda_utils.h:315
DType __device__ CudaMax(DType a, DType b)
Definition: cuda_utils.h:157
bool SupportsTensorCore(int device_id)
Determine whether a cuda-capable gpu's architecture supports Tensor Core math. Assume not if device_i...
Definition: cuda_utils.h:332
const char * CusolverGetErrorString(cusolverStatus_t error)
Get string representation of cuSOLVER errors.
Definition: cuda_utils.h:95
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
Definition: cuda_utils.h:339
const char * CurandGetErrorString(curandStatus_t status)
Get string representation of cuRAND errors.
Definition: cuda_utils.h:124
int ComputeCapabilityMinor(int device_id)
Determine minor version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:291
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:186
const char * CublasGetErrorString(cublasStatus_t error)
Get string representation of cuBLAS errors.
Definition: cuda_utils.h:64