mxnet
cuda_utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_CUDA_UTILS_H_
26 #define MXNET_COMMON_CUDA_UTILS_H_
27 
28 #include <dmlc/logging.h>
29 #include <dmlc/parameter.h>
30 #include <dmlc/optional.h>
31 #include <mshadow/base.h>
32 
// Parsing aid for JetBrains IDEs: when __JETBRAINS_IDE__ is defined, CUDA keywords,
// a few device intrinsics and the built-in index variables are replaced by plain C++
// stand-ins so the IDE's code model can parse this header. When the macro is not
// defined, nothing in this block is compiled.
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
// Execution-space / memory-space qualifiers expand to nothing.
#define __host__
#define __device__
#define __global__
#define __shared__
#define __forceinline__
// No-op or identity replacements for intrinsics referenced by device code.
inline void __syncthreads() {}
inline void __threadfence_block() {}
template<class T> inline T __clz(const T val) { return val; }
// Fake three-component built-in variables (threadIdx, blockIdx, blockDim).
struct __cuda_fake_struct { int x; int y; int z; };
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
extern __cuda_fake_struct blockDim;
#endif  // __JETBRAINS_IDE__
49 
50 #if MXNET_USE_CUDA
51 
52 #include <cuda_runtime.h>
53 #include <cublas_v2.h>
54 #include <curand.h>
55 
#ifdef __CUDACC__
/*!
 * \brief Compile-time guard against unsupported device architectures.
 *
 * When this header is compiled for a device architecture below 3.0 the #error
 * directive stops the build; in every other configuration the function just
 * returns true.
 */
inline __device__ bool __is_supported_cuda_architecture() {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 300
  return true;
#else
#error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)"
  return false;
#endif  // !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 300
}
#endif  // __CUDACC__
70 
namespace mxnet {
namespace common {
/*! \brief Common utilities for CUDA. */
namespace cuda {

/*!
 * \brief Get string representation of cuBLAS errors.
 * \param error The status code returned by a cuBLAS call.
 * \return The enumerator name, or "Unknown cuBLAS status" for codes not listed here.
 */
inline const char* CublasGetErrorString(cublasStatus_t error) {
  switch (error) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
      break;
  }
  return "Unknown cuBLAS status";
}

/*!
 * \brief Get string representation of cuSOLVER errors.
 * \param error The status code returned by a cuSOLVER call.
 * \return The enumerator name, or "Unknown cuSOLVER status" for codes not listed here.
 *
 * NOTE(review): this header does not include <cusolverDn.h> itself, so
 * cusolverStatus_t must already be declared by the including file — confirm.
 */
inline const char* CusolverGetErrorString(cusolverStatus_t error) {
  switch (error) {
    case CUSOLVER_STATUS_SUCCESS:
      return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
      return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_MAPPING_ERROR:
      return "CUSOLVER_STATUS_MAPPING_ERROR";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_NOT_SUPPORTED";
    case CUSOLVER_STATUS_ZERO_PIVOT:
      return "CUSOLVER_STATUS_ZERO_PIVOT";
    case CUSOLVER_STATUS_INVALID_LICENSE:
      return "CUSOLVER_STATUS_INVALID_LICENSE";
    default:
      break;
  }
  return "Unknown cuSOLVER status";
}

/*!
 * \brief Get string representation of cuRAND errors.
 * \param status The status code returned by a cuRAND call.
 * \return The enumerator name, or "Unknown cuRAND status" for codes not listed here.
 */
inline const char* CurandGetErrorString(curandStatus_t status) {
  switch (status) {
    case CURAND_STATUS_SUCCESS:
      return "CURAND_STATUS_SUCCESS";
    case CURAND_STATUS_VERSION_MISMATCH:
      return "CURAND_STATUS_VERSION_MISMATCH";
    case CURAND_STATUS_NOT_INITIALIZED:
      return "CURAND_STATUS_NOT_INITIALIZED";
    case CURAND_STATUS_ALLOCATION_FAILED:
      return "CURAND_STATUS_ALLOCATION_FAILED";
    case CURAND_STATUS_TYPE_ERROR:
      return "CURAND_STATUS_TYPE_ERROR";
    case CURAND_STATUS_OUT_OF_RANGE:
      return "CURAND_STATUS_OUT_OF_RANGE";
    case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
      return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
    case CURAND_STATUS_LAUNCH_FAILURE:
      return "CURAND_STATUS_LAUNCH_FAILURE";
    case CURAND_STATUS_PREEXISTING_FAILURE:
      return "CURAND_STATUS_PREEXISTING_FAILURE";
    case CURAND_STATUS_INITIALIZATION_FAILED:
      return "CURAND_STATUS_INITIALIZATION_FAILED";
    case CURAND_STATUS_ARCH_MISMATCH:
      return "CURAND_STATUS_ARCH_MISMATCH";
    case CURAND_STATUS_INTERNAL_ERROR:
      return "CURAND_STATUS_INTERNAL_ERROR";
    default:
      break;
  }
  return "Unknown cuRAND status";
}

/*! \brief Device-side maximum of two values (returns b when they compare equal). */
template <typename DType>
inline DType __device__ CudaMax(DType a, DType b) {
  return a > b ? a : b;
}

/*! \brief Device-side minimum of two values (returns b when they compare equal). */
template <typename DType>
inline DType __device__ CudaMin(DType a, DType b) {
  return a < b ? a : b;
}

}  // namespace cuda
}  // namespace common
}  // namespace mxnet
185 
/*!
 * \brief Check the CUDA runtime's last recorded error.
 * \param msg Message streamed before the CUDA error text when the check fails.
 *
 * Uses cudaGetLastError(), which also clears the recorded error. The local is
 * given an unlikely name so it cannot capture an identifier the caller uses
 * inside `msg`.
 */
