mxnet — mshadow/random.h (Doxygen source listing; the documentation index follows the code below)
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MSHADOW_RANDOM_H_
28 #define MSHADOW_RANDOM_H_
29 
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <random>
#include "./base.h"
#include "./tensor.h"
#include "./tensor_container.h"
36 
37 #if MSHADOW_IN_CXX11
38 #include <random> // use cxx11 random by default
39 #endif
40 
41 #if _MSC_VER
42 #define rand_r(x) rand()
43 #endif
44 
45 
46 namespace mshadow {
52 template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
53 class Random {};
54 
56 template<typename DType>
57 class Random<cpu, DType> {
58  public:
63  explicit Random(int seed) {
64  this->Seed(seed);
65  buffer_.Resize(Shape1(kRandBufferSize));
66  }
67  ~Random(void) {
68  }
73  inline void Seed(int seed) {
74 #if MSHADOW_IN_CXX11
75  rnd_engine_.seed(seed);
76 #endif
77  this->rseed_ = static_cast<unsigned>(seed);
78  }
83  inline unsigned GetSeed() const {
84  return rseed_;
85  }
90  inline void set_stream(Stream<cpu> *stream) {
91  }
92 
93 // These samplers are only avail in C++11.
94 #if MSHADOW_IN_CXX11
95 
100  inline unsigned GetRandInt() {
101  return rnd_engine_();
102  }
103 
107  inline void GetRandInt(const Tensor<cpu, 1, unsigned>& dst) {
108  std::generate_n(dst.dptr_, dst.size(0), [&](){ return rnd_engine_(); });
109  }
110 
117  template<int dim, class Sampler>
118  inline void SampleDistribution(Tensor<cpu, dim, DType> *dst, Sampler sampler) {
119  if (dst->CheckContiguous()) {
120  std::generate_n(dst->dptr_, dst->shape_.Size(), sampler);
121  } else {
122  Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
123  for (index_t i = 0; i < mat.size(0); ++i) {
124  std::generate_n(mat[i].dptr_, mat.size(1), sampler);
125  }
126  }
127  }
128 
136  template<int dim, typename PType>
137  inline void SampleUniform(Tensor<cpu, dim, DType> *dst,
138  PType a = 0.0f , PType b = 1.0f ) {
139  // Ensure that half_t is handled correctly.
140  typedef typename std::conditional<std::is_floating_point<DType>::value,
141  DType, double>::type FType;
142  typedef typename std::conditional<std::is_integral<DType>::value,
143  std::uniform_int_distribution<DType>,
144  std::uniform_real_distribution<FType>>::type GType;
145  GType dist_uniform(a, b);
146  SampleDistribution(dst, [&](){ return dist_uniform(rnd_engine_);});
147  }
148 
156  template<int dim, typename PType>
157  inline void SampleGaussian(Tensor<cpu, dim, DType> *dst,
158  PType mu = 0.0f, PType sigma = 1.0f ) {
159  if (sigma <= 0) {
160  *dst = mu; return;
161  }
162  typedef typename std::conditional<std::is_floating_point<DType>::value,
163  DType, double>::type GType;
164  std::normal_distribution<GType> dist_normal(mu, sigma);
165  SampleDistribution(dst, [&](){ return dist_normal(rnd_engine_);});
166  }
167 
175  template<int dim, typename PType>
176  inline void SampleGamma(Tensor<cpu, dim, DType> *dst,
177  PType alpha, PType beta) {
178  typedef typename std::conditional<std::is_floating_point<DType>::value,
179  DType, double>::type GType;
180  std::gamma_distribution<GType> dist_gamma(alpha, beta);
181  SampleDistribution(dst, [&](){ return dist_gamma(rnd_engine_);});
182  }
183 
190  template<int dim, typename PType>
191  inline void SampleExponential(Tensor<cpu, dim, DType> *dst, PType lambda ) {
192  typedef typename std::conditional<std::is_floating_point<DType>::value,
193  DType, double>::type GType;
194  std::exponential_distribution<GType> dist_exp(lambda);
195  SampleDistribution(dst, [&](){ return dist_exp(rnd_engine_);});
196  }
197 
204  template<int dim, typename PType>
205  inline void SamplePoisson(Tensor<cpu, dim, DType> *dst, PType lambda) {
206  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
207  std::poisson_distribution<GType> dist_poisson(lambda);
208  SampleDistribution(dst, [&](){ return static_cast<DType>(dist_poisson(rnd_engine_));});
209  }
210 
218  template<int dim, typename PType1, typename PType2>
219  inline void SampleNegativeBinomial(Tensor<cpu, dim, DType> *dst, PType1 k, PType2 p) {
220  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
221  std::negative_binomial_distribution<GType> dist_negbinomial(k, p);
222  SampleDistribution(dst, [&](){ return static_cast<DType>(dist_negbinomial(rnd_engine_));});
223  }
224 
233  template<int dim, typename PType>
234  inline void SampleGeneralizedNegativeBinomial(Tensor<cpu, dim, DType> *dst,
235  PType mu, PType alpha) {
236  if (alpha == PType(0)) {
237  SamplePoisson(dst, mu); // limit of Poisson
238  } else {
239  PType r(PType(1) / alpha);
240  PType beta = mu * alpha;
241  std::gamma_distribution<> dist_gamma(r, beta);
242  typedef typename std::conditional<std::is_integral<DType>::value, DType, int>::type GType;
243  SampleDistribution(dst,
244  [&](){ std::poisson_distribution<GType> dist_poisson(dist_gamma(rnd_engine_));
245  return static_cast<DType>(dist_poisson(rnd_engine_));});
246  }
247  }
248 #endif
249 
261  template<int dim>
262  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
264  buffer_.Resize(Shape1(shape.Size()));
265  this->SampleGaussian(&buffer_, 0.0f, 1.0f);
266  return expr::reshape(buffer_, shape);
267  }
279  template<int dim>
280  inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
282  buffer_.Resize(Shape1(shape.Size()));
283  this->SampleUniform(&buffer_, 0.0f, 1.0f);
284  return expr::reshape(buffer_, shape);
285  }
286 
287  std::mt19937 &GetRndEngine() {
288  return rnd_engine_;
289  }
290 
291  private:
292 #if MSHADOW_IN_CXX11
293 
294  std::mt19937 rnd_engine_;
296  unsigned rseed_;
297 
298 #else
299 
301  unsigned rseed_;
302  // functions
303  template<int dim>
304  inline void SampleUniform(Tensor<cpu, dim, DType> *dst,
305  DType a = 0.0f, DType b = 1.0f) {
306  if (dst->CheckContiguous()) {
307  this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b);
308  } else {
309  Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
310  for (index_t i = 0; i < mat.size(0); ++i) {
311  this->GenUniform(mat[i].dptr_, mat.size(1), a, b);
312  }
313  }
314  }
315  template<int dim>
316  inline void SampleGaussian(Tensor<cpu, dim, DType> *dst,
317  DType mu = 0.0f, DType sigma = 1.0f) {
318  if (sigma <= 0.0f) {
319  *dst = mu; return;
320  }
321  if (dst->CheckContiguous()) {
322  this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
323  } else {
324  Tensor<cpu, 2, DType> mat = dst->FlatTo2D();
325  for (index_t i = 0; i < mat.size(0); ++i) {
326  this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma);
327  }
328  }
329  }
330  inline void GenUniform(float *dptr, index_t size, float a, float b) {
331  for (index_t j = 0; j < size; ++j) {
332  dptr[j] = static_cast<float>(RandNext()) * (b - a) + a;
333  }
334  }
335  inline void GenUniform(double *dptr, index_t size, double a, double b) {
336  for (index_t j = 0; j < size; ++j) {
337  dptr[j] = static_cast<double>(RandNext()) * (b - a) + a;
338  }
339  }
340  inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) {
341  this->GenGaussianX(dptr, size, mu, sigma);
342  }
343  inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) {
344  this->GenGaussianX(dptr, size, mu, sigma);
345  }
346  inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) {
347  DType g1 = 0.0f, g2 = 0.0f;
348  for (index_t j = 0; j < size; ++j) {
349  if ((j & 1) == 0) {
350  this->SampleNormal2D(&g1, &g2);
351  dptr[j] = mu + g1 * sigma;
352  } else {
353  dptr[j] = mu + g2 * sigma;
354  }
355  }
356  }
358  inline DType RandNext(void) {
359  return static_cast<DType>(rand_r(&rseed_)) /
360  (static_cast<DType>(RAND_MAX) + 1.0f);
361  }
363  inline DType RandNext2(void) {
364  return (static_cast<DType>(rand_r(&rseed_)) + 1.0f) /
365  (static_cast<DType>(RAND_MAX) + 2.0f);
366  }
372  inline void SampleNormal2D(DType *xx_, DType *yy_) {
373  DType &xx = *xx_, &yy = *yy_;
374  DType x, y, s;
375  do {
376  x = 2.0f * RandNext2() - 1.0f;
377  y = 2.0f * RandNext2() - 1.0f;
378  s = x * x + y * y;
379  } while (s >= 1.0f || s == 0.0f);
380  DType t = std::sqrt(-2.0f * std::log(s) / s);
381  xx = x * t; yy = y * t;
382  }
383 #endif
384 
386 }; // class Random<cpu, DType>
387 
388 // only allow GPU PRNG when cuda is enabled
389 #if MSHADOW_USE_CUDA
390 
391 template<typename DType>
392 class Random<gpu, DType> {
393  public:
398  explicit Random(int seed) : gen_(NULL) {
399  this->Seed(seed);
400  buffer_.Resize(Shape1(kRandBufferSize));
401  }
403  DeleteGenerator();
404  }
409  inline void set_stream(Stream<gpu> *stream) {
410  curandStatus_t status;
411  status = curandSetStream(gen_, Stream<gpu>::GetStream(stream));
412 
413  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed";
414  }
419  inline void Seed(int seed) {
420  // Create a new rng, either initially or if the RNG type can't reset its offset.
421  if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS))
422  CreateGenerator();
423  // Now set the seed.
424  curandStatus_t status;
425  status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast<uint64_t>(seed));
426  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed.";
427  }
431  inline void GetRandInt(const Tensor<gpu, 1, unsigned>& dst) {
432  curandStatus_t status = curandGenerate(gen_, dst.dptr_, dst.size(0));
433  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed.";
434  }
442  template<int dim>
443  inline void SampleUniform(Tensor<gpu, dim, DType> *dst,
444  DType a = 0.0f, DType b = 1.0f);
445 
453  template<int dim>
454  inline void SampleGaussian(Tensor<gpu, dim, DType> *dst,
455  DType mu = 0.0f, DType sigma = 1.0f);
469  template<int dim>
470  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
471  gaussian(Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f);
483  template<int dim>
484  inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
485  uniform(Shape<dim> shape);
486 
487  private:
488  inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) {
489  curandStatus_t status;
490  status = curandGenerateNormal(gen_, dptr, size, mu, sigma);
491  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed."
492  << " size = " << size
493  << ",mu = " << mu
494  << ",sigma = " << sigma;
495  }
496  inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) {
497  curandStatus_t status;
498  status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma);
499  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed."
500  << " size = " << size
501  << ",mu = " << mu
502  << ",sigma = " << sigma;
503  }
504  inline void GenUniform(float *dptr, size_t size) {
505  curandStatus_t status;
506  status = curandGenerateUniform(gen_, dptr, size);
507  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed."
508  << " size = " << size;
509  }
510  inline void GenUniform(double *dptr, size_t size) {
511  curandStatus_t status;
512  status = curandGenerateUniformDouble(gen_, dptr, size);
513  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed."
514  << " size = " << size;
515  }
516  inline void CreateGenerator() {
517  if (gen_ != NULL)
518  DeleteGenerator();
519  curandStatus_t status;
520  status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT);
521  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Cannot create CURAND Generator";
522  }
523  inline void DeleteGenerator() {
524  if (gen_ != NULL) {
525  curandStatus_t status;
526  status = curandDestroyGenerator(gen_);
527  CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed";
528  gen_ = NULL;
529  }
530  }
532  curandGenerator_t gen_;
535 }; // class Random<gpu, DType>
536 #endif // MSHADOW_USE_CUDA
537 
538 #ifdef __CUDACC__
539 // implementations that depends on cuda kernels
540 template<typename DType>
541 template<int dim>
543  Tensor<gpu, dim, DType> *dst, DType a, DType b) {
544  if (a == 0.0f && b == 1.0f) {
545  if (dst->CheckContiguous()) {
546  this->GenUniform(dst->dptr_, dst->shape_.Size());
547  } else {
548  *dst = this->uniform(dst->shape_);
549  }
550  } else {
551  *dst = this->uniform(dst->shape_) * (b - a) + a;
552  }
553 }
554 template<typename DType>
555 template<int dim>
557  Tensor<gpu, dim, DType> *dst, DType mu, DType sigma) {
558  // We need to check whether the shape size is even since CuRand supports only normal distribution
559  // generation of even number of elements.
560  if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) {
561  this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma);
562  } else {
563  *dst = this->gaussian(dst->shape_, mu, sigma);
564  }
565 }
566 
567 template<typename DType>
568 template<int dim>
569 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
570 Random<gpu, DType>::gaussian(Shape<dim> shape, DType mu, DType sigma) {
571  size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1;
572  // allocate alligned size
573  buffer_.Resize(Shape1(aligned_sz));
574  buffer_.Resize(Shape1(shape.Size()));
575  this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma);
576  return expr::reshape(buffer_, shape);
577 }
578 
579 template<typename DType>
580 template<int dim>
581 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
583  buffer_.Resize(Shape1(shape.Size()));
584  this->GenUniform(buffer_.dptr_, buffer_.size(0));
585  return expr::reshape(buffer_, shape);
586 }
587 #endif // __CUDACC__
588 } // namespace mshadow
589 #endif // MSHADOW_RANDOM_H_
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > gaussian(Shape< dim > shape)
return a temporal expression storing standard gaussian random variables; the temporal tensor is only valid before the next call of gaussian or uniform
Definition: random.h:263
random number generator
Definition: random.h:53
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:419
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:145
DType * dptr_
pointer to the data
Definition: tensor.h:435
unsigned GetSeed() const
get random seed used in random generator
Definition: random.h:83
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:73
Definition: stream_gpu-inl.h:38
~Random(void) MSHADOW_THROW_EXCEPTION
Definition: random.h:402
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:437
~Random(void)
Definition: random.h:67
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > uniform(Shape< dim > shape)
return a temporal expression storing standard uniform [0,1) random variables; the temporal tensor is only valid before the next call of gaussian or uniform
Definition: random.h:281
header file of tensor data structure and functions; this lib requires explicit memory allocation and de-allocation
device name CPU
Definition: tensor.h:40
device name GPU
Definition: tensor.h:47
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:520
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:329
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
int32_t index_t
type that will be used for index
Definition: base.h:336
ReshapeExp< SrcExp, DType, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, DType, etype > &src, Shape< dimdst > oshape)
a expression that reshapes a tensor to another shape
Definition: reshape.h:67
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
tensor container that does memory allocation and resize like STL, use it to save the lines of FreeSpa...
Definition: tensor_container.h:41
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
void GetRandInt(const Tensor< gpu, 1, unsigned > &dst)
get a set of random integers
Definition: random.h:431
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:492
reshape the content to another shape input: Tensor<Device,dimsrc>: ishape output: Tensor<Device...
Definition: reshape.h:40
std::mt19937 & GetRndEngine()
Definition: random.h:287
Random(int seed)
constructor of random engine
Definition: random.h:398
tensor container that does memory allocation and resize like STL
void set_stream(Stream< cpu > *stream)
set the stream of computation
Definition: random.h:90
Random(int seed)
constructor of random engine
Definition: random.h:63
overloaded + operator between half_t and bf16_t
Definition: base.h:327
#define MSHADOW_THROW_EXCEPTION
Definition: base.h:253
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:506
general tensor
Definition: tensor.h:421
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
void set_stream(Stream< gpu > *stream)
set the stream of computation
Definition: random.h:409
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
computaion stream structure, used for asynchronous computations
Definition: tensor.h:384