mxnet
cuda_utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_CUDA_UTILS_H_
26 #define MXNET_COMMON_CUDA_UTILS_H_
27 
28 #include <dmlc/logging.h>
29 #include <dmlc/parameter.h>
30 #include <dmlc/optional.h>
31 #include <mshadow/base.h>
32 
// Stand-in definitions used only when __JETBRAINS_IDE__ is defined
// (presumably so the IDE's C++ parser can digest CUDA sources -- confirm).
// None of this is compiled in a regular build.
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
// CUDA function/variable qualifiers expand to nothing.
#define __host__
#define __device__
#define __global__
#define __forceinline__
#define __shared__
// Empty placeholders for the CUDA barrier/fence intrinsics.
inline void __syncthreads() {}
inline void __threadfence_block() {}
// Placeholder only: returns its argument unchanged, not a leading-zero count.
template<class T> inline T __clz(const T val) { return val; }
// Three-int stand-in type for the built-in index/dimension variables below.
struct __cuda_fake_struct { int x; int y; int z; };
extern __cuda_fake_struct blockDim;
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
#endif
49 
50 #if MXNET_USE_CUDA
51 
52 #include <cuda_runtime.h>
53 #include <cublas_v2.h>
54 #include <curand.h>
55 
56 namespace mxnet {
57 namespace common {
59 namespace cuda {
65 inline const char* CublasGetErrorString(cublasStatus_t error) {
66  switch (error) {
67  case CUBLAS_STATUS_SUCCESS:
68  return "CUBLAS_STATUS_SUCCESS";
69  case CUBLAS_STATUS_NOT_INITIALIZED:
70  return "CUBLAS_STATUS_NOT_INITIALIZED";
71  case CUBLAS_STATUS_ALLOC_FAILED:
72  return "CUBLAS_STATUS_ALLOC_FAILED";
73  case CUBLAS_STATUS_INVALID_VALUE:
74  return "CUBLAS_STATUS_INVALID_VALUE";
75  case CUBLAS_STATUS_ARCH_MISMATCH:
76  return "CUBLAS_STATUS_ARCH_MISMATCH";
77  case CUBLAS_STATUS_MAPPING_ERROR:
78  return "CUBLAS_STATUS_MAPPING_ERROR";
79  case CUBLAS_STATUS_EXECUTION_FAILED:
80  return "CUBLAS_STATUS_EXECUTION_FAILED";
81  case CUBLAS_STATUS_INTERNAL_ERROR:
82  return "CUBLAS_STATUS_INTERNAL_ERROR";
83  case CUBLAS_STATUS_NOT_SUPPORTED:
84  return "CUBLAS_STATUS_NOT_SUPPORTED";
85  default:
86  break;
87  }
88  return "Unknown cuBLAS status";
89 }
90 
96 inline const char* CusolverGetErrorString(cusolverStatus_t error) {
97  switch (error) {
98  case CUSOLVER_STATUS_SUCCESS:
99  return "CUSOLVER_STATUS_SUCCESS";
100  case CUSOLVER_STATUS_NOT_INITIALIZED:
101  return "CUSOLVER_STATUS_NOT_INITIALIZED";
102  case CUSOLVER_STATUS_ALLOC_FAILED:
103  return "CUSOLVER_STATUS_ALLOC_FAILED";
104  case CUSOLVER_STATUS_INVALID_VALUE:
105  return "CUSOLVER_STATUS_INVALID_VALUE";
106  case CUSOLVER_STATUS_ARCH_MISMATCH:
107  return "CUSOLVER_STATUS_ARCH_MISMATCH";
108  case CUSOLVER_STATUS_EXECUTION_FAILED:
109  return "CUSOLVER_STATUS_EXECUTION_FAILED";
110  case CUSOLVER_STATUS_INTERNAL_ERROR:
111  return "CUSOLVER_STATUS_INTERNAL_ERROR";
112  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
113  return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
114  default:
115  break;
116  }
117  return "Unknown cuSOLVER status";
118 }
119 
125 inline const char* CurandGetErrorString(curandStatus_t status) {
126  switch (status) {
127  case CURAND_STATUS_SUCCESS:
128  return "CURAND_STATUS_SUCCESS";
129  case CURAND_STATUS_VERSION_MISMATCH:
130  return "CURAND_STATUS_VERSION_MISMATCH";
131  case CURAND_STATUS_NOT_INITIALIZED:
132  return "CURAND_STATUS_NOT_INITIALIZED";
133  case CURAND_STATUS_ALLOCATION_FAILED:
134  return "CURAND_STATUS_ALLOCATION_FAILED";
135  case CURAND_STATUS_TYPE_ERROR:
136  return "CURAND_STATUS_TYPE_ERROR";
137  case CURAND_STATUS_OUT_OF_RANGE:
138  return "CURAND_STATUS_OUT_OF_RANGE";
139  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
140  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
141  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
142  return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
143  case CURAND_STATUS_LAUNCH_FAILURE:
144  return "CURAND_STATUS_LAUNCH_FAILURE";
145  case CURAND_STATUS_PREEXISTING_FAILURE:
146  return "CURAND_STATUS_PREEXISTING_FAILURE";
147  case CURAND_STATUS_INITIALIZATION_FAILED:
148  return "CURAND_STATUS_INITIALIZATION_FAILED";
149  case CURAND_STATUS_ARCH_MISMATCH:
150  return "CURAND_STATUS_ARCH_MISMATCH";
151  case CURAND_STATUS_INTERNAL_ERROR:
152  return "CURAND_STATUS_INTERNAL_ERROR";
153  }
154  return "Unknown cuRAND status";
155 }
156 
157 template <typename DType>
158 inline DType __device__ CudaMax(DType a, DType b) {
159  return a > b ? a : b;
160 }
161 
162 template <typename DType>
163 inline DType __device__ CudaMin(DType a, DType b) {
164  return a < b ? a : b;
165 }
166 
167 } // namespace cuda
168 } // namespace common
169 } // namespace mxnet
170 
/*!
 * \brief Fetch the CUDA runtime's last error via cudaGetLastError() and
 *        CHECK that it equals cudaSuccess.
 * \param msg Text streamed into the failure message, ahead of " CUDA: " and
 *            the runtime's error string.
 *
 * Expands to a braced block; the local `e` is scoped to that block.
 */
