mxnet
random.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_RANDOM_H_
27 #define MSHADOW_RANDOM_H_
28 
29 #include <cstdlib>
30 #include <algorithm>
31 #include <random>
32 #include "./base.h"
33 #include "./tensor.h"
34 #include "./tensor_container.h"
35 
36 #if MSHADOW_IN_CXX11
37 #include <random> // use cxx11 random by default
38 #endif
39 
40 #if _MSC_VER
41 #define rand_r(x) rand()
42 #endif
43 
44 
45 namespace mshadow {
/*!
 * \brief random number generator, primary template.
 *  Only the cpu and gpu specializations below provide functionality.
 * \tparam Device the device the generator runs on
 * \tparam DType the data type of the generated samples
 */
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class Random {};
53 
55 template<typename DType>
56 class Random<cpu, DType> {
57  public:
  /*!
   * \brief constructor of random engine
   * \param seed random number seed
   */
  explicit Random(int seed) {
    this->Seed(seed);
    // pre-allocate the buffer used by the gaussian()/uniform() expressions
    buffer_.Resize(Shape1(kRandBufferSize));
  }
  /*! \brief destructor; members clean up after themselves */
  ~Random(void) {
  }
  /*!
   * \brief seed random number generator using this seed
   * \param seed seed of prng
   */
  inline void Seed(int seed) {
#if MSHADOW_IN_CXX11
    rnd_engine_.seed(seed);
#endif
    // rseed_ is kept in both builds; the pre-C++11 path feeds it to rand_r
    this->rseed_ = static_cast<unsigned>(seed);
  }
  /*!
   * \brief get random seed used in random generator
   * \return seed in unsigned
   */
  inline unsigned GetSeed() const {
    return rseed_;
  }
  /*!
   * \brief set the stream of computation; a no-op on cpu
   * \param stream computation stream (unused)
   */
  inline void set_stream(Stream<cpu> *stream) {
  }
91 
92 // These samplers are only avail in C++11.
93 #if MSHADOW_IN_CXX11
94 
  /*!
   * \brief get a random unsigned integer
   * \return the next raw draw of the mt19937 engine
   */
  inline unsigned GetRandInt() {
    return rnd_engine_();
  }
102 
106  inline void GetRandInt(const Tensor<cpu, 1, unsigned>& dst) {
107  std::generate_n(dst.dptr_, dst.size(0), [&](){ return rnd_engine_(); });
108  }
109 
116  template<int dim, class Sampler>
117  inline void SampleDistribution(Tensor<cpu, dim, DType> *dst, Sampler sampler) {
118  if (dst->CheckContiguous()) {
119  std::generate_n(dst->dptr_, dst->shape_.Size(), sampler);
120  } else {
121  Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
122  for (index_t i = 0; i < mat.size(0); ++i) {
123  std::generate_n(mat[i].dptr_, mat.size(1), sampler);
124  }
125  }
126  }
127 
  /*!
   * \brief generate uniformly distributed data
   * \param dst destination tensor
   * \param a lower bound
   * \param b upper bound
   * \tparam dim dimension of tensor
   * \tparam PType type of the bound parameters
   */
  template<int dim, typename PType>
  inline void SampleUniform(Tensor<cpu, dim, DType> *dst,
                            PType a = 0.0f , PType b = 1.0f ) {
    // Ensure that half_t is handled correctly.
    // Non-native-float DType samples in double; the assignment inside
    // SampleDistribution converts down to DType.
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type FType;
    // NOTE(review): integral DType uses uniform_int_distribution (closed
    // range [a, b]) while floating point uses uniform_real_distribution
    // (half-open [a, b)) -- callers should be aware of the difference.
    typedef typename std::conditional<std::is_integral<DType>::value,
                                      std::uniform_int_distribution<DType>,
                                      std::uniform_real_distribution<FType>>::type GType;
    GType dist_uniform(a, b);
    SampleDistribution(dst, [&](){ return dist_uniform(rnd_engine_);});
  }
147 
  /*!
   * \brief generate gaussian distributed data
   * \param dst destination tensor
   * \param mu mean
   * \param sigma standard deviation; if sigma <= 0 the tensor is filled with mu
   * \tparam dim dimension of tensor
   * \tparam PType type of the distribution parameters
   */
  template<int dim, typename PType>
  inline void SampleGaussian(Tensor<cpu, dim, DType> *dst,
                             PType mu = 0.0f, PType sigma = 1.0f ) {
    if (sigma <= 0) {
      // degenerate distribution: all mass at the mean
      *dst = mu; return;
    }
    // sample in DType when it is a native float type, otherwise in double
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type GType;
    std::normal_distribution<GType> dist_normal(mu, sigma);
    SampleDistribution(dst, [&](){ return dist_normal(rnd_engine_);});
  }
166 
  /*!
   * \brief generate gamma distributed data
   * \param dst destination tensor
   * \param alpha shape parameter of std::gamma_distribution
   * \param beta second parameter of std::gamma_distribution, i.e. its SCALE.
   *  NOTE(review): if callers interpret beta as a rate this is inverted --
   *  confirm the intended parameterization against call sites.
   * \tparam dim dimension of tensor
   * \tparam PType type of the distribution parameters
   */
  template<int dim, typename PType>
  inline void SampleGamma(Tensor<cpu, dim, DType> *dst,
                          PType alpha, PType beta) {
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type GType;
    std::gamma_distribution<GType> dist_gamma(alpha, beta);
    SampleDistribution(dst, [&](){ return dist_gamma(rnd_engine_);});
  }
182 
  /*!
   * \brief generate exponentially distributed data
   * \param dst destination tensor
   * \param lambda rate parameter of std::exponential_distribution
   * \tparam dim dimension of tensor
   * \tparam PType type of the rate parameter
   */
  template<int dim, typename PType>
  inline void SampleExponential(Tensor<cpu, dim, DType> *dst, PType lambda ) {
    typedef typename std::conditional<std::is_floating_point<DType>::value,
                                      DType, double>::type GType;
    std::exponential_distribution<GType> dist_exp(lambda);
    SampleDistribution(dst, [&](){ return dist_exp(rnd_engine_);});
  }
196 
  /*!
   * \brief generate Poisson distributed data
   * \param dst destination tensor
   * \param lambda rate parameter of std::poisson_distribution
   * \tparam dim dimension of tensor
   * \tparam PType type of the rate parameter
   */
  template<int dim, typename PType>
  inline void SamplePoisson(Tensor<cpu, dim, DType> *dst, PType lambda) {
    // the distribution must produce an integer type; fall back to int for
    // non-integral DType, then cast each sample to DType
    typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
    std::poisson_distribution<GType> dist_poisson(lambda);
    SampleDistribution(dst, [&](){ return static_cast<DType>(dist_poisson(rnd_engine_));});
  }
209 
  /*!
   * \brief generate negative-binomial distributed data
   * \param dst destination tensor
   * \param k first parameter of std::negative_binomial_distribution
   * \param p second parameter (success probability)
   * \tparam dim dimension of tensor
   * \tparam PType1 type of k
   * \tparam PType2 type of p
   */
  template<int dim, typename PType1, typename PType2>
  inline void SampleNegativeBinomial(Tensor<cpu, dim, DType> *dst, PType1 k, PType2 p) {
    // distribution requires an integer result type; cast to DType per sample
    typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
    std::negative_binomial_distribution<GType> dist_negbinomial(k, p);
    SampleDistribution(dst, [&](){ return static_cast<DType>(dist_negbinomial(rnd_engine_));});
  }
223 
  /*!
   * \brief generate data from a generalized negative binomial
   *  (gamma-Poisson mixture), parameterized by mean mu and dispersion alpha
   * \param dst destination tensor
   * \param mu mean of the distribution
   * \param alpha dispersion; alpha == 0 degenerates to Poisson(mu)
   * \tparam dim dimension of tensor
   * \tparam PType type of the distribution parameters
   */
  template<int dim, typename PType>
  inline void SampleGeneralizedNegativeBinomial(Tensor<cpu, dim, DType> *dst,
                                                PType mu, PType alpha) {
    if (alpha == PType(0)) {
      SamplePoisson(dst, mu); // limit of Poisson
    } else {
      // gamma shape r = 1/alpha and scale beta = mu*alpha give a gamma mean of mu
      PType r(PType(1) / alpha);
      PType beta = mu * alpha;
      // NOTE(review): gamma_distribution<> defaults to double here, unlike the
      // other samplers which derive the type from DType -- confirm intentional.
      std::gamma_distribution<> dist_gamma(r, beta);
      typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
      // each output sample draws a fresh gamma rate, then a Poisson with that rate
      SampleDistribution(dst,
        [&](){ std::poisson_distribution<GType> dist_poisson(dist_gamma(rnd_engine_));
               return static_cast<DType>(dist_poisson(rnd_engine_));});
    }
  }
247 #endif
248 
260  template<int dim>
261  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
263  buffer_.Resize(Shape1(shape.Size()));
264  this->SampleGaussian(&buffer_, 0.0f, 1.0f);
265  return expr::reshape(buffer_, shape);
266  }
278  template<int dim>
279  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
281  buffer_.Resize(Shape1(shape.Size()));
282  this->SampleUniform(&buffer_, 0.0f, 1.0f);
283  return expr::reshape(buffer_, shape);
284  }
285 
  /*! \brief expose the underlying C++11 engine for callers that need raw access */
  std::mt19937 &GetRndEngine() {
    return rnd_engine_;
  }
