mxnet
ndarray.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_NDARRAY_H_
25 #define MXNET_NDARRAY_H_
26 
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#if MXNET_USE_MKLDNN == 1
#include <mkldnn.hpp>
#endif
#include "./base.h"
#include "./storage.h"
#include "./engine.h"
45 // check c++11
46 #if DMLC_USE_CXX11 == 0
47 #error "cxx11 was required for ndarray module"
48 #endif
49 
50 namespace mxnet {
51 // enum for storage types
52 namespace csr {
54 }
55 
56 namespace rowsparse {
58 }
59 
61  kUndefinedStorage = -1, // undefined storage
62  kDefaultStorage, // dense
63  kRowSparseStorage, // row sparse
64  kCSRStorage, // csr
65 };
66 
68  kNormalErr, // normal
69  kCSRShapeErr, // shape mismatch for csr
70  kCSRIndPtrErr, // indptr error for csr
71  kCSRIdxErr, // idx error for csr
72  kRSPShapeErr, // shape mismatch for row sparse
73  kRSPIdxErr, // indices error for row sparse
74 };
75 
76 class MKLDNNMemory;
77 
81 class NDArray {
82  public:
85  : entry_(nullptr) {
86  }
94  NDArray(const mxnet::TShape &shape, Context ctx,
95  bool delay_alloc = false, int dtype = mshadow::default_type_flag)
96  : ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
97  shape_(shape),
98  dtype_(dtype),
99  storage_type_(kDefaultStorage),
100  entry_(nullptr) {
101  }
104  NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
105  bool delay_alloc = true, int dtype = mshadow::default_type_flag,
106  std::vector<int> aux_types = {}, mxnet::ShapeVector aux_shapes = {},
107  mxnet::TShape storage_shape = mxnet::TShape(mshadow::Shape1(0)));
114  explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag)
115  : ptr_(std::make_shared<Chunk>(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype)),
116  shape_(),
117  dtype_(dtype),
118  storage_type_(kDefaultStorage),
119  entry_(nullptr) {
120  }
128  NDArray(const TBlob &data, int dev_id)
129  : ptr_(std::make_shared<Chunk>(data, dev_id)),
130  shape_(data.shape_),
131  dtype_(data.type_flag_),
132  storage_type_(kDefaultStorage),
133  entry_(nullptr) {
134  }
135 
144  NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
145  : ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) {
146  deleter(); // call custom deleter
147  delete p; // delete Chunk object
148  }),
149  shape_(data.shape_),
150  dtype_(data.type_flag_), storage_type_(kDefaultStorage),
151  entry_(nullptr) {
152  }
153 
155  NDArray(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype)
156  : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
157  shape_(shape),
158  dtype_(dtype),
159  storage_type_(kDefaultStorage),
160  entry_(nullptr) {
161  }
162 
173  NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape,
174  const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
175  : ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)),
176  shape_(shape),
177  dtype_(data.type_flag_),
178  storage_type_(stype),
179  entry_(nullptr) {
180  }
185  void Init(const mxnet::TShape &shape) {
186  ptr_->Init(shape, this->dtype_);
187  this->shape_ = shape;
188  }
192  void SetShapeFromChunk();
193  /*
194  * This indicates whether an array is a view of another array (created by
195  * reshape or slice). If an array is a view and the data is stored in
196  * MKLDNN format, we need to convert the data to the default format when
197  * data in the view is accessed.
198  */
199  inline bool IsView() const {
200  // View only works on the default storage
201  if (storage_type() != kDefaultStorage)
202  return false;
203  // If the array reuses memory, its shape may be different from the storage
204  // shape. However, we shouldn't consider it as a view.
205  if (reuse_)
206  return false;
207  return byte_offset_ > 0 || shape() != ptr_->storage_shape;
208  }
209 
210  /* \brief Check whether the two arrays are the same array */
211  inline bool IsSame(const NDArray& other) const {
212  return ptr_ == other.ptr_ &&
213  shape_ == other.shape_ &&
214  byte_offset_ == other.byte_offset_ &&
215  dtype_ == other.dtype_;
216  }
217 
  /*!
   * \return the logical shape of the current NDArray
   */
  inline const mxnet::TShape& shape() const {
    return shape_;
  }
  /*!
   * \brief The shape of the chunk that actually stores the data, which may
   *  differ from the logical shape for sparse storage types.
   * \return storage shape of the underlying chunk; only for non-default storage
   */
  inline const mxnet::TShape &storage_shape() const {
    CHECK(ptr_ != nullptr);
    CHECK_NE(storage_type(), kDefaultStorage)
        << "storage_shape() is not intended for kDefaultStorage.";
    return ptr_->storage_shape;
  }
235 
  /*!
   * \brief get the shape of aux_data(index)
   * \param index the index of the aux data (e.g. csr::kIndPtr, csr::kIdx)
   * \return the shape of the aux data; only for non-default storage
   */
  inline const mxnet::TShape& aux_shape(size_t index) const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "aux_shape() is not intended for kDefaultStorage.";
    return ptr_->aux_shapes[index];
  }
246 
247  /* \return the shapes of all aux data */
249  CHECK_NE(storage_type(), kDefaultStorage)
250  << "aux_shapes() is not intended for kDefaultStorage.";
251  return ptr_->aux_shapes;
252  }
253 
  /*! \return the dtype flags of all aux data; only for non-default storage */
  const std::vector<int>& aux_types() const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "aux_types() is not intended for kDefaultStorage.";
    return ptr_->aux_types;
  }
260 
  /*!
   * \brief Set the shape of the index-th aux data. Useful when the aux size
   *  (e.g. number of non-zeros) is only known after an operation runs.
   *  Declared const because it mutates the shared chunk, not this handle.
   * \param index which aux data to update
   * \param shape the new shape of that aux data
   */
  inline void set_aux_shape(size_t index, const mxnet::TShape& shape) const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "set_aux_shape() is not intended for kDefaultStorage.";
    ptr_->set_aux_shape(index, shape);
  }
273 
  /*!
   * \brief Return a TBlob view of the data. For dense arrays any delayed
   *  allocation is performed first; tblob_ is refreshed before returning.
   * \return reference to the internal TBlob (aliases the array's memory)
   */
  inline const TBlob& data() const {
    if (storage_type() == kDefaultStorage) CheckAndAlloc();
    SetTBlob();
    return tblob_;
  }