#define CHECK_CUDA_ERROR(msg)                                                  \
  {                                                                            \
    cudaError_t mxnet_cuda_last_err_ = cudaGetLastError();                     \
    CHECK_EQ(mxnet_cuda_last_err_, cudaSuccess)                                \
        << (msg) << " CUDA: " << cudaGetErrorString(mxnet_cuda_last_err_);     \
  }
195 
/*!
 * \brief Protected CUDA call.
 * \param func Expression returning a cudaError_t.
 *
 * Fails the CHECK unless the result is cudaSuccess or cudaErrorCudartUnloading
 * (the latter presumably tolerated so calls made during runtime teardown do not
 * abort). The local has an unlikely name so that `func` may reference a caller
 * variable called `e` without it being shadowed by the macro's own local.
 */
#define CUDA_CALL(func)                                                        \
  {                                                                            \
    cudaError_t mxnet_cuda_call_err_ = (func);                                 \
    CHECK(mxnet_cuda_call_err_ == cudaSuccess ||                               \
          mxnet_cuda_call_err_ == cudaErrorCudartUnloading)                    \
        << "CUDA: " << cudaGetErrorString(mxnet_cuda_call_err_);               \
  }
208 
/*!
 * \brief Protected cuBLAS call.
 * \param func Expression returning a cublasStatus_t.
 *
 * Fails the CHECK unless the status is CUBLAS_STATUS_SUCCESS. The local has an
 * unlikely name so it cannot shadow a caller identifier used inside `func`.
 */
#define CUBLAS_CALL(func)                                                      \
  {                                                                            \
    cublasStatus_t mxnet_cublas_status_ = (func);                              \
    CHECK_EQ(mxnet_cublas_status_, CUBLAS_STATUS_SUCCESS)                      \
        << "cuBLAS: "                                                          \
        << mxnet::common::cuda::CublasGetErrorString(mxnet_cublas_status_);    \
  }
221 
/*!
 * \brief Protected cuSOLVER call.
 * \param func Expression returning a cusolverStatus_t.
 *
 * Fails the CHECK unless the status is CUSOLVER_STATUS_SUCCESS. The local has an
 * unlikely name so it cannot shadow a caller identifier used inside `func`.
 */
#define CUSOLVER_CALL(func)                                                    \
  {                                                                            \
    cusolverStatus_t mxnet_cusolver_status_ = (func);                          \
    CHECK_EQ(mxnet_cusolver_status_, CUSOLVER_STATUS_SUCCESS)                  \
        << "cuSolver: "                                                        \
        << mxnet::common::cuda::CusolverGetErrorString(mxnet_cusolver_status_);\
  }
234 
/*!
 * \brief Protected cuRAND call.
 * \param func Expression returning a curandStatus_t.
 *
 * Fails the CHECK unless the status is CURAND_STATUS_SUCCESS. The local has an
 * unlikely name so it cannot shadow a caller identifier used inside `func`.
 */