289 
 private:
#if MSHADOW_IN_CXX11
  /*! \brief use C++11 random engine */
  std::mt19937 rnd_engine_;
  /*! \brief random number seed, as passed to Seed() */
  unsigned rseed_;

#else

  /*! \brief random number seed consumed by rand_r in the fallback path */
  unsigned rseed_;
  // functions
  /*!
   * \brief pre-C++11 fallback: generate uniform data in [a, b)
   * \param dst destination tensor
   * \param a lower bound
   * \param b upper bound
   * \tparam dim dimension of tensor
   */
  template<int dim>
  inline void SampleUniform(Tensor<cpu, dim, DType> *dst,
                            DType a = 0.0f, DType b = 1.0f) {
    if (dst->CheckContiguous()) {
      this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b);
    } else {
      // non-contiguous: fill row by row through a 2-D view
      Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
      for (index_t i = 0; i < mat.size(0); ++i) {
        this->GenUniform(mat[i].dptr_, mat.size(1), a, b);
      }
    }
  }
  /*!
   * \brief pre-C++11 fallback: generate gaussian data
   * \param dst destination tensor
   * \param mu mean
   * \param sigma standard deviation; if sigma <= 0 the tensor is filled with mu
   * \tparam dim dimension of tensor
   */
  template<int dim>
  inline void SampleGaussian(Tensor<cpu, dim, DType> *dst,
                             DType mu = 0.0f, DType sigma = 1.0f) {
    if (sigma <= 0.0f) {
      // degenerate distribution: all mass at the mean
      *dst = mu; return;
    }
    if (dst->CheckContiguous()) {
      this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
    } else {
      // non-contiguous: fill row by row through a 2-D view
      Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
      for (index_t i = 0; i < mat.size(0); ++i) {
        this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma);
      }
    }
  }
  /*! \brief fill dptr[0..size) with uniform floats in [a, b); RandNext() is in [0, 1) */
  inline void GenUniform(float *dptr, index_t size, float a, float b) {
    for (index_t j = 0; j < size; ++j) {
      dptr[j] = static_cast<float>(RandNext()) * (b - a) + a;
    }
  }
  /*! \brief fill dptr[0..size) with uniform doubles in [a, b); RandNext() is in [0, 1) */
  inline void GenUniform(double *dptr, index_t size, double a, double b) {
    for (index_t j = 0; j < size; ++j) {
      dptr[j] = static_cast<double>(RandNext()) * (b - a) + a;
    }
  }
  /*! \brief float overload: forwards to the shared GenGaussianX implementation */
  inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) {
    this->GenGaussianX(dptr, size, mu, sigma);
  }
  /*! \brief double overload: forwards to the shared GenGaussianX implementation */
  inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) {
    this->GenGaussianX(dptr, size, mu, sigma);
  }
  /*!
   * \brief fill dptr[0..size) with N(mu, sigma^2) samples.
   *  SampleNormal2D yields two independent standard normals per call; a pair
   *  is drawn on every even index and the spare is consumed on the odd index.
   */
  inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) {
    DType g1 = 0.0f, g2 = 0.0f;
    for (index_t j = 0; j < size; ++j) {
      if ((j & 1) == 0) {
        this->SampleNormal2D(&g1, &g2);
        dptr[j] = mu + g1 * sigma;
      } else {
        dptr[j] = mu + g2 * sigma;
      }
    }
  }
  /*! \brief get next uniform sample in [0, 1); rand_r output divided by RAND_MAX + 1 */
  inline DType RandNext(void) {
    return static_cast<DType>(rand_r(&rseed_)) /
        (static_cast<DType>(RAND_MAX) + 1.0f);
  }
  /*!
   * \brief get next uniform sample in the OPEN interval (0, 1); the +1/+2
   *  shifts guarantee the result is never 0, keeping log(s) finite in
   *  SampleNormal2D
   */
  inline DType RandNext2(void) {
    return (static_cast<DType>(rand_r(&rseed_)) + 1.0f) /
        (static_cast<DType>(RAND_MAX) + 2.0f);
  }
  /*!
   * \brief sample two independent standard normals via the Marsaglia polar
   *  method: rejection-sample a point in the unit disk, then scale by
   *  sqrt(-2 log(s) / s); avoids trigonometric calls
   * \param xx_ output of the first standard normal
   * \param yy_ output of the second standard normal
   */
  inline void SampleNormal2D(DType *xx_, DType *yy_) {
    DType &xx = *xx_, &yy = *yy_;
    DType x, y, s;
    do {
      x = 2.0f * RandNext2() - 1.0f;
      y = 2.0f * RandNext2() - 1.0f;
      s = x * x + y * y;
      // reject points outside the unit disk and the degenerate origin
    } while (s >= 1.0f || s == 0.0f);
    DType t = std::sqrt(-2.0f * std::log(s) / s);
    xx = x * t; yy = y * t;
  }
