24 #ifndef MXNET_NDARRAY_H_ 25 #define MXNET_NDARRAY_H_ 28 #include <dmlc/logging.h> 39 #if MXNET_USE_MKLDNN == 1 46 #if DMLC_USE_CXX11 == 0 47 #error "cxx11 was required for ndarray module" 96 : ptr_(
std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
129 : ptr_(
std::make_shared<Chunk>(data, dev_id)),
131 dtype_(data.type_flag_),
144 NDArray(
const TBlob &data,
int dev_id,
const std::function<
void()>& deleter)
145 : ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) {
156 : ptr_(
std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
174 const TBlob &data,
const std::vector<TBlob> &aux_data,
int dev_id)
175 : ptr_(
std::make_shared<Chunk>(stype, data, aux_data, dev_id)),
177 dtype_(data.type_flag_),
178 storage_type_(stype),
186 ptr_->Init(shape, this->dtype_);
187 this->shape_ = shape;
192 void SetShapeFromChunk();
207 return byte_offset_ > 0 || shape() != ptr_->storage_shape;
212 return ptr_ == other.ptr_ &&
213 shape_ == other.shape_ &&
214 byte_offset_ == other.byte_offset_ &&
215 dtype_ == other.dtype_;
230 CHECK(ptr_ !=
nullptr);
232 <<
"storage_shape() is not intended for kDefaultStorage.";
233 return ptr_->storage_shape;
243 <<
"aux_shape() is not intended for kDefaultStorage.";
244 return ptr_->aux_shapes[index];
250 <<
"aux_shapes() is not intended for kDefaultStorage.";
251 return ptr_->aux_shapes;
257 <<
"aux_types() is not intended for kDefaultStorage.";
258 return ptr_->aux_types;
270 <<
"set_aux_shape() is not intended for kDefaultStorage.";
271 ptr_->set_aux_shape(index, shape);
291 auto stype = storage_type();
293 auto shape = aux_shape(i);
294 auto type = aux_type(i);
296 auto dptr =
static_cast<DType*
>(ptr_->aux_handles[i].dptr);
298 <<
"Unexpected storage type: " << stype;
299 res =
TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
308 return ptr_->shandle.ctx;
318 return ptr_->aux_types[i];
322 return storage_type_;
326 return ptr_.get() ==
nullptr;
329 bool fresh_out_grad()
const;
331 void set_fresh_out_grad(
bool state)
const;
337 if (is_none())
return false;
338 auto stype = storage_type();
340 <<
"storage_initialized() is not intended for kDefaultStorage.";
343 <<
"inconsistent storage shape " << storage_shape()
347 CHECK_EQ(aux_shape(
csr::kIdx)[0], storage_shape()[0])
348 <<
"inconsistent storage shape " << storage_shape()
349 <<
" vs. aux shape " << aux_shape(
csr::kIdx);
352 LOG(FATAL) <<
"Unknown storage type";
361 return ptr_->shandle;
368 if (is_none())
return;
376 if (is_none())
return;
384 },
Context{}, {}, {ptr_->var});
397 return var()->version();
409 bool LegacyLoad(
dmlc::Stream *strm,
const uint32_t magic);
494 void SyncCopyFromCPU(
const void *data,
size_t size)
const;
499 void SyncCopyFromNDArray(
const NDArray &src,
int i = -1,
int j = -1);
511 void SyncCopyToCPU(
void *data,
size_t size)
const;
517 void SyncCheckFormat(
const bool full_check)
const;
548 NDArray aux_ndarray(
size_t i)
const;
565 <<
"AsArray is intended only for kDefaultStorage.";
566 CHECK_GE(ptr_->shandle.size,
568 <<
"NDArray.AsArray: target memory size is bigger";
606 CHECK(shape_ == arr.shape_) <<
"ndarray shape is different from the target";
607 CHECK(dtype_ == arr.dtype_) <<
"ndarray dtype is different from the target";
610 <<
"Only to be used with CSR and RSP storage types";
613 arr.ptr_->shandle = ptr_->shandle;
614 ptr_->shandle = shandle_dst;
616 ptr_->storage_shape = arr.ptr_->storage_shape;
617 ptr_->storage_type = arr.ptr_->storage_type;
618 ptr_->ctx = arr.ptr_->ctx;
622 CHECK(ptr_->aux_handles.size() == arr.ptr_->aux_handles.size())
623 <<
"ndarray number of aux_handles is different from target";
624 for (
auto &aux_handle : arr.ptr_->aux_handles) {
626 ptr_->aux_handles[aux_idx] = aux_handle;
627 aux_handle = aux_dst;
630 ptr_->aux_types = arr.ptr_->aux_types;
631 ptr_->aux_shapes = arr.ptr_->aux_shapes;
662 ptr_->CheckAndAlloc();
687 <<
"CheckAndAlloc(aux_shapes) is not intended for kDefaultStorage";
688 ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_);
692 <<
"CheckAndAllocData is not intended for kDefaultStorage";
693 ptr_->CheckAndAllocData(storage_shape, dtype_);
697 <<
"CheckAndAllocAuxData is not intended for kDefaultStorage";
698 ptr_->CheckAndAllocAuxData(i, aux_shape);
701 #if MXNET_USE_MKLDNN == 1 706 explicit NDArray(
const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
711 explicit NDArray(
const mkldnn::memory::desc &md);
715 bool IsMKLDNNData()
const {
716 return ptr_->IsMKLDNN();
721 bool IsDefaultData()
const {
722 return ptr_->IsDefault();
734 const mkldnn::memory *GetMKLDNNData()
const;
739 const mkldnn::memory *GetMKLDNNData(
const mkldnn::memory::desc &md)
const;
745 const mkldnn::memory *GetMKLDNNDataReorder(
746 const mkldnn::memory::desc &md)
const;
751 void CopyFrom(
const mkldnn::memory &mem);
756 mkldnn::memory *CreateMKLDNNData(
const mkldnn::memory::desc &md);
763 void Reorder2DefaultAsync()
const;
764 void MKLDNNDataReorderAsync(
const mkldnn::memory::desc &md)
const;
770 NDArray Reorder2Default()
const;
776 NDArray Reorder2DefaultFloatFormat()
const;
778 void InvalidateMKLDNNData();
795 void UpdateMKLDNNMemDesc(
const mkldnn::memory::desc &desc);
805 const std::vector<NDArray>& data,
806 const std::vector<std::string>& names);
814 std::vector<NDArray>* data,
815 std::vector<std::string>* keys);
831 std::vector<Storage::Handle> aux_handles;
833 #if MXNET_USE_MKLDNN == 1 836 std::shared_ptr<MKLDNNMemory> mkl_mem_;
853 std::vector<int> aux_types;
864 std::shared_ptr<Storage> storage_ref_;
866 std::weak_ptr<Engine> engine_ref_;
870 Chunk() : static_data(true), delay_alloc(false),
871 storage_ref_(
Storage::_GetSharedRef()),
872 engine_ref_(
Engine::_GetSharedRef()) {}
876 : static_data(false), delay_alloc(true), ctx(ctx_),
877 storage_ref_(
Storage::_GetSharedRef()),
878 engine_ref_(
Engine::_GetSharedRef()) {
879 storage_shape = shape;
886 this->CheckAndAlloc();
890 Chunk(
const TBlob &data,
int dev_id)
891 : static_data(true), delay_alloc(false),
892 storage_ref_(
Storage::_GetSharedRef()),
893 engine_ref_(
Engine::_GetSharedRef()) {
906 storage_shape = data.
shape_;
909 Chunk(
int shared_pid,
int shared_id,
const mxnet::TShape& shape,
int dtype)
910 : static_data(false), delay_alloc(false),
911 storage_ref_(
Storage::_GetSharedRef()),
912 engine_ref_(
Engine::_GetSharedRef()) {
920 storage_shape = shape;
924 bool delay_alloc_,
int dtype,
const std::vector<int> &aux_types_,
926 : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
927 aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
928 aux_shapes(aux_shapes_), storage_ref_(
Storage::_GetSharedRef()),
929 engine_ref_(
Engine::_GetSharedRef()) {
933 for (
size_t i = 0; i < aux_shapes.size(); i++) {
934 CheckAndAllocAuxData(i, aux_shapes[i]);
937 aux_handles[i].ctx = ctx;
940 CheckAndAllocData(storage_shape, dtype);
945 const std::vector<TBlob> &aux_data,
int dev_id)
946 : static_data(true), delay_alloc(false), storage_type(storage_type_),
947 storage_ref_(
Storage::_GetSharedRef()), engine_ref_(
Engine::_GetSharedRef()) {
963 storage_shape = data.
shape_;
965 for (
const auto &aux : aux_data) {
967 aux_handle.
ctx = ctx;
968 aux_handle.
dptr = aux.dptr_;
970 aux_handles.push_back(aux_handle);
971 aux_types.emplace_back(aux.type_flag_);
972 aux_shapes.emplace_back(aux.shape_);
977 inline void set_aux_shape(
const size_t i,
const mxnet::TShape& shape) {
978 aux_shapes[i] = shape;
979 if (storage_shape.
ndim() >= 0) {
981 storage_shape[0] = shape[0];
983 storage_shape[0] = shape[0];
989 inline void CheckAndAlloc(
void) {
992 #if MXNET_USE_MKLDNN == 1 1001 void CheckAndAlloc(uint64_t dbytes) {
1003 <<
"CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
1004 dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.
size));
1007 #if MXNET_USE_MKLDNN == 1 1010 delay_alloc =
false;
1011 }
else if (shandle.
size < dbytes) {
1016 #if MXNET_USE_MKLDNN == 1 1023 auto size = shape.
Size();
1024 storage_shape = shape;
1026 this->CheckAndAlloc();
1036 storage_shape[0] = aux_shape[0];
1037 CheckAndAllocData(storage_shape, dtype);
1041 CheckAndAllocData(aux_shapes[
csr::kIdx], dtype);
1043 LOG(FATAL) <<
"Storage type " << storage_type <<
" not implemented for CheckAndAlloc";
1050 void CheckAndAllocData(
const mxnet::TShape &shape,
int dtype);
1052 #if MXNET_USE_MKLDNN == 1 1058 void Reorder2Default();
1060 void MKLDNNDataReorder(
const mkldnn::memory::desc &md);
1061 bool IsMKLDNN()
const;
1062 bool IsDefault()
const;
1070 inline void CheckAndAllocAuxData(
size_t i,
const mxnet::TShape &shape) {
1071 CHECK_EQ(shape.
ndim(), 1) <<
"shape must be 1D in CheckAndAllocAuxData";
1073 <<
"storage type cannot be kUndefinedStorage in CheckAndAllocAuxData";
1075 <<
"storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
1076 if (aux_handles.size() <= i) {
1077 aux_handles.resize(i + 1);
1080 if (aux_handles[i].size < aux_bytes) {
1087 set_aux_shape(i, shape);
1093 void SetTBlob()
const;
1096 std::shared_ptr<Chunk> ptr_{
nullptr};
1100 size_t byte_offset_ = 0;
1104 bool reuse_ =
false;
1116 mutable TBlob tblob_;
1281 typedef std::function<void (
NDArray **used_vars,
1306 NDArrayAPIFunction> {
1332 int num_params,
char **param_keys,
char **param_vals) {
1333 (*fsetvalue)(s[0], mutate_vars[0]);
1335 num_mutate_vars = 1; num_scalars = 1;
1336 this->add_argument(
"src",
"real_t",
"Source input to the function.");
1349 body = [fternary](
NDArray **used_vars,
1351 int num_params,
char **param_keys,
char **param_vals) {
1352 (*fternary)(*used_vars[0], *used_vars[1], *used_vars[2], mutate_vars[0]);
1354 num_use_vars = 3; num_mutate_vars = 1;
1356 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1357 this->add_argument(
"mhs",
"NDArray",
"Middle operand to the function.");
1358 this->add_argument(
"rhs",
"NDArray",
"Right operand to the function.");
1371 int num_params,
char **param_keys,
char **param_vals) {
1372 (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
1374 num_use_vars = 2; num_mutate_vars = 1;
1376 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1377 this->add_argument(
"rhs",
"NDArray",
"Right operand to the function.");
1390 int num_params,
char **param_keys,
char **param_vals) {
1391 (*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
1393 num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
1395 this->add_argument(
"lhs",
"NDArray",
"Left operand to the function.");
1396 this->add_argument(
"rhs",
"real_t",
"Right operand to the function.");
1408 int num_params,
char **param_keys,
char **param_vals) {
1409 (*funary)(*used_vars[0], mutate_vars[0]);
1411 num_use_vars = 1; num_mutate_vars = 1;
1413 this->add_argument(
"src",
"NDArray",
"Source input to the function.");
1423 void (*fgeneric)(
NDArray **used_vars,
1426 const std::map<std::string, std::string>& param)) {
1428 int num_params,
char **param_keys,
char **param_vals) {
1429 std::map<std::string, std::string> param;
1430 for (
int i = 0; i < num_params; ++i) {
1431 param[param_keys[i]] = param_vals[i];
1433 fgeneric(used_vars, s, mutate_vars, param);
1443 num_use_vars = n;
return *
this;
1451 num_mutate_vars = n;
return *
this;
1459 num_scalars = n;
return *
this;
1467 type_mask = tmask;
return *
this;
1482 #define MXNET_REGISTER_NDARRAY_FUN(name) \ 1483 DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name) 1491 #endif // MXNET_NDARRAY_H_ const mxnet::ShapeVector & aux_shapes() const
Definition: ndarray.h:248
const int default_type_flag
type enum value for default real type
Definition: base.h:484
NDArrayStorageType
Definition: ndarray.h:60
TBlob aux_data(size_t i) const
Definition: ndarray.h:290
const std::vector< int > & aux_types() const
Definition: ndarray.h:255
NDArrayFunctionReg & set_num_mutate_vars(unsigned n)
set the number of mutate variables
Definition: ndarray.h:1450
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, const TBlob &data, const std::vector< TBlob > &aux_data, int dev_id)
constructing a static NDArray of non-default storage that shares data with TBlob. Use with caution: al...
Definition: ndarray.h:173
bool is_none() const
Definition: ndarray.h:325
void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const
Definition: ndarray.h:695
NDArrayFormatErr
Definition: ndarray.h:67
mxnet::TShape shape_
shape of the tensor
Definition: tensor_blob.h:71
Common base class for function registry.
Definition: registry.h:151
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
void RandomSeed(uint32_t seed)
Seed all random number generators in mxnet.
Engine that schedules all the operations according to dependency.
size_t byte_offset() const
Definition: ndarray.h:392
Context ctx() const
Definition: ndarray.h:306
const mxnet::TShape & aux_shape(size_t index) const
get the shape of aux_data(index)
Definition: ndarray.h:241
void SparseUpdateChunk(const NDArray &arr) const
Update ndarray chunk storage handles using existing ndarray storage handles. Also update the aux_handl...
Definition: ndarray.h:605
NDArrayFunctionReg()
constructor
Definition: ndarray.h:1318
namespace of mxnet
Definition: api_registry.h:33
Storage manager across multiple devices.
Definition: storage.h:35
NDArray operator*(const NDArray &lhs, const NDArray &rhs)
elementwise multiplication
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:145
virtual void Free(Handle handle)=0
Free storage.
NDArrayFunctionReg & set_num_use_vars(unsigned n)
set the number of used variables
Definition: ndarray.h:1442
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:96
static Context GPU(int32_t dev_id=-1)
int type_mask
information on how function should be called from API
Definition: ndarray.h:1314
NDArrayFunctionReg & set_function(void(*funary)(const NDArray &src, NDArray *out))
set the function body to a unary NDArray function; this will also auto set the parameters correctly ...
Definition: ndarray.h:1405
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:73
Definition: optional.h:251
NDArrayFunctionReg & set_num_scalars(unsigned n)
set the number of scalar arguments
Definition: ndarray.h:1458
unsigned num_mutate_vars
number of variables mutated by this function
Definition: ndarray.h:1310
execution time context. The information needed in runtime for actual execution.
Definition: base.h:349
interface of stream I/O for serialization
Definition: io.h:30
void * dptr
Pointer to the data.
Definition: storage.h:44
NDArrayFunctionReg & set_function(void(*fscalar)(const NDArray &lhs, const real_t &rhs, NDArray *out))
set the function body to a binary NDArray function; this will also auto set the parameters correctly ...
Definition: ndarray.h:1386
Graph node data structure.
base class of engine variables.
Definition: engine.h:43
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)
macro to quickly declare traits information
Definition: type_traits.h:126
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:262
Context ctx
Context information about device and ID.
Definition: storage.h:52
NDArray()
default constructor
Definition: ndarray.h:84
unsigned num_use_vars
number of variables used by this function
Definition: ndarray.h:1308
int shared_id
Definition: storage.h:57
NDArrayFunctionReg & set_function(void(*fternary)(const NDArray &lhs, const NDArray &mhs, const NDArray &rhs, NDArray *out))
set the function body to a ternary NDArray function; this will also auto set the parameters correctly ...
Definition: ndarray.h:1345
int dtype() const
Definition: ndarray.h:313
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:50
void Init(const mxnet::TShape &shape)
initialize the NDArray, assuming it is not assigned a meaningful shape before
Definition: ndarray.h:185
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:819
RowSparseAuxType
Definition: ndarray.h:57
all the scalar should go before use_vars
Definition: ndarray.h:1292
bool storage_initialized() const
Returns true if a sparse ndarray's aux_data and storage are initialized. Throws an exception if the in...
Definition: ndarray.h:336
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each element of out.
Engine::VarHandle var() const
Definition: ndarray.h:388
void * dptr_
pointer to the data
Definition: tensor_blob.h:69
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const
Definition: ndarray.h:685
whether this function allows the handles in the target to be empty NDArray that are not yet initializ...
Definition: ndarray.h:1301
size_t version() const
return var version of the NDArray
Definition: ndarray.h:396
C Tensor object, manages memory of DLTensor. This data structure is intended to facilitate the borrowi...
Definition: dlpack.h:157
namespace for dmlc
Definition: array_view.h:12
NDArrayStorageType storage_type() const
Definition: ndarray.h:321
virtual void WaitForVar(VarHandle var)=0
Wait for a variable.
void CopyFromTo(const NDArray &from, const NDArray *to, int priority=0)
issue a copy operation from one NDArray to another; the two ndarrays can sit on different devices; this...
NDArray Detach() const
Return a copy of this NDArray without autograd history.
Definition: ndarray.h:649
CSRAuxType
Definition: ndarray.h:53
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each element of out.
Storage manager across multiple devices.
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr, bool wait=false)=0
Push an asynchronous operation to the engine.
Storage handle.
Definition: storage.h:40
static Context CPUShared(int32_t dev_id=0)
size_t num_aux_data(NDArrayStorageType stype)
NDArrayFunctionReg & set_type_mask(int tmask)
set type mask
Definition: ndarray.h:1466
void WaitToRead() const
Block until all the pending write operations with respect to current NDArray are finished, and read can be performed.
Definition: ndarray.h:367
NDArray(const TBlob &data, int dev_id, const std::function< void()> &deleter)
constructing a static NDArray that shares data with a TBlob and carries a deleter. Use with caution: all...
Definition: ndarray.h:144
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:206
an entry that represents output data from a node
Definition: node.h:51
Handle Alloc(size_t size, Context ctx)
Allocate new contiguous memory of a given size.
Definition: storage.h:65
NDArray operator-(const NDArray &lhs, const NDArray &rhs)
elementwise subtraction
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1479
NDArrayFunctionReg & set_function(void(*fsetvalue)(const real_t &rhs, NDArray *out))
set the function body to an NDArray setvalue function; this will also auto set the parameters correctly...
Definition: ndarray.h:1329
NDArray operator+(const NDArray &lhs, const NDArray &rhs)
elementwise add
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each element of out.
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:43
Registry entry for NDArrayFunction.
Definition: ndarray.h:1304
void CheckAndAllocData(const mxnet::TShape &storage_shape) const
Definition: ndarray.h:690
bool IsView() const
Definition: ndarray.h:199
NDArrayFunctionReg & set_function(void(*fbinary)(const NDArray &lhs, const NDArray &rhs, NDArray *out))
set the function body to a binary NDArray function; this will also auto set the parameters correctly ...
Definition: ndarray.h:1367
void WaitToWrite() const
Block until all the pending read/write operations with respect to current NDArray are finished...
Definition: ndarray.h:375
Storage::Handle storage_handle() const
get storage handle
Definition: ndarray.h:357
Dependency engine that schedules operations.
Definition: engine.h:116
void set_aux_shape(size_t index, const mxnet::TShape &shape) const
For a sparse operation on a csr matrix for example, the size of the column index array is an estimate...
Definition: ndarray.h:268
static Context CPU(int32_t dev_id=0)
const TBlob & data() const
Definition: ndarray.h:277
runtime functions for NDArray
Definition: imperative.h:50
size_t Size() const
Definition: tuple.h:520
int aux_type(size_t i) const
Definition: ndarray.h:316
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:72
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:437
void ReshapeAndAlloc(const mxnet::TShape &shape)
Allocate the space if the allocation has been delayed or the requested size is bigger than the availa...
Definition: ndarray.h:674
all the use_vars should go before scalar
Definition: ndarray.h:1290
void CheckAndAlloc() const
Allocate the space if its allocation was delayed. This is an internal function used by system that norma...
Definition: ndarray.h:660
const mxnet::TShape & storage_shape() const
Definition: ndarray.h:229
unsigned num_scalars
number of scalars used by this function
Definition: ndarray.h:1312
NDArray(int shared_pid, int shared_id, const mxnet::TShape &shape, int dtype)
create ndarray from shared memory
Definition: ndarray.h:155
#define MSHADOW_TYPE_SWITCH(type, DType,...)
Definition: base.h:1074
NDArray(const mxnet::TShape &shape, Context ctx, bool delay_alloc=false, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray
Definition: ndarray.h:94
bool shape_is_known(const TShape &x)
Definition: tuple.h:692
const mxnet::TShape & shape() const
Definition: ndarray.h:221
overloaded + operator between half_t and bf16_t
Definition: base.h:334
mshadow::index_t index_t
index type, usually unsigned
Definition: base.h:94
NDArray AsArray(const mxnet::TShape &shape, int dtype) const
Create an NDArray that shares memory with the current one. The new array must have smaller memory size than...
Definition: ndarray.h:563
size_t size
Size of the storage.
Definition: storage.h:48
bool IsSame(const NDArray &other) const
Definition: ndarray.h:211
void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out)
Sample generalized negative binomial distribution for each element of out.
Context information about the execution environment.
Definition: base.h:101
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each element of out.
int ndim() const
Definition: tuple.h:217
ndarray interface
Definition: ndarray.h:81
NDArray(Context ctx, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray whose shape is unknown, hence the NDArray is inherently lazily creat...
Definition: ndarray.h:114
NDArray(const TBlob &data, int dev_id)
constructing a static NDArray that shares data with TBlob. Use with caution: allocate ONLY ONE NDArray...
Definition: ndarray.h:128
void ElementwiseSum(const std::vector< NDArray > &source, NDArray *out, int priority=0)
Perform elementwise sum over each data from source, store result into out.
std::function< void(NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> NDArrayAPIFunction
definition of NDArray function
Definition: ndarray.h:1286
Symbol is a helper class used to represent the operator node in Graph.
Definition: symbolic.h:50
void SampleNegBinomial(int32_t k, real_t p, NDArray *out)
Sample negative binomial distribution for each element of out.
NDArrayFunctionReg & set_function(void(*fgeneric)(NDArray **used_vars, real_t *s, NDArray **mutate_vars, const std::map< std::string, std::string > &param))
set the function body to a generic NDArray function; this will also auto set the parameters correctly ...
Definition: ndarray.h:1422
type traits information header
int shared_pid
Id for IPC shared memory.
Definition: storage.h:56
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:65
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each element of out.
NDArray operator/(const NDArray &lhs, const NDArray &rhs)
elementwise division
NDArrayFunctionTypeMask
mask information on how functions can be exposed
Definition: ndarray.h:1288