285  NDArray grad() const;
286 
  /*!
   * \brief Return a TBlob view of the i-th aux data (e.g. indices/indptr).
   *  Only valid for kRowSparseStorage / kCSRStorage arrays; the returned
   *  TBlob aliases the aux storage (no copy).
   * \param i index of the aux data
   */
  inline TBlob aux_data(size_t i) const {
    auto stype = storage_type();
    TBlob res;
    auto shape = aux_shape(i);
    auto type = aux_type(i);
    // The type switch materializes a typed pointer for the TBlob constructor.
    MSHADOW_TYPE_SWITCH(type, DType, {
      auto dptr = static_cast<DType*>(ptr_->aux_handles[i].dptr);
      CHECK(stype == kRowSparseStorage || stype == kCSRStorage)
          << "Unexpected storage type: " << stype;
      res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
    });
    return res;
  }
  /*! \return the context (device) on which the data is held */
  inline Context ctx() const {
    CHECK(!is_none());
    return ptr_->shandle.ctx;
  }
  /*! \return the data type flag of the array's values */
  inline int dtype() const {
    return dtype_;
  }
  /*! \return the data type flag of the i-th aux data */
  inline int aux_type(size_t i) const {
    CHECK(!is_none());
    return ptr_->aux_types[i];
  }
320 
322  return storage_type_;
323  }
  /*! \return whether this NDArray is empty (holds no chunk) */
  inline bool is_none() const {
    return ptr_.get() == nullptr;
  }
329  bool fresh_out_grad() const;
331  void set_fresh_out_grad(bool state) const;
  /*!
   * \brief Returns true if a sparse array's aux data and storage are
   *  initialized (non-empty). Fails a CHECK if the index aux shape is
   *  inconsistent with the storage shape. Not intended for dense arrays.
   */
  inline bool storage_initialized() const {
    if (is_none()) return false;
    auto stype = storage_type();
    CHECK_NE(stype, kDefaultStorage)
        << "storage_initialized() is not intended for kDefaultStorage.";
    if (stype == kRowSparseStorage) {
      // the row-index array length must match the number of stored rows
      CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0])
          << "inconsistent storage shape " << storage_shape()
          << " vs. aux shape " << aux_shape(rowsparse::kIdx);
      return aux_shape(rowsparse::kIdx).Size() != 0;
    } else if (stype == kCSRStorage) {
      // the column-index array length must match the number of stored values
      CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0])
          << "inconsistent storage shape " << storage_shape()
          << " vs. aux shape " << aux_shape(csr::kIdx);
      return aux_shape(csr::kIdx).Size() != 0;
    } else {
      LOG(FATAL) << "Unknown storage type";
    }
    return true;
  }
358  CHECK(!is_none());
359  CHECK_EQ(storage_type(), kDefaultStorage);
360  CheckAndAlloc();
361  return ptr_->shandle;
362  }
  /*!
   * \brief Block the calling thread until the engine variable of this array
   *  is clear, i.e. pending operations that touch it have completed.
   */
  inline void WaitToRead() const {
    if (is_none()) return;
    Engine::Get()->WaitForVar(ptr_->var);
  }
375  inline void WaitToWrite() const {
376  if (is_none()) return;
382  [](RunContext, Engine::CallbackOnComplete on_complete) {
383  on_complete();
384  }, Context{}, {}, {ptr_->var});
385  Engine::Get()->WaitForVar(ptr_->var);
386  }
  /*! \return the engine variable used for dependency scheduling of this array */
  inline Engine::VarHandle var() const {
    return ptr_->var;
  }
  /*! \return byte offset of this (possibly sliced) array within the chunk storage */
  inline size_t byte_offset() const {
    return byte_offset_;
  }
  /*! \return the current version of the underlying engine variable */
  inline size_t version() const {
    return var()->version();
  }
403  void Save(dmlc::Stream *strm) const;
409  bool LegacyLoad(dmlc::Stream *strm, const uint32_t magic);
415  bool Load(dmlc::Stream *strm);
421  NDArray &operator=(real_t scalar);
428  NDArray &operator+=(const NDArray &src);
435  NDArray &operator+=(const real_t &src);
442  NDArray &operator-=(const NDArray &src);
449  NDArray &operator-=(const real_t &src);
456  NDArray &operator*=(const NDArray &src);
463  NDArray &operator*=(const real_t &src);
470  NDArray &operator/=(const NDArray &src);
477  NDArray &operator/=(const real_t &src);
483  NDArray Copy(Context ctx) const;
494  void SyncCopyFromCPU(const void *data, size_t size) const;
495 
499  void SyncCopyFromNDArray(const NDArray &src, int i = -1, int j = -1);
500 
511  void SyncCopyToCPU(void *data, size_t size) const;
517  void SyncCheckFormat(const bool full_check) const;
524  NDArray Slice(index_t begin, index_t end) const;
531  NDArray SliceWithRecord(index_t begin, index_t end);
537  NDArray At(index_t idx) const;
543  NDArray AtWithRecord(index_t idx);
548  NDArray aux_ndarray(size_t i) const;
549 
554  NDArray data_ndarray() const;
555 
  /*!
   * \brief Create a dense NDArray that reuses this array's memory with a
   *  (possibly) different shape and dtype. The requested bytes must fit in
   *  the existing storage; views and non-default storage are not allowed.
   * \param shape the new shape
   * \param dtype the new data type
   * \return an NDArray aliasing this array's memory (with reuse_ set)
   */
  inline NDArray AsArray(const mxnet::TShape &shape, int dtype) const {
    CHECK_EQ(storage_type(), kDefaultStorage)
        << "AsArray is intended only for kDefaultStorage.";
    CHECK_GE(ptr_->shandle.size,
             shape.Size() * mshadow::mshadow_sizeof(dtype))
        << "NDArray.AsArray: target memory size is bigger";
    // We can't reuse memory in a view.
    CHECK(!IsView());
    NDArray ret = *this;
    ret.shape_ = shape;
    ret.dtype_ = dtype;
    ret.reuse_ = true;
    return ret;
  }