#define CHECK_CUDA_ERROR(msg)                                                \
  {                                                                          \
    cudaError_t e = cudaGetLastError();                                      \
    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
  }
180 
/*!
 * \brief Protected CUDA call.
 * \param func Expression yielding a cudaError_t; it is evaluated exactly once.
 *
 * The CHECK passes for cudaSuccess and also for cudaErrorCudartUnloading;
 * any other result fails it with the runtime's error string.
 */
#define CUDA_CALL(func)                                            \
  {                                                                \
    cudaError_t e = (func);                                        \
    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)       \
        << "CUDA: " << cudaGetErrorString(e);                      \
  }
193 
/*!
 * \brief Protected cuBLAS call.
 * \param func Expression yielding a cublasStatus_t; it is evaluated exactly once.
 *
 * Fails a CHECK_EQ, reporting CublasGetErrorString(), unless the result is
 * CUBLAS_STATUS_SUCCESS.
 */
#define CUBLAS_CALL(func)                                                \
  {                                                                      \
    cublasStatus_t e = (func);                                           \
    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                                   \
        << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e);   \
  }
206 
/*!
 * \brief Protected cuSOLVER call.
 * \param func Expression yielding a cusolverStatus_t; it is evaluated exactly once.
 *
 * Fails a CHECK_EQ, reporting CusolverGetErrorString(), unless the result is
 * CUSOLVER_STATUS_SUCCESS.
 */
#define CUSOLVER_CALL(func)                                                  \
  {                                                                          \
    cusolverStatus_t e = (func);                                             \
    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                                     \
        << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e);   \
  }
219 
/*!
 * \brief Protected cuRAND call.
 * \param func Expression yielding a curandStatus_t; it is evaluated exactly once.
 *
 * Fails a CHECK_EQ, reporting CurandGetErrorString(), unless the result is
 * CURAND_STATUS_SUCCESS.
 */
#define CURAND_CALL(func)                                                \
  {                                                                      \
    curandStatus_t e = (func);                                           \
    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                                   \
        << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e);   \
  }
232 
/*!
 * \brief Protected NVRTC call.
 * \param x Expression yielding an nvrtcResult; it is evaluated exactly once.
 *
 * On a result other than NVRTC_SUCCESS the CHECK_EQ fails; the message holds
 * the stringified call text (#x) followed by nvrtcGetErrorString().
 * NOTE(review): unlike the sibling macros, `x` is not parenthesised in the
 * initialiser.
 */
