mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_RANDOM_GENERATOR_H_
26 #define MXNET_RANDOM_GENERATOR_H_
27 
28 #include <random>
29 #include <new>
30 #include "./base.h"
31 
32 #if MXNET_USE_CUDA
33 #include <curand_kernel.h>
34 #include <math.h>
35 #endif // MXNET_USE_CUDA
36 
37 namespace mxnet {
38 namespace common {
39 namespace random {
40 
// Primary template of the per-device random generator. It is only declared
// here; the usable implementations are the partial specializations for
// `cpu` and `gpu` that follow in this header.
// NOTE(review): MSHADOW_DEFAULT_DTYPE presumably expands to a default
// template argument for DType (e.g. `= <default real type>`) — confirm in mshadow.
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class RandGenerator;
// CPU random generator: owns an array of kNumRandomStates std::mt19937
// engines. Samplers bind to one engine by index through the nested Impl class.
template<typename DType>
class RandGenerator<cpu, DType> {
 public:
  // at least how many random numbers should be generated by one CPU thread.
  static const int kMinNumRandomPerThread;
  // store how many global random states for CPU.
  static const int kNumRandomStates;

  // implementation class for random number generator
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    // Type returned by uniform()/normal(): DType itself when DType is a
    // floating-point type, double otherwise.
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type FType;

    // Binds this sampler to engine gen->states_[state_idx]. The engine is
    // used in place (not copied), so every draw advances that shared engine.
    explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
        : engine_(gen->states_ + state_idx) {}

    Impl(const Impl &) = delete;
    Impl &operator=(const Impl &) = delete;

    // One raw engine draw, converted to int.
    MSHADOW_XINLINE int rand() { return engine_->operator()(); }

    // Combines two raw engine draws as (first << 31) + second.
    MSHADOW_XINLINE int64_t rand_int64() {
      // Widen to 64 bits *before* shifting: std::mt19937::result_type is
      // uint_fast32_t, which is only 32 bits wide on some platforms, where a
      // shift by 31 would discard all but the lowest bit of the first draw.
      // Drawing into named locals also pins the order of the two engine
      // calls, which is unspecified when both are operands of one '+'.
      const uint64_t high = static_cast<uint64_t>(engine_->operator()());
      const uint64_t low = static_cast<uint64_t>(engine_->operator()());
      return static_cast<int64_t>((high << 31) + low);
    }

    // Sample from a default-constructed distribution:
    // integral DType -> std::uniform_int_distribution<DType> over [0, max(DType)];
    // otherwise      -> std::uniform_real_distribution<FType> over [0, 1).
    MSHADOW_XINLINE FType uniform() {
      typedef typename std::conditional<std::is_integral<DType>::value,
                                        std::uniform_int_distribution<DType>,
                                        std::uniform_real_distribution<FType>>::type GType;
      GType dist_uniform;
      return dist_uniform(*engine_);
    }

    // Standard normal sample (mean 0, stddev 1) in FType.
    MSHADOW_XINLINE FType normal() {
      std::normal_distribution<FType> dist_normal;
      return dist_normal(*engine_);
    }

   private:
    // Points into the owning generator's states_ array; not owned.
    std::mt19937 *engine_;
  };  // class RandGenerator<cpu, DType>::Impl

  // Allocates kNumRandomStates default-constructed engines; call Seed() to
  // give them distinct seeds. Must be paired with FreeState().
  static void AllocState(RandGenerator<cpu, DType> *inst) {
    inst->states_ = new std::mt19937[kNumRandomStates];
  }

  // Releases the engine array allocated by AllocState().
  static void FreeState(RandGenerator<cpu, DType> *inst) {
    delete[] inst->states_;
  }

  // Seeds engine i with (seed + i), using uint32_t arithmetic, so each state
  // gets a different seed. The stream argument is unused on CPU.
  MSHADOW_XINLINE void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
    for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
  }

 private:
  std::mt19937 *states_;
};  // class RandGenerator<cpu, DType>
102 
103 template<typename DType>
105 
106 template<typename DType>
108 
109 #if MXNET_USE_CUDA
110 
// GPU random generator: holds a device array of Philox4_32_10 cuRAND states.
// AllocState/FreeState/Seed are declared here and defined elsewhere.
template<typename DType>
class RandGenerator<gpu, DType> {
 public:
  // at least how many random numbers should be generated by one GPU thread.
  static const int kMinNumRandomPerThread;
  // store how many global random states for GPU.
  static const int kNumRandomStates;

  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
  // by using 1.0-curand_uniform().
  // Needed as some samplers in sampler.h won't be able to deal with
  // one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    // Copies gen->states_[state_idx] into a local copy (state_) for efficiency;
    // the destructor writes it back.
    __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw 32-bit curand() draw, converted to int.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Combines two raw curand() draws as (first << 31) + second.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // curand() returns a 32-bit unsigned int. It must be widened *before*
      // the shift; shifting first keeps only its lowest bit. The locals also
      // pin the order of the two draws, which is unspecified as '+' operands.
      const uint64_t high = static_cast<uint64_t>(curand(&state_));
      const uint64_t low = static_cast<uint64_t>(curand(&state_));
      return static_cast<int64_t>((high << 31) + low);
    }

    // Uniform float in [0, 1): curand_uniform() yields (0, 1], so subtract from 1.
    MSHADOW_FORCE_INLINE __device__ float uniform() {
      return 1.0f - curand_uniform(&state_);
    }

    // Standard normal float sample.
    MSHADOW_FORCE_INLINE __device__ float normal() {
      return curand_normal(&state_);
    }

   private:
    RandGenerator<gpu, DType> *global_gen_;
    int global_state_idx_;
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, DType>::Impl

  static void AllocState(RandGenerator<gpu, DType> *inst);

  static void FreeState(RandGenerator<gpu, DType> *inst);

  void Seed(mshadow::Stream<gpu> *s, uint32_t seed);

 private:
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, DType>
171 
// Double-precision GPU specialization: same Philox4_32_10 state layout as the
// generic GPU generator, but uniform()/normal() return double.
// NOTE(review): unlike RandGenerator<gpu, DType>, this specialization declares
// no AllocState/FreeState/Seed or kNum* constants; presumably its states_ are
// provided by another generator instance — confirm at the call sites.
template<>
class RandGenerator<gpu, double> {
 public:
  // uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
  // by using 1.0-curand_uniform().
  // Needed as some samplers in sampler.h won't be able to deal with
  // one of the boundary cases.
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    // Copies gen->states_[state_idx] into a local copy (state_) for efficiency;
    // the destructor writes it back.
    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // store the curand state back into global memory
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw 32-bit curand() draw, converted to int.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    // Combines two raw curand() draws as (first << 31) + second.
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      // curand() returns a 32-bit unsigned int. It must be widened *before*
      // the shift; shifting first keeps only its lowest bit. The locals also
      // pin the order of the two draws, which is unspecified as '+' operands.
      const uint64_t high = static_cast<uint64_t>(curand(&state_));
      const uint64_t low = static_cast<uint64_t>(curand(&state_));
      return static_cast<int64_t>((high << 31) + low);
    }

    // Uniform double in [0, 1): curand_uniform_double() yields (0, 1].
    // A double literal avoids mixing float and double arithmetic.
    MSHADOW_FORCE_INLINE __device__ double uniform() {
      return 1.0 - curand_uniform_double(&state_);
    }

    // Standard normal double sample.
    MSHADOW_FORCE_INLINE __device__ double normal() {
      return curand_normal_double(&state_);
    }

   private:
    RandGenerator<gpu, double> *global_gen_;
    int global_state_idx_;
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, double>::Impl

 private:
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, double>
221 
222 #endif // MXNET_USE_CUDA
223 
224 } // namespace random
225 } // namespace common
226 } // namespace mxnet
227 #endif // MXNET_RANDOM_GENERATOR_H_
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:152
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:87
Definition: random_generator.h:173
namespace of mxnet
Definition: base.h:89
static const int kMinNumRandomPerThread
Definition: random_generator.h:115
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:130
MSHADOW_XINLINE int rand()
Definition: random_generator.h:64
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:78
__device__ ~Impl()
Definition: random_generator.h:135
static const int kNumRandomStates
Definition: random_generator.h:117
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:196
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:91
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:200
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:58
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:148
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:57
static const int kNumRandomStates
Definition: random_generator.h:50
Definition: random_generator.h:112
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:66
mshadow::gpu gpu
mxnet gpu
Definition: base.h:93
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:144
static const int kMinNumRandomPerThread
Definition: random_generator.h:48
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:70
mshadow::cpu cpu
mxnet cpu
Definition: base.h:91
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:140
__device__ ~Impl()
Definition: random_generator.h:191
Definition: random_generator.h:42
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:208
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:204
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:186
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:95