577 
583  DLManagedTensor* ToDLPack() const;
584 
596  static NDArray FromDLPack(const DLManagedTensor* tensor, bool transient_handle);
597 
  /*!
   * \brief Update this array's chunk by swapping storage handles (data and
   *  aux) with an existing sparse ndarray, and copying its chunk metadata.
   *  Declared const because only the shared chunk is mutated.
   *  Only to be used with CSR and RSP storage types.
   * \param arr the source array; its handles end up holding this array's
   *  previous storage (swap, not copy), keeping both sides freeable.
   */
  inline void SparseUpdateChunk(const NDArray &arr) const {
    CHECK(shape_ == arr.shape_) << "ndarray shape is different from the target";
    CHECK(dtype_ == arr.dtype_) << "ndarray dtype is different from the target";
    auto stype = arr.storage_type();
    CHECK(stype == kCSRStorage || stype == kRowSparseStorage)
        << "Only to be used with CSR and RSP storage types";
    // swap shandles between src and dst
    Storage::Handle shandle_dst = arr.ptr_->shandle;
    arr.ptr_->shandle = ptr_->shandle;
    ptr_->shandle = shandle_dst;

    // copy chunk metadata describing the swapped-in storage
    ptr_->storage_shape = arr.ptr_->storage_shape;
    ptr_->storage_type = arr.ptr_->storage_type;
    ptr_->ctx = arr.ptr_->ctx;

    // swap aux_handles between src and dst
    size_t aux_idx = 0;
    CHECK(ptr_->aux_handles.size() == arr.ptr_->aux_handles.size())
        << "ndarray number of aux_handles is different from target";
    for (auto &aux_handle : arr.ptr_->aux_handles) {
      Storage::Handle aux_dst = ptr_->aux_handles[aux_idx];
      ptr_->aux_handles[aux_idx] = aux_handle;
      aux_handle = aux_dst;
      aux_idx++;
    }
    ptr_->aux_types = arr.ptr_->aux_types;
    ptr_->aux_shapes = arr.ptr_->aux_shapes;
  }
633 
639  NDArray Reshape(const mxnet::TShape &shape) const;
645  NDArray ReshapeWithRecord(const mxnet::TShape &shape);
649  NDArray Detach() const {
650  NDArray ret(*this);
651  ret.entry_ = nnvm::NodeEntry(nullptr);
652  return ret;
653  }
654 
655  nnvm::Symbol get_autograd_symbol() const;
  /*!
   * \brief Allocate the dense storage if allocation was delayed; no-op
   *  otherwise. Only for kDefaultStorage.
   */
  inline void CheckAndAlloc() const {
    CHECK_EQ(storage_type(), kDefaultStorage);
    ptr_->CheckAndAlloc();
  }
664 
  /*!
   * \brief Change the logical shape and make sure the chunk can hold
   *  shape.Size() elements of dtype_, growing the storage if necessary.
   *  Only for dense, non-empty arrays.
   * \param shape the new shape
   */
  void ReshapeAndAlloc(const mxnet::TShape& shape) {
    CHECK_EQ(storage_type(), kDefaultStorage);
    CHECK(!is_none());
    shape_ = shape;
    ptr_->CheckAndAlloc(shape.Size() * mshadow::mshadow_sizeof(dtype_));
  }
680 
  /*!
   * \brief Alloc memory for non-default storage;
   *  aux_shape is only known at run time.
   * \param aux_shapes the shapes of all aux data
   */
  inline void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "CheckAndAlloc(aux_shapes) is not intended for kDefaultStorage";
    ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_);
  }
  /*!
   * \brief Alloc the main data storage of a non-default-storage array,
   *  given the run-time storage shape.
   * \param storage_shape shape of the data chunk to allocate
   */
  inline void CheckAndAllocData(const mxnet::TShape &storage_shape) const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "CheckAndAllocData is not intended for kDefaultStorage";
    ptr_->CheckAndAllocData(storage_shape, dtype_);
  }
  /*!
   * \brief Alloc storage for the i-th aux data of a non-default-storage array.
   * \param i index of the aux data
   * \param aux_shape shape of the aux data to allocate
   */
  inline void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const {
    CHECK_NE(storage_type(), kDefaultStorage)
        << "CheckAndAllocAuxData is not intended for kDefaultStorage";
    ptr_->CheckAndAllocAuxData(i, aux_shape);
  }
700 
701 #if MXNET_USE_MKLDNN == 1
702  /*
703  * Create NDArray from mkldnn memory.
704  * mkldnn_mem The mkldnn memory to be managed.
705  */
706  explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
707  /*
708  * Create NDArray from mkldnn memory descriptor.
709  * mem_pd The mkldnn memory descriptor to be created.
710  */
711  explicit NDArray(const mkldnn::memory::desc &md);
  /*
   * Test if the data is stored in one of the special MKLDNN formats
   * (i.e. a layout other than the default MXNet one).
   */
  bool IsMKLDNNData() const {
    return ptr_->IsMKLDNN();
  }
  /*
   * Test if the data is stored in one of the default MXNet formats.
   */
  bool IsDefaultData() const {
    return ptr_->IsDefault();
  }