#define CURAND_CALL(func)                                                      \
  {                                                                            \
    curandStatus_t mxnet_curand_status_ = (func);                              \
    CHECK_EQ(mxnet_curand_status_, CURAND_STATUS_SUCCESS)                      \
        << "cuRAND: "                                                          \
        << mxnet::common::cuda::CurandGetErrorString(mxnet_curand_status_);    \
  }
247 
/*!
 * \brief Protected NVRTC call.
 * \param x Expression returning an nvrtcResult.
 *
 * Fails the CHECK unless the result is NVRTC_SUCCESS; the failure message
 * includes the stringified call text. The argument is parenthesized and the
 * local given an unlikely name so neither interacts with the caller's code.
 */
#define NVRTC_CALL(x)                                                          \
  {                                                                            \
    nvrtcResult mxnet_nvrtc_result_ = (x);                                     \
    CHECK_EQ(mxnet_nvrtc_result_, NVRTC_SUCCESS)                               \
        << #x " failed with error "                                            \
        << nvrtcGetErrorString(mxnet_nvrtc_result_);                           \
  }
261 
/*!
 * \brief Protected CUDA driver-API call.
 * \param func Expression returning a CUresult.
 *
 * On any result other than CUDA_SUCCESS this logs FATAL. If cuGetErrorString
 * rejects the code (CUDA_ERROR_INVALID_VALUE), the numeric code is printed
 * instead of a message. Locals carry unlikely names so they cannot shadow
 * caller identifiers used inside `func`.
 */
#define CUDA_DRIVER_CALL(func)                                                 \
  {                                                                            \
    CUresult mxnet_cu_result_ = (func);                                        \
    if (mxnet_cu_result_ != CUDA_SUCCESS) {                                    \
      char const * mxnet_cu_err_msg_ = nullptr;                                \
      if (cuGetErrorString(mxnet_cu_result_, &mxnet_cu_err_msg_) ==            \
          CUDA_ERROR_INVALID_VALUE) {                                          \
        LOG(FATAL) << "CUDA Driver: Unknown error " << mxnet_cu_result_;       \
      } else {                                                                 \
        LOG(FATAL) << "CUDA Driver: " << mxnet_cu_err_msg_;                    \
      }                                                                        \
    }                                                                          \
  }
280 
281 
// Loop-unrolling hints usable from inside other macros (hence _Pragma rather than
// #pragma). When _MSC_VER is defined both hints expand to nothing.
#if defined(_MSC_VER)
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#else
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#endif  // defined(_MSC_VER)
289 
/*!
 * \brief Determine the major version number of the GPU's CUDA compute architecture.
 * \param device_id Index of the CUDA device to query.
 * \return The value of cudaDevAttrComputeCapabilityMajor for that device.
 */
inline int ComputeCapabilityMajor(int device_id) {
  int cc_major = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor,
                                   device_id));
  return cc_major;
}
301 
/*!
 * \brief Determine the minor version number of the GPU's CUDA compute architecture.
 * \param device_id Index of the CUDA device to query.
 * \return The value of cudaDevAttrComputeCapabilityMinor for that device.
 */
inline int ComputeCapabilityMinor(int device_id) {
  int cc_minor = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor,
                                   device_id));
  return cc_minor;
}
313 
/*!
 * \brief Return the integer SM architecture (e.g. Volta = 70).
 * \param device_id Index of the CUDA device to query.
 * \return major * 10 + minor compute capability.
 */
inline int SMArch(int device_id) {
  const int cc_major = ComputeCapabilityMajor(device_id);
  const int cc_minor = ComputeCapabilityMinor(device_id);
  return cc_major * 10 + cc_minor;
}
324 
/*!
 * \brief Determine whether a CUDA-capable GPU's architecture supports float16 math.
 * \param device_id Index of the CUDA device; negative ids are answered with false.
 * \return true for compute capability 5.3 and above.
 */
inline bool SupportsFloat16Compute(int device_id) {
  if (device_id < 0) return false;
  // Kepler and most Maxwell GPUs do not support fp16 compute; 5.3 is the first that does.
  const int cc_major = ComputeCapabilityMajor(device_id);
  if (cc_major > 5) return true;
  // The minor version is only queried when the major version is exactly 5.
  return cc_major == 5 && ComputeCapabilityMinor(device_id) >= 3;
}
341 
/*!
 * \brief Determine whether a CUDA-capable GPU's architecture supports Tensor Core math.
 * \param device_id Index of the CUDA device; negative ids are answered with false.
 * \return true when the compute-capability major version is 7 or higher.
 */