382 #endif
383 
385 }; // class Random<cpu, DType>
386 
387 // only allow GPU PRNG when cuda is enabled
388 #if MSHADOW_USE_CUDA
389 
390 template<typename DType>
391 class Random<gpu, DType> {
392  public:
  /*!
   * \brief constructor of random engine
   * \param seed random number seed
   */
  explicit Random(int seed) : gen_(NULL) {
    // Seed() creates the cuRAND generator on first use (gen_ starts NULL)
    this->Seed(seed);
    buffer_.Resize(Shape1(kRandBufferSize));
  }
402  DeleteGenerator();
403  }
  /*!
   * \brief bind the cuRAND generator to a computation stream
   * \param stream computation stream; subsequent generation runs on it
   */
  inline void set_stream(Stream<gpu> *stream) {
    curandStatus_t status;
    status = curandSetStream(gen_, Stream<gpu>::GetStream(stream));

    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed";
  }
  /*!
   * \brief seed the cuRAND generator, (re)creating it when necessary
   * \param seed seed of prng
   */
  inline void Seed(int seed) {
    // Create a new rng, either initially or if the RNG type can't reset its offset.
    // (a failing curandSetGeneratorOffset doubles as the "can't reset" probe)
    if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS))
      CreateGenerator();
    // Now set the seed.
    curandStatus_t status;
    status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast<uint64_t>(seed));
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed.";
  }
  /*!
   * \brief fill a 1-D device tensor with random unsigned integers
   * \param dst destination tensor on the gpu
   */
  inline void GetRandInt(const Tensor<gpu, 1, unsigned>& dst) {
    curandStatus_t status = curandGenerate(gen_, dst.dptr_, dst.size(0));
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed.";
  }
  /*!
   * \brief generate uniform data in [a, b); defined under __CUDACC__ below
   * \param dst destination tensor
   * \param a lower bound
   * \param b upper bound
   * \tparam dim dimension of tensor
   */
  template<int dim>
  inline void SampleUniform(Tensor<gpu, dim, DType> *dst,
                            DType a = 0.0f, DType b = 1.0f);

  /*!
   * \brief generate gaussian data; defined under __CUDACC__ below
   * \param dst destination tensor
   * \param mu mean
   * \param sigma standard deviation
   * \tparam dim dimension of tensor
   */
  template<int dim>
  inline void SampleGaussian(Tensor<gpu, dim, DType> *dst,
                             DType mu = 0.0f, DType sigma = 1.0f);
  /*!
   * \brief return a temporal expression of standard gaussian samples backed by
   *  buffer_; only valid until the next call that reuses the buffer
   */
  template<int dim>
  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
  gaussian(Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f);
  /*!
   * \brief return a temporal expression of standard uniform samples backed by
   *  buffer_; only valid until the next call that reuses the buffer
   */
  template<int dim>
  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
  uniform(Shape<dim> shape);