724  /*
725  * All functions below return a raw pointer to mkldnn memory. Actually there
726  * is a shared pointer that hold the memory either in NDArray or in MKLDNN
727  * stream. As long as we call these functions inside an operator, the return
728  * memory is always valid.
729  */
730 
731  /*
732  * This function returns mkldnn::memory with the default primitive_desc.
733  */
734  const mkldnn::memory *GetMKLDNNData() const;
735  /*
736  * This function returns mkldnn::memory with the given primitive_desc
737  * as long as the array size meets the required size in the given primitive_desc.
738  */
739  const mkldnn::memory *GetMKLDNNData(const mkldnn::memory::desc &md) const;
740  /*
741  * This function returns mkldnn::memory with the given primitive_desc.
742  * The returned mkldnn::memory will have the same physical layout as
743  * the given primitive_desc.
744  */
745  const mkldnn::memory *GetMKLDNNDataReorder(
746  const mkldnn::memory::desc &md) const;
747 
748  /*
749  * This function copies data from mkldnn memory.
750  */
751  void CopyFrom(const mkldnn::memory &mem);
752  /*
753  * This function allocates memory for array and creates mkldnn memory
754  * with the specified format.
755  */
756  mkldnn::memory *CreateMKLDNNData(const mkldnn::memory::desc &md);
757 
758  /*
759  * These are the async version of the methods above.
760  * It changes the layout of this NDArray, but it happens after all accesses to
761  * the array are complete.
762  */
763  void Reorder2DefaultAsync() const;
764  void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md) const;
765 
766  /*
767  * This creates a new NDArray with the reordered data.
768  * It doesn't affect the data of the original NDArray.
769  */
770  NDArray Reorder2Default() const;
771 
772  /*
773  * This creates a new NDArray using f32 with the reordered data.
774  * It doesn't affect the data of the original NDArray.
775  */
776  NDArray Reorder2DefaultFloatFormat() const;
777 
778  void InvalidateMKLDNNData();
779 
780  /*
781  * This function is used inside operators to reshape an array.
782  * It doesn't change the layout of the original array and allocate memory from
783  * the temporary buffer. The returned array is only valid inside the current
784  * invocation of this operator.
785  * This is different from Reshape. Reshape will cause data in the array to be
786  * converted to the default layout and allocate memory from malloc directly,
787  * which can be expensive.
788  * It's used by FullyConnected right now.
789  */
790  NDArray MKLDNNDataReshape(const mxnet::TShape &shape) const;
791 
795  void UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc);
796 #endif
797 
804  static void Save(dmlc::Stream* fo,
805  const std::vector<NDArray>& data,
806  const std::vector<std::string>& names);
813  static void Load(dmlc::Stream* fi,
814  std::vector<NDArray>* data,
815  std::vector<std::string>* keys);
816 
817  private:
818  friend class Imperative;
820  // shandle is used to store the actual values in the NDArray
821  // aux_handles store the aux data(such as indices) if it's needed by non-default storage.
  /*!
   * \brief The real data chunk backing an NDArray.
   *  shandle is used to store the actual values in the NDArray;
   *  aux_handles store the aux data (such as indices) if it's needed by
   *  non-default storage.
   */
  struct Chunk {
    /*! \brief storage handle of the main data */
    Storage::Handle shandle;
    /*! \brief storage handles of the aux data (indices, indptr, ...) */
    std::vector<Storage::Handle> aux_handles;

#if MXNET_USE_MKLDNN == 1
    /*! \brief wrapper of MKLDNN memory; may hold data in an MKLDNN-specific layout */
    std::shared_ptr<MKLDNNMemory> mkl_mem_;
#endif
    /*! \brief engine variable used to schedule operations on this chunk */
    Engine::VarHandle var;
    /*! \brief if true, the memory was borrowed (e.g. from a TBlob) and is not
     *  owned by the Storage manager */
    bool static_data;
    /*! \brief whether allocation of shandle is postponed until first use */
    bool delay_alloc;
    // the type of the storage. The storage_type is never kUndefinedStorage once the chunk
    // is constructed.
    NDArrayStorageType storage_type = kDefaultStorage;
    /*! \brief dtype flags of the aux data */
    std::vector<int> aux_types;
    // context of data
    Context ctx;
    // The shape of the chunk data.
    // This might not be the same shape as the NDArray, since the storage may be sparse.
    // The default value for storage_shape is {0} when an empty non-default NDArray is created.
    mxnet::TShape storage_shape;
    // The shape of aux data. The default value for the shape depends on the type of storage.
    // If aux_shapes[i].Size() is zero, aux data i is empty.
    mxnet::ShapeVector aux_shapes;
    /*! \brief shared reference keeping the Storage manager alive while chunks exist */
    std::shared_ptr<Storage> storage_ref_;
    /*! \brief weak reference to the Engine that created var */
    std::weak_ptr<Engine> engine_ref_;

    /*! \brief default constructor: static data, nothing allocated */
    Chunk() : static_data(true), delay_alloc(false),
              storage_ref_(Storage::_GetSharedRef()),
              engine_ref_(Engine::_GetSharedRef()) {}

    /*! \brief construct a dense chunk from shape and context */
    Chunk(mxnet::TShape shape, Context ctx_, bool delay_alloc_, int dtype)
        : static_data(false), delay_alloc(true), ctx(ctx_),
          storage_ref_(Storage::_GetSharedRef()),
          engine_ref_(Engine::_GetSharedRef()) {
      storage_shape = shape;
      if (shape_is_known(storage_shape)) {
        shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
      }
      var = Engine::Get()->NewVariable();
      shandle.ctx = ctx_;
      // delay_alloc is initialized to true above so that CheckAndAlloc()
      // actually performs the allocation when eager allocation is requested.
      if (!delay_alloc_) {
        this->CheckAndAlloc();
      }
    }

    /*! \brief construct a dense chunk that borrows (does not own) a TBlob's memory */
    Chunk(const TBlob &data, int dev_id)
        : static_data(true), delay_alloc(false),
          storage_ref_(Storage::_GetSharedRef()),
          engine_ref_(Engine::_GetSharedRef()) {
      CHECK(storage_type == kDefaultStorage);
      var = Engine::Get()->NewVariable();
      if (data.dev_mask() == cpu::kDevMask) {
        ctx = Context::CPU();
      } else {
        CHECK_EQ(data.dev_mask(), gpu::kDevMask);
        ctx = Context::GPU(dev_id);
      }
      // init shandle
      shandle.ctx = ctx;
      shandle.dptr = data.dptr_;
      shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
      storage_shape = data.shape_;
    }

    /*! \brief construct a dense chunk backed by shared memory (CPUShared context) */
    Chunk(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype)
        : static_data(false), delay_alloc(false),
          storage_ref_(Storage::_GetSharedRef()),
          engine_ref_(Engine::_GetSharedRef()) {
      var = Engine::Get()->NewVariable();
      ctx = Context::CPUShared(0);
      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
      shandle.ctx = ctx;
      shandle.shared_pid = shared_pid;
      shandle.shared_id = shared_id;
      Storage::Get()->Alloc(&shandle);
      storage_shape = shape;
    }

    // Constructor for a non-default storage chunk
    Chunk(NDArrayStorageType storage_type_, const mxnet::TShape &storage_shape_, Context ctx_,
          bool delay_alloc_, int dtype, const std::vector<int> &aux_types_,
          const mxnet::ShapeVector &aux_shapes_)
        : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_),
          aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_),
          aux_shapes(aux_shapes_), storage_ref_(Storage::_GetSharedRef()),
          engine_ref_(Engine::_GetSharedRef()) {
      shandle.ctx = ctx;
      var = Engine::Get()->NewVariable();
      // aux_handles always reflect the correct number of aux data
      for (size_t i = 0; i < aux_shapes.size(); i++) {
        CheckAndAllocAuxData(i, aux_shapes[i]);
        // this line is needed in case when aux_shapes[i].Size() = 0
        // aux_handles[i] will not be updated and take only default value.
        aux_handles[i].ctx = ctx;
      }
      if (!delay_alloc) {
        CheckAndAllocData(storage_shape, dtype);
      }
    }

    /*! \brief construct a non-default-storage chunk that borrows TBlob memory */
    Chunk(const NDArrayStorageType storage_type_, const TBlob &data,
          const std::vector<TBlob> &aux_data, int dev_id)
        : static_data(true), delay_alloc(false), storage_type(storage_type_),
          storage_ref_(Storage::_GetSharedRef()), engine_ref_(Engine::_GetSharedRef()) {
      using namespace mshadow;
      CHECK_NE(storage_type, kDefaultStorage);
      // init var
      var = Engine::Get()->NewVariable();
      // init ctx
      if (data.dev_mask() == cpu::kDevMask) {
        ctx = Context::CPU();
      } else {
        CHECK_EQ(data.dev_mask(), gpu::kDevMask);
        ctx = Context::GPU(dev_id);
      }
      // init shandle
      shandle.ctx = ctx;
      shandle.dptr = data.dptr_;
      shandle.size = data.shape_.Size() * mshadow_sizeof(data.type_flag_);
      storage_shape = data.shape_;
      // init aux handles
      for (const auto &aux : aux_data) {
        Storage::Handle aux_handle;
        aux_handle.ctx = ctx;
        aux_handle.dptr = aux.dptr_;
        aux_handle.size = aux.shape_.Size() * mshadow_sizeof(aux.type_flag_);
        aux_handles.push_back(aux_handle);
        aux_types.emplace_back(aux.type_flag_);
        aux_shapes.emplace_back(aux.shape_);
      }
    }

    /*! \brief set the shape of the i-th aux data and keep storage_shape consistent */
    inline void set_aux_shape(const size_t i, const mxnet::TShape& shape) {
      aux_shapes[i] = shape;
      if (storage_shape.ndim() >= 0) {
        // For RSP/CSR the index array length determines the number of stored
        // rows / non-zeros, respectively.
        if (storage_type == kRowSparseStorage && i == rowsparse::kIdx) {
          storage_shape[0] = shape[0];
        } else if (storage_type == kCSRStorage && i == csr::kIdx) {
          storage_shape[0] = shape[0];
        }
      }
    }

    /*! \brief perform the delayed allocation of shandle if it is still pending */
    inline void CheckAndAlloc(void) {
      if (delay_alloc) {
        shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
        mkl_mem_ = nullptr;
#endif
        delay_alloc = false;
      }
    }

    /*!
     * \brief ensure the dense storage holds at least dbytes
     *  (size is the number of bytes); grows the allocation if too small
     */
    void CheckAndAlloc(uint64_t dbytes) {
      CHECK_EQ(kDefaultStorage, storage_type)
          << "CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
      dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.size));
      if (delay_alloc) {
        shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
        mkl_mem_ = nullptr;
#endif
        delay_alloc = false;
      } else if (shandle.size < dbytes) {
        // free storage
        Storage::Get()->Free(shandle);
        // init storage
        shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
#if MXNET_USE_MKLDNN == 1
        mkl_mem_ = nullptr;
#endif
      }
    }

    /*! \brief initialize shape/size and allocate, for a previously empty chunk */
    void Init(const mxnet::TShape &shape, int dtype) {
      auto size = shape.Size();
      storage_shape = shape;
      shandle.size = size * mshadow::mshadow_sizeof(dtype);
      this->CheckAndAlloc();
    }

    /*! \brief allocate data and aux data for non-default storage given run-time aux shapes */
    inline void CheckAndAlloc(const mxnet::TShape &shape, const mxnet::ShapeVector &aux_shapes,
                              int dtype) {
      // calculate size, perform allocation
      if (kRowSparseStorage == storage_type) {
        // For row sparse, aux_shape indicates the number of rows to allocate
        auto aux_shape = aux_shapes[rowsparse::kIdx];
        CheckAndAllocAuxData(rowsparse::kIdx, aux_shape);
        mxnet::TShape storage_shape(shape);
        storage_shape[0] = aux_shape[0];
        CheckAndAllocData(storage_shape, dtype);
      } else if (kCSRStorage == storage_type) {
        CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]);
        CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]);
        // For CSR the value array has the same length as the column index array.
        CheckAndAllocData(aux_shapes[csr::kIdx], dtype);
      } else {
        LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc";
      }
    }

    // create storage handle for data based on shape and dtype, assuming ctx is set
    // storage shape is also updated
    // if data is already allocated, try reuse the storage. Otherwise, free the current one
    // and allocate new storage
    void CheckAndAllocData(const mxnet::TShape &shape, int dtype);

