mxnet
random_generator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_RANDOM_GENERATOR_H_
25 #define MXNET_RANDOM_GENERATOR_H_
26 
27 #include <random>
28 #include <new>
29 #include "./base.h"
30 
31 #if MXNET_USE_CUDA
32 #include <curand_kernel.h>
33 #include <math.h>
34 #endif // MXNET_USE_CUDA
35 
36 namespace mxnet {
37 namespace common {
38 namespace random {
39 
/*!
 * \brief Primary template of the parallel random number generator.
 *
 * Only declared here; the cpu and gpu partial specializations in this header
 * provide the definitions.
 * \tparam Device mshadow device tag (this header specializes cpu and gpu).
 * \tparam DType  element type of the generated samples. MSHADOW_DEFAULT_DTYPE
 *                presumably supplies a default type argument — confirm in mshadow.
 */
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class RandGenerator;

/*!
 * \brief CPU random number generator backed by a pool of std::mt19937 engines.
 *
 * The pool holds kNumRandomStates engines (allocated by AllocState, released by
 * FreeState). Sampling goes through an Impl bound to one engine of the pool.
 * \tparam DType element type of the generated samples.
 */
template<typename DType>
class RandGenerator<cpu, DType> {
 public:
  // Minimum number of random numbers that one CPU thread should generate.
  static const int kMinNumRandomPerThread;
  // Number of engines in the global CPU random state pool.
  static const int kNumRandomStates;

  /*!
   * \brief Sampling view onto one engine of the generator's pool.
   *
   * engine_ points directly into the shared pool (no local copy), so draws
   * advance the pooled engine itself.
   * NOTE(review): two Impl objects bound to the same state_idx used from
   * different threads would race on that engine — callers presumably give each
   * thread its own index; confirm.
   */
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    // Floating type for real-valued draws: DType itself when it is a
    // floating-point type, otherwise double.
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type FType;

    /*!
     * \param gen       generator whose engine pool is used; must outlive this Impl.
     * \param state_idx index of the engine in gen's pool (not bounds-checked;
     *                  presumably in [0, kNumRandomStates)).
     */
    explicit Impl(RandGenerator<cpu, DType> *gen, int state_idx)
        : engine_(gen->states_ + state_idx) {}

    Impl(const Impl &) = delete;
    Impl &operator=(const Impl &) = delete;

    // One raw engine output, converted to int.
    MSHADOW_XINLINE int rand() { return engine_->operator()(); }

    /*!
     * \brief Uniform random integer in [0, 2^63).
     *
     * std::mt19937 yields 32-bit words. The first draw supplies bits 31..62 and
     * the top 31 bits of the second draw supply bits 0..30. Widening happens
     * before the shift, so no bits are lost even where uint_fast32_t is only
     * 32 bits wide. OR-ing instead of adding means the result can neither
     * overflow int64 nor become non-uniform. The two draws are sequenced
     * explicitly; inside a single `+` expression their order is unspecified.
     */
    MSHADOW_XINLINE int64_t rand_int64() {
      const uint64_t hi = static_cast<uint32_t>(engine_->operator()());
      const uint64_t lo = static_cast<uint32_t>(engine_->operator()());
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    /*!
     * \brief Uniform draw returned as FType.
     *
     * For integral DType this uses a default-constructed
     * std::uniform_int_distribution<DType>, i.e. the range
     * [0, numeric_limits<DType>::max()]. Otherwise it draws a real value in [0, 1).
     */
    MSHADOW_XINLINE FType uniform() {
      typedef typename std::conditional<std::is_integral<DType>::value,
                                        std::uniform_int_distribution<DType>,
                                        std::uniform_real_distribution<FType>>::type GType;
      GType dist_uniform;
      return dist_uniform(*engine_);
    }

    // Standard normal draw (mean 0, stddev 1) returned as FType.
    MSHADOW_XINLINE FType normal() {
      std::normal_distribution<FType> dist_normal;
      return dist_normal(*engine_);
    }

   private:
    // Engine inside the owning generator's pool (not owned by this Impl).
    std::mt19937 *engine_;
  };  // class RandGenerator<cpu, DType>::Impl

  // Allocates kNumRandomStates default-constructed engines for inst.
  static void AllocState(RandGenerator<cpu, DType> *inst) {
    inst->states_ = new std::mt19937[kNumRandomStates];
  }

  // Releases the pool allocated by AllocState and clears the pointer so that a
  // repeated FreeState does not double-delete.
  static void FreeState(RandGenerator<cpu, DType> *inst) {
    delete[] inst->states_;
    inst->states_ = nullptr;
  }

  // Seeds engine i with (seed + i), giving every engine in the pool a distinct seed.
  MSHADOW_XINLINE void Seed(mshadow::Stream<cpu> *, uint32_t seed) {
    for (int i = 0; i < kNumRandomStates; ++i) (states_ + i)->seed(seed + i);
  }

  // export global random states, used by c++ custom operator
  MSHADOW_XINLINE void *GetStates() {
    return static_cast<void*>(states_);
  }

 private:
  // Engine pool; managed by AllocState / FreeState.
  std::mt19937 *states_;
};  // class RandGenerator<cpu, DType>
106 
107 template<typename DType>
109 
110 template<typename DType>
112 
113 #if MXNET_USE_CUDA
114 
/*!
 * \brief GPU random number generator backed by a pool of cuRAND Philox states.
 *
 * Device code samples through an Impl bound to one state of the pool. The
 * host-side management functions (AllocState, FreeState, Seed, GetStates) are
 * only declared here and defined elsewhere.
 * \tparam DType element type of the generated samples.
 */
template<typename DType>
class RandGenerator<gpu, DType> {
 public:
  // Minimum number of random numbers that one GPU thread should generate.
  static const int kMinNumRandomPerThread;
  // Number of states in the global GPU random state pool.
  static const int kNumRandomStates;

  /*!
   * \brief Device-side sampling view onto one cuRAND state of the pool.
   *
   * The constructor copies the selected state into this object so that draws
   * work on the copy, and the destructor stores it back into the pool.
   * Uniform numbers follow the STL convention [0, 1) by returning
   * 1 - curand_uniform(), because curand_uniform itself returns (0, 1]. Some
   * samplers in sampler.h cannot handle the other boundary case.
   */
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    /*!
     * \param gen       generator whose state pool is used; must outlive this Impl.
     * \param state_idx index of the state in gen's pool (not bounds-checked).
     */
    __device__ explicit Impl(RandGenerator<gpu, DType> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // Write the advanced state back so the next Impl continues the sequence.
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw 32-bit curand() output, converted to int.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    /*!
     * \brief Uniform random integer in [0, 2^63).
     *
     * curand() returns a 32-bit unsigned value. Shifting it before widening
     * (as in `curand() << 31`) keeps only one bit of the draw, so the value is
     * widened first. The first draw fills bits 31..62 and the top 31 bits of
     * the second draw fill bits 0..30. OR-ing keeps the result uniform and free
     * of int64 overflow. The draws are sequenced explicitly.
     */
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      const uint64_t hi = curand(&state_);
      const uint64_t lo = curand(&state_);
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    // Uniform float in [0, 1): curand_uniform returns (0, 1], so it is flipped.
    MSHADOW_FORCE_INLINE __device__ float uniform() {
      return 1.0f - curand_uniform(&state_);
    }

    // Standard normal float (mean 0, stddev 1).
    MSHADOW_FORCE_INLINE __device__ float normal() {
      return curand_normal(&state_);
    }

   private:
    // Owning generator; its pool receives the state back in the destructor.
    RandGenerator<gpu, DType> *global_gen_;
    // Index of this Impl's state in global_gen_'s pool.
    int global_state_idx_;
    // Working copy of the pooled state.
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, DType>::Impl

  static void AllocState(RandGenerator<gpu, DType> *inst);

  static void FreeState(RandGenerator<gpu, DType> *inst);

  void Seed(mshadow::Stream<gpu> *s, uint32_t seed);

  // export global random states, used by c++ custom operator
  void* GetStates();

 private:
  // Pool of cuRAND Philox states; managed by AllocState / FreeState.
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, DType>
178 
/*!
 * \brief Double-precision GPU sampling specialization.
 *
 * Declares only Impl and the states_ pool; it has no AllocState, FreeState,
 * Seed or GetStates of its own.
 * NOTE(review): its member layout matches RandGenerator<gpu, DType>, so it is
 * presumably obtained by reinterpreting a generic GPU generator — confirm at
 * the call sites.
 */
template<>
class RandGenerator<gpu, double> {
 public:
  /*!
   * \brief Device-side double-precision sampling view onto one cuRAND state.
   *
   * The constructor copies the selected state into this object, and the
   * destructor stores it back into the pool. Uniform numbers follow the STL
   * convention [0, 1) by returning 1 - curand_uniform_double(), because
   * curand_uniform_double itself returns (0, 1]. Some samplers in sampler.h
   * cannot handle the other boundary case.
   */
  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
  class Impl {
   public:
    Impl &operator=(const Impl &) = delete;
    Impl(const Impl &) = delete;

    /*!
     * \param gen       generator whose state pool is used; must outlive this Impl.
     * \param state_idx index of the state in gen's pool (not bounds-checked).
     */
    __device__ explicit Impl(RandGenerator<gpu, double> *gen, int state_idx)
        : global_gen_(gen),
          global_state_idx_(state_idx),
          state_(*(gen->states_ + state_idx)) {}

    __device__ ~Impl() {
      // Write the advanced state back so the next Impl continues the sequence.
      global_gen_->states_[global_state_idx_] = state_;
    }

    // One raw 32-bit curand() output, converted to int.
    MSHADOW_FORCE_INLINE __device__ int rand() {
      return curand(&state_);
    }

    /*!
     * \brief Uniform random integer in [0, 2^63).
     *
     * Each 32-bit curand() draw is widened before shifting, so no bits are
     * lost. The first draw fills bits 31..62 and the top 31 bits of the second
     * draw fill bits 0..30. OR-ing keeps the result uniform and free of int64
     * overflow. The draws are sequenced explicitly.
     */
    MSHADOW_FORCE_INLINE __device__ int64_t rand_int64() {
      const uint64_t hi = curand(&state_);
      const uint64_t lo = curand(&state_);
      return static_cast<int64_t>((hi << 31) | (lo >> 1));
    }

    // Uniform double in [0, 1): curand_uniform_double returns (0, 1], so it is flipped.
    MSHADOW_FORCE_INLINE __device__ double uniform() {
      return 1.0 - curand_uniform_double(&state_);
    }

    // Standard normal double (mean 0, stddev 1).
    MSHADOW_FORCE_INLINE __device__ double normal() {
      return curand_normal_double(&state_);
    }

   private:
    // Owning generator; its pool receives the state back in the destructor.
    RandGenerator<gpu, double> *global_gen_;
    // Index of this Impl's state in global_gen_'s pool.
    int global_state_idx_;
    // Working copy of the pooled state.
    curandStatePhilox4_32_10_t state_;
  };  // class RandGenerator<gpu, double>::Impl

 private:
  // Pool of cuRAND Philox states (allocation is not handled in this class).
  curandStatePhilox4_32_10_t *states_;
};  // class RandGenerator<gpu, double>
228 
229 #endif // MXNET_USE_CUDA
230 
231 } // namespace random
232 } // namespace common
233 } // namespace mxnet
234 #endif // MXNET_RANDOM_GENERATOR_H_
MSHADOW_FORCE_INLINE __device__ float normal()
Definition: random_generator.h:156
static void AllocState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:86
Definition: random_generator.h:180
namespace of mxnet
Definition: api_registry.h:33
Definition: stream_gpu-inl.h:37
static const int kMinNumRandomPerThread
Definition: random_generator.h:119
__device__ Impl(RandGenerator< gpu, DType > *gen, int state_idx)
Definition: random_generator.h:134
MSHADOW_XINLINE int rand()
Definition: random_generator.h:63
MSHADOW_XINLINE FType normal()
Definition: random_generator.h:77
__device__ ~Impl()
Definition: random_generator.h:139
#define MSHADOW_FORCE_INLINE
Definition: base.h:225
static const int kNumRandomStates
Definition: random_generator.h:121
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:203
static void FreeState(RandGenerator< cpu, DType > *inst)
Definition: random_generator.h:90
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:207
device name CPU
Definition: tensor.h:39
Impl(RandGenerator< cpu, DType > *gen, int state_idx)
Definition: random_generator.h:57
device name GPU
Definition: tensor.h:46
#define MSHADOW_XINLINE
Definition: base.h:230
MSHADOW_FORCE_INLINE __device__ float uniform()
Definition: random_generator.h:152
std::conditional< std::is_floating_point< DType >::value, DType, double >::type FType
Definition: random_generator.h:56
static const int kNumRandomStates
Definition: random_generator.h:49
Definition: random_generator.h:116
MSHADOW_XINLINE int64_t rand_int64()
Definition: random_generator.h:65
MSHADOW_FORCE_INLINE __device__ int64_t rand_int64()
Definition: random_generator.h:148
static const int kMinNumRandomPerThread
Definition: random_generator.h:47
MSHADOW_XINLINE FType uniform()
Definition: random_generator.h:69
MSHADOW_XINLINE void * GetStates()
Definition: random_generator.h:99
MSHADOW_FORCE_INLINE __device__ int rand()
Definition: random_generator.h:144
__device__ ~Impl()
Definition: random_generator.h:198
Definition: random_generator.h:41
MSHADOW_FORCE_INLINE __device__ double normal()
Definition: random_generator.h:215
MSHADOW_FORCE_INLINE __device__ double uniform()
Definition: random_generator.h:211
__device__ Impl(RandGenerator< gpu, double > *gen, int state_idx)
Definition: random_generator.h:193
MSHADOW_XINLINE void Seed(mshadow::Stream< cpu > *, uint32_t seed)
Definition: random_generator.h:94
computation stream structure, used for asynchronous computations
Definition: tensor.h:383