cuda_utils.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_CUDA_UTILS_H_
26 #define MXNET_COMMON_CUDA_UTILS_H_
27 
28 #include <dmlc/logging.h>
29 #include <dmlc/parameter.h>
30 #include <dmlc/optional.h>
31 #include <mshadow/base.h>
32 
34 #ifdef __JETBRAINS_IDE__
35 #define __CUDACC__ 1
36 #define __host__
37 #define __device__
38 #define __global__
39 #define __forceinline__
40 #define __shared__
41 inline void __syncthreads() {}
42 inline void __threadfence_block() {}
43 template<class T> inline T __clz(const T val) { return val; }
44 struct __cuda_fake_struct { int x; int y; int z; };
45 extern __cuda_fake_struct blockDim;
46 extern __cuda_fake_struct threadIdx;
47 extern __cuda_fake_struct blockIdx;
48 #endif
49 
50 #if MXNET_USE_CUDA
51 
52 #include <cuda_runtime.h>
53 #include <cublas_v2.h>
54 #include <curand.h>
55 
60 #ifdef __CUDACC__
61 inline __device__ bool __is_supported_cuda_architecture() {
62 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
63 #error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)"
64  return false;
65 #else
66  return true;
67 #endif // __CUDA_ARCH__ < 300
68 }
69 #endif // __CUDACC__
70 
75 #define CHECK_CUDA_ERROR(msg) \
76  { \
77  cudaError_t e = cudaGetLastError(); \
78  CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
79  }
80 
87 #define CUDA_CALL(func) \
88  { \
89  cudaError_t e = (func); \
90  CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
91  << "CUDA: " << cudaGetErrorString(e); \
92  }
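/*
 * Example: guarding ordinary CUDA runtime calls with CUDA_CALL. A minimal
 * sketch, assuming it lives in a .cu translation unit that includes this
 * header; the helper name and sizes are placeholders.
 * \code
 * void FillDeviceBuffer(size_t n) {
 *   float* dptr = nullptr;
 *   CUDA_CALL(cudaMalloc(&dptr, n * sizeof(float)));   // aborts via CHECK on failure
 *   CUDA_CALL(cudaMemset(dptr, 0, n * sizeof(float)));
 *   CUDA_CALL(cudaFree(dptr));
 * }
 * \endcode
 */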
93 
100 #define CUBLAS_CALL(func) \
101  { \
102  cublasStatus_t e = (func); \
103  CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
104  << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
105  }
106 
113 #define CUSOLVER_CALL(func) \
114  { \
115  cusolverStatus_t e = (func); \
116  CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
117  << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
118  }
119 
126 #define CURAND_CALL(func) \
127  { \
128  curandStatus_t e = (func); \
129  CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
130  << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
131  }
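/*
 * Example: the library-call macros follow the same pattern as CUDA_CALL but
 * check the library-specific status type. A minimal sketch; handle and
 * generator names are placeholders.
 * \code
 * cublasHandle_t blas_handle;
 * CUBLAS_CALL(cublasCreate(&blas_handle));
 *
 * curandGenerator_t gen;
 * CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
 * CURAND_CALL(curandSetPseudoRandomGeneratorSeed(gen, 1234ULL));
 *
 * CURAND_CALL(curandDestroyGenerator(gen));
 * CUBLAS_CALL(cublasDestroy(blas_handle));
 * \endcode
 */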
132 
139 #define NVRTC_CALL(x) \
140  { \
141  nvrtcResult result = x; \
142  CHECK_EQ(result, NVRTC_SUCCESS) \
143  << #x " failed with error " \
144  << nvrtcGetErrorString(result); \
145  }
146 
153 #define CUDA_DRIVER_CALL(func) \
154  { \
155  CUresult e = (func); \
156  if (e != CUDA_SUCCESS) { \
157  char const * err_msg = nullptr; \
158  if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
159  LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
160  } else { \
161  LOG(FATAL) << "CUDA Driver: " << err_msg; \
162  } \
163  } \
164  }
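/*
 * Example: CUDA_DRIVER_CALL wraps driver-API calls returning CUresult. A
 * minimal sketch, assuming the caller includes <cuda.h>, which this header
 * does not pull in by itself.
 * \code
 * CUDA_DRIVER_CALL(cuInit(0));
 * int driver_version = 0;
 * CUDA_DRIVER_CALL(cuDriverGetVersion(&driver_version));
 * \endcode
 */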
165 
166 
167 #if !defined(_MSC_VER)
168 #define CUDA_UNROLL _Pragma("unroll")
169 #define CUDA_NOUNROLL _Pragma("nounroll")
170 #else
171 #define CUDA_UNROLL
172 #define CUDA_NOUNROLL
173 #endif
174 
175 namespace mxnet {
176 namespace common {
178 namespace cuda {
184 inline const char* CublasGetErrorString(cublasStatus_t error) {
185  switch (error) {
186  case CUBLAS_STATUS_SUCCESS:
187  return "CUBLAS_STATUS_SUCCESS";
188  case CUBLAS_STATUS_NOT_INITIALIZED:
189  return "CUBLAS_STATUS_NOT_INITIALIZED";
190  case CUBLAS_STATUS_ALLOC_FAILED:
191  return "CUBLAS_STATUS_ALLOC_FAILED";
192  case CUBLAS_STATUS_INVALID_VALUE:
193  return "CUBLAS_STATUS_INVALID_VALUE";
194  case CUBLAS_STATUS_ARCH_MISMATCH:
195  return "CUBLAS_STATUS_ARCH_MISMATCH";
196  case CUBLAS_STATUS_MAPPING_ERROR:
197  return "CUBLAS_STATUS_MAPPING_ERROR";
198  case CUBLAS_STATUS_EXECUTION_FAILED:
199  return "CUBLAS_STATUS_EXECUTION_FAILED";
200  case CUBLAS_STATUS_INTERNAL_ERROR:
201  return "CUBLAS_STATUS_INTERNAL_ERROR";
202  case CUBLAS_STATUS_NOT_SUPPORTED:
203  return "CUBLAS_STATUS_NOT_SUPPORTED";
204  default:
205  break;
206  }
207  return "Unknown cuBLAS status";
208 }
209 
215 inline const char* CusolverGetErrorString(cusolverStatus_t error) {
216  switch (error) {
217  case CUSOLVER_STATUS_SUCCESS:
218  return "CUSOLVER_STATUS_SUCCESS";
219  case CUSOLVER_STATUS_NOT_INITIALIZED:
220  return "CUSOLVER_STATUS_NOT_INITIALIZED";
221  case CUSOLVER_STATUS_ALLOC_FAILED:
222  return "CUSOLVER_STATUS_ALLOC_FAILED";
223  case CUSOLVER_STATUS_INVALID_VALUE:
224  return "CUSOLVER_STATUS_INVALID_VALUE";
225  case CUSOLVER_STATUS_ARCH_MISMATCH:
226  return "CUSOLVER_STATUS_ARCH_MISMATCH";
227  case CUSOLVER_STATUS_EXECUTION_FAILED:
228  return "CUSOLVER_STATUS_EXECUTION_FAILED";
229  case CUSOLVER_STATUS_INTERNAL_ERROR:
230  return "CUSOLVER_STATUS_INTERNAL_ERROR";
231  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
232  return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
233  default:
234  break;
235  }
236  return "Unknown cuSOLVER status";
237 }
238 
244 inline const char* CurandGetErrorString(curandStatus_t status) {
245  switch (status) {
246  case CURAND_STATUS_SUCCESS:
247  return "CURAND_STATUS_SUCCESS";
248  case CURAND_STATUS_VERSION_MISMATCH:
249  return "CURAND_STATUS_VERSION_MISMATCH";
250  case CURAND_STATUS_NOT_INITIALIZED:
251  return "CURAND_STATUS_NOT_INITIALIZED";
252  case CURAND_STATUS_ALLOCATION_FAILED:
253  return "CURAND_STATUS_ALLOCATION_FAILED";
254  case CURAND_STATUS_TYPE_ERROR:
255  return "CURAND_STATUS_TYPE_ERROR";
256  case CURAND_STATUS_OUT_OF_RANGE:
257  return "CURAND_STATUS_OUT_OF_RANGE";
258  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
259  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
260  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
261  return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
262  case CURAND_STATUS_LAUNCH_FAILURE:
263  return "CURAND_STATUS_LAUNCH_FAILURE";
264  case CURAND_STATUS_PREEXISTING_FAILURE:
265  return "CURAND_STATUS_PREEXISTING_FAILURE";
266  case CURAND_STATUS_INITIALIZATION_FAILED:
267  return "CURAND_STATUS_INITIALIZATION_FAILED";
268  case CURAND_STATUS_ARCH_MISMATCH:
269  return "CURAND_STATUS_ARCH_MISMATCH";
270  case CURAND_STATUS_INTERNAL_ERROR:
271  return "CURAND_STATUS_INTERNAL_ERROR";
272  }
273  return "Unknown cuRAND status";
274 }
275 
276 template <typename DType>
277 inline DType __device__ CudaMax(DType a, DType b) {
278  return a > b ? a : b;
279 }
280 
281 template <typename DType>
282 inline DType __device__ CudaMin(DType a, DType b) {
283  return a < b ? a : b;
284 }
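/*
 * Example: CudaMax and CudaMin are trivial device-side helpers. A minimal
 * sketch of a hypothetical clamp built on top of them, assuming the call site
 * is inside (or using) namespace mxnet::common::cuda.
 * \code
 * template <typename DType>
 * __device__ DType CudaClamp(DType x, DType lo, DType hi) {
 *   return CudaMin(CudaMax(x, lo), hi);
 * }
 * \endcode
 */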
285 
286 class DeviceStore {
287  public:
289  explicit DeviceStore(int requested_device = -1, bool restore = true) :
290  restore_device_(-1),
291  current_device_(requested_device),
292  restore_(restore) {
293  if (restore_)
294  CUDA_CALL(cudaGetDevice(&restore_device_));
295  if (requested_device != restore_device_) {
296  SetDevice(requested_device);
297  }
298  }
299 
300  ~DeviceStore() {
301  if (restore_ &&
302  current_device_ != restore_device_ &&
303  current_device_ != -1 &&
304  restore_device_ != -1)
305  CUDA_CALL(cudaSetDevice(restore_device_));
306  }
307 
308  void SetDevice(int device) {
309  if (device != -1) {
310  CUDA_CALL(cudaSetDevice(device));
311  current_device_ = device;
312  }
313  }
314 
315  private:
316  int restore_device_;
317  int current_device_;
318  bool restore_;
319 };
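/*
 * Example: DeviceStore is an RAII guard around cudaSetDevice. A minimal
 * sketch; the device id is a placeholder.
 * \code
 * {
 *   // Switch to device 1; since restore defaults to true, the previously
 *   // current device is restored when the guard goes out of scope.
 *   mxnet::common::cuda::DeviceStore device_guard(1);
 *   // ... make allocations / launch kernels on device 1 ...
 * }  // the original device is active again here
 * \endcode
 */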
320 
321 } // namespace cuda
322 } // namespace common
323 } // namespace mxnet
324 
330 inline int ComputeCapabilityMajor(int device_id) {
331  int major = 0;
332  CUDA_CALL(cudaDeviceGetAttribute(&major,
333  cudaDevAttrComputeCapabilityMajor, device_id));
334  return major;
335 }
336 
342 inline int ComputeCapabilityMinor(int device_id) {
343  int minor = 0;
344  CUDA_CALL(cudaDeviceGetAttribute(&minor,
345  cudaDevAttrComputeCapabilityMinor, device_id));
346  return minor;
347 }
348 
354 inline int SMArch(int device_id) {
355  auto major = ComputeCapabilityMajor(device_id);
356  auto minor = ComputeCapabilityMinor(device_id);
357  return 10 * major + minor;
358 }
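/*
 * Example: querying the compute capability of device 0. A minimal sketch
 * using the dmlc logging already included by this header.
 * \code
 * int dev = 0;
 * int major = ComputeCapabilityMajor(dev);
 * int minor = ComputeCapabilityMinor(dev);
 * LOG(INFO) << "GPU " << dev << " is sm_" << SMArch(dev)
 *           << " (compute " << major << "." << minor << ")";
 * \endcode
 */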
359 
366 inline bool SupportsFloat16Compute(int device_id) {
367  if (device_id < 0) {
368  return false;
369  } else {
370  // Kepler and most Maxwell GPUs do not support fp16 compute
371  int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
372  return (computeCapabilityMajor > 5) ||
373  (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
374  }
375 }
376 
383 inline bool SupportsTensorCore(int device_id) {
384  // Volta (sm_70) supports TensorCore algos
385  return device_id >= 0 &&
386  ComputeCapabilityMajor(device_id) >= 7;
387 }
388 
389 // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
390 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
391 
396 inline bool GetEnvAllowTensorCore() {
397  // Since these statics are in the '.h' file, they will exist and will be set
398  // separately in each compilation unit. Not ideal, but cleaner than creating a
399  // cuda_utils.cc solely to have a single instance and initialization.
400  static bool allow_tensor_core = false;
401  static bool is_set = false;
402  if (!is_set) {
403  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
404  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
405  allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
406  dmlc::optional<bool>(default_value)).value();
407  is_set = true;
408  }
409  return allow_tensor_core;
410 }
411 
412 // The policy if the user hasn't set the environment variable
413 // MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
414 #define MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT false
415 
419 inline bool GetEnvAllowTensorCoreConversion() {
420  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be
421  // legal.
422  bool default_value = MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT;
423  return dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION",
424  dmlc::optional<bool>(default_value))
425  .value();
426 }
427 
428 #if CUDA_VERSION >= 9000
429 // Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous.
430 inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
431  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
432  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
433  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
434  return handle_math_mode;
435 }
436 #endif
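/*
 * Example: combining the capability check, the environment policy and the
 * cuBLAS math-mode setter. A minimal sketch for CUDA >= 9.0; the function
 * name, handle and device id are placeholders.
 * \code
 * void EnableTensorCoreIfAllowed(cublasHandle_t blas_handle, int device_id) {
 *   if (SupportsTensorCore(device_id) && GetEnvAllowTensorCore()) {
 *     // Returns the previous math mode; ignored here.
 *     SetCublasMathMode(blas_handle, CUBLAS_TENSOR_OP_MATH);
 *   }
 * }
 * \endcode
 */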
437 
438 #endif // MXNET_USE_CUDA
439 
440 #if MXNET_USE_CUDNN
441 
442 #include <cudnn.h>
443 
444 #define CUDNN_CALL(func) \
445  { \
446  cudnnStatus_t e = (func); \
447  CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
448  }
449 
457 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
458 #if CUDNN_MAJOR >= 7
459  int max_algos = 0;
460  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
461  return max_algos;
462 #else
463  return 10;
464 #endif
465 }
466 
474 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
475 #if CUDNN_MAJOR >= 7
476  int max_algos = 0;
477  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
478  return max_algos;
479 #else
480  return 10;
481 #endif
482 }
483 
491 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
492 #if CUDNN_MAJOR >= 7
493  int max_algos = 0;
494  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
495  return max_algos;
496 #else
497  return 10;
498 #endif
499 }
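/*
 * Example: these helpers are typically used to size the result arrays passed
 * to cudnnFindConvolutionForwardAlgorithm and friends. A minimal sketch,
 * assuming <vector> is available and cudnn_handle was created elsewhere.
 * \code
 * std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_perf(MaxForwardAlgos(cudnn_handle));
 * std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_perf(MaxBackwardFilterAlgos(cudnn_handle));
 * std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_perf(MaxBackwardDataAlgos(cudnn_handle));
 * \endcode
 */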
500 
501 #endif // MXNET_USE_CUDNN
502 
503 // Overload atomicAdd to work for doubles on all architectures
504 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
505 // From CUDA Programming Guide
506 static inline __device__ void atomicAdd(double *address, double val) {
507  unsigned long long* address_as_ull = // NOLINT(*)
508  reinterpret_cast<unsigned long long*>(address); // NOLINT(*)
509  unsigned long long old = *address_as_ull; // NOLINT(*)
510  unsigned long long assumed; // NOLINT(*)
511 
512  do {
513  assumed = old;
514  old = atomicCAS(address_as_ull, assumed,
515  __double_as_longlong(val +
516  __longlong_as_double(assumed)));
517 
518  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
519  } while (assumed != old);
520 }
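/*
 * Example: with the overloads in this file, the same atomicAdd spelling works
 * for double (and, further below, for half, uint8, int8 and int64) regardless
 * of architecture. A minimal kernel sketch:
 * \code
 * __global__ void SumKernel(const double* in, double* out, int n) {
 *   int i = blockIdx.x * blockDim.x + threadIdx.x;
 *   if (i < n) {
 *     atomicAdd(out, in[i]);  // resolves to the native or the emulated overload
 *   }
 * }
 * \endcode
 */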
521 #endif
522 
523 // Overload atomicAdd for half precision
524 // Taken from:
525 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
526 #if defined(__CUDA_ARCH__)
527 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
528  mshadow::half::half_t val) {
529  unsigned int *address_as_ui =
530  reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
531  (reinterpret_cast<size_t>(address) & 2));
532  unsigned int old = *address_as_ui;
533  unsigned int assumed;
534 
535  do {
536  assumed = old;
537  mshadow::half::half_t hsum;
538  hsum.half_ =
539  reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
540  hsum += val;
541  old = reinterpret_cast<size_t>(address) & 2
542  ? (old & 0xffff) | (hsum.half_ << 16)
543  : (old & 0xffff0000) | hsum.half_;
544  old = atomicCAS(address_as_ui, assumed, old);
545  } while (assumed != old);
546 }
547 
548 static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
549  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
550  unsigned int old = *address_as_ui;
551  unsigned int shift = (((size_t)address & 0x3) << 3);
552  unsigned int sum;
553  unsigned int assumed;
554 
555  do {
556  assumed = old;
557  sum = val + static_cast<uint8_t>((old >> shift) & 0xff);
558  old = (old & ~(0x000000ff << shift)) | (sum << shift);
559  old = atomicCAS(address_as_ui, assumed, old);
560  } while (assumed != old);
561 }
562 
563 static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
564  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
565  unsigned int old = *address_as_ui;
566  unsigned int shift = (((size_t)address & 0x3) << 3);
567  unsigned int sum;
568  unsigned int assumed;
569 
570  do {
571  assumed = old;
572  sum = val + static_cast<int8_t>((old >> shift) & 0xff);
573  old = (old & ~(0x000000ff << shift)) | (sum << shift);
574  old = atomicCAS(address_as_ui, assumed, old);
575  } while (assumed != old);
576 }
577 
578 // Overload atomicAdd to work for signed int64 on all architectures
579 static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
580  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
581 }
582 
583 template <typename DType>
584 __device__ inline DType ldg(const DType* address) {
585 #if __CUDA_ARCH__ >= 350
586  return __ldg(address);
587 #else
588  return *address;
589 #endif
590 }
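/*
 * Example: ldg() routes reads through the read-only data cache (__ldg) on
 * sm_35 and newer and degrades to a plain load otherwise. A minimal kernel
 * sketch:
 * \code
 * template <typename DType>
 * __global__ void ScaleKernel(const DType* in, DType* out, DType alpha, int n) {
 *   int i = blockIdx.x * blockDim.x + threadIdx.x;
 *   if (i < n) {
 *     out[i] = alpha * ldg(in + i);
 *   }
 * }
 * \endcode
 */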
591 #endif
592 
593 #endif // MXNET_COMMON_CUDA_UTILS_H_