mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_RANDOM_GENERATOR_H_
26 #define MXNET_COMMON_RANDOM_GENERATOR_H_
27 
#include <mxnet/base.h>
#include <limits>
#include <new>
#include <random>
#include <type_traits>
31 
32 #if MXNET_USE_CUDA
33 #include <curand_kernel.h>
34 #include "../common/cuda_utils.h"
35 #endif // MXNET_USE_CUDA
36 
37 namespace mxnet {
38 namespace common {
39 namespace random {
40 
41 template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
43 
44 template<typename DType>
45 class RandGenerator<cpu, DType> {
46  public:
47  // at least how many random numbers should be generated by one CPU thread.
48  static const int kMinNumRandomPerThread;
49  // store how many global random states for CPU.
50  static const int kNumRandomStates;
51 
52  // implementation class for random number generator
53  class Impl {
54  public:
55  typedef typename std::conditional<std::is_floating_point<DType>::value,
56  DType, double>::type FType;
57 
58  explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
59  : engine_(gen->states_ + state_idx) {}
60 
61  Impl(const Impl &) = delete;
62  Impl &operator=(const Impl &) = delete;
63 
64  MSHADOW_XINLINE int rand() { return engine_->operator()(); }
65 
66  MSHADOW_XINLINE FType uniform() {
67  typedef typename std::conditional<std::is_integral<DType>::value,
68  std::uniform_int_distribution<DType>,
69  std::uniform_real_distribution<FType>>::type GType;
70  GType dist_uniform;
71  return dist_uniform(*engine_);
72  }
73 
74  MSHADOW_XINLINE FType normal() {
75  std::normal_distribution<FType> dist_normal;
76  return dist_normal(*engine_);
77  }
78 
79  private:
80  std::mt19937 *engine_;
81  };
82 
84  inst->states_ = new std::mt19937[kNumRandomStates];
85  }
86 
87  static void FreeState(RandGenerator<cpu, DType> *inst) {
88  delete[] inst->states_;
89  }
90 
91  MSHADOW_XINLINE void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
92  for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
93  }
94 
95  private:
96  std::mt19937 *states_;
97 }; // class RandGenerator<cpu, DType>
98 
99 template<typename DType>
101 
102 template<typename DType>
104 
105 #if MXNET_USE_CUDA
106 
107 template<typename DType>
108 class RandGenerator<gpu, DType> {
109  public:
110  // at least how many random numbers should be generated by one GPU thread.
111  static const int kMinNumRandomPerThread;
112  // store how many global random states for GPU.
113  static const int kNumRandomStates;
114 
115  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
116  // by using 1.0-curand_uniform().
117  // Needed as some samplers in sampler.h won't be able to deal with
118  // one of the boundary cases.
// Device-side handle onto one Philox state of the generator. The constructor
// copies slot `state_idx` of the generator's state array into this object and
// the destructor writes the advanced state back, so draws made through the
// handle persist once it goes out of scope.
//
// uniform() follows the STL convention [0, 1) (include 0, exclude 1) by
// returning 1 - curand_uniform(), since curand_uniform() yields (0, 1].
// Some samplers in sampler.h cannot handle one of the boundary values.
class Impl {
 public:
  Impl &operator=(const Impl &) = delete;
  Impl(const Impl &) = delete;

  // Copy state to local memory for efficiency.
  __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
      : global_gen_(gen),
        global_state_idx_(state_idx),
        state_(*(gen->states_ + state_idx)) {}

  __device__ ~Impl() {
    // store the curand state back into global memory
    global_gen_->states_[global_state_idx_] = state_;
  }

  // Raw 32-bit curand() output, converted to int.
  MSHADOW_FORCE_INLINE __device__ int rand() {
    return curand(&state_);
  }

  // Uniform float in [0, 1).
  MSHADOW_FORCE_INLINE __device__ float uniform() {
    const float u = 1.0f - curand_uniform(&state_);
    // curand_uniform() can return values <= 2^-25, for which 1.0f - x rounds
    // to exactly 1.0f. Map that rare case to the largest float below 1 so
    // the [0, 1) contract the samplers depend on really holds.
    return u < 1.0f ? u : nextafterf(1.0f, 0.0f);
  }

  // Standard normal float from curand_normal().
  MSHADOW_FORCE_INLINE __device__ float normal() {
    return curand_normal(&state_);
  }

 private:
  RandGenerator<gpu, DType> *global_gen_;  // owner of the state array
  int global_state_idx_;                    // slot written back in ~Impl()
  curandStatePhilox4_32_10_t state_;        // working copy of that slot
};  // class RandGenerator<gpu, DType>::Impl
152 
154  CUDA_CALL(cudaMalloc(&inst->states_,
155  kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
156  }
157 
159  CUDA_CALL(cudaFree(inst->states_));
160  }
161 
162  void Seed(mshadow::Stream<gpu> *s, uint32_t seed);
163 
164  private:
165  curandStatePhilox4_32_10_t *states_;
166 }; // class RandGenerator<gpu, DType>
167 
// Double-precision GPU generator: same Philox state layout as the primary
// template, but uniform()/normal() produce doubles.
// NOTE(review): unlike the primary template, this specialization declares no
// AllocState/FreeState/Seed or state-count constants; its states_ is
// presumably set up elsewhere — confirm.
template<>
class RandGenerator<gpu, double> {
 public:
  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
  // by using 1.0-curand_uniform_double().
  // Needed as some samplers in sampler.h won't be able to deal with
  // one of the boundary cases.
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    // Copy state to local memory for efficiency.
    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // Raw 32-bit curand() output, converted to int.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Uniform double in [0, 1). The subtraction is done in double (the
    // previous float-typed 1.0 was inconsistent with the return type).
    MSHADOW_FORCE_INLINE __device__ double uniform() {
      const double u = 1.0 - curand_uniform_double(&state_);
      // Guard against 1.0 - x rounding to exactly 1.0 for very small x,
      // which would break the [0, 1) contract.
      return u < 1.0 ? u : nextafter(1.0, 0.0);
    }

    // Standard normal double from curand_normal_double().
    MSHADOW_FORCE_INLINE __device__ double normal() {
      return curand_normal_double(&state_);
    }

   private:
    RandGenerator<gpu, double> *global_gen_;  // owner of the state array
    int global_state_idx_;                     // slot written back in ~Impl()
    curandStatePhilox4_32_10_t state_;         // working copy of that slot
  };  // class RandGenerator<gpu, double>::Impl

 private:
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, double>
212 
213 #endif // MXNET_USE_CUDA
214 
215 } // namespace random
216 } // namespace common
217 } // namespace mxnet
218 #endif // MXNET_COMMON_RANDOM_GENERATOR_H_
static void FreeState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:158
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:143
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:83
Definition: random_generator.h:169
namespace of mxnet
Definition: base.h:127
static const int kMinNumRandomPerThread
Definition: random_generator.h:111
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:125
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:74
__device__ ~Impl()
Definition: random_generator.h:130
static const int kNumRandomStates
Definition: random_generator.h:113
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:191
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:139
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:108
mshadow::gpu gpu
mxnet gpu
Definition: base.h:131
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:66
mshadow::cpu cpu
mxnet cpu
Definition: base.h:129
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:135
__device__ ~Impl()
Definition: random_generator.h:186
static void AllocState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:153
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:199
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:187
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:195
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:181
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:91