#if MXNET_USE_MKLDNN == 1
    // Have MKL memory reference to the data in the default storage
    // or create memory for MKLDNN.
    void SetMKLMem(const mxnet::TShape &shape, int dtype);
    // If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
    // save the result in shandle.
    void Reorder2Default();
    // Reorder data to a specified layout.
    void MKLDNNDataReorder(const mkldnn::memory::desc &md);
    bool IsMKLDNN() const;
    bool IsDefault() const;
#endif

    // create storage handle for aux data based on shape
    // this function assumes ctx, aux shapes and aux types are set
    // aux shape is also updated
    // if aux data is already allocated, try reuse the storage. Otherwise, free the current one
    // and allocate new storage
    inline void CheckAndAllocAuxData(size_t i, const mxnet::TShape &shape) {
      CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData";
      CHECK_NE(storage_type, kUndefinedStorage)
          << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData";
      CHECK_NE(storage_type, kDefaultStorage)
          << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData";
      if (aux_handles.size() <= i) {
        aux_handles.resize(i + 1);
      }
      size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]);
      if (aux_handles[i].size < aux_bytes) {
        // free storage
        Storage::Get()->Free(aux_handles[i]);
        // init aux storage
        aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);
      }
      // init shape
      set_aux_shape(i, shape);
    }
    /*! \brief destructor (defined out of line; releases owned storage) */
    ~Chunk();
  };  // struct Chunk
