25 #ifndef MXNET_COMMON_RANDOM_GENERATOR_H_ 26 #define MXNET_COMMON_RANDOM_GENERATOR_H_ 33 #include <curand_kernel.h> 34 #include "../common/cuda_utils.h" 35 #endif // MXNET_USE_CUDA 41 template<
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
44 template<
typename DType>
55 typedef typename std::conditional<std::is_floating_point<DType>::value,
59 : engine_(gen->states_ + state_idx) {}
61 Impl(
const Impl &) =
delete;
62 Impl &operator=(
const Impl &) =
delete;
64 MSHADOW_XINLINE
int rand() {
return engine_->operator()(); }
67 typedef typename std::conditional<std::is_integral<DType>::value,
68 std::uniform_int_distribution<DType>,
69 std::uniform_real_distribution<FType>>::type GType;
71 return dist_uniform(*engine_);
75 std::normal_distribution<FType> dist_normal;
76 return dist_normal(*engine_);
80 std::mt19937 *engine_;
84 inst->states_ =
new std::mt19937[kNumRandomStates];
88 delete[] inst->states_;
91 MSHADOW_XINLINE
void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
92 for (
int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
96 std::mt19937 *states_;
99 template<
typename DType>
102 template<
typename DType>
107 template<
typename DType>
121 Impl &operator=(
const Impl &) =
delete;
122 Impl(
const Impl &) =
delete;
127 global_state_idx_(state_idx),
128 state_(*(gen->states_ + state_idx)) {}
132 global_gen_->states_[global_state_idx_] = state_;
135 MSHADOW_FORCE_INLINE __device__
int rand() {
136 return curand(&state_);
139 MSHADOW_FORCE_INLINE __device__
float uniform() {
140 return static_cast<float>(1.0) - curand_uniform(&state_);
143 MSHADOW_FORCE_INLINE __device__
float normal() {
144 return curand_normal(&state_);
149 int global_state_idx_;
150 curandStatePhilox4_32_10_t state_;
155 kNumRandomStates *
sizeof(curandStatePhilox4_32_10_t)));
162 void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
165 curandStatePhilox4_32_10_t *states_;
177 Impl &operator=(
const Impl &) =
delete;
178 Impl(
const Impl &) =
delete;
183 global_state_idx_(state_idx),
184 state_(*(gen->states_ + state_idx)) {}
188 global_gen_->states_[global_state_idx_] = state_;
191 MSHADOW_FORCE_INLINE __device__
int rand() {
192 return curand(&state_);
195 MSHADOW_FORCE_INLINE __device__
double uniform() {
196 return static_cast<float>(1.0) - curand_uniform_double(&state_);
199 MSHADOW_FORCE_INLINE __device__
double normal() {
200 return curand_normal_double(&state_);
205 int global_state_idx_;
206 curandStatePhilox4_32_10_t state_;
210 curandStatePhilox4_32_10_t *states_;
213 #endif // MXNET_USE_CUDA 218 #endif // MXNET_COMMON_RANDOM_GENERATOR_H_ static void FreeState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:158
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:143
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:83
Definition: random_generator.h:169
namespace of mxnet
Definition: base.h:118
static const int kMinNumRandomPerThread
Definition: random_generator.h:111
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:125
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:74
__device__ ~Impl()
Definition: random_generator.h:130
static const int kNumRandomStates
Definition: random_generator.h:113
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:191
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:139
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:108
mshadow::gpu gpu
mxnet gpu
Definition: base.h:122
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:66
Definition: random_generator.h:45
mshadow::cpu cpu
mxnet cpu
Definition: base.h:120
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:135
__device__ ~Impl()
Definition: random_generator.h:186
static void AllocState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:153
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:199
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:202
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:195
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:181
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:91