cuda_utils.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_COMMON_CUDA_UTILS_H_
25 #define MXNET_COMMON_CUDA_UTILS_H_
26 
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <dmlc/optional.h>
30 #include <mshadow/base.h>
31 #include <mxnet/libinfo.h>
32 
34 #ifdef __JETBRAINS_IDE__
35 #define __CUDACC__ 1
36 #define __host__
37 #define __device__
38 #define __global__
39 #define __forceinline__
40 #define __shared__
41 inline void __syncthreads() {}
42 inline void __threadfence_block() {}
43 template<class T> inline T __clz(const T val) { return val; }
44 struct __cuda_fake_struct { int x; int y; int z; };
45 extern __cuda_fake_struct blockDim;
46 extern __cuda_fake_struct threadIdx;
47 extern __cuda_fake_struct blockIdx;
48 #endif
49 
50 #define QUOTE(x) #x
51 #define QUOTEVALUE(x) QUOTE(x)
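// For illustration: QUOTE stringizes its argument as written, while QUOTEVALUE expands
// macros first. E.g. if CUDA_VERSION is defined as 10020, QUOTE(CUDA_VERSION) yields
// "CUDA_VERSION" whereas QUOTEVALUE(CUDA_VERSION) yields "10020".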
52 
53 #if MXNET_USE_CUDA
54 
55 #include <cuda_runtime.h>
56 #include <cublas_v2.h>
57 #include <curand.h>
58 
59 #include <vector>
60 
61 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
62  static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
63  QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
64  QUOTEVALUE(min_version) " or later.")
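// Possible usage (illustrative sketch, not taken from this file):
//   STATIC_ASSERT_CUDA_VERSION_GE(9000);  // refuse to compile against CUDA older than 9.0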
65 
70 #ifdef __CUDACC__
71 inline __device__ bool __is_supported_cuda_architecture() {
72 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
73 #error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)"
74  return false;
75 #else
76  return true;
77 #endif // __CUDA_ARCH__ < 300
78 }
79 #endif // __CUDACC__
80 
85 #define CHECK_CUDA_ERROR(msg) \
86  { \
87  cudaError_t e = cudaGetLastError(); \
88  CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
89  }
90 
97 #define CUDA_CALL(func) \
98  { \
99  cudaError_t e = (func); \
100  CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
101  << "CUDA: " << cudaGetErrorString(e); \
102  }
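// Possible usage (illustrative sketch, not taken from this file): wrap CUDA runtime calls so
// a failure aborts with a readable message, e.g.
//   CUDA_CALL(cudaSetDevice(0));
//   CUDA_CALL(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, stream));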
103 
110 #define CUBLAS_CALL(func) \
111  { \
112  cublasStatus_t e = (func); \
113  CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
114  << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
115  }
116 
123 #define CUSOLVER_CALL(func) \
124  { \
125  cusolverStatus_t e = (func); \
126  CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
127  << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
128  }
129 
136 #define CURAND_CALL(func) \
137  { \
138  curandStatus_t e = (func); \
139  CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
140  << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
141  }
142 
149 #define NVRTC_CALL(x) \
150  { \
151  nvrtcResult result = x; \
152  CHECK_EQ(result, NVRTC_SUCCESS) \
153  << #x " failed with error " \
154  << nvrtcGetErrorString(result); \
155  }
156 
163 #define CUDA_DRIVER_CALL(func) \
164  { \
165  CUresult e = (func); \
166  if (e != CUDA_SUCCESS) { \
167  char const * err_msg = nullptr; \
168  if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
169  LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
170  } else { \
171  LOG(FATAL) << "CUDA Driver: " << err_msg; \
172  } \
173  } \
174  }
175 
176 
177 #if !defined(_MSC_VER)
178 #define CUDA_UNROLL _Pragma("unroll")
179 #define CUDA_NOUNROLL _Pragma("nounroll")
180 #else
181 #define CUDA_UNROLL
182 #define CUDA_NOUNROLL
183 #endif
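// Possible usage (illustrative sketch, not taken from this file): request unrolling of a
// fixed-trip-count loop inside a kernel; on MSVC the macro expands to nothing.
//   CUDA_UNROLL
//   for (int i = 0; i < 4; ++i) {
//     acc += vals[i];
//   }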
184 
185 namespace mxnet {
186 namespace common {
188 namespace cuda {
192 template<typename DType>
193 struct CublasType;
194 
195 // With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
196 // datatype cublasDataType_t. The older cudaDataType_t values could be
197 // included below, but since this class was introduced to support the cuBLAS v8
198 // call cublasGemmEx(), burdening the class with the legacy type values
199 // was not needed.
200 
201 template<>
202 struct CublasType<float> {
203  static const int kFlag = mshadow::kFloat32;
204 #if CUDA_VERSION >= 8000
205  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
206 #endif
207  typedef float ScaleType;
208  static const float one;
209  static const float zero;
210 };
211 template<>
212 struct CublasType<double> {
213  static const int kFlag = mshadow::kFloat64;
214 #if CUDA_VERSION >= 8000
215  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
216 #endif
217  typedef double ScaleType;
218  static const double one;
219  static const double zero;
220 };
221 template<>
222 struct CublasType<mshadow::half::half_t> {
223  static const int kFlag = mshadow::kFloat16;
224 #if CUDA_VERSION >= 8000
225  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
226 #endif
227  typedef float ScaleType;
228  static const mshadow::half::half_t one;
229  static const mshadow::half::half_t zero;
230 };
231 template<>
232 struct CublasType<uint8_t> {
233  static const int kFlag = mshadow::kUint8;
234 #if CUDA_VERSION >= 8000
235  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
236 #endif
237  typedef uint8_t ScaleType;
238  static const uint8_t one = 1;
239  static const uint8_t zero = 0;
240 };
241 template<>
242 struct CublasType<int32_t> {
243  static const int kFlag = mshadow::kInt32;
244 #if CUDA_VERSION >= 8000
245  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
246 #endif
247  typedef int32_t ScaleType;
248  static const int32_t one = 1;
249  static const int32_t zero = 0;
250 };
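// Illustrative sketch (assumed usage, not taken from this file): a float GEMM wrapper built on
// cublasGemmEx() (CUDA 8.0+) could rely on these traits roughly as follows.
//   using Type = mxnet::common::cuda::CublasType<float>;
//   float alpha = Type::one, beta = Type::zero;
//   CUBLAS_CALL(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
//                            &alpha, A, Type::kCudaFlag, lda,
//                                    B, Type::kCudaFlag, ldb,
//                            &beta,  C, Type::kCudaFlag, ldc,
//                            Type::kCudaFlag, CUBLAS_GEMM_DFALT));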
251 
257 inline const char* CublasGetErrorString(cublasStatus_t error) {
258  switch (error) {
259  case CUBLAS_STATUS_SUCCESS:
260  return "CUBLAS_STATUS_SUCCESS";
261  case CUBLAS_STATUS_NOT_INITIALIZED:
262  return "CUBLAS_STATUS_NOT_INITIALIZED";
263  case CUBLAS_STATUS_ALLOC_FAILED:
264  return "CUBLAS_STATUS_ALLOC_FAILED";
265  case CUBLAS_STATUS_INVALID_VALUE:
266  return "CUBLAS_STATUS_INVALID_VALUE";
267  case CUBLAS_STATUS_ARCH_MISMATCH:
268  return "CUBLAS_STATUS_ARCH_MISMATCH";
269  case CUBLAS_STATUS_MAPPING_ERROR:
270  return "CUBLAS_STATUS_MAPPING_ERROR";
271  case CUBLAS_STATUS_EXECUTION_FAILED:
272  return "CUBLAS_STATUS_EXECUTION_FAILED";
273  case CUBLAS_STATUS_INTERNAL_ERROR:
274  return "CUBLAS_STATUS_INTERNAL_ERROR";
275  case CUBLAS_STATUS_NOT_SUPPORTED:
276  return "CUBLAS_STATUS_NOT_SUPPORTED";
277  default:
278  break;
279  }
280  return "Unknown cuBLAS status";
281 }
282 
283 #if CUDA_VERSION >= 8000
284 
289 inline cublasOperation_t CublasTransposeOp(bool transpose) {
290  return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
291 }
292 #endif
293 
299 inline const char* CusolverGetErrorString(cusolverStatus_t error) {
300  switch (error) {
301  case CUSOLVER_STATUS_SUCCESS:
302  return "CUSOLVER_STATUS_SUCCESS";
303  case CUSOLVER_STATUS_NOT_INITIALIZED:
304  return "CUSOLVER_STATUS_NOT_INITIALIZED";
305  case CUSOLVER_STATUS_ALLOC_FAILED:
306  return "CUSOLVER_STATUS_ALLOC_FAILED";
307  case CUSOLVER_STATUS_INVALID_VALUE:
308  return "CUSOLVER_STATUS_INVALID_VALUE";
309  case CUSOLVER_STATUS_ARCH_MISMATCH:
310  return "CUSOLVER_STATUS_ARCH_MISMATCH";
311  case CUSOLVER_STATUS_EXECUTION_FAILED:
312  return "CUSOLVER_STATUS_EXECUTION_FAILED";
313  case CUSOLVER_STATUS_INTERNAL_ERROR:
314  return "CUSOLVER_STATUS_INTERNAL_ERROR";
315  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
316  return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
317  default:
318  break;
319  }
320  return "Unknown cuSOLVER status";
321 }
322 
328 inline const char* CurandGetErrorString(curandStatus_t status) {
329  switch (status) {
330  case CURAND_STATUS_SUCCESS:
331  return "CURAND_STATUS_SUCCESS";
332  case CURAND_STATUS_VERSION_MISMATCH:
333  return "CURAND_STATUS_VERSION_MISMATCH";
334  case CURAND_STATUS_NOT_INITIALIZED:
335  return "CURAND_STATUS_NOT_INITIALIZED";
336  case CURAND_STATUS_ALLOCATION_FAILED:
337  return "CURAND_STATUS_ALLOCATION_FAILED";
338  case CURAND_STATUS_TYPE_ERROR:
339  return "CURAND_STATUS_TYPE_ERROR";
340  case CURAND_STATUS_OUT_OF_RANGE:
341  return "CURAND_STATUS_OUT_OF_RANGE";
342  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
343  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
344  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
345  return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
346  case CURAND_STATUS_LAUNCH_FAILURE:
347  return "CURAND_STATUS_LAUNCH_FAILURE";
348  case CURAND_STATUS_PREEXISTING_FAILURE:
349  return "CURAND_STATUS_PREEXISTING_FAILURE";
350  case CURAND_STATUS_INITIALIZATION_FAILED:
351  return "CURAND_STATUS_INITIALIZATION_FAILED";
352  case CURAND_STATUS_ARCH_MISMATCH:
353  return "CURAND_STATUS_ARCH_MISMATCH";
354  case CURAND_STATUS_INTERNAL_ERROR:
355  return "CURAND_STATUS_INTERNAL_ERROR";
356  }
357  return "Unknown cuRAND status";
358 }
359 
360 template <typename DType>
361 inline DType __device__ CudaMax(DType a, DType b) {
362  return a > b ? a : b;
363 }
364 
365 template <typename DType>
366 inline DType __device__ CudaMin(DType a, DType b) {
367  return a < b ? a : b;
368 }
369 
370 class DeviceStore {
371  public:
373  explicit DeviceStore(int requested_device = -1, bool restore = true) :
374  restore_device_(-1),
375  current_device_(requested_device),
376  restore_(restore) {
377  if (restore_)
378  CUDA_CALL(cudaGetDevice(&restore_device_));
379  if (requested_device != restore_device_) {
380  SetDevice(requested_device);
381  }
382  }
383 
384  ~DeviceStore() {
385  if (restore_ &&
386  current_device_ != restore_device_ &&
387  current_device_ != -1 &&
388  restore_device_ != -1)
389  CUDA_CALL(cudaSetDevice(restore_device_));
390  }
391 
392  void SetDevice(int device) {
393  if (device != -1) {
394  CUDA_CALL(cudaSetDevice(device));
395  current_device_ = device;
396  }
397  }
398 
399  private:
400  int restore_device_;
401  int current_device_;
402  bool restore_;
403 };
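// Possible usage (illustrative sketch, not taken from this file): make a GPU current for the
// lifetime of a scope and restore the previously current device afterwards.
//   {
//     DeviceStore device_store(1);              // switches to device 1 if it is not current
//     CUDA_CALL(cudaMalloc(&ptr, nbytes));      // runs against device 1
//   }                                           // destructor restores the old current device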
404 
413 int get_load_type(size_t N);
414 
425 int get_rows_per_block(size_t row_size, int num_threads_per_block);
426 
427 } // namespace cuda
428 } // namespace common
429 } // namespace mxnet
430 
432 constexpr size_t kMaxNumGpus = 64;
433 
434 // The implementations below assume that accesses of 32-bit ints are inherently atomic and
435 // can be read/written by multiple threads without locks. The values held should be < 2^31.
436 
445 inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
446  cudaDeviceAttr attr, const char *attr_name) {
447  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
448  LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
449  } else if ((*cached_values)[device_id] < 0) {
450  int temp = -1;
451  CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
452  (*cached_values)[device_id] = static_cast<int32_t>(temp);
453  }
454  return (*cached_values)[device_id];
455 }
456 
462 inline int ComputeCapabilityMajor(int device_id) {
463  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
464  return cudaAttributeLookup(device_id, &capability_major,
465  cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
466 }
467 
473 inline int ComputeCapabilityMinor(int device_id) {
474  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
475  return cudaAttributeLookup(device_id, &capability_minor,
476  cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
477 }
478 
484 inline int SMArch(int device_id) {
485  auto major = ComputeCapabilityMajor(device_id);
486  auto minor = ComputeCapabilityMinor(device_id);
487  return 10 * major + minor;
488 }
489 
495 inline int MultiprocessorCount(int device_id) {
496  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
497  return cudaAttributeLookup(device_id, &sm_counts,
498  cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
499 }
500 
506 inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
507  static std::vector<int32_t> max_smem_per_multiprocessor(kMaxNumGpus, -1);
508  return cudaAttributeLookup(device_id, &max_smem_per_multiprocessor,
509  cudaDevAttrMaxSharedMemoryPerMultiprocessor,
510  "MaxSharedMemoryPerMultiprocessor");
511 }
512 
518 inline bool SupportsCooperativeLaunch(int device_id) {
519  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
520  return cudaAttributeLookup(device_id, &coop_launch,
521  cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
522 }
523 
530 inline bool SupportsFloat16Compute(int device_id) {
531  if (device_id < 0) {
532  return false;
533  } else {
534  // Kepler and most Maxwell GPUs do not support fp16 compute
535  int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
536  return (computeCapabilityMajor > 5) ||
537  (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
538  }
539 }
540 
547 inline bool SupportsTensorCore(int device_id) {
548  // Volta (sm_70) supports TensorCore algos
549  return device_id >= 0 &&
550  ComputeCapabilityMajor(device_id) >= 7;
551 }
552 
553 // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
554 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
555 
560 inline bool GetEnvAllowTensorCore() {
561  // Since these statics are in the '.h' file, they will exist and will be set
562  // separately in each compilation unit. Not ideal, but cleaner than creating a
563  // cuda_utils.cc solely to have a single instance and initialization.
564  static bool allow_tensor_core = false;
565  static bool is_set = false;
566  if (!is_set) {
567  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
568  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
569  allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
570  dmlc::optional<bool>(default_value)).value();
571  is_set = true;
572  }
573  return allow_tensor_core;
574 }
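// For illustration (assumption, not taken from this file): the policy can be overridden per
// process through the environment, e.g. running with MXNET_CUDA_ALLOW_TENSOR_CORE=0 makes
// GetEnvAllowTensorCore() return false and thereby forbids TensorCore algo selection.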
575 
576 // The policy if the user hasn't set the environment variable
577 // MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
578 #define MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT false
579 
583 inline bool GetEnvAllowTensorCoreConversion() {
584  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be
585  // legal.
586  bool default_value = MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT;
587  return dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION",
588  dmlc::optional<bool>(default_value))
589  .value();
590 }
591 
592 #if CUDA_VERSION >= 9000
593 // Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous.
594 inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
595  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
596  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
597  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
598  return handle_math_mode;
599 }
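// Possible usage (illustrative sketch, not taken from this file): enable TensorCore math for
// one cuBLAS call, then restore the handle's previous mode.
//   cublasMath_t saved = SetCublasMathMode(blas_handle, CUBLAS_TENSOR_OP_MATH);
//   // ... issue the cuBLAS routine that may use TensorCores ...
//   SetCublasMathMode(blas_handle, saved);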
600 #endif
601 
602 #endif // MXNET_USE_CUDA
603 
604 #if MXNET_USE_CUDNN
605 
606 #include <cudnn.h>
607 
608 // Creating CUDNN_VERSION_AS_STRING as follows avoids a static_assert error message that shows
609 // the formula for CUDNN_VERSION, i.e. "1000 * 7 + 100 * 6 + 0" rather than the number "7600".
610 static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10,
611  "CUDNN_VERSION_AS_STRING macro assumptions violated.");
612 #if CUDNN_PATCHLEVEL >= 10
613 #define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \
614  QUOTEVALUE(CUDNN_MINOR) \
615  QUOTEVALUE(CUDNN_PATCHLEVEL)
616 #else
617 #define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \
618  QUOTEVALUE(CUDNN_MINOR) \
619  "0" QUOTEVALUE(CUDNN_PATCHLEVEL)
620 #endif
621 
622 #define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \
623  static_assert(CUDNN_VERSION >= min_version, "Compiled-against cuDNN version " \
624  CUDNN_VERSION_AS_STRING " is too old, please upgrade system to version " \
625  QUOTEVALUE(min_version) " or later.")
626 
627 #define CUDNN_CALL(func) \
628  { \
629  cudnnStatus_t e = (func); \
630  CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
631  }
632 
640 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
641  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
642  int max_algos = 0;
643  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
644  return max_algos;
645 }
646 
654 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
655  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
656  int max_algos = 0;
657  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
658  return max_algos;
659 }
660 
668 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
669  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
670  int max_algos = 0;
671  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
672  return max_algos;
673 }
674 
675 #endif // MXNET_USE_CUDNN
676 
677 // Overload atomicAdd to work for floats on all architectures
678 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
679 // From CUDA Programming Guide
680 static inline __device__ void atomicAdd(double *address, double val) {
681  unsigned long long* address_as_ull = // NOLINT(*)
682  reinterpret_cast<unsigned long long*>(address); // NOLINT(*)
683  unsigned long long old = *address_as_ull; // NOLINT(*)
684  unsigned long long assumed; // NOLINT(*)
685 
686  do {
687  assumed = old;
688  old = atomicCAS(address_as_ull, assumed,
689  __double_as_longlong(val +
690  __longlong_as_double(assumed)));
691 
692  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
693  } while (assumed != old);
694 }
695 #endif
696 
697 // Overload atomicAdd for half precision
698 // Taken from:
699 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
700 #ifdef __CUDACC__
701 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
702  mshadow::half::half_t val) {
703  unsigned int *address_as_ui =
704  reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
705  (reinterpret_cast<size_t>(address) & 2));
706  unsigned int old = *address_as_ui;
707  unsigned int assumed;
708 
709  do {
710  assumed = old;
711  mshadow::half::half_t hsum;
712  hsum.half_ =
713  reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
714  hsum += val;
715  old = reinterpret_cast<size_t>(address) & 2
716  ? (old & 0xffff) | (hsum.half_ << 16)
717  : (old & 0xffff0000) | hsum.half_;
718  old = atomicCAS(address_as_ui, assumed, old);
719  } while (assumed != old);
720 }
721 
722 static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
723  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
724  unsigned int old = *address_as_ui;
725  unsigned int shift = (((size_t)address & 0x3) << 3);
726  unsigned int sum;
727  unsigned int assumed;
728 
729  do {
730  assumed = old;
731  sum = val + static_cast<uint8_t>((old >> shift) & 0xff);
732  old = (old & ~(0x000000ff << shift)) | (sum << shift);
733  old = atomicCAS(address_as_ui, assumed, old);
734  } while (assumed != old);
735 }
736 
737 static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
738  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
739  unsigned int old = *address_as_ui;
740  unsigned int shift = (((size_t)address & 0x3) << 3);
741  unsigned int sum;
742  unsigned int assumed;
743 
744  do {
745  assumed = old;
746  sum = val + static_cast<int8_t>((old >> shift) & 0xff);
747  old = (old & ~(0x000000ff << shift)) | (sum << shift);
748  old = atomicCAS(address_as_ui, assumed, old);
749  } while (assumed != old);
750 }
751 
752 // Overload atomicAdd to work for signed int64 on all architectures
753 static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
754  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
755 }
756 
757 template <typename DType>
758 __device__ inline DType ldg(const DType* address) {
759 #if __CUDA_ARCH__ >= 350
760  return __ldg(address);
761 #else
762  return *address;
763 #endif
764 }
765 
766 template <typename OP, typename T>
767 __device__ inline T warp_reduce(T value, OP redfun) {
768  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
769  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
770  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
771  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
772  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
773  return value;
774 }
775 
776 template <typename OP>
777 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
778  float v = static_cast<float>(value);
779  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
780  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
781  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
782  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
783  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
784  return mshadow::half::half_t(v);
785 }
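// Possible usage (illustrative sketch, not taken from this file): sum a per-thread value across
// the 32 lanes of a warp with a device-side functor; lane 0 ends up holding the full warp sum.
//   struct Sum {
//     __device__ float operator()(float a, float b) const { return a + b; }
//   };
//   float warp_sum = warp_reduce(thread_val, Sum{});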
786 
787 #endif // __CUDACC__
788 
789 #endif // MXNET_COMMON_CUDA_UTILS_H_