25 #ifndef MXNET_RANDOM_GENERATOR_H_ 26 #define MXNET_RANDOM_GENERATOR_H_ 33 #include <curand_kernel.h> 35 #endif // MXNET_USE_CUDA 41 template<
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
44 template<
typename DType>
56 typedef typename std::conditional<std::is_floating_point<DType>::value,
59 : engine_(gen->states_ + state_idx) {}
61 Impl(
const Impl &) =
delete;
62 Impl &operator=(
const Impl &) =
delete;
64 MSHADOW_XINLINE
int rand() {
return engine_->operator()(); }
67 return static_cast<int64_t
>(engine_->operator()() << 31) + engine_->operator()();
71 typedef typename std::conditional<std::is_integral<DType>::value,
72 std::uniform_int_distribution<DType>,
73 std::uniform_real_distribution<FType>>::type GType;
75 return dist_uniform(*engine_);
79 std::normal_distribution<FType> dist_normal;
80 return dist_normal(*engine_);
84 std::mt19937 *engine_;
88 inst->states_ =
new std::mt19937[kNumRandomStates];
92 delete[] inst->states_;
95 MSHADOW_XINLINE
void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
96 for (
int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
100 std::mt19937 *states_;
103 template<
typename DType>
106 template<
typename DType>
111 template<
typename DType>
126 Impl &operator=(
const Impl &) =
delete;
127 Impl(
const Impl &) =
delete;
132 global_state_idx_(state_idx),
133 state_(*(gen->states_ + state_idx)) {}
137 global_gen_->states_[global_state_idx_] = state_;
140 MSHADOW_FORCE_INLINE __device__
int rand() {
141 return curand(&state_);
145 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
148 MSHADOW_FORCE_INLINE __device__
float uniform() {
149 return static_cast<float>(1.0) - curand_uniform(&state_);
152 MSHADOW_FORCE_INLINE __device__
float normal() {
153 return curand_normal(&state_);
158 int global_state_idx_;
159 curandStatePhilox4_32_10_t state_;
166 void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
169 curandStatePhilox4_32_10_t *states_;
182 Impl &operator=(
const Impl &) =
delete;
183 Impl(
const Impl &) =
delete;
188 global_state_idx_(state_idx),
189 state_(*(gen->states_ + state_idx)) {}
193 global_gen_->states_[global_state_idx_] = state_;
196 MSHADOW_FORCE_INLINE __device__
int rand() {
197 return curand(&state_);
201 return static_cast<int64_t
>(curand(&state_) << 31) + curand(&state_);
204 MSHADOW_FORCE_INLINE __device__
double uniform() {
205 return static_cast<float>(1.0) - curand_uniform_double(&state_);
208 MSHADOW_FORCE_INLINE __device__
double normal() {
209 return curand_normal_double(&state_);
214 int global_state_idx_;
215 curandStatePhilox4_32_10_t state_;
219 curandStatePhilox4_32_10_t *states_;
222 #endif // MXNET_USE_CUDA 227 #endif // MXNET_RANDOM_GENERATOR_H_ MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:152
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Definition: random_generator.h:173
namespace of mxnet
Definition: base.h:89
static const int kMinNumRandomPerThread
Definition: random_generator.h:115
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:130
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:78
__device__ ~Impl()
Definition: random_generator.h:135
static const int kNumRandomStates
Definition: random_generator.h:117
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:196
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:91
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:200
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:148
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:57
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:112
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:66
mshadow::gpu gpu
mxnet gpu
Definition: base.h:93
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:144
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:70
Definition: random_generator.h:45
mshadow::cpu cpu
mxnet cpu
Definition: base.h:91
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:140
__device__ ~Impl()
Definition: random_generator.h:191
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:208
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:204
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:186
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:95