mxnet
tensor_blob.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef MXNET_TENSOR_BLOB_H_
29 #define MXNET_TENSOR_BLOB_H_
30 
31 #include <dmlc/logging.h>
32 #include <dmlc/json.h>
33 #include <dlpack/dlpack.h>
34 #include <vector>
35 #include <iostream>
36 #include <utility>
37 #include <algorithm>
38 #include "./base.h"
39 
40 namespace mxnet {
41 
42 // redefine DLPack enumeration to be backward compatible.
43 constexpr const int kCPU = kDLCPU;
44 constexpr const int kGPU = kDLGPU;
45 // extension type code under TVM function.
46 // Currently NNVM reserved 16 to 19 type code from TVM
47 // 16, 17, 18 is used by NNVM compiler already.
48 // Pick code 19 for MXNet NDArray
49 constexpr const int kTVMNDArrayTypeCode = 19;
50 
51 /* Forward declaration for friend declaration in TBlob */
52 class NDArray;
53 
66 class TBlob {
67  friend class NDArray;
68  public:
70  void *dptr_;
75 
77  TBlob(void)
78  : dptr_(NULL),
79  type_flag_(mshadow::DataType<real_t>::kFlag) {
80  SetDLTensor(cpu::kDevMask, 0);
81  }
89  template<typename DType>
90  TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id = -1)
91  : dptr_(dptr), shape_(shape),
92  type_flag_(mshadow::DataType<DType>::kFlag) {
93  SetDLTensor(dev_mask, dev_id);
94  }
103  TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id = -1)
104  : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
105  SetDLTensor(dev_mask, dev_id);
106  }
114  template<typename Device, int dim, typename DType>
115  TBlob(const mshadow::Tensor<Device, dim, DType> &src) { // NOLINT(*)
116  *this = src;
117  }
126  template<typename Device, int dim, typename DType>
127  inline TBlob &operator=(const mshadow::Tensor<Device, dim, DType> &src) {
128  dptr_ = src.dptr_;
129  shape_ = src.shape_;
130  type_flag_ = mshadow::DataType<DType>::kFlag;
131  SetDLTensor(Device::kDevMask, -1);
132  return *this;
133  }
137  inline bool CheckContiguous(void) const {
138  return true;
139  }
145  inline TBlob reshape(const TShape& shape) const {
146  CHECK_EQ(this->shape_.Size(), shape.Size()) << "Shape size mismatch "
147  << this->shape_.Size() << " v.s. " << shape.Size();
148  TBlob ret(this->dptr_, shape, this->dev_mask(), this->type_flag_, this->dev_id());
149  return ret;
150  }
158  template<typename Device, typename DType>
159  inline mshadow::Tensor<Device, 2, DType> FlatTo2D(
160  mshadow::Stream<Device> *stream = NULL) const {
161  CHECK(Device::kDevMask == this->dev_mask())
162  << "TBlob.get: device type do not match specified type";
163  CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
164  << "TBlob.get_with_shape: data type do not match specified type."
165  << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
166  return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
167  shape_.FlatTo2D(),
168  shape_[shape_.ndim() - 1],
169  stream);
170  }
178  template<typename Device, typename DType>
179  inline mshadow::Tensor<Device, 1, DType> FlatTo1D(
180  mshadow::Stream<Device> *stream = NULL) const {
181  return this->get_with_shape<Device, 1, DType>(
182  mshadow::Shape1(shape_.Size()), stream);
183  }
185  inline int ndim(void) const {
186  return shape_.ndim();
187  }
193  inline index_t size(index_t idx) const {
194  return shape_[idx];
195  }
197  inline index_t Size(void) const {
198  return shape_.Size();
199  }
201  template<typename DType>
202  inline DType* dptr() const {
203  CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
204  << "TBlob.get_with_shape: data type do not match specified type."
205  << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
206  return static_cast<DType*>(dptr_);
207  }
209  inline int dev_mask() const {
210  return dltensor_.ctx.device_type;
211  }
213  inline int dev_id() const {
214  return dltensor_.ctx.device_id;
215  }
220  inline const DLTensor& dltensor() const {
221  return dltensor_;
222  }
223 
233  template<typename Device, int dim, typename DType>
234  inline mshadow::Tensor<Device, dim, DType> get(mshadow::Stream<Device> *stream = NULL) const {
235  CHECK(Device::kDevMask == this->dev_mask())
236  << "TBlob.get: device type do not match specified type";
237  return mshadow::Tensor<Device, dim, DType>(dptr<DType>(),
238  shape_.get<dim>(), shape_[shape_.ndim() - 1], stream);
239  }
250  template<typename Device, int dim, typename DType>
251  inline mshadow::Tensor<Device, dim, DType> get_with_shape(
252  const mshadow::Shape<dim> &shape,
253  mshadow::Stream<Device> *stream = NULL) const {
254  CHECK(Device::kDevMask == this->dev_mask())
255  << "TBlob.get: device type do not match specified type";
256  CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous";
257  CHECK_EQ(this->shape_.Size(), shape.Size())
258  << "TBlob.get_with_shape: new and old shape do not match total elements";
259  return mshadow::Tensor<Device, dim, DType>(dptr<DType>(), shape,
260  shape[dim - 1], stream);
261  }
271  template<typename Device, typename DType>
272  inline mshadow::Tensor<Device, 3, DType> FlatTo3D(
273  int axis, mshadow::Stream<Device> *stream = NULL) const {
274  return this->get_with_shape<Device, 3, DType>(
275  this->shape_.FlatTo3D(axis), stream);
276  }
287  template<typename Device, typename DType>
288  inline mshadow::Tensor<Device, 3, DType> FlatTo3D(
289  int axis_begin, int axis_end,
290  mshadow::Stream<Device> *stream = NULL) const {
291  return this->get_with_shape<Device, 3, DType>(
292  this->shape_.FlatTo3D(axis_begin, axis_end), stream);
293  }
303  template<typename Device, int dim, typename DType>
304  inline mshadow::Tensor<Device, dim, DType> FlatToKD(
305  mshadow::Stream<Device> *stream = NULL) const {
306  mshadow::Shape<dim> shape;
307  shape[0] = 1;
308  // Pad higher dimensions in case dim > ndim()
309  for (int i = 0; i < dim - ndim(); ++i) {
310  shape[i] = 1;
311  }
312  // Collapse higher dimensions in case dim < ndim()
313  for (int i = 0; i < ndim() - dim + 1; ++i) {
314  shape[0] *= shape_[i];
315  }
316  // Preserve lower dimensions.
317  for (int i = std::max(0, ndim() - dim + 1); i < ndim(); ++i) {
318  shape[i - ndim() + dim] = shape_[i];
319  }
320  return this->get_with_shape<Device, dim, DType>(shape, stream);
321  }
322 
323  private:
324  static DLDataType DTypeTransform(int type_flag) {
325  switch (type_flag) {
326  case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1};
327  case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1};
328  case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1};
329  case mshadow::kUint8: return DLDataType{kDLUInt, 8, 1};
330  case mshadow::kInt32: return DLDataType{kDLInt, 32, 1};
331  case mshadow::kInt8: return DLDataType{kDLInt, 8, 1};
332  case mshadow::kInt64: return DLDataType{kDLInt, 64, 1};
333  default: {
334  LOG(FATAL) << "Unknown type_flag=" << type_flag;
335  return DLDataType();
336  }
337  }
338  }
339 
340  inline void SetDLTensor(int dev_mask, int dev_id) {
341  dltensor_.data = dptr_;
342  dltensor_.ctx = DLContext{static_cast<DLDeviceType>(dev_mask), dev_id};
343  dltensor_.ndim = shape_.ndim();
344  dltensor_.dtype = DTypeTransform(type_flag_);
345  dltensor_.shape = shape_.data();
346  dltensor_.strides = NULL;
347  dltensor_.byte_offset = 0;
348  }
349 
350  private:
352  DLTensor dltensor_;
353 };
354 } // namespace mxnet
355 
356 namespace dmlc {
357 // Add a few patches to support TShape in dmlc/parameter.
358 DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
359 DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<int>, "Shape(tuple)");
360 DMLC_DECLARE_TYPE_NAME(nnvm::Tuple<dmlc::optional<int>>, "Shape(tuple)");
361 
362 namespace parameter {
363 
364 template<>
365 class FieldEntry<mxnet::TShape>
366  : public FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> {
367  public:
368  FieldEntry() : enforce_nonzero_(false), expect_ndim_(0) {}
369  // parent class
370  typedef FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> Parent;
371 
372  virtual void Check(void *head) const {
373  Parent::Check(head);
374  mxnet::TShape &v = this->Get(head);
375  if (expect_ndim_ != 0 && v.ndim() != expect_ndim_) {
376  std::ostringstream os;
377  os << "value " << v << "for Parameter " << this->key_
378  << " has wrong dimensions, expected dimension=" << expect_ndim_;
379  throw dmlc::ParamError(os.str());
380  }
381  if (enforce_nonzero_) {
382  for (mxnet::index_t i = 0; i < v.ndim(); ++i) {
383  if (v[i] == 0U) {
384  std::ostringstream os;
385  os << "value " << v << "for Parameter " << this->key_
386  << " is invalid, the input shape must be nonzero in all dimensions";
387  throw dmlc::ParamError(os.str());
388  }
389  }
390  }
391  }
393  this->enforce_nonzero_ = true;
394  return this->self();
395  }
397  expect_ndim_ = ndim;
398  return this->self();
399  }
400 
401  private:
402  // whether all the entries need to be nonzero
403  bool enforce_nonzero_;
404  // expected number of dimension, default = 0 means no restriction.
405  mxnet::index_t expect_ndim_;
406 };
407 
408 } // namespace parameter
409 } // namespace dmlc
410 
411 #endif // MXNET_TENSOR_BLOB_H_
TBlob & operator=(const mshadow::Tensor< Device, dim, DType > &src)
assignment from tensor
Definition: tensor_blob.h:127
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple< dmlc::optional< int >>,"Shape(tuple)")
constexpr const int kTVMNDArrayTypeCode
Definition: tensor_blob.h:49
TShape shape_
shape of the tensor
Definition: tensor_blob.h:72
FieldEntry< mxnet::TShape > & set_expect_ndim(mxnet::index_t ndim)
Definition: tensor_blob.h:396
namespace of mxnet
Definition: base.h:118
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:126
TBlob(void)
default constructor, default copy assign will work
Definition: tensor_blob.h:77
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:74
mshadow::Tensor< Device, dim, DType > get_with_shape(const mshadow::Shape< dim > &shape, mshadow::Stream< Device > *stream=NULL) const
fetch a tensor in the given shape. If the size does not match the stored size, an error will be issued ...
Definition: tensor_blob.h:251
nnvm::TShape TShape
Shape data structure used to record shape information.
Definition: base.h:128
FieldEntry< mxnet::TShape > & enforce_nonzero()
Definition: tensor_blob.h:392
FieldEntryBase< FieldEntry< mxnet::TShape >, mxnet::TShape > Parent
Definition: tensor_blob.h:370
mshadow::Tensor< Device, 1, DType > FlatTo1D(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 1 dimension, collapse all the dimensions together.
Definition: tensor_blob.h:179
constexpr const int kGPU
Definition: tensor_blob.h:44
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2735
index_t size(index_t idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor_blob.h:193
void * dptr_
pointer to the data
Definition: tensor_blob.h:70
Definition: ndarray.h:1368
DType * dptr() const
get pointer in dtype
Definition: tensor_blob.h:202
int ndim(void) const
return number of dimension of the tensor inside
Definition: tensor_blob.h:185
TBlob(const mshadow::Tensor< Device, dim, DType > &src)
constructor from tensor
Definition: tensor_blob.h:115
const DLTensor & dltensor() const
return the corresponding DLTensor
Definition: tensor_blob.h:220
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis_begin, int axis_end, mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 3 dimension, collapse the dimension: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
Definition: tensor_blob.h:288
mshadow::Tensor< Device, dim, DType > FlatToKD(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to specified number of dimensions, collapse the highest dimensions or pad with hig...
Definition: tensor_blob.h:304
mshadow::Tensor< Device, 2, DType > FlatTo2D(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor_blob.h:159
virtual void Check(void *head) const
Definition: tensor_blob.h:372
index_t Size(void) const
total number of elements in the tensor
Definition: tensor_blob.h:197
bool CheckContiguous(void) const
Definition: tensor_blob.h:137
TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:90
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis, mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 3 dimension, collapse the dimension before and after specified axis...
Definition: tensor_blob.h:272
mshadow::index_t index_t
index type, usually unsigned
Definition: base.h:124
TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:103
constexpr const int kCPU
Definition: tensor_blob.h:43
TBlob reshape(const TShape &shape) const
reshape to shape
Definition: tensor_blob.h:145
FieldEntry()
Definition: tensor_blob.h:368
ndarray interface
Definition: ndarray.h:82
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:209
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
int dev_id() const
device index of the corresponding device
Definition: tensor_blob.h:213