mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_RANDOM_GENERATOR_H_
26 #define MXNET_COMMON_RANDOM_GENERATOR_H_
27 
28 #include <mxnet/base.h>
29 #include <random>
30 #include <new>
31 
32 #if MXNET_USE_CUDA
33 #include <curand_kernel.h>
34 #include "../common/cuda_utils.h"
35 #endif // MXNET_USE_CUDA
36 
37 using namespace mshadow;
38 
39 namespace mxnet {
40 namespace common {
41 namespace random {
42 
43 template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
45 
46 template<typename DType>
47 class RandGenerator<cpu, DType> {
48  public:
49  // at least how many random numbers should be generated by one CPU thread.
50  static const int kMinNumRandomPerThread;
51  // store how many global random states for CPU.
52  static const int kNumRandomStates;
53 
54  // implementation class for random number generator
55  class Impl {
56  public:
57  typedef typename std::conditional<std::is_floating_point<DType>::value,
58  DType, double>::type FType;
59 
60  explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
61  : engine_(gen->states_ + state_idx) {}
62 
63  Impl(const Impl &) = delete;
64  Impl &operator=(const Impl &) = delete;
65 
66  MSHADOW_XINLINE int rand() { return engine_->operator()(); }
67 
68  MSHADOW_XINLINE FType uniform() {
69  typedef typename std::conditional<std::is_integral<DType>::value,
70  std::uniform_int_distribution<DType>,
71  std::uniform_real_distribution<FType>>::type GType;
72  GType dist_uniform;
73  return dist_uniform(*engine_);
74  }
75 
76  MSHADOW_XINLINE FType normal() {
77  std::normal_distribution<FType> dist_normal;
78  return dist_normal(*engine_);
79  }
80 
81  private:
82  std::mt19937 *engine_;
83  };
84 
86  inst->states_ = new std::mt19937[kNumRandomStates];
87  }
88 
89  static void FreeState(RandGenerator<cpu, DType> *inst) {
90  delete[] inst->states_;
91  }
92 
93  MSHADOW_XINLINE void Seed(Stream<cpu> *, uint32_t seed) {
94  for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
95  }
96 
97  private:
98  std::mt19937 *states_;
99 }; // class RandGenerator<cpu, DType>
100 
101 template<typename DType>
103 
104 template<typename DType>
106 
107 #if MXNET_USE_CUDA
108 
109 template<typename DType>
110 class RandGenerator<gpu, DType> {
111  public:
112  // at least how many random numbers should be generated by one GPU thread.
113  static const int kMinNumRandomPerThread;
114  // store how many global random states for GPU.
115  static const int kNumRandomStates;
116 
117  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
118  // by using 1.0-curand_uniform().
119  // Needed as some samplers in sampler.h won't be able to deal with
120  // one of the boundary cases.
121  class Impl {
122  public:
123  Impl &operator=(const Impl &) = delete;
124  Impl(const Impl &) = delete;
125 
126  // Copy state to local memory for efficiency.
127  __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
128  : global_gen_(gen),
129  global_state_idx_(state_idx),
130  state_(*(gen->states_ + state_idx)) {}
131 
132  __device__ ~Impl() {
133  // store the curand state back into global memory
134  global_gen_->states_[global_state_idx_] = state_;
135  }
136 
137  MSHADOW_FORCE_INLINE __device__ int rand() {
138  return curand(&state_);
139  }
140 
141  MSHADOW_FORCE_INLINE __device__ float uniform() {
142  return static_cast<float>(1.0) - curand_uniform(&state_);
143  }
144 
145  MSHADOW_FORCE_INLINE __device__ float normal() {
146  return curand_normal(&state_);
147  }
148 
149  private:
150  RandGenerator<gpu, DType> *global_gen_;
151  int global_state_idx_;
152  curandStatePhilox4_32_10_t state_;
153  }; // class RandGenerator<gpu, DType>::Impl
154 
156  CUDA_CALL(cudaMalloc(&inst->states_,
157  kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
158  }
159 
161  CUDA_CALL(cudaFree(inst->states_));
162  }
163 
164  void Seed(Stream<gpu> *s, uint32_t seed);
165 
166  private:
167  curandStatePhilox4_32_10_t *states_;
168 }; // class RandGenerator<gpu, DType>
169 
170 template<>
171 class RandGenerator<gpu, double> {
172  public:
173  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
174  // by using 1.0-curand_uniform().
175  // Needed as some samplers in sampler.h won't be able to deal with
176  // one of the boundary cases.
177  class Impl {
178  public:
179  Impl &operator=(const Impl &) = delete;
180  Impl(const Impl &) = delete;
181 
182  // Copy state to local memory for efficiency.
183  __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
184  : global_gen_(gen),
185  global_state_idx_(state_idx),
186  state_(*(gen->states_ + state_idx)) {}
187 
188  __device__ ~Impl() {
189  // store the curand state back into global memory
190  global_gen_->states_[global_state_idx_] = state_;
191  }
192 
193  MSHADOW_FORCE_INLINE __device__ int rand() {
194  return curand(&state_);
195  }
196 
197  MSHADOW_FORCE_INLINE __device__ double uniform() {
198  return static_cast<float>(1.0) - curand_uniform_double(&state_);
199  }
200 
201  MSHADOW_FORCE_INLINE __device__ double normal() {
202  return curand_normal_double(&state_);
203  }
204 
205  private:
206  RandGenerator<gpu, double> *global_gen_;
207  int global_state_idx_;
208  curandStatePhilox4_32_10_t state_;
209  }; // class RandGenerator<gpu, double>::Impl
210 
211  private:
212  curandStatePhilox4_32_10_t *states_;
213 }; // class RandGenerator<gpu, double>
214 
215 #endif // MXNET_USE_CUDA
216 
217 } // namespace random
218 } // namespace common
219 } // namespace mxnet
220 #endif // MXNET_COMMON_RANDOM_GENERATOR_H_
static void FreeState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:160
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:145
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:85
Definition: random_generator.h:171
namespace of mxnet
Definition: base.h:127
static const int kMinNumRandomPerThread
Definition: random_generator.h:113
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:127
MSHADOW_XINLINE int rand()
Definition: random_generator.h:66
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:76
__device__ ~Impl()
Definition: random_generator.h:132
static const int kNumRandomStates
Definition: random_generator.h:115
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:193
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:89
MSHADOW_XINLINE void Seed(Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:93
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:60
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:141
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:58
static const int kNumRandomStates
Definition: random_generator.h:52
Definition: random_generator.h:110
mshadow::gpu gpu
mxnet gpu
Definition: base.h:131
static const int kMinNumRandomPerThread
Definition: random_generator.h:50
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:68
mshadow::cpu cpu
mxnet cpu
Definition: base.h:129
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:137
__device__ ~Impl()
Definition: random_generator.h:188
static void AllocState(RandGenerator< gpu, DType > *inst)
Definition: random_generator.h:155
Definition: random_generator.h:44
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:201
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:187
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:197
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:183