#define NVRTC_CALL(x)                                   \
  {                                                     \
    nvrtcResult result = x;                             \
    CHECK_EQ(result, NVRTC_SUCCESS)                     \
        << #x " failed with error "                     \
        << nvrtcGetErrorString(result);                 \
  }
246 
/*!
 * \brief Protected CUDA driver API call.
 * \param func Expression yielding a CUresult; it is evaluated exactly once.
 *
 * On any result other than CUDA_SUCCESS this logs at FATAL level. The message
 * is looked up with cuGetErrorString(); if that lookup itself reports
 * CUDA_ERROR_INVALID_VALUE, the raw result code is logged instead.
 */
#define CUDA_DRIVER_CALL(func)                                            \
  {                                                                       \
    CUresult e = (func);                                                  \
    if (e != CUDA_SUCCESS) {                                              \
      char const * err_msg = nullptr;                                     \
      if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) {    \
        LOG(FATAL) << "CUDA Driver: Unknown error " << e;                 \
      } else {                                                            \
        LOG(FATAL) << "CUDA Driver: " << err_msg;                         \
      }                                                                   \
    }                                                                     \
  }
265 
266 
// Loop-unrolling hints. On compilers other than MSVC they expand to
// _Pragma("unroll") / _Pragma("nounroll"); under MSVC they expand to nothing.
#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif
274 
280 inline int ComputeCapabilityMajor(int device_id) {
281  int major = 0;
282  CUDA_CALL(cudaDeviceGetAttribute(&major,
283  cudaDevAttrComputeCapabilityMajor, device_id));
284  return major;
285 }
286 
292 inline int ComputeCapabilityMinor(int device_id) {
293  int minor = 0;
294  CUDA_CALL(cudaDeviceGetAttribute(&minor,
295  cudaDevAttrComputeCapabilityMinor, device_id));
296  return minor;
297 }
298 
304 inline int SMArch(int device_id) {
305  auto major = ComputeCapabilityMajor(device_id);
306  auto minor = ComputeCapabilityMinor(device_id);
307  return 10 * major + minor;
308 }
309 
316 inline bool SupportsFloat16Compute(int device_id) {
317  if (device_id < 0) {
318  return false;
319  } else {
320  // Kepler and most Maxwell GPUs do not support fp16 compute
321  int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
322  return (computeCapabilityMajor > 5) ||
323  (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
324  }
325 }
326 
333 inline bool SupportsTensorCore(int device_id) {
334  // Volta (sm_70) supports TensorCore algos
335  return device_id >= 0 &&
336  ComputeCapabilityMajor(device_id) >=7;
337 }
338 
// The policy applied when the user has not set the environment variable
// MXNET_CUDA_ALLOW_TENSOR_CORE.
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
341 
346 inline bool GetEnvAllowTensorCore() {
347  // Since these statics are in the '.h' file, they will exist and will be set
348  // separately in each compilation unit. Not ideal, but cleaner than creating a
349  // cuda_utils.cc solely to have a single instance and initialization.
350  static bool allow_tensor_core = false;
351  static bool is_set = false;
352  if (!is_set) {
353  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
354  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
355  allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
356  dmlc::optional<bool>(default_value)).value();
357  is_set = true;
358  }
359  return allow_tensor_core;
360 }
361 
#if CUDA_VERSION >= 9000
/*!
 * \brief Sets the cuBLAS math mode that determines the 'allow TensorCore' policy.
 * \param blas_handle The cuBLAS handle whose math mode is changed.
 * \param new_math_type The math mode to install on the handle.
 * \return The math mode the handle had before this call.
 */
inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
  cublasMath_t previous_mode = CUBLAS_DEFAULT_MATH;
  // Read the current mode first so it can be handed back to the caller.
  CUBLAS_CALL(cublasGetMathMode(blas_handle, &previous_mode));
  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
  return previous_mode;
}
#endif
371 
372 #endif // MXNET_USE_CUDA
373 
374 #if MXNET_USE_CUDNN
375 
376 #include <cudnn.h>
377 
/*!
 * \brief Protected cuDNN call.
 * \param func Expression yielding a cudnnStatus_t; it is evaluated exactly once.
 *
 * Fails a CHECK_EQ, reporting cudnnGetErrorString(), unless the result is
 * CUDNN_STATUS_SUCCESS.
 */