485 
486  private:
  /*!
   * \brief generate size gaussian floats on the device.
   *  NOTE(review): per the caller's comment, curandGenerateNormal only
   *  supports an even number of elements -- callers must guarantee that.
   */
  inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) {
    curandStatus_t status;
    status = curandGenerateNormal(gen_, dptr, size, mu, sigma);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed."
                                            << " size = " << size
                                            << ",mu = " << mu
                                            << ",sigma = " << sigma;
  }
  /*!
   * \brief generate size gaussian doubles on the device.
   *  NOTE(review): per the caller's comment, curandGenerateNormalDouble only
   *  supports an even number of elements -- callers must guarantee that.
   */
  inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) {
    curandStatus_t status;
    status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed."
                                            << " size = " << size
                                            << ",mu = " << mu
                                            << ",sigma = " << sigma;
  }
  /*! \brief generate size uniform floats on the device via curandGenerateUniform */
  inline void GenUniform(float *dptr, size_t size) {
    curandStatus_t status;
    status = curandGenerateUniform(gen_, dptr, size);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed."
                                            << " size = " << size;
  }
  /*! \brief generate size uniform doubles on the device via curandGenerateUniformDouble */
  inline void GenUniform(double *dptr, size_t size) {
    curandStatus_t status;
    status = curandGenerateUniformDouble(gen_, dptr, size);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed."
                                            << " size = " << size;
  }
  /*! \brief (re)create the cuRAND generator, destroying any existing one first */
  inline void CreateGenerator() {
    if (gen_ != NULL)
      DeleteGenerator();
    curandStatus_t status;
    status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT);
    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Cannot create CURAND Generator";
  }
