24 #ifndef MXNET_RANDOM_GENERATOR_H_ 25 #define MXNET_RANDOM_GENERATOR_H_ 32 #include <curand_kernel.h> 34 #endif // MXNET_USE_CUDA 40 template<
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
43 template<
typename DType>
55 typedef typename std::conditional<std::is_floating_point<DType>::value,
58 : engine_(gen->states_ + state_idx) {}
60 Impl(
const Impl &) =
delete;
61 Impl &operator=(
const Impl &) =
delete;
66 return static_cast<int64_t
>(engine_->operator()() << 31) + engine_->operator()();
70 typedef typename std::conditional<std::is_integral<DType>::value,
71 std::uniform_int_distribution<DType>,
72 std::uniform_real_distribution<FType>>::type GType;
74 return dist_uniform(*engine_);
78 std::normal_distribution<FType> dist_normal;
79 return dist_normal(*engine_);
83 std::mt19937 *engine_;
87 inst->states_ =
new std::mt19937[kNumRandomStates];
91 delete[] inst->states_;
95 for (
int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
100 return static_cast<void*
>(states_);
104 std::mt19937 *states_;
107 template<
typename DType>
110 template<
typename DType>
115 template<
typename DType>
130 Impl &operator=(
const Impl &) =
delete;
131 Impl(
const Impl &) =
delete;
136 global_state_idx_(state_idx),
137 state_(*(gen->states_ + state_idx)) {}
141 global_gen_->states_[global_state_idx_] = state_;
145 return curand(&state_);
149 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
153 return static_cast<float>(1.0) - curand_uniform(&state_);
157 return curand_normal(&state_);
162 int global_state_idx_;
163 curandStatePhilox4_32_10_t state_;
176 curandStatePhilox4_32_10_t *states_;
189 Impl &operator=(
const Impl &) =
delete;
190 Impl(
const Impl &) =
delete;
195 global_state_idx_(state_idx),
196 state_(*(gen->states_ + state_idx)) {}
200 global_gen_->states_[global_state_idx_] = state_;
204 return curand(&state_);
208 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
212 return static_cast<float>(1.0) - curand_uniform_double(&state_);
216 return curand_normal_double(&state_);
221 int global_state_idx_;
222 curandStatePhilox4_32_10_t state_;
226 curandStatePhilox4_32_10_t *states_;
229 #endif // MXNET_USE_CUDA 234 #endif // MXNET_RANDOM_GENERATOR_H_ MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:156
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:86
Definition: random_generator.h:180
namespace of mxnet
Definition: api_registry.h:33
Definition: stream_gpu-inl.h:37
static const int kMinNumRandomPerThread
Definition: random_generator.h:119
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:134
MSHADOW_XINLINE int rand()
Definition: random_generator.h:63
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:77
__device__ ~Impl()
Definition: random_generator.h:139
#define MSHADOW_FORCE_INLINE
Definition: base.h:225
static const int kNumRandomStates
Definition: random_generator.h:121
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:203
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:90
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:207
device name CPU
Definition: tensor.h:39
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:57
device name GPU
Definition: tensor.h:46
#define MSHADOW_XINLINE
Definition: base.h:230
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:152
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
static const int kNumRandomStates
Definition: random_generator.h:49
Definition: random_generator.h:116
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:65
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:148
static const int kMinNumRandomPerThread
Definition: random_generator.h:47
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:69
Definition: random_generator.h:44
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:99
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:144
__device__ ~Impl()
Definition: random_generator.h:198
Definition: random_generator.h:41
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:215
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:211
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:193
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:94
computaion stream structure, used for asynchronous computations
Definition: tensor.h:383