27 #ifndef MSHADOW_RANDOM_H_ 28 #define MSHADOW_RANDOM_H_ 42 #define rand_r(x) rand() 52 template<
typename Device,
typename DType MSHADOW_DEFAULT_DTYPE>
56 template<
typename DType>
73 inline void Seed(
int seed) {
75 rnd_engine_.seed(seed);
77 this->rseed_ =
static_cast<unsigned>(seed);
100 inline unsigned GetRandInt() {
101 return rnd_engine_();
108 std::generate_n(dst.
dptr_, dst.
size(0), [&](){
return rnd_engine_(); });
117 template<
int dim,
class Sampler>
120 std::generate_n(dst->
dptr_, dst->
shape_.Size(), sampler);
124 std::generate_n(mat[i].dptr_, mat.
size(1), sampler);
136 template<
int dim,
typename PType>
138 PType a = 0.0f , PType b = 1.0f ) {
140 typedef typename std::conditional<std::is_floating_point<DType>::value,
141 DType,
double>::type FType;
142 typedef typename std::conditional<std::is_integral<DType>::value,
143 std::uniform_int_distribution<DType>,
144 std::uniform_real_distribution<FType>>::type GType;
145 GType dist_uniform(a, b);
146 SampleDistribution(dst, [&](){
return dist_uniform(rnd_engine_);});
156 template<
int dim,
typename PType>
158 PType mu = 0.0f, PType sigma = 1.0f ) {
162 typedef typename std::conditional<std::is_floating_point<DType>::value,
163 DType,
double>::type GType;
164 std::normal_distribution<GType> dist_normal(mu, sigma);
165 SampleDistribution(dst, [&](){
return dist_normal(rnd_engine_);});
175 template<
int dim,
typename PType>
177 PType alpha, PType beta) {
178 typedef typename std::conditional<std::is_floating_point<DType>::value,
179 DType,
double>::type GType;
180 std::gamma_distribution<GType> dist_gamma(alpha, beta);
181 SampleDistribution(dst, [&](){
return dist_gamma(rnd_engine_);});
190 template<
int dim,
typename PType>
192 typedef typename std::conditional<std::is_floating_point<DType>::value,
193 DType,
double>::type GType;
194 std::exponential_distribution<GType> dist_exp(lambda);
195 SampleDistribution(dst, [&](){
return dist_exp(rnd_engine_);});
204 template<
int dim,
typename PType>
206 typedef typename std::conditional<std::is_integral<DType>::value, DType,
int>::type GType;
207 std::poisson_distribution<GType> dist_poisson(lambda);
208 SampleDistribution(dst, [&](){
return static_cast<DType
>(dist_poisson(rnd_engine_));});
218 template<
int dim,
typename PType1,
typename PType2>
220 typedef typename std::conditional<std::is_integral<DType>::value, DType,
int>::type GType;
221 std::negative_binomial_distribution<GType> dist_negbinomial(k, p);
222 SampleDistribution(dst, [&](){
return static_cast<DType
>(dist_negbinomial(rnd_engine_));});
233 template<
int dim,
typename PType>
235 PType mu, PType alpha) {
236 if (alpha == PType(0)) {
239 PType r(PType(1) / alpha);
240 PType beta = mu * alpha;
241 std::gamma_distribution<> dist_gamma(r, beta);
242 typedef typename std::conditional<std::is_integral<DType>::value, DType,
int>::type GType;
243 SampleDistribution(dst,
244 [&](){ std::poisson_distribution<GType> dist_poisson(dist_gamma(rnd_engine_));
245 return static_cast<DType
>(dist_poisson(rnd_engine_));});
280 inline expr::ReshapeExp<Tensor<cpu, 1, DType>, DType, dim, 1>
294 std::mt19937 rnd_engine_;
305 DType a = 0.0f, DType b = 1.0f) {
307 this->GenUniform(dst->
dptr_, dst->
shape_.Size(), a, b);
311 this->GenUniform(mat[i].dptr_, mat.
size(1), a, b);
317 DType mu = 0.0f, DType sigma = 1.0f) {
322 this->GenGaussian(dst->
dptr_, dst->
shape_.Size(), mu, sigma);
326 this->GenGaussian(mat[i].dptr_, mat.
size(1), mu, sigma);
330 inline void GenUniform(
float *dptr,
index_t size,
float a,
float b) {
331 for (
index_t j = 0; j < size; ++j) {
332 dptr[j] =
static_cast<float>(RandNext()) * (b - a) + a;
335 inline void GenUniform(
double *dptr,
index_t size,
double a,
double b) {
336 for (
index_t j = 0; j < size; ++j) {
337 dptr[j] =
static_cast<double>(RandNext()) * (b - a) + a;
340 inline void GenGaussian(
float *dptr,
index_t size,
float mu,
float sigma) {
341 this->GenGaussianX(dptr, size, mu, sigma);
343 inline void GenGaussian(
double *dptr,
index_t size,
double mu,
double sigma) {
344 this->GenGaussianX(dptr, size, mu, sigma);
346 inline void GenGaussianX(DType *dptr,
index_t size, DType mu, DType sigma) {
347 DType g1 = 0.0f, g2 = 0.0f;
348 for (
index_t j = 0; j < size; ++j) {
350 this->SampleNormal2D(&g1, &g2);
351 dptr[j] = mu + g1 * sigma;
353 dptr[j] = mu + g2 * sigma;
358 inline DType RandNext(
void) {
359 return static_cast<DType
>(rand_r(&rseed_)) /
360 (static_cast<DType>(RAND_MAX) + 1.0f);
363 inline DType RandNext2(
void) {
364 return (static_cast<DType>(rand_r(&rseed_)) + 1.0f) /
365 (
static_cast<DType
>(RAND_MAX) + 2.0f);
372 inline void SampleNormal2D(DType *xx_, DType *yy_) {
373 DType &xx = *xx_, &yy = *yy_;
376 x = 2.0f * RandNext2() - 1.0f;
377 y = 2.0f * RandNext2() - 1.0f;
379 }
while (s >= 1.0f || s == 0.0f);
380 DType t = std::sqrt(-2.0f * std::log(s) / s);
381 xx = x * t; yy = y * t;
391 template<
typename DType>
410 curandStatus_t status;
413 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"set_stream CURAND failed";
421 if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS))
424 curandStatus_t status;
425 status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast<uint64_t>(seed));
426 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"Set CURAND seed failed.";
432 curandStatus_t status = curandGenerate(gen_, dst.
dptr_, dst.
size(0));
433 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"CURAND Gen rand ints failed.";
444 DType a = 0.0f, DType b = 1.0f);
455 DType mu = 0.0f, DType sigma = 1.0f);
471 gaussian(
Shape<dim> shape, DType mu = 0.0f, DType sigma = 1.0f);
484 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
488 inline void GenGaussian(
float *dptr,
size_t size,
float mu,
float sigma) {
489 curandStatus_t status;
490 status = curandGenerateNormal(gen_, dptr, size, mu, sigma);
491 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"CURAND Gen Normal float failed." 492 <<
" size = " << size
494 <<
",sigma = " << sigma;
496 inline void GenGaussian(
double *dptr,
size_t size,
double mu,
double sigma) {
497 curandStatus_t status;
498 status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma);
499 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"CURAND Gen Normal double failed." 500 <<
" size = " << size
502 <<
",sigma = " << sigma;
504 inline void GenUniform(
float *dptr,
size_t size) {
505 curandStatus_t status;
506 status = curandGenerateUniform(gen_, dptr, size);
507 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"CURAND Gen Uniform float failed." 508 <<
" size = " << size;
510 inline void GenUniform(
double *dptr,
size_t size) {
511 curandStatus_t status;
512 status = curandGenerateUniformDouble(gen_, dptr, size);
513 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"CURAND Gen Uniform double failed." 514 <<
" size = " << size;
516 inline void CreateGenerator() {
519 curandStatus_t status;
520 status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT);
521 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"Cannot create CURAND Generator";
523 inline void DeleteGenerator() {
525 curandStatus_t status;
526 status = curandDestroyGenerator(gen_);
527 CHECK_EQ(status, CURAND_STATUS_SUCCESS) <<
"Destory CURAND Gen failed";
532 curandGenerator_t gen_;
536 #endif // MSHADOW_USE_CUDA 540 template<
typename DType>
544 if (a == 0.0f && b == 1.0f) {
548 *dst = this->uniform(dst->
shape_);
551 *dst = this->uniform(dst->
shape_) * (b - a) + a;
554 template<
typename DType>
561 this->GenGaussian(dst->
dptr_, dst->
shape_.Size(), mu, sigma);
563 *dst = this->gaussian(dst->
shape_, mu, sigma);
567 template<
typename DType>
571 size_t aligned_sz = ((shape.
Size() + 1UL) >> 1) << 1;
573 buffer_.Resize(
Shape1(aligned_sz));
575 this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma);
579 template<
typename DType>
581 inline expr::ReshapeExp<Tensor<gpu, 1, DType>, DType, dim, 1>
584 this->GenUniform(buffer_.dptr_, buffer_.size(0));
589 #endif // MSHADOW_RANDOM_H_ expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > gaussian(Shape< dim > shape)
return a temporary expression storing standard Gaussian random variables; the temporary tensor is only v...
Definition: random.h:263
random number generator
Definition: random.h:53
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:419
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:145
DType * dptr_
pointer to the data
Definition: tensor.h:435
unsigned GetSeed() const
get random seed used in random generator
Definition: random.h:83
void Seed(int seed)
seed random number generator using this seed
Definition: random.h:73
Definition: stream_gpu-inl.h:38
~Random(void) MSHADOW_THROW_EXCEPTION
Definition: random.h:402
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:437
~Random(void)
Definition: random.h:67
expr::ReshapeExp< Tensor< cpu, 1, DType >, DType, dim, 1 > uniform(Shape< dim > shape)
return a temporary expression storing standard uniform [0,1) random variables; the temporary tensor is only valid before ...
Definition: random.h:281
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
device name CPU
Definition: tensor.h:40
device name GPU
Definition: tensor.h:47
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:520
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:329
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
int32_t index_t
type that will be used for index
Definition: base.h:336
ReshapeExp< SrcExp, DType, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, DType, etype > &src, Shape< dimdst > oshape)
a expression that reshapes a tensor to another shape
Definition: reshape.h:67
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
tensor container that does memory allocation and resize like STL, use it to save the lines of FreeSpa...
Definition: tensor_container.h:41
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
void GetRandInt(const Tensor< gpu, 1, unsigned > &dst)
get a set of random integers
Definition: random.h:431
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:492
reshape the content to another shape input: Tensor<Device,dimsrc>: ishape output: Tensor<Device...
Definition: reshape.h:40
std::mt19937 & GetRndEngine()
Definition: random.h:287
Random(int seed)
constructor of random engine
Definition: random.h:398
tensor container that does memory allocation and resize like STL
void set_stream(Stream< cpu > *stream)
set the stream of computation
Definition: random.h:90
Random(int seed)
constructor of random engine
Definition: random.h:63
overloaded + operator between half_t and bf16_t
Definition: base.h:327
#define MSHADOW_THROW_EXCEPTION
Definition: base.h:253
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:506
general tensor
Definition: tensor.h:421
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
void set_stream(Stream< gpu > *stream)
set the stream of computation
Definition: random.h:409
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
computation stream structure, used for asynchronous computations
Definition: tensor.h:384