25 #ifndef MXNET_NDARRAY_H_ 26 #define MXNET_NDARRAY_H_ 29 #include <dmlc/logging.h> 40 #if MXNET_USE_MKLDNN == 1 47 #if DMLC_USE_CXX11 == 0 48 #error "cxx11 was required for ndarray module" 97 : ptr_(
std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
130 : ptr_(
std::make_shared<Chunk>(data, dev_id)),
132 dtype_(data.type_flag_),
145 NDArray(
const TBlob &data,
int dev_id,
const std::function<
void()>& deleter)
146 : ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) {
157 : ptr_(
std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
175 const TBlob &data,
const std::vector<TBlob> &aux_data,
int dev_id)
176 : ptr_(
std::make_shared<Chunk>(stype, data, aux_data, dev_id)),
178 dtype_(data.type_flag_),
179 storage_type_(stype),
187 ptr_->Init(shape, this->dtype_);
188 this->shape_ = shape;
193 void SetShapeFromChunk();
208 return byte_offset_ > 0 || shape() != ptr_->storage_shape;
213 return ptr_ == other.ptr_ &&
214 shape_ == other.shape_ &&
215 byte_offset_ == other.byte_offset_ &&
216 dtype_ == other.dtype_;
231 CHECK(ptr_ !=
nullptr);
233 <<
"storage_shape() is not intended for kDefaultStorage.";
234 return ptr_->storage_shape;
244 <<
"aux_shape() is not intended for kDefaultStorage.";
245 return ptr_->aux_shapes[index];
251 <<
"aux_shapes() is not intended for kDefaultStorage.";
252 return ptr_->aux_shapes;
258 <<
"aux_types() is not intended for kDefaultStorage.";
259 return ptr_->aux_types;
271 <<
"set_aux_shape() is not intended for kDefaultStorage.";
272 ptr_->set_aux_shape(index, shape);
292 auto stype = storage_type();
294 auto shape = aux_shape(i);
295 auto type = aux_type(i);
297 auto dptr =
static_cast<DType*
>(ptr_->aux_handles[i].dptr);
299 <<
"Unexpected storage type: " << stype;
300 res =
TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
309 return ptr_->shandle.ctx;
319 return ptr_->aux_types[i];
323 return storage_type_;
327 return ptr_.get() ==
nullptr;
330 bool fresh_out_grad()
const;
332 void set_fresh_out_grad(
bool state)
const;
338 if (is_none())
return false;
339 auto stype = storage_type();
341 <<
"storage_initialized() is not intended for kDefaultStorage.";
344 <<
"inconsistent storage shape " << storage_shape()
348 CHECK_EQ(aux_shape(
csr::kIdx)[0], storage_shape()[0])
349 <<
"inconsistent storage shape " << storage_shape()
350 <<
" vs. aux shape " << aux_shape(
csr::kIdx);
353 LOG(FATAL) <<
"Unknown storage type";
362 return ptr_->shandle;
369 if (is_none())
return;
377 if (is_none())
return;
385 },
Context{}, {}, {ptr_->var});
398 return var()->version();
410 bool LegacyLoad(
dmlc::Stream *strm,
const uint32_t magic);
495 void SyncCopyFromCPU(
const void *data,
size_t size)
const;
500 void SyncCopyFromNDArray(
const NDArray &src,
int i = -1,
int j = -1);
512 void SyncCopyToCPU(
void *data,
size_t size)
const;
518 void SyncCheckFormat(
const bool full_check)
const;
549 NDArray aux_ndarray(
size_t i)
const;
566 <<
"AsArray is intended only for kDefaultStorage.";
567 CHECK_GE(ptr_->shandle.size,
569 <<
"NDArray.AsArray: target memory size is bigger";
607 CHECK(shape_ == arr.shape_) <<
"ndarray shape is different from the target";
608 CHECK(dtype_ == arr.dtype_) <<
"ndarray dtype is different from the target";
611 <<
"Only to be used with CSR and RSP storage types";
614 arr.ptr_->shandle = ptr_->shandle;
615 ptr_->shandle = shandle_dst;
617 ptr_->storage_shape = arr.ptr_->storage_shape;
618 ptr_->storage_type = arr.ptr_->storage_type;
619 ptr_->ctx = arr.ptr_->ctx;
623 CHECK(ptr_->aux_handles.size() == arr.ptr_->aux_handles.size())
624 <<
"ndarray number of aux_handles is different from target";
625 for (
auto &aux_handle : arr.ptr_->aux_handles) {
627 ptr_->aux_handles[aux_idx] = aux_handle;
628 aux_handle = aux_dst;
631 ptr_->aux_types = arr.ptr_->aux_types;
632 ptr_->aux_shapes = arr.ptr_->aux_shapes;
663 ptr_->CheckAndAlloc();
688 <<
"CheckAndAlloc(aux_shapes) is not intended for kDefaultStorage";
689 ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_);
693 <<
"CheckAndAllocData is not intended for kDefaultStorage";
694 ptr_->CheckAndAllocData(storage_shape, dtype_);
698 <<
"CheckAndAllocAuxData is not intended for kDefaultStorage";
699 ptr_->CheckAndAllocAuxData(i, aux_shape);
702 #if MXNET_USE_MKLDNN == 1 707 explicit NDArray(
const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
712 explicit NDArray(
const mkldnn::memory::desc &md);
716 bool IsMKLDNNData()
const {
717 return ptr_->IsMKLDNN();
722 bool IsDefaultData()
const {
723 return ptr_->IsDefault();
735 const mkldnn::memory *GetMKLDNNData()
const;
740 const mkldnn::memory *GetMKLDNNData(
const mkldnn::memory::desc &md)
const;
746 const mkldnn::memory *GetMKLDNNDataReorder(
747 const mkldnn::memory::desc &md)
const;
752 void CopyFrom(
const mkldnn::memory &mem);
757 mkldnn::memory *CreateMKLDNNData(
const mkldnn::memory::desc &md);
764 void Reorder2DefaultAsync()
const;
765 void MKLDNNDataReorderAsync(
const mkldnn::memory::desc &md)
const;
771 NDArray Reorder2Default()
const;
777 NDArray Reorder2DefaultFloatFormat()
const;
779 void InvalidateMKLDNNData();
796 void UpdateMKLDNNMemDesc(
const mkldnn::memory::desc &desc);
806 const std::vector<NDArray>& data,
807 const std::vector<std::string>& names);
815 std::vector<NDArray>* data,
816 std::vector<std::string>* keys);
832 std::vector<Storage::Handle> aux_handles;
834 #if MXNET_USE_MKLDNN == 1 837 std::shared_ptr<MKLDNNMemory> mkl_mem_;
854 std::vector<int> aux_types;
865 std::shared_ptr<Storage> storage_ref_;
867 std::weak_ptr<Engine> engine_ref_;
871 Chunk() : static_data(true), delay_alloc(false),
872 storage_ref_(
Storage::_GetSharedRef()),
873 engine_ref_(
Engine::_GetSharedRef()) {}
877 : static_data(false), delay_alloc(true), ctx(ctx_),
878 storage_ref_(
Storage::_GetSharedRef()),
879 engine_ref_(
Engine::_GetSharedRef()) {
880 storage_shape = shape;
887 this->CheckAndAlloc();
891 Chunk(
const TBlob &data,
int dev_id)
892 : static_data(true), delay_alloc(false),
893 storage_ref_(
Storage::_GetSharedRef()),
894 engine_ref_(
Engine::_GetSharedRef()) {
907 storage_shape = data.
shape_;
910 Chunk(
int shared_pid,
int shared_id,
const mxnet::TShape& shape,
int dtype)
911 : static_data(false), delay_alloc(false),
912 storage_ref_(
Storage::_GetSharedRef()),
913 engine_ref_(
Engine::_GetSharedRef()) {
921 storage_shape = shape;
925 bool delay_alloc_,
int dtype,
const std::vector<int> &aux_types_,
927 : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
928 aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
929 aux_shapes(aux_shapes_), storage_ref_(
Storage::_GetSharedRef()),
930 engine_ref_(
Engine::_GetSharedRef()) {
934 for (
size_t i = 0; i < aux_shapes.size(); i++) {
935 CheckAndAllocAuxData(i, aux_shapes[i]);
938 aux_handles[i].ctx = ctx;
941 CheckAndAllocData(storage_shape, dtype);
946 const std::vector<TBlob> &aux_data,
int dev_id)
947 : static_data(true), delay_alloc(false), storage_type(storage_type_),
948 storage_ref_(
Storage::_GetSharedRef()), engine_ref_(
Engine::_GetSharedRef()) {
964 storage_shape = data.
shape_;
966 for (
const auto &aux : aux_data) {
968 aux_handle.
ctx = ctx;
969 aux_handle.
dptr = aux.dptr_;
971 aux_handles.push_back(aux_handle);
972 aux_types.emplace_back(aux.type_flag_);
973 aux_shapes.emplace_back(aux.shape_);
978 inline void set_aux_shape(
const size_t i,
const mxnet::TShape& shape) {
979 aux_shapes[i] = shape;
980 if (storage_shape.
ndim() >= 0) {
982 storage_shape[0] = shape[0];
984 storage_shape[0] = shape[0];
990 inline void CheckAndAlloc(
void) {
993 #if MXNET_USE_MKLDNN == 1 1002 void CheckAndAlloc(uint64_t dbytes) {
1004 <<
"CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
1005 dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.
size));
1008 #if MXNET_USE_MKLDNN == 1 1011 delay_alloc =
false;
1012 }
else if (shandle.
size < dbytes) {
1017 #if MXNET_USE_MKLDNN == 1 1024 auto size = shape.
Size();
1025 storage_shape = shape;
1027 this->CheckAndAlloc();
1037 storage_shape[0] = aux_shape[0];
1038 CheckAndAllocData(storage_shape, dtype);
1042 CheckAndAllocData(aux_shapes[
csr::kIdx], dtype);
1044 LOG(FATAL) <<
"Storage type " << storage_type <<
" not implemented for CheckAndAlloc";
1051 void CheckAndAllocData(
const mxnet::TShape &shape,
int dtype);
1053 #if MXNET_USE_MKLDNN == 1 1059 void Reorder2Default();
1061 void MKLDNNDataReorder(
const mkldnn::memory::desc &md);
1062 bool IsMKLDNN()
const;
1063 bool IsDefault()
const;
1071 inline void CheckAndAllocAuxData(
size_t i,
const mxnet::TShape &shape) {
1072 CHECK_EQ(shape.
ndim(), 1) <<
"shape must be 1D in CheckAndAllocAuxData";
1074 <<
"storage type cannot be kUndefinedStorage in CheckAndAllocAuxData";
1076 <<
"storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
1077 if (aux_handles.size() <= i) {
1078 aux_handles.resize(i + 1);
1081 if (aux_handles[i].size < aux_bytes) {
1088 set_aux_shape(i, shape);
1094 void SetTBlob()
const;
1097 std::shared_ptr<Chunk> ptr_{
nullptr};
1101 size_t byte_offset_ = 0;
1105 bool reuse_ =
false;
1117 mutable TBlob tblob_;
1282 typedef std::function<void (
NDArray **used_vars,
1307 NDArrayAPIFunction> {
1333 int num_params,
char **param_keys,
char **param_vals) {
1334 (*fsetvalue)(s[0], mutate_vars[0]);
1336 num_mutate_vars = 1; num_scalars = 1;
1337 this->add_argument(
"src",
"real_t",
"Source input to the function.");
1350 body = [fternary](
NDArray **used_vars,
1352 int num_params,
char **param_keys,
char **param_vals) {
1353 (*fternary)(*used_vars[0], *used_vars[1], *used_vars[2], mutate_vars[0]);
1355 num_use_vars = 3; num_mutate_vars = 1;
1357 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1358 this->add_argument(
"mhs",
"NDArray",
"Middle operand to the function.");
1359 this->add_argument(
"rhs",
"NDArray",
"Right operand to the function.");
1372 int num_params,
char **param_keys,
char **param_vals) {
1373 (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
1375 num_use_vars = 2; num_mutate_vars = 1;
1377 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1378 this->add_argument(
"rhs",
"NDArray",
"Right operand to the function.");
1391 int num_params,
char **param_keys,
char **param_vals) {
1392 (*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
1394 num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
1396 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1397 this->add_argument(
"rhs",
"real_t",
"Right operand to the function.");
1409 int num_params,
char **param_keys,
char **param_vals) {
1410 (*funary)(*used_vars[0], mutate_vars[0]);
1412 num_use_vars = 1; num_mutate_vars = 1;
1414 this->add_argument(
"src",
"NDArray",
"Source input to the function.");
1424 void (*fgeneric)(
NDArray **used_vars,
1427 const std::map<std::string, std::string>& param)) {
1429 int num_params,
char **param_keys,
char **param_vals) {
1430 std::map<std::string, std::string> param;
1431 for (
int i = 0; i < num_params; ++i) {
1432 param[param_keys[i]] = param_vals[i];
1434 fgeneric(used_vars, s, mutate_vars, param);
1444 num_use_vars = n;
return *
this;
1452 num_mutate_vars = n;
return *
this;
1460 num_scalars = n;
return *
this;
1468 type_mask = tmask;
return *
this;
1483 #define MXNET_REGISTER_NDARRAY_FUN(name) \ 1484 DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) 1492 #endif // MXNET_NDARRAY_H_ const mxnet::ShapeVector & aux_shapes() const
Definition: ndarray.h:249
const int default_type_flag
type enum value for default real type
Definition: base.h:477
NDArrayStorageType
Definition: ndarray.h:61
TBlob aux_data(size_t i) const
Definition: ndarray.h:291
const std::vector< int > & aux_types() const
Definition: ndarray.h:256
NDArrayFunctionReg & set_num_mutate_vars(unsigned n)
set the number of mutate variables
Definition: ndarray.h:1451
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, const TBlob &data, const std::vector< TBlob > &aux_data, int dev_id)
constructing a static NDArray of non-default storage that shares data with TBlob Use with caution: al...
Definition: ndarray.h:174
bool is_none() const
Definition: ndarray.h:326
void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const
Definition: ndarray.h:696
NDArrayFormatErr
Definition: ndarray.h:68
mxnet::TShape shape_
shape of the tensor
Definition: tensor_blob.h:72
Common base class for function registry.
Definition: registry.h:151
ScalarExp< DType > scalar(DType s)
create an scalar expression
Definition: expression.h:104
void RandomSeed(uint32_t seed)
Seed all random number generator in mxnet.
Engine that schedules all the operations according to dependency.
size_t byte_offset() const
Definition: ndarray.h:393
Context ctx() const
Definition: ndarray.h:307
const mxnet::TShape & aux_shape(size_t index) const
get the shape of aux_data(index)
Definition: ndarray.h:242
void SparseUpdateChunk(const NDArray &arr) const
Update ndarray chunk storage handles using existing ndarray storage handles Also update the aux_handl...
Definition: ndarray.h:606
NDArrayFunctionReg()
constructor
Definition: ndarray.h:1319
namespace of mxnet
Definition: api_registry.h:33
Storage manager across multiple devices.
Definition: storage.h:36
NDArray operator*(const NDArray &lhs, const NDArray &rhs)
elementwise multiplication
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:146
virtual void Free(Handle handle)=0
Free storage.
NDArrayFunctionReg & set_num_use_vars(unsigned n)
set the number of mutate variables
Definition: ndarray.h:1443
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:97
static Context GPU(int32_t dev_id=-1)
int type_mask
information on how function should be called from API
Definition: ndarray.h:1315
NDArrayFunctionReg & set_function(void(*funary)(const NDArray &src, NDArray *out))
set the function body to a unary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1406
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:74
Definition: optional.h:251
NDArrayFunctionReg & set_num_scalars(unsigned n)
set the number of scalar arguments
Definition: ndarray.h:1459
unsigned num_mutate_vars
number of variable mutated by this function
Definition: ndarray.h:1311
execution time context. The information needed in runtime for actual execution.
Definition: base.h:350
interface of stream I/O for serialization
Definition: io.h:30
void * dptr
Pointer to the data.
Definition: storage.h:45
NDArrayFunctionReg & set_function(void(*fscalar)(const NDArray &lhs, const real_t &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1387
Graph node data structure.
base class of engine variables.
Definition: engine.h:44
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)
macro to quickly declare traits information
Definition: type_traits.h:126
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:263
Context ctx
Context information about device and ID.
Definition: storage.h:53
NDArray()
default constructor
Definition: ndarray.h:85
unsigned num_use_vars
number of variable used by this function
Definition: ndarray.h:1309
int shared_id
Definition: storage.h:58
NDArrayFunctionReg & set_function(void(*fternary)(const NDArray &lhs, const NDArray &mhs, const NDArray &rhs, NDArray *out))
set the function body to a ternary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1346
int dtype() const
Definition: ndarray.h:314
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:51
void Init(const mxnet::TShape &shape)
initialize the NDArray, assuming it is not assigned a meaningful shape before
Definition: ndarray.h:186
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:820
RowSparseAuxType
Definition: ndarray.h:58
all the scalar should go before use_vars
Definition: ndarray.h:1293
bool storage_initialized() const
Returns true if a sparse ndarray's aux_data and storage are initialized Throws an exception if the in...
Definition: ndarray.h:337
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
Engine::VarHandle var() const
Definition: ndarray.h:389
void * dptr_
pointer to the data
Definition: tensor_blob.h:70
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const
Definition: ndarray.h:686
whether this function allows the handles in the target to be empty NDArray that are not yet initializ...
Definition: ndarray.h:1302
size_t version() const
return var version of the NDArray
Definition: ndarray.h:397
C Tensor object, manage memory of DLTensor. This data structure is intended to facilitate the borrowi...
Definition: dlpack.h:157
namespace for dmlc
Definition: array_view.h:12
NDArrayStorageType storage_type() const
Definition: ndarray.h:322
virtual void WaitForVar(VarHandle var)=0
Wait for a variable.
void CopyFromTo(const NDArray &from, const NDArray *to, int priority=0)
issue an copy operation from one NDArray to another the two ndarray can sit on different devices this...
NDArray Detach() const
Return a copy of this NDArray without autograd history.
Definition: ndarray.h:650
CSRAuxType
Definition: ndarray.h:54
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
Storage manager across multiple devices.
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr, bool wait=false)=0
Push an asynchronous operation to the engine.
Storage handle.
Definition: storage.h:41
static Context CPUShared(int32_t dev_id=0)
size_t num_aux_data(NDArrayStorageType stype)
NDArrayFunctionReg & set_type_mask(int tmask)
set type mask
Definition: ndarray.h:1467
void WaitToRead() const
Block until all the pending write operations with respect to current NDArray are finished, and read can be performed.
Definition: ndarray.h:368
NDArray(const TBlob &data, int dev_id, const std::function< void()> &deleter)
constructing a static NDArray that shares data with TBlob which is with deleter Use with caution: all...
Definition: ndarray.h:145
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
an entry that represents output data from a node
Definition: node.h:51
Handle Alloc(size_t size, Context ctx)
Allocate a new contiguous memory for a given size.
Definition: storage.h:66
NDArray operator-(const NDArray &lhs, const NDArray &rhs)
elementwise subtraction
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1472
NDArrayFunctionReg & set_function(void(*fsetvalue)(const real_t &rhs, NDArray *out))
set the function body to a NDArray setvalue function this will also auto set the parameters correctly...
Definition: ndarray.h:1330
NDArray operator+(const NDArray &lhs, const NDArray &rhs)
elementwise add
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:44
Registry entry for NDArrayFunction.
Definition: ndarray.h:1305
void CheckAndAllocData(const mxnet::TShape &storage_shape) const
Definition: ndarray.h:691
bool IsView() const
Definition: ndarray.h:200
NDArrayFunctionReg & set_function(void(*fbinary)(const NDArray &lhs, const NDArray &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1368
void WaitToWrite() const
Block until all the pending read/write operations with respect to current NDArray are finished...
Definition: ndarray.h:376
Storage::Handle storage_handle() const
get storage handle
Definition: ndarray.h:358
Dependency engine that schedules operations.
Definition: engine.h:117
void set_aux_shape(size_t index, const mxnet::TShape &shape) const
For a sparse operation on a csr matrix for example, the size of the column index array is an estimate...
Definition: ndarray.h:269
static Context CPU(int32_t dev_id=0)
const TBlob & data() const
Definition: ndarray.h:278
runtime functions for NDArray
Definition: imperative.h:50
size_t Size() const
Definition: tuple.h:521
int aux_type(size_t i) const
Definition: ndarray.h:317
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:73
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:438
void ReshapeAndAlloc(const mxnet::TShape &shape)
Allocate the space if the allocation has been delayed or the requested size is bigger than the availa...
Definition: ndarray.h:675
all the use_vars should go before scalar
Definition: ndarray.h:1291
void CheckAndAlloc() const
Allocate the space if it is delayed allocated. This is an internal function used by system that norma...
Definition: ndarray.h:661
const mxnet::TShape & storage_shape() const
Definition: ndarray.h:230
unsigned num_scalars
number of scalars used by this function
Definition: ndarray.h:1313
NDArray(int shared_pid, int shared_id, const mxnet::TShape &shape, int dtype)
create ndarray from shared memory
Definition: ndarray.h:156
#define MSHADOW_TYPE_SWITCH(type, DType,...)
Definition: base.h:1067
NDArray(const mxnet::TShape &shape, Context ctx, bool delay_alloc=false, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray
Definition: ndarray.h:95
bool shape_is_known(const TShape &x)
Definition: tuple.h:693
const mxnet::TShape & shape() const
Definition: ndarray.h:222
overloaded + operator between half_t and bf16_t
Definition: base.h:327
mshadow::index_t index_t
index type usually use unsigned
Definition: base.h:95
NDArray AsArray(const mxnet::TShape &shape, int dtype) const
Create a NDArray that shares memory with current one The new array must have smaller memory size than...
Definition: ndarray.h:564
size_t size
Size of the storage.
Definition: storage.h:49
bool IsSame(const NDArray &other) const
Definition: ndarray.h:212
void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out)
Sample generalized negative binomial distribution for each elements of out.
Context information about the execution environment.
Definition: base.h:102
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
int ndim() const
Definition: tuple.h:218
ndarray interface
Definition: ndarray.h:82
NDArray(Context ctx, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray whose shape is unknown, hence the NDArray is inherently lazily creat...
Definition: ndarray.h:115
NDArray(const TBlob &data, int dev_id)
constructing a static NDArray that shares data with TBlob Use with caution: allocate ONLY ONE NDArray...
Definition: ndarray.h:129
void ElementwiseSum(const std::vector< NDArray > &source, NDArray *out, int priority=0)
Perform elementwise sum over each data from source, store result into out.
std::function< void(NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> NDArrayAPIFunction
definition of NDArray function
Definition: ndarray.h:1287
Symbol is help class used to represent the operator node in Graph.
Definition: symbolic.h:50
void SampleNegBinomial(int32_t k, real_t p, NDArray *out)
Sample negative binomial distribution for each elements of out.
NDArrayFunctionReg & set_function(void(*fgeneric)(NDArray **used_vars, real_t *s, NDArray **mutate_vars, const std::map< std::string, std::string > ¶m))
set the function body to a unary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1423
type traits information header
int shared_pid
Id for IPC shared memory.
Definition: storage.h:57
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
NDArray operator/(const NDArray &lhs, const NDArray &rhs)
elementwise division
NDArrayFunctionTypeMask
mask information on how functions can be exposed
Definition: ndarray.h:1289