522  inline void DeleteGenerator() {
523  if (gen_ != NULL) {
524  curandStatus_t status;
525  status = curandDestroyGenerator(gen_);
526  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed";
527  gen_ = NULL;
528  }
529  }
531  curandGenerator_t gen_;
534 }; // class Random<gpu, DType>
535 #endif // MSHADOW_USE_CUDA
536 
537 #ifdef __CUDACC__
538 // implementations that depends on cuda kernels
539 template<typename DType>
540 template<int dim>
542  Tensor<gpu, dim, DType> *dst, DType a, DType b) {
543  if (a == 0.0f && b == 1.0f) {
544  if (dst->CheckContiguous()) {
545  this->GenUniform(dst->dptr_, dst->shape_.Size());
546  } else {
547  *dst = this->uniform(dst->shape_);
548  }
549  } else {
550  *dst = this->uniform(dst->shape_) * (b - a) + a;
551  }
552 }
553 template<typename DType>
554 template<int dim>
556  Tensor<gpu, dim, DType> *dst, DType mu, DType sigma) {
557  // We need to check whether the shape size is even since CuRand supports only normal distribution
558  // generation of even number of elements.
559  if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) {
560  this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
561  } else {
562  *dst = this->gaussian(dst->shape_, mu, sigma);
563  }
564 }
565 
/*!
 * \brief return a temporal expression of standard-gaussian samples reshaped to
 *  shape; the result aliases buffer_ and is only valid until the buffer is
 *  reused
 * \param shape requested output shape
 * \param mu mean
 * \param sigma standard deviation
 */
