28 #ifndef MXNET_TENSOR_BLOB_H_ 29 #define MXNET_TENSOR_BLOB_H_ 31 #include <dmlc/logging.h> 32 #include <dmlc/json.h> 33 #include <dlpack/dlpack.h> 43 constexpr
const int kCPU = kDLCPU;
44 constexpr
const int kGPU = kDLGPU;
79 type_flag_(mshadow::DataType<
real_t>::kFlag) {
80 SetDLTensor(cpu::kDevMask, 0);
89 template<
typename DType>
91 : dptr_(dptr), shape_(shape),
92 type_flag_(mshadow::DataType<DType>::kFlag) {
93 SetDLTensor(dev_mask,
dev_id);
104 : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
105 SetDLTensor(dev_mask,
dev_id);
112 : dptr_(dltensor.data),
113 shape_(
TShape(dltensor.shape, dltensor.shape + dltensor.
ndim)),
114 type_flag_(DLDataTypeTransform(dltensor.dtype)),
115 dltensor_(dltensor) {
117 if (dltensor.strides !=
nullptr) {
119 const int &
ndim = dltensor.ndim;
120 const int64_t *shape = dltensor.shape;
121 const int64_t *strides = dltensor.strides;
124 if (strides[ndim - 1] != 1) {
127 for (
int i = ndim - 2; i >= 0; --i) {
128 if (strides[i] != shape[i + 1] * strides[i + 1]) {
135 LOG(FATAL) <<
"Unsupported DLPack because MXNet only support compact tensor now";
147 template<
typename Device,
int dim,
typename DType>
148 TBlob(
const mshadow::Tensor<Device, dim, DType> &src) {
159 template<
typename Device,
int dim,
typename DType>
163 type_flag_ = mshadow::DataType<DType>::kFlag;
164 SetDLTensor(Device::kDevMask, -1);
179 CHECK_EQ(this->shape_.Size(), shape.Size()) <<
"Shape size mismatch " 180 << this->shape_.
Size() <<
" v.s. " << shape.Size();
191 template<
typename Device,
typename DType>
193 mshadow::Stream<Device> *stream = NULL)
const {
194 CHECK(Device::kDevMask == this->
dev_mask())
195 <<
"TBlob.get: device type do not match specified type";
196 CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
197 <<
"TBlob.get_with_shape: data type do not match specified type." 198 <<
"Expected: " << type_flag_ <<
" v.s. given " << mshadow::DataType<DType>::kFlag;
199 return mshadow::Tensor<Device, 2, DType>(
static_cast<DType*
>(
dptr_),
201 shape_[shape_.ndim() - 1],
211 template<
typename Device,
typename DType>
213 mshadow::Stream<Device> *stream = NULL)
const {
214 return this->get_with_shape<Device, 1, DType>(
215 mshadow::Shape1(shape_.Size()), stream);
219 return shape_.ndim();
231 return shape_.Size();
234 template<
typename DType>
236 CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
237 <<
"TBlob.get_with_shape: data type do not match specified type." 238 <<
"Expected: " << type_flag_ <<
" v.s. given " << mshadow::DataType<DType>::kFlag;
239 return static_cast<DType*
>(
dptr_);
243 return dltensor_.ctx.device_type;
247 return dltensor_.ctx.device_id;
266 template<
typename Device,
int dim,
typename DType>
267 inline mshadow::Tensor<Device, dim, DType>
get(mshadow::Stream<Device> *stream = NULL)
const {
268 CHECK(Device::kDevMask == this->
dev_mask())
269 <<
"TBlob.get: device type do not match specified type";
270 return mshadow::Tensor<Device, dim, DType>(dptr<DType>(),
271 shape_.get<dim>(), shape_[shape_.ndim() - 1], stream);
283 template<
typename Device,
int dim,
typename DType>
285 const mshadow::Shape<dim> &shape,
286 mshadow::Stream<Device> *stream = NULL)
const {
287 CHECK(Device::kDevMask == this->
dev_mask())
288 <<
"TBlob.get: device type do not match specified type";
289 CHECK_EQ(this->
CheckContiguous(),
true) <<
"TBlob.get_reshape: must be contiguous";
290 CHECK_EQ(this->shape_.Size(), shape.Size())
291 <<
"TBlob.get_with_shape: new and old shape do not match total elements";
292 return mshadow::Tensor<Device, dim, DType>(dptr<DType>(), shape,
293 shape[dim - 1], stream);
304 template<
typename Device,
typename DType>
306 int axis, mshadow::Stream<Device> *stream = NULL)
const {
307 return this->get_with_shape<Device, 3, DType>(
308 this->shape_.FlatTo3D(axis), stream);
320 template<
typename Device,
typename DType>
322 int axis_begin,
int axis_end,
323 mshadow::Stream<Device> *stream = NULL)
const {
324 return this->get_with_shape<Device, 3, DType>(
325 this->shape_.FlatTo3D(axis_begin, axis_end), stream);
336 template<
typename Device,
int dim,
typename DType>
337 inline mshadow::Tensor<Device, dim, DType>
FlatToKD(
338 mshadow::Stream<Device> *stream = NULL)
const {
339 mshadow::Shape<dim> shape;
342 for (
int i = 0; i < dim -
ndim(); ++i) {
346 for (
int i = 0; i <
ndim() - dim + 1; ++i) {
347 shape[0] *= shape_[i];
351 shape[i -
ndim() + dim] = shape_[i];
353 return this->get_with_shape<Device, dim, DType>(shape, stream);
357 static DLDataType DTypeTransform(
int type_flag) {
359 case mshadow::kFloat32:
return DLDataType{kDLFloat, 32, 1};
360 case mshadow::kFloat64:
return DLDataType{kDLFloat, 64, 1};
361 case mshadow::kFloat16:
return DLDataType{kDLFloat, 16, 1};
362 case mshadow::kUint8:
return DLDataType{kDLUInt, 8, 1};
363 case mshadow::kInt32:
return DLDataType{kDLInt, 32, 1};
364 case mshadow::kInt8:
return DLDataType{kDLInt, 8, 1};
365 case mshadow::kInt64:
return DLDataType{kDLInt, 64, 1};
367 LOG(FATAL) <<
"Unknown type_flag=" << type_flag;
372 static int DLDataTypeTransform(DLDataType dldata_type) {
373 if (dldata_type.lanes != 1) {
374 LOG(FATAL) <<
"Unsupported DLDataType whose lanes != 1";
376 switch (dldata_type.code) {
378 switch (dldata_type.bits) {
379 case 16:
return mshadow::kFloat16;
380 case 32:
return mshadow::kFloat32;
381 case 64:
return mshadow::kFloat64;
385 switch (dldata_type.bits) {
386 case 8:
return mshadow::kUint8;
390 switch (dldata_type.bits) {
391 case 8:
return mshadow::kInt8;
392 case 32:
return mshadow::kInt32;
393 case 64:
return mshadow::kInt64;
397 LOG(FATAL) <<
"Unknown DLDataType{" << dldata_type.code
398 <<
", " << dldata_type.bits
399 <<
", " << dldata_type.lanes <<
"}";
400 return mshadow::kFloat32;
404 dltensor_.data =
dptr_;
405 dltensor_.ctx = DLContext{
static_cast<DLDeviceType
>(
dev_mask), dev_id};
406 dltensor_.ndim = shape_.ndim();
407 dltensor_.dtype = DTypeTransform(type_flag_);
408 dltensor_.shape = shape_.data();
409 dltensor_.strides =
nullptr;
410 dltensor_.byte_offset = 0;
425 namespace parameter {
429 :
public FieldEntryBase<FieldEntry<mxnet::TShape>, mxnet::TShape> {
435 virtual void Check(
void *head)
const {
438 if (expect_ndim_ != 0 && v.ndim() != expect_ndim_) {
439 std::ostringstream os;
440 os <<
"value " << v <<
"for Parameter " << this->key_
441 <<
" has wrong dimensions, expected dimension=" << expect_ndim_;
442 throw dmlc::ParamError(os.str());
444 if (enforce_nonzero_) {
447 std::ostringstream os;
448 os <<
"value " << v <<
"for Parameter " << this->key_
449 <<
" is invalid, the input shape must be nonzero in all dimensions";
450 throw dmlc::ParamError(os.str());
456 this->enforce_nonzero_ =
true;
466 bool enforce_nonzero_;
474 #endif // MXNET_TENSOR_BLOB_H_ TBlob & operator=(const mshadow::Tensor< Device, dim, DType > &src)
assignment from tensor
Definition: tensor_blob.h:160
DMLC_DECLARE_TYPE_NAME(nnvm::Tuple< dmlc::optional< int >>,"Shape(tuple)")
constexpr const int kTVMNDArrayTypeCode
Definition: tensor_blob.h:49
TBlob(const DLTensor &dltensor)
constructor that constructs a TBlob from a DLTensor
Definition: tensor_blob.h:111
TShape shape_
shape of the tensor
Definition: tensor_blob.h:72
FieldEntry< mxnet::TShape > & set_expect_ndim(mxnet::index_t ndim)
Definition: tensor_blob.h:459
namespace of mxnet
Definition: base.h:118
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:126
TBlob(void)
default constructor, default copy assign will work
Definition: tensor_blob.h:77
Definition: tensor_blob.h:428
int type_flag_
type flag of the tensor blob
Definition: tensor_blob.h:74
mshadow::Tensor< Device, dim, DType > get_with_shape(const mshadow::Shape< dim > &shape, mshadow::Stream< Device > *stream=NULL) const
fetch a tensor in the given shape. If the size does not match the stored size, an error will be issued ...
Definition: tensor_blob.h:284
nnvm::TShape TShape
Shape data structure used to record shape information.
Definition: base.h:128
FieldEntry< mxnet::TShape > & enforce_nonzero()
Definition: tensor_blob.h:455
FieldEntryBase< FieldEntry< mxnet::TShape >, mxnet::TShape > Parent
Definition: tensor_blob.h:433
mshadow::Tensor< Device, 1, DType > FlatTo1D(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 1 dimension, collapse all the dimensions together.
Definition: tensor_blob.h:212
constexpr const int kGPU
Definition: tensor_blob.h:44
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2756
index_t size(index_t idx) const
return the size of the i-th dimension, counting from the highest dimension
Definition: tensor_blob.h:226
void * dptr_
pointer to the data
Definition: tensor_blob.h:70
Definition: ndarray.h:1457
DType * dptr() const
get pointer in dtype
Definition: tensor_blob.h:235
int ndim(void) const
return the number of dimensions of the tensor inside
Definition: tensor_blob.h:218
TBlob(const mshadow::Tensor< Device, dim, DType > &src)
constructor from tensor
Definition: tensor_blob.h:148
const DLTensor & dltensor() const
return the corresponding DLTensor
Definition: tensor_blob.h:253
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis_begin, int axis_end, mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 3 dimensions, collapsing the dimensions into: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim).
Definition: tensor_blob.h:321
mshadow::Tensor< Device, dim, DType > FlatToKD(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to specified number of dimensions, collapse the highest dimensions or pad with hig...
Definition: tensor_blob.h:337
mshadow::Tensor< Device, 2, DType > FlatTo2D(mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 2 dimensions, collapsing the higher dimensions together
Definition: tensor_blob.h:192
virtual void Check(void *head) const
Definition: tensor_blob.h:435
index_t Size(void) const
total number of elements in the tensor
Definition: tensor_blob.h:230
bool CheckContiguous(void) const
Definition: tensor_blob.h:170
TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:90
mshadow::Tensor< Device, 3, DType > FlatTo3D(int axis, mshadow::Stream< Device > *stream=NULL) const
flatten the tensor to 3 dimensions, collapsing the dimensions before and after the specified axis...
Definition: tensor_blob.h:305
mshadow::index_t index_t
index type, usually unsigned
Definition: base.h:124
TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id=-1)
constructor that constructs a TBlob from contiguous memory
Definition: tensor_blob.h:103
constexpr const int kCPU
Definition: tensor_blob.h:43
TBlob reshape(const TShape &shape) const
reshape to shape
Definition: tensor_blob.h:178
FieldEntry()
Definition: tensor_blob.h:431
ndarray interface
Definition: ndarray.h:82
int dev_mask() const
device mask of the corresponding device
Definition: tensor_blob.h:242
tensor blob class that can be used to hold a tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
int dev_id() const
device index of the corresponding device
Definition: tensor_blob.h:246