inline bool SupportsTensorCore(int device_id) {
  if (device_id < 0) {
    return false;
  }
  // Volta (sm_70) supports TensorCore algos.
  return ComputeCapabilityMajor(device_id) >= 7;
}
353 
// The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE.
// Guarded so a build may override it (e.g. -DMXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT=false)
// without triggering a macro-redefinition diagnostic.
#ifndef MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
#endif  // MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
356 
/*!
 * \brief Returns global policy for TensorCore algo use.
 * \return The value of environment variable MXNET_CUDA_ALLOW_TENSOR_CORE, or
 *         MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT when it is not set.
 *
 * The environment is read on the first call only; later calls return the cached value.
 */
inline bool GetEnvAllowTensorCore() {
  // A function-local static gets C++11's one-time, thread-safe initialization, replacing
  // the former unsynchronized is_set/allow_tensor_core flag pair. If GetEnv throws, the
  // static stays uninitialized and the next call retries, as the old flag pair did.
  // NOTE(review): for an inline function with external linkage the standard makes this a
  // single object program-wide (not one per compilation unit), although separately linked
  // shared libraries may each end up with their own copy.
  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
  static const bool allow_tensor_core =
      dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
                   dmlc::optional<bool>(MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT)).value();
  return allow_tensor_core;
}
376 
#if CUDA_VERSION >= 9000
/*!
 * \brief Sets the cuBLAS math mode that determines the 'allow TensorCore' policy.
 * \param blas_handle cuBLAS handle to modify.
 * \param new_math_type Math mode to install on the handle.
 * \return The math mode the handle had before this call.
 */
inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
  cublasMath_t previous_mode = CUBLAS_DEFAULT_MATH;
  // Capture the current mode first so the caller can restore it later.
  CUBLAS_CALL(cublasGetMathMode(blas_handle, &previous_mode));
  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
  return previous_mode;
}
#endif  // CUDA_VERSION >= 9000
386 
387 #endif // MXNET_USE_CUDA
388 
389 #if MXNET_USE_CUDNN
390 
391 #include <cudnn.h>
392 
/*!
 * \brief Protected cuDNN call.
 * \param func Expression returning a cudnnStatus_t.
 *
 * Fails the CHECK unless the status is CUDNN_STATUS_SUCCESS. The local has an
 * unlikely name so it cannot shadow a caller identifier used inside `func`.
 */
#define CUDNN_CALL(func)                                                       \
  {                                                                            \
    cudnnStatus_t mxnet_cudnn_status_ = (func);                                \
    CHECK_EQ(mxnet_cudnn_status_, CUDNN_STATUS_SUCCESS)                        \
        << "cuDNN: " << cudnnGetErrorString(mxnet_cudnn_status_);              \
  }
398 
/*!
 * \brief Upper bound on the number of cuDNN convolution forward algorithms.
 * \param cudnn_handle Handle queried on cuDNN 7+; unused on older versions.
 * \return The count reported by cuDNN 7+, otherwise a fixed 10.
 */
inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR < 7
  return 10;
#else
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#endif
}
415 
/*!
 * \brief Upper bound on the number of cuDNN convolution backward-filter algorithms.
 * \param cudnn_handle Handle queried on cuDNN 7+; unused on older versions.
 * \return The count reported by cuDNN 7+, otherwise a fixed 10.
 */
inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR < 7
  return 10;
#else
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#endif
}
432 
/*!
 * \brief Upper bound on the number of cuDNN convolution backward-data algorithms.
 * \param cudnn_handle Handle queried on cuDNN 7+; unused on older versions.
 * \return The count reported by cuDNN 7+, otherwise a fixed 10.
 */
inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR < 7
  return 10;
