cuda_utils.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda_utils.h
 * \brief CUDA utilities.
 */
#ifndef MXNET_COMMON_CUDA_UTILS_H_
#define MXNET_COMMON_CUDA_UTILS_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <mshadow/base.h>

// No-op stubs for CUDA keywords and builtins so that JetBrains IDEs
// (e.g. CLion) can parse .cu/.cuh files without a real CUDA toolchain.
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
#define __host__
#define __device__
#define __global__
#define __forceinline__
#define __shared__
inline void __syncthreads() {}
inline void __threadfence_block() {}
template <class T> inline T __clz(const T val) { return val; }
struct __cuda_fake_struct { int x; int y; int z; };
extern __cuda_fake_struct blockDim;
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
#endif

#if MXNET_USE_CUDA

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <curand.h>

namespace mxnet {
namespace common {
namespace cuda {

/*!
 * \brief Get string representation of cuBLAS errors.
 */
inline const char* CublasGetErrorString(cublasStatus_t error) {
  switch (error) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    default:
      break;
  }
  return "Unknown cuBLAS status";
}

/*!
 * \brief Get string representation of cuSOLVER errors.
 */
inline const char* CusolverGetErrorString(cusolverStatus_t error) {
  switch (error) {
    case CUSOLVER_STATUS_SUCCESS:
      return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
      return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    default:
      break;
  }
  return "Unknown cuSOLVER status";
}

/*!
 * \brief Get string representation of cuRAND errors.
 */
inline const char* CurandGetErrorString(curandStatus_t status) {
  switch (status) {
    case CURAND_STATUS_SUCCESS:
      return "CURAND_STATUS_SUCCESS";
    case CURAND_STATUS_VERSION_MISMATCH:
      return "CURAND_STATUS_VERSION_MISMATCH";
    case CURAND_STATUS_NOT_INITIALIZED:
      return "CURAND_STATUS_NOT_INITIALIZED";
    case CURAND_STATUS_ALLOCATION_FAILED:
      return "CURAND_STATUS_ALLOCATION_FAILED";
    case CURAND_STATUS_TYPE_ERROR:
      return "CURAND_STATUS_TYPE_ERROR";
    case CURAND_STATUS_OUT_OF_RANGE:
      return "CURAND_STATUS_OUT_OF_RANGE";
    case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
      return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
    case CURAND_STATUS_LAUNCH_FAILURE:
      return "CURAND_STATUS_LAUNCH_FAILURE";
    case CURAND_STATUS_PREEXISTING_FAILURE:
      return "CURAND_STATUS_PREEXISTING_FAILURE";
    case CURAND_STATUS_INITIALIZATION_FAILED:
      return "CURAND_STATUS_INITIALIZATION_FAILED";
    case CURAND_STATUS_ARCH_MISMATCH:
      return "CURAND_STATUS_ARCH_MISMATCH";
    case CURAND_STATUS_INTERNAL_ERROR:
      return "CURAND_STATUS_INTERNAL_ERROR";
  }
  return "Unknown cuRAND status";
}

/*! \brief Returns the larger of \a a and \a b, usable in device code. */
template <typename DType>
inline DType __device__ CudaMax(DType a, DType b) {
  return a > b ? a : b;
}

/*! \brief Returns the smaller of \a a and \a b, usable in device code. */
template <typename DType>
inline DType __device__ CudaMin(DType a, DType b) {
  return a < b ? a : b;
}

}  // namespace cuda
}  // namespace common
}  // namespace mxnet

/*!
 * \brief Check the last CUDA error and fail with \a msg if one occurred.
 */
#define CHECK_CUDA_ERROR(msg)                                                  \
  {                                                                            \
    cudaError_t e = cudaGetLastError();                                        \
    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e);   \
  }
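
// Usage sketch (illustrative; 'MyKernel', 'grid', 'block', 'dev_ptr', and 'n'
// are hypothetical): check for errors right after a kernel launch.
//   MyKernel<<<grid, block>>>(dev_ptr, n);
//   CHECK_CUDA_ERROR("MyKernel");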

/*!
 * \brief Protected CUDA call.
 */
#define CUDA_CALL(func)                                        \
  {                                                            \
    cudaError_t e = (func);                                    \
    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)   \
        << "CUDA: " << cudaGetErrorString(e);                  \
  }
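
// Usage sketch (illustrative; 'dev_ptr' and 'n' are hypothetical): wrap each
// runtime-API call so a failure aborts with the decoded error string.
//   float* dev_ptr = nullptr;
//   CUDA_CALL(cudaMalloc(&dev_ptr, n * sizeof(float)));
//   CUDA_CALL(cudaMemset(dev_ptr, 0, n * sizeof(float)));
//   CUDA_CALL(cudaFree(dev_ptr));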

/*!
 * \brief Protected cuBLAS call.
 */
#define CUBLAS_CALL(func)                                                  \
  {                                                                        \
    cublasStatus_t e = (func);                                             \
    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                                     \
        << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e);     \
  }
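
// Usage sketch (illustrative; 'handle' is hypothetical):
//   cublasHandle_t handle;
//   CUBLAS_CALL(cublasCreate(&handle));
//   // ... enqueue cuBLAS calls on the handle ...
//   CUBLAS_CALL(cublasDestroy(handle));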

/*!
 * \brief Protected cuSOLVER call.
 */
#define CUSOLVER_CALL(func)                                                  \
  {                                                                          \
    cusolverStatus_t e = (func);                                             \
    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                                     \
        << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e);   \
  }
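
// Usage sketch (illustrative; 'handle' is hypothetical):
//   cusolverDnHandle_t handle;
//   CUSOLVER_CALL(cusolverDnCreate(&handle));
//   CUSOLVER_CALL(cusolverDnDestroy(handle));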

/*!
 * \brief Protected cuRAND call.
 */
#define CURAND_CALL(func)                                                \
  {                                                                      \
    curandStatus_t e = (func);                                           \
    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                                   \
        << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e);   \
  }
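
// Usage sketch (illustrative; 'gen', 'dev_ptr', and 'n' are hypothetical):
//   curandGenerator_t gen;
//   CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
//   CURAND_CALL(curandGenerateUniform(gen, dev_ptr, n));
//   CURAND_CALL(curandDestroyGenerator(gen));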

/*!
 * \brief Protected NVRTC call.
 */
#define NVRTC_CALL(x)                        \
  {                                          \
    nvrtcResult result = (x);                \
    CHECK_EQ(result, NVRTC_SUCCESS)          \
        << #x " failed with error "          \
        << nvrtcGetErrorString(result);      \
  }
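
// Usage sketch (illustrative; 'kernel_src' is hypothetical CUDA source text):
//   nvrtcProgram prog;
//   NVRTC_CALL(nvrtcCreateProgram(&prog, kernel_src, "kernel.cu", 0, nullptr, nullptr));
//   NVRTC_CALL(nvrtcCompileProgram(prog, 0, nullptr));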

/*!
 * \brief Protected CUDA driver call.
 */
#define CUDA_DRIVER_CALL(func)                                           \
  {                                                                      \
    CUresult e = (func);                                                 \
    if (e != CUDA_SUCCESS) {                                             \
      char const* err_msg = nullptr;                                     \
      if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) {   \
        LOG(FATAL) << "CUDA Driver: Unknown error " << e;                \
      } else {                                                           \
        LOG(FATAL) << "CUDA Driver: " << err_msg;                        \
      }                                                                  \
    }                                                                    \
  }
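
// Usage sketch (illustrative; 'ptx' is a hypothetical NUL-terminated PTX image):
//   CUmodule module;
//   CUDA_DRIVER_CALL(cuModuleLoadData(&module, ptx));
//   CUfunction func;
//   CUDA_DRIVER_CALL(cuModuleGetFunction(&func, module, "my_kernel"));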

#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif
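
// Usage sketch (illustrative) inside device code; on compilers that honor the
// pragma the loop below is fully unrolled, elsewhere the macro is a no-op:
//   CUDA_UNROLL
//   for (int i = 0; i < 4; ++i) {
//     acc += vals[i];
//   }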

/*!
 * \brief Determine major version number of the gpu's cuda compute architecture.
 */
inline int ComputeCapabilityMajor(int device_id) {
  int major = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
  return major;
}

/*!
 * \brief Determine minor version number of the gpu's cuda compute architecture.
 */
inline int ComputeCapabilityMinor(int device_id) {
  int minor = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
  return minor;
}

/*!
 * \brief Return the integer SM architecture (e.g. Volta = 70).
 */
inline int SMArch(int device_id) {
  auto major = ComputeCapabilityMajor(device_id);
  auto minor = ComputeCapabilityMinor(device_id);
  return 10 * major + minor;
}
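
// Usage sketch (illustrative): dispatch on the packed architecture number,
// e.g. Pascal = 60, Volta = 70.
//   if (SMArch(device_id) >= 70) {
//     // choose a Volta-or-newer code path
//   }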

/*!
 * \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
 *        Assume not if device_id is negative.
 */
inline bool SupportsFloat16Compute(int device_id) {
  if (device_id < 0) {
    return false;
  } else {
    // Kepler and most Maxwell GPUs do not support fp16 compute
    int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
    return (computeCapabilityMajor > 5) ||
           (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
  }
}
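
// Usage sketch (illustrative; 'device_id' as above): guard fp16 kernels on
// hardware support (sm_53 or newer).
//   if (SupportsFloat16Compute(device_id)) {
//     // launch a float16 compute path
//   }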

/*!
 * \brief Determine whether a cuda-capable gpu's architecture supports Tensor Core math.
 *        Assume not if device_id is negative.
 */
inline bool SupportsTensorCore(int device_id) {
  // Volta (sm_70) and later support TensorCore algos
  return device_id >= 0 &&
         ComputeCapabilityMajor(device_id) >= 7;
}

// The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true

/*!
 * \brief Returns global policy for TensorCore algo use.
 */
inline bool GetEnvAllowTensorCore() {
  // Since these statics are in the '.h' file, they will exist and will be set
  // separately in each compilation unit. Not ideal, but cleaner than creating a
  // cuda_utils.cc solely to have a single instance and initialization.
  static bool allow_tensor_core = false;
  static bool is_set = false;
  if (!is_set) {
    // Use of optional<bool> here permits "0", "1", "true" and "false" to all be legal.
    bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
    allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
                                     dmlc::optional<bool>(default_value)).value();
    is_set = true;
  }
  return allow_tensor_core;
}
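
// Usage sketch (illustrative): gate TensorCore use on both the hardware
// capability and the user's policy.
//   bool use_tensor_core = SupportsTensorCore(device_id) && GetEnvAllowTensorCore();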

#if CUDA_VERSION >= 9000
// Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous.
inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
  return handle_math_mode;
}
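
// Usage sketch (illustrative; 'blas_handle' is hypothetical): enable TensorCore
// math around one call, then restore the handle's previous mode.
//   cublasMath_t saved = SetCublasMathMode(blas_handle, CUBLAS_TENSOR_OP_MATH);
//   // ... a cublasGemmEx call may now pick TensorCore algorithms ...
//   SetCublasMathMode(blas_handle, saved);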
#endif

#endif  // MXNET_USE_CUDA

#if MXNET_USE_CUDNN

#include <cudnn.h>

/*!
 * \brief Protected cuDNN call.
 */
#define CUDNN_CALL(func)                                                        \
  {                                                                             \
    cudnnStatus_t e = (func);                                                   \
    CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e);   \
  }

/*!
 * \brief Upper bound on the number of cuDNN convolution forward algorithms.
 */
inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  int max_algos = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
  return max_algos;
#else
  return 10;
#endif
}

/*!
 * \brief Upper bound on the number of cuDNN convolution backward-filter algorithms.
 */
inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  int max_algos = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
  return max_algos;
#else
  return 10;
#endif
}

/*!
 * \brief Upper bound on the number of cuDNN convolution backward-data algorithms.
 */
inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  int max_algos = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
  return max_algos;
#else
  return 10;
#endif
}
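
// Usage sketch (illustrative; 'handle' is hypothetical, requires <vector>):
// size a result buffer before asking cuDNN to enumerate forward algorithms.
//   std::vector<cudnnConvolutionFwdAlgoPerf_t> perf(MaxForwardAlgos(handle));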

#endif  // MXNET_USE_CUDNN

// Overload atomicAdd to work for doubles on all architectures
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// From CUDA Programming Guide
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long* address_as_ull =                 // NOLINT(*)
      reinterpret_cast<unsigned long long*>(address);  // NOLINT(*)
  unsigned long long old = *address_as_ull;            // NOLINT(*)
  unsigned long long assumed;                          // NOLINT(*)

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}
#endif
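
// With the overload above, device code can call atomicAdd on doubles uniformly:
// pre-sm_60 GPUs use the CAS loop, sm_60+ use the hardware builtin
// (illustrative; 'sum' and 'v' are hypothetical):
//   __global__ void AccumKernel(double* sum, double v) { atomicAdd(sum, v); }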

// Overload atomicAdd for half precision
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
#if defined(__CUDA_ARCH__)
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // atomicCAS operates on 32-bit words, so work on the aligned unsigned int
  // that contains the 16-bit half at 'address'.
  unsigned int *address_as_ui =
      reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
                                       (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    mshadow::half::half_t hsum;
    // Extract the 16 bits that hold our half (upper or lower, by alignment).
    hsum.half_ =
        reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum += val;
    // Splice the updated 16 bits back into the 32-bit word.
    old = reinterpret_cast<size_t>(address) & 2
              ? (old & 0xffff) | (hsum.half_ << 16)
              : (old & 0xffff0000) | hsum.half_;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}

/*! \brief Read-only cached load: uses __ldg on sm_35+ devices, a plain load otherwise. */
template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}
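
// Usage sketch (illustrative; 'in' and 'i' are hypothetical): read-only loads
// go through the texture cache path on sm_35+ and degrade gracefully elsewhere.
//   float x = ldg(in + i);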
#endif

#endif  // MXNET_COMMON_CUDA_UTILS_H_