1092 
1093  void SetTBlob() const;
1094 
1096  std::shared_ptr<Chunk> ptr_{nullptr};
1098  mxnet::TShape shape_;
1100  size_t byte_offset_ = 0;
1102  int dtype_ = -1;
1104  bool reuse_ = false;
1106  NDArrayStorageType storage_type_ = kUndefinedStorage;
1108  nnvm::NodeEntry entry_;
1116  mutable TBlob tblob_;
1117 }; // class NDArray
1118 
1122 size_t num_aux_data(NDArrayStorageType stype);
1123 
1135 void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);
1136 
1150 void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false);
1151 
1158 void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority = 0);
1159 
1166 NDArray operator+(const NDArray &lhs, const NDArray &rhs);
1173 NDArray operator+(const NDArray &lhs, const real_t &rhs);
1180 NDArray operator-(const NDArray &lhs, const NDArray &rhs);
1187 NDArray operator-(const NDArray &lhs, const real_t &rhs);
1194 NDArray operator*(const NDArray &lhs, const NDArray &rhs); \
1201 NDArray operator*(const NDArray &lhs, const real_t &rhs);
1208 NDArray operator/(const NDArray &lhs, const NDArray &rhs);
1215 NDArray operator/(const NDArray &lhs, const real_t &rhs);
1216 
1221 void RandomSeed(uint32_t seed);
1226 void RandomSeed(Context ctx, uint32_t seed);
1233 void SampleUniform(real_t begin, real_t end, NDArray *out);
1240 void SampleGaussian(real_t mu, real_t sigma, NDArray *out);
1247 void SampleGamma(real_t alpha, real_t beta, NDArray *out);
1253 void SampleExponential(real_t lambda, NDArray *out);
1259 void SamplePoisson(real_t lambda, NDArray *out);
1266 void SampleNegBinomial(int32_t k, real_t p, NDArray *out);
1273 void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out);
1274 
1275 
1276 //--------------------------------------------------------------
1277 // The following part are API Registration of NDArray functions.
1278 //--------------------------------------------------------------
1279 
1281 typedef std::function<void (NDArray **used_vars,
1282  real_t *scalars,
1283  NDArray **mutate_vars,
1284  int num_params,
1285  char **param_keys,
1286  char **param_vals)> NDArrayAPIFunction;
1302 };
1305  : public dmlc::FunctionRegEntryBase<NDArrayFunctionReg,
1306  NDArrayAPIFunction> {
1308  unsigned num_use_vars;
1312  unsigned num_scalars;
1319  : num_use_vars(0),
1320  num_mutate_vars(0),
1321  num_scalars(0),
1322  type_mask(0) {}
  /*!
   * \brief register a set-value function (scalar -> output), auto-filling
   *  the argument metadata (1 mutate var, 1 scalar).
   * \param fsetvalue the function body
   * \return reference to self for chaining
   */
  inline NDArrayFunctionReg &set_function(void (*fsetvalue)(const real_t &rhs,
                                                            NDArray *out)) {
    body = [fsetvalue] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                        int num_params, char **param_keys, char **param_vals) {
      (*fsetvalue)(s[0], mutate_vars[0]);
    };
    num_mutate_vars = 1; num_scalars = 1;
    this->add_argument("src", "real_t", "Source input to the function.");
    return *this;
  }
  /*!
   * \brief register a ternary function (lhs, mhs, rhs -> output),
   *  auto-filling the argument metadata (3 use vars, 1 mutate var).
   * \param fternary the function body
   * \return reference to self for chaining
   */
  inline NDArrayFunctionReg &set_function(void(*fternary)(const NDArray &lhs,
                                                          const NDArray &mhs,
                                                          const NDArray &rhs,
                                                          NDArray *out)) {
    body = [fternary](NDArray **used_vars,
                      real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fternary)(*used_vars[0], *used_vars[1], *used_vars[2], mutate_vars[0]);
    };
    num_use_vars = 3; num_mutate_vars = 1;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("mhs", "NDArray", "Middle operand to the function.");
    this->add_argument("rhs", "NDArray", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief register a binary function (lhs, rhs -> output),
   *  auto-filling the argument metadata (2 use vars, 1 mutate var).
   * \param fbinary the function body
   * \return reference to self for chaining
   */
  inline NDArrayFunctionReg &set_function(void (*fbinary)(const NDArray &lhs,
                                                          const NDArray &rhs,
                                                          NDArray *out)) {
    body = [fbinary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fbinary)(*used_vars[0], *used_vars[1], mutate_vars[0]);
    };
    num_use_vars = 2; num_mutate_vars = 1;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("rhs", "NDArray", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief register an NDArray-scalar function (lhs, scalar -> output),
   *  auto-filling the argument metadata (1 use var, 1 mutate var, 1 scalar).
   * \param fscalar the function body
   * \return reference to self for chaining
   */
  inline NDArrayFunctionReg &set_function(void (*fscalar)(const NDArray &lhs,
                                                          const real_t &rhs,
                                                          NDArray *out)) {
    body = [fscalar] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                      int num_params, char **param_keys, char **param_vals) {
      (*fscalar)(*used_vars[0], s[0], mutate_vars[0]);
    };
    num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1;
    this->add_argument("lhs", "NDArray", "Left operand to the function.");
    this->add_argument("rhs", "real_t", "Right operand to the function.");
    return *this;
  }
  /*!
   * \brief register a unary function (src -> output),
   *  auto-filling the argument metadata (1 use var, 1 mutate var).
   * \param funary the function body
   * \return reference to self for chaining
   */
  inline NDArrayFunctionReg &set_function(void (*funary)(const NDArray &src,
                                                         NDArray *out)) {
    body = [funary] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
                     int num_params, char **param_keys, char **param_vals) {
      (*funary)(*used_vars[0], mutate_vars[0]);
    };
    num_use_vars = 1; num_mutate_vars = 1;
    this->add_argument("src", "NDArray", "Source input to the function.");
    return *this;
  }