#else
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#endif
}
449 
450 #endif // MXNET_USE_CUDNN
451 
// Overload atomicAdd for double on device architectures below 6.0, using a 64-bit
// compare-and-swap loop (pattern from the CUDA Programming Guide).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long* const bits_ptr =  // NOLINT(*)
      reinterpret_cast<unsigned long long*>(address);  // NOLINT(*)
  unsigned long long observed = *bits_ptr;  // NOLINT(*)
  for (;;) {
    const unsigned long long expected = observed;  // NOLINT(*)
    observed = atomicCAS(bits_ptr, expected,
                         __double_as_longlong(val + __longlong_as_double(expected)));
    // Compare raw bits rather than doubles so a NaN value cannot make the loop spin
    // forever (NaN != NaN).
    if (observed == expected) {
      break;
    }
  }
}
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
471 
472 // Overload atomicAdd for half precision
473 // Taken from:
474 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
475 #if defined(__CUDA_ARCH__)
/*!
 * \brief Atomically add `val` to the half-precision value at `address`.
 *
 * Works on the enclosing 4-byte-aligned 32-bit word with a 32-bit atomicCAS loop:
 * bit 1 of the address selects whether the target occupies the upper or lower
 * 16 bits of that word, and the other half is written back unchanged.
 */
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // 2 when the half sits in the upper 16 bits of its word, 0 for the lower 16 bits.
  const size_t upper_offset = reinterpret_cast<size_t>(address) & 2;
  const bool is_upper_half = upper_offset != 0;
  unsigned int *address_as_ui =
      reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - upper_offset);
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    mshadow::half::half_t hsum;
    hsum.half_ = is_upper_half ? (old >> 16) : (old & 0xffff);
    hsum += val;
    // Widen to unsigned before shifting: the 16-bit half_ field would otherwise be
    // promoted to signed int, and shifting a value >= 0x8000 left by 16 overflows int.
    const unsigned int hbits = static_cast<unsigned int>(hsum.half_);
    old = is_upper_half ? (old & 0xffffu) | (hbits << 16)
                        : (old & 0xffff0000u) | hbits;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
496 
// Overload atomicAdd to work for signed int64 on all architectures.
// Addition modulo 2^64 yields the same bits for signed and unsigned operands, so the
// call is forwarded to the 64-bit unsigned atomicAdd overload.
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  unsigned long long *const as_unsigned = reinterpret_cast<unsigned long long*>(address);  // NOLINT
  const unsigned long long delta = static_cast<unsigned long long>(val);  // NOLINT
  atomicAdd(as_unsigned, delta);
}
501 
/*!
 * \brief Load *address, via __ldg on device architectures 3.5 and newer.
 *
 * __ldg performs the load through the read-only data cache; on older
 * architectures this is a plain dereference.
 */
template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ < 350
  return *address;
#else
  return __ldg(address);
#endif
}
510 #endif
511 
512 #endif // MXNET_COMMON_CUDA_UTILS_H_
#define CUBLAS_CALL(func)
Protected cuBLAS call.
Definition: cuda_utils.h:215
int ComputeCapabilityMajor(int device_id)
Determine the major version number of the GPU's CUDA compute architecture.
Definition: cuda_utils.h:295
namespace of mxnet
Definition: base.h:118
bool GetEnvAllowTensorCore()
Returns global policy for TensorCore algo use.
Definition: cuda_utils.h:361
int SMArch(int device_id)
Return the integer SM architecture (e.g. Volta = 70).
Definition: cuda_utils.h:319
DType __device__ CudaMin(DType a, DType b)
Definition: cuda_utils.h:178
bool SupportsFloat16Compute(int device_id)
Determine whether a CUDA-capable GPU's architecture supports float16 math. Assume not if device_id is negative.
Definition: cuda_utils.h:331
DType __device__ CudaMax(DType a, DType b)
Definition: cuda_utils.h:173
bool SupportsTensorCore(int device_id)
Determine whether a CUDA-capable GPU's architecture supports Tensor Core math. Assume not if device_id is negative.
Definition: cuda_utils.h:348
const char * CusolverGetErrorString(cusolverStatus_t error)
Get string representation of cuSOLVER errors.
Definition: cuda_utils.h:111
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
Definition: cuda_utils.h:355
const char * CurandGetErrorString(curandStatus_t status)
Get string representation of cuRAND errors.
Definition: cuda_utils.h:140
int ComputeCapabilityMinor(int device_id)
Determine the minor version number of the GPU's CUDA compute architecture.
Definition: cuda_utils.h:307
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:202
const char * CublasGetErrorString(cublasStatus_t error)
Get string representation of cuBLAS errors.
Definition: cuda_utils.h:80