#define CUDNN_CALL(func)                                                        \
  {                                                                             \
    cudnnStatus_t e = (func);                                                   \
    CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e);   \
  }
383 
391 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
392 #if CUDNN_MAJOR >= 7
393  int max_algos = 0;
394  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
395  return max_algos;
396 #else
397  return 10;
398 #endif
399 }
400 
408 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
409 #if CUDNN_MAJOR >= 7
410  int max_algos = 0;
411  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
412  return max_algos;
413 #else
414  return 10;
415 #endif
416 }
417 
425 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
426 #if CUDNN_MAJOR >= 7
427  int max_algos = 0;
428  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
429  return max_algos;
430 #else
431  return 10;
432 #endif
433 }
434 
435 #endif // MXNET_USE_CUDNN
436 
// Provide a double-precision atomicAdd overload when compiling device code
// for architectures below sm_60.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
/*!
 * \brief Atomically add `val` to the double stored at `address`.
 *        Implemented as a compare-and-swap loop over the value's 64-bit
 *        pattern (from the CUDA Programming Guide).
 * \param address Location of the double to update.
 * \param val Value to add.
 */
static inline __device__ void atomicAdd(double *address, double val) {
  typedef unsigned long long bits64_t;  // NOLINT(*)
  bits64_t* const word = reinterpret_cast<bits64_t*>(address);
  bits64_t observed = *word;
  bits64_t expected;

  do {
    expected = observed;
    // Swap in the new sum only if the word still holds the pattern we read.
    observed = atomicCAS(word, expected,
                         __double_as_longlong(val + __longlong_as_double(expected)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (expected != observed);
}
#endif
456 
457 // Overload atomicAdd for half precision
458 // Taken from:
459 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
460 #if defined(__CUDA_ARCH__)
461 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
462  mshadow::half::half_t val) {
463  unsigned int *address_as_ui =
464  reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
465  (reinterpret_cast<size_t>(address) & 2));
466  unsigned int old = *address_as_ui;
467  unsigned int assumed;
468 
469  do {
470  assumed = old;
471  mshadow::half::half_t hsum;
472  hsum.half_ =
473  reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
474  hsum += val;
475  old = reinterpret_cast<size_t>(address) & 2
476  ? (old & 0xffff) | (hsum.half_ << 16)
477  : (old & 0xffff0000) | hsum.half_;
478  old = atomicCAS(address_as_ui, assumed, old);
479  } while (assumed != old);
480 }
481 
482 // Overload atomicAdd to work for signed int64 on all architectures
483 static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
484  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
485 }
486 
487 template <typename DType>
488 __device__ inline DType ldg(const DType* address) {
489 #if __CUDA_ARCH__ >= 350
490  return __ldg(address);
491 #else
492  return *address;
493 #endif
494 }
495 #endif
496 
497 #endif // MXNET_COMMON_CUDA_UTILS_H_
#define CUBLAS_CALL(func)
Protected cuBLAS call.
Definition: cuda_utils.h:200
int ComputeCapabilityMajor(int device_id)
Determine major version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:280
namespace of mxnet
Definition: base.h:127
bool GetEnvAllowTensorCore()
Returns global policy for TensorCore algo use.
Definition: cuda_utils.h:346
int SMArch(int device_id)
Return the integer SM architecture (e.g. Volta = 70).
Definition: cuda_utils.h:304
DType __device__ CudaMin(DType a, DType b)
Definition: cuda_utils.h:163
bool SupportsFloat16Compute(int device_id)
Determine whether a cuda-capable gpu's architecture supports float16 math. Assume not if device_id is negative.
Definition: cuda_utils.h:316
DType __device__ CudaMax(DType a, DType b)
Definition: cuda_utils.h:158
bool SupportsTensorCore(int device_id)
Determine whether a cuda-capable gpu's architecture supports Tensor Core math. Assume not if device_id is negative.
Definition: cuda_utils.h:333
const char * CusolverGetErrorString(cusolverStatus_t error)
Get string representation of cuSOLVER errors.
Definition: cuda_utils.h:96
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
Definition: cuda_utils.h:340
const char * CurandGetErrorString(curandStatus_t status)
Get string representation of cuRAND errors.
Definition: cuda_utils.h:125
int ComputeCapabilityMinor(int device_id)
Determine minor version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:292
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:187
const char * CublasGetErrorString(cublasStatus_t error)
Get string representation of cuBLAS errors.
Definition: cuda_utils.h:65