1423  void (*fgeneric)(NDArray **used_vars,
1424  real_t *s,
1425  NDArray **mutate_vars,
1426  const std::map<std::string, std::string>& param)) {
1427  body = [fgeneric] (NDArray **used_vars, real_t *s, NDArray **mutate_vars,
1428  int num_params, char **param_keys, char **param_vals) {
1429  std::map<std::string, std::string> param;
1430  for (int i = 0; i < num_params; ++i) {
1431  param[param_keys[i]] = param_vals[i];
1432  }
1433  fgeneric(used_vars, s, mutate_vars, param);
1434  };
1435  return *this;
1436  }
1442  inline NDArrayFunctionReg &set_num_use_vars(unsigned n) {
1443  num_use_vars = n; return *this;
1444  }
1451  num_mutate_vars = n; return *this;
1452  }
1458  inline NDArrayFunctionReg &set_num_scalars(unsigned n) {
1459  num_scalars = n; return *this;
1460  }
1466  inline NDArrayFunctionReg &set_type_mask(int tmask) {
1467  type_mask = tmask; return *this;
1468  }
1469 }; // NDArrayFunctionReg
1470 
1482 #define MXNET_REGISTER_NDARRAY_FUN(name) \
1483  DMLC_REGISTRY_REGISTER(::mxnet::NDArrayFunctionReg, NDArrayFunctionReg, name)
1484 
1485 } // namespace mxnet
1486 
1487 namespace dmlc {
1489 DMLC_DECLARE_TRAITS(has_saveload, mxnet::NDArray, true);
1490 } // namespace dmlc
1491 #endif // MXNET_NDARRAY_H_
const mxnet::ShapeVector & aux_shapes() const
Definition: ndarray.h:248
Definition: ndarray.h:73
const int default_type_flag
type enum value for default real type
Definition: base.h:484
Definition: ndarray.h:62
NDArrayStorageType
Definition: ndarray.h:60
TBlob aux_data(size_t i) const
Definition: ndarray.h:290
Definition: ndarray.h:53
const std::vector< int > & aux_types() const
Definition: ndarray.h:255
NDArrayFunctionReg & set_num_mutate_vars(unsigned n)
set the number of mutate variables
Definition: ndarray.h:1450
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, const TBlob &data, const std::vector< TBlob > &aux_data, int dev_id)
constructing a static NDArray of non-default storage that shares data with TBlob Use with caution: al...
Definition: ndarray.h:173
bool is_none() const
Definition: ndarray.h:325
void CheckAndAllocAuxData(size_t i, const mxnet::TShape &aux_shape) const
Definition: ndarray.h:695
NDArrayFormatErr
Definition: ndarray.h:67
mxnet::TShape shape_
shape of the tensor
Definition: tensor_blob.h:71
Common base class for function registry.
Definition: registry.h:151
ScalarExp< DType > scalar(DType s)
create an scalar expression
Definition: expression.h:103
void RandomSeed(uint32_t seed)
Seed all random number generator in mxnet.
Engine that schedules all the operations according to dependency.
size_t byte_offset() const
Definition: ndarray.h:392
Context ctx() const
Definition: ndarray.h:306
const mxnet::TShape & aux_shape(size_t index) const
get the shape of aux_data(index)
Definition: ndarray.h:241
void SparseUpdateChunk(const NDArray &arr) const
Update ndarray chunk storage handles using existing ndarray storage handles Also update the aux_handl...
Definition: ndarray.h:605
NDArrayFunctionReg()
constructor
Definition: ndarray.h:1318
namespace of mxnet
Definition: api_registry.h:33
Storage manager across multiple devices.
Definition: storage.h:35
NDArray operator*(const NDArray &lhs, const NDArray &rhs)
elementwise multiplication
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:145
virtual void Free(Handle handle)=0
Free storage.
NDArrayFunctionReg & set_num_use_vars(unsigned n)
set the number of use variables
Definition: ndarray.h:1442
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:96
static Context GPU(int32_t dev_id=-1)
int type_mask
information on how function should be called from API
Definition: ndarray.h:1314
NDArrayFunctionReg & set_function(void(*funary)(const NDArray &src, NDArray *out))
set the function body to a unary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1405
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:73
Definition: optional.h:251
NDArrayFunctionReg & set_num_scalars(unsigned n)
set the number of scalar arguments
Definition: ndarray.h:1458
Definition: ndarray.h:71
unsigned num_mutate_vars
number of variable mutated by this function
Definition: ndarray.h:1310
execution time context. The information needed in runtime for actual execution.
Definition: base.h:349
interface of stream I/O for serialization
Definition: io.h:30
void * dptr
Pointer to the data.
Definition: storage.h:44
NDArrayFunctionReg & set_function(void(*fscalar)(const NDArray &lhs, const real_t &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1386
Graph node data structure.
base class of engine variables.
Definition: engine.h:43
Definition: ndarray.h:64
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)
macro to quickly declare traits information
Definition: type_traits.h:126
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:262
Context ctx
Context information about device and ID.
Definition: storage.h:52
NDArray()
default constructor
Definition: ndarray.h:84
unsigned num_use_vars
number of variable used by this function
Definition: ndarray.h:1308
int shared_id
Definition: storage.h:57
NDArrayFunctionReg & set_function(void(*fternary)(const NDArray &lhs, const NDArray &mhs, const NDArray &rhs, NDArray *out))
set the function body to a ternary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1345
int dtype() const
Definition: ndarray.h:313
Definition: ndarray.h:61
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:50
void Init(const mxnet::TShape &shape)
initialize the NDArray, assuming it is not assigned a meaningful shape before
Definition: ndarray.h:185
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:819
RowSparseAuxType
Definition: ndarray.h:57
Definition: ndarray.h:69
all the scalars should go before use_vars
Definition: ndarray.h:1292
bool storage_initialized() const
Returns true if a sparse ndarray's aux_data and storage are initialized Throws an exception if the in...
Definition: ndarray.h:336
void SampleExponential(real_t lambda, NDArray *out)
Sample exponential distribution for each elements of out.
Engine::VarHandle var() const
Definition: ndarray.h:388
void * dptr_
pointer to the data
Definition: tensor_blob.h:69
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
void CheckAndAlloc(const mxnet::ShapeVector &aux_shapes) const
Definition: ndarray.h:685
Definition: ndarray.h:57
whether this function allows the handles in the target to be empty NDArray that are not yet initializ...
Definition: ndarray.h:1301
size_t version() const
return var version of the NDArray
Definition: ndarray.h:396
Definition: ndarray.h:72
C Tensor object, manage memory of DLTensor. This data structure is intended to facilitate the borrowi...
Definition: dlpack.h:157
static Storage * Get()
namespace for dmlc
Definition: array_view.h:12
NDArrayStorageType storage_type() const
Definition: ndarray.h:321
virtual void WaitForVar(VarHandle var)=0
Wait for a variable.
void CopyFromTo(const NDArray &from, const NDArray *to, int priority=0)
issue an copy operation from one NDArray to another the two ndarray can sit on different devices this...
NDArray Detach() const
Return a copy of this NDArray without autograd history.
Definition: ndarray.h:649
CSRAuxType
Definition: ndarray.h:53
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample gaussian distribution for each elements of out.
Definition: ndarray.h:53
Storage manager across multiple devices.
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr, bool wait=false)=0
Push an asynchronous operation to the engine.
Storage handle.
Definition: storage.h:40
static Context CPUShared(int32_t dev_id=0)
Definition: ndarray.h:63
size_t num_aux_data(NDArrayStorageType stype)
NDArrayFunctionReg & set_type_mask(int tmask)
set type mask
Definition: ndarray.h:1466
void WaitToRead() const
Block until all the pending write operations with respect to current NDArray are finished, and read can be performed.
Definition: ndarray.h:367
NDArray(const TBlob &data, int dev_id, const std::function< void()> &deleter)
constructing a static NDArray that shares data with TBlob which is with deleter Use with caution: all...
Definition: ndarray.h:144
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:206
an entry that represents output data from a node
Definition: node.h:51
Handle Alloc(size_t size, Context ctx)
Allocate a new contiguous memory for a given size.
Definition: storage.h:65
NDArray operator-(const NDArray &lhs, const NDArray &rhs)
elementwise subtraction
Definition: ndarray.h:70
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1479
NDArrayFunctionReg & set_function(void(*fsetvalue)(const real_t &rhs, NDArray *out))
set the function body to a NDArray setvalue function this will also auto set the parameters correctly...
Definition: ndarray.h:1329
NDArray operator+(const NDArray &lhs, const NDArray &rhs)
elementwise add
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample uniform distribution for each elements of out.
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:43
Registry entry for NDArrayFunction.
Definition: ndarray.h:1304
void CheckAndAllocData(const mxnet::TShape &storage_shape) const
Definition: ndarray.h:690
bool IsView() const
Definition: ndarray.h:199
NDArrayFunctionReg & set_function(void(*fbinary)(const NDArray &lhs, const NDArray &rhs, NDArray *out))
set the function body to a binary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1367
void WaitToWrite() const
Block until all the pending read/write operations with respect to current NDArray are finished...
Definition: ndarray.h:375
Storage::Handle storage_handle() const
get storage handle
Definition: ndarray.h:357
Dependency engine that schedules operations.
Definition: engine.h:116
void set_aux_shape(size_t index, const mxnet::TShape &shape) const
For a sparse operation on a csr matrix for example, the size of the column index array is an estimate...
Definition: ndarray.h:268
static Context CPU(int32_t dev_id=0)
const TBlob & data() const
Definition: ndarray.h:277
runtime functions for NDArray
Definition: imperative.h:50
size_t Size() const
Definition: tuple.h:520
int aux_type(size_t i) const
Definition: ndarray.h:316
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:72
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:437
void ReshapeAndAlloc(const mxnet::TShape &shape)
Allocate the space if the allocation has been delayed or the requested size is bigger than the availa...
Definition: ndarray.h:674
all the use_vars should go before scalars
Definition: ndarray.h:1290
void CheckAndAlloc() const
Allocate the space if it is delayed allocated. This is an internal function used by system that norma...
Definition: ndarray.h:660
const mxnet::TShape & storage_shape() const
Definition: ndarray.h:229
unsigned num_scalars
number of scalars used by this function
Definition: ndarray.h:1312
static Engine * Get()
NDArray(int shared_pid, int shared_id, const mxnet::TShape &shape, int dtype)
create ndarray from shared memory
Definition: ndarray.h:155
#define MSHADOW_TYPE_SWITCH(type, DType,...)
Definition: base.h:1074
NDArray(const mxnet::TShape &shape, Context ctx, bool delay_alloc=false, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray
Definition: ndarray.h:94
Definition: ndarray.h:68
bool shape_is_known(const TShape &x)
Definition: tuple.h:692
const mxnet::TShape & shape() const
Definition: ndarray.h:221
overloaded + operator between half_t and bf16_t
Definition: base.h:334
mshadow::index_t index_t
index type usually use unsigned
Definition: base.h:94
NDArray AsArray(const mxnet::TShape &shape, int dtype) const
Create a NDArray that shares memory with current one The new array must have smaller memory size than...
Definition: ndarray.h:563
size_t size
Size of the storage.
Definition: storage.h:48
bool IsSame(const NDArray &other) const
Definition: ndarray.h:211
void SampleGenNegBinomial(real_t mu, real_t alpha, NDArray *out)
Sample generalized negative binomial distribution for each elements of out.
Context information about the execution environment.
Definition: base.h:101
void SamplePoisson(real_t lambda, NDArray *out)
Sample Poisson distribution for each elements of out.
int ndim() const
Definition: tuple.h:217
ndarray interface
Definition: ndarray.h:81
NDArray(Context ctx, int dtype=mshadow::default_type_flag)
constructs a new dynamic NDArray whose shape is unknown, hence the NDArray is inherently lazily creat...
Definition: ndarray.h:114
NDArray(const TBlob &data, int dev_id)
constructing a static NDArray that shares data with TBlob Use with caution: allocate ONLY ONE NDArray...
Definition: ndarray.h:128
void ElementwiseSum(const std::vector< NDArray > &source, NDArray *out, int priority=0)
Perform elementwise sum over each data from source, store result into out.
std::function< void(NDArray **used_vars, real_t *scalars, NDArray **mutate_vars, int num_params, char **param_keys, char **param_vals)> NDArrayAPIFunction
definition of NDArray function
Definition: ndarray.h:1286
Symbol is help class used to represent the operator node in Graph.
Definition: symbolic.h:50
void SampleNegBinomial(int32_t k, real_t p, NDArray *out)
Sample negative binomial distribution for each elements of out.
NDArrayFunctionReg & set_function(void(*fgeneric)(NDArray **used_vars, real_t *s, NDArray **mutate_vars, const std::map< std::string, std::string > &param))
set the function body to a unary NDArray function this will also auto set the parameters correctly ...
Definition: ndarray.h:1422
type traits information header
int shared_pid
Id for IPC shared memory.
Definition: storage.h:56
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:65
void SampleGamma(real_t alpha, real_t beta, NDArray *out)
Sample gamma distribution for each elements of out.
NDArray operator/(const NDArray &lhs, const NDArray &rhs)
elementwise division
NDArrayFunctionTypeMask
mask information on how functions can be exposed
Definition: ndarray.h:1288