template<typename DType>
template<int dim>
inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
Random<gpu, DType>::gaussian(Shape<dim> shape, DType mu, DType sigma) {
  // curandGenerateNormal* only supports an even element count (see the
  // comment in SampleGaussian), so round the request up to even
  size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1;
  // allocate aligned size, then shrink the logical size back to shape.Size().
  // NOTE(review): this relies on TensorContainer::Resize keeping the larger
  // allocation when shrinking, because GenGaussian below writes aligned_sz
  // elements -- verify against tensor_container.h.
  buffer_.Resize(Shape1(aligned_sz));
  buffer_.Resize(Shape1(shape.Size()));
  this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma);
  return expr::reshape(buffer_, shape);
}
577 
578 template<typename DType>
579 template<int dim>
580 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
582  buffer_.Resize(Shape1(shape.Size()));
583  this->GenUniform(buffer_.dptr_, buffer_.size(0));
584  return expr::reshape(buffer_, shape);
585 }
586 #endif // __CUDACC__
587 } // namespace mshadow
588 #endif // MSHADOW_RANDOM_H_
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > gaussian(Shape< dim > shape)
return a temporal expression storing standard gaussian random variables the temporal tensor is only v...
Definition: random.h:262
random number generator
Definition: random.h:52
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:418
DType * dptr_
pointer to the data
Definition: tensor.h:434
unsigned GetSeed() const
get random seed used in random generator
Definition: random.h:82
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:72
Definition: stream_gpu-inl.h:37
~Random(void) MSHADOW_THROW_EXCEPTION
Definition: random.h:401
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:436
~Random(void)
Definition: random.h:66
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > uniform(Shape< dim > shape)
return a temporal expression storing standard uniform [0,1) the temporal tensor is only valid before ...
Definition: random.h:280
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
device name CPU
Definition: tensor.h:39
device name GPU
Definition: tensor.h:46
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:336
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
int32_t index_t
type that will be used for index
Definition: base.h:343
ReshapeExp< SrcExp, DType, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, DType, etype > &src, Shape< dimdst > oshape)
a expression that reshapes a tensor to another shape
Definition: reshape.h:66
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:519
tensor container that does memory allocation and resize like STL, use it to save the lines of FreeSpa...
Definition: tensor_container.h:40
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:505
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:491
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:206
void GetRandInt(const Tensor< gpu, 1, unsigned > &dst)
get a set of random integers
Definition: random.h:430
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
reshape the content to another shape input: Tensor<Device,dimsrc>: ishape output: Tensor<Device...
Definition: reshape.h:39
std::mt19937 & GetRndEngine()
Definition: random.h:286
Random(int seed)
constructor of random engine
Definition: random.h:397
tensor container that does memory allocation and resize like STL
void set_stream(Stream< cpu > *stream)
set the stream of computation
Definition: random.h:89
Random(int seed)
constructor of random engine
Definition: random.h:62
overloaded + operator between half_t and bf16_t
Definition: base.h:334
#define MSHADOW_THROW_EXCEPTION
Definition: base.h:260
general tensor
Definition: tensor.h:420
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
void set_stream(Stream< gpu > *stream)
set the stream of computation
Definition: random.h:408
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:144
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
computation stream structure, used for asynchronous computations
Definition: tensor.h:383