24 #ifndef MXNET_TUPLE_H_ 25 #define MXNET_TUPLE_H_ 28 #include <type_traits> 33 #include "nnvm/op_attr_types.h" 34 #include "nnvm/graph_attr_types.h" 35 #include "nnvm/graph.h" 36 #include "nnvm/pass.h" 53 template<
typename ValueType>
77 inline Tuple(std::initializer_list<ValueType> init) {
78 this->
assign(init.begin(), init.end());
84 inline Tuple(std::vector<ValueType> init) {
85 this->
assign(init.begin(), init.end());
101 template<
typename RandomAccessIterator>
103 RandomAccessIterator
end) {
112 template<
typename RandomAccessIterator>
114 RandomAccessIterator
end) {
115 this->
SetDim(end - begin);
117 std::copy(begin, end, this->
begin());
135 if (src.
ndim() == -1) {
157 this->
assign(init.begin(), init.end());
166 if (
ndim() == -1)
return true;
174 return !(*
this == s);
177 inline const ValueType *
begin()
const {
185 inline const ValueType*
end()
const {
202 CHECK(i >= 0 && i <
ndim()) <<
"index = " << i <<
" must be in range [0, " <<
ndim() <<
")";
211 CHECK(i >= 0 && i <
ndim()) <<
"index = " << i <<
" must be in range [0, " <<
ndim() <<
")";
218 inline void Save(dmlc::JSONWriter* writer)
const {
219 std::vector<ValueType> tmp(
begin(),
end());
226 inline void Load(dmlc::JSONReader* reader) {
227 std::vector<ValueType> tmp;
229 this->
assign(tmp.begin(), tmp.end());
237 friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
238 if (t.ndim() == -1) {
246 const ValueType*
begin = t.begin();
247 const ValueType*
end = t.end();
248 for (
const ValueType* it = begin; it !=
end; ++it) {
249 if (it != begin) os <<
',';
265 if (isdigit(ch) || ch ==
'-') {
273 if (ch ==
'(' || ch ==
'[')
break;
275 is.setstate(std::ios::failbit);
281 while (isspace(is.peek())) {
284 if (is.peek() ==
')' || is.peek() ==
']') {
291 std::vector<ValueType> tmp;
297 }
while (isspace(ch));
298 if (std::is_integral<ValueType>::value && ch ==
'L') {
307 if (ch ==
')' || ch ==
']') {
312 if (ch ==
')' || ch ==
']')
break;
313 }
else if (ch ==
')' || ch ==
']') {
316 is.setstate(std::ios::failbit);
320 t.
assign(tmp.begin(), tmp.end());
329 template<
typename DType = ValueType,
typename TStream>
330 inline void Save(TStream *strm)
const;
338 template<
typename DType = ValueType,
typename TStream>
339 inline bool Load(TStream *strm);
354 CHECK_GE(ndim, -1) <<
"ndim cannot be less than -1, received " <<
ndim;
355 if (ndim > kStackCache &&
360 }
else if (ndim <= 0 &&
data_heap_ !=
nullptr) {
372 CHECK_GE(ndim, -1) <<
"shape ndim must be >= -1, while received " <<
ndim;
378 CHECK_GE(dim_size, -1) <<
"shape dim size must be >= -1, while received " << dim_size;
379 return dim_size != -1;
409 std::fill_n(
begin(), ndim, value);
417 if (s.
ndim() == -1) {
427 inline TShape(std::initializer_list<dim_t> init) {
428 this->
assign(init.begin(), init.end());
445 template<
typename RandomAccessIterator,
446 typename std::enable_if<
447 std::is_same<typename std::iterator_traits<RandomAccessIterator>::iterator_category,
448 std::random_access_iterator_tag>::value,
int>::type = 0>
450 RandomAccessIterator
end) {
459 if (src.
ndim() == -1) {
480 for (
const dim_t* it = start; it != fin; ++it) {
481 CHECK(
dim_size_is_known(*it)) <<
"Shape dim size cannot be a negative value " << *it;
491 inline size_t ProdShape(
int dimstart,
int dimend)
const {
493 CHECK_GE(dimstart, 0) <<
"dimstart must be >= 0, while received " << dimstart;
494 CHECK_LE(dimend, this->
ndim()) <<
"dimend must be <= " << this->
ndim()
495 <<
", while received " << dimend;
497 const dim_t *d = this->data();
498 for (
int i = dimstart; i < dimend; ++i) {
499 CHECK(
dim_size_is_known(d[i])) <<
"Shape dim size must be known, while received " << d[i];
512 #ifdef MSHADOW_XINLINE 514 inline TShape(
const mshadow::Shape<dim> &s) {
515 this->
assign(s.shape_, s.shape_ + dim);
519 inline TShape(mshadow::Shape<dim> &&s) {
520 this->
assign(s.shape_, s.shape_ + dim);
530 this->
assign(shape.shape_, shape.shape_ + dim);
539 inline mshadow::Shape<dim>
get()
const {
540 CHECK_EQ(dim,
ndim())
541 <<
"dimension do not match target dimension " << dim <<
" vs " <<
ndim();
542 const dim_t *d = this->data();
543 mshadow::Shape<dim> s;
544 for (
int i = 0; i < dim; ++i) {
553 inline mshadow::Shape<2> FlatTo2D(
void)
const {
556 if (
ndim() == 0)
return mshadow::Shape2(1, 1);
557 const dim_t *d = this->data();
558 s.shape_[1] = d[
ndim() - 1];
560 for (
int i = 1; i <
ndim(); ++i) {
572 inline mshadow::Shape<3> FlatTo3D(
int axis_begin,
int axis_end)
const {
573 CHECK(axis_end >= axis_begin);
576 if (
ndim() == 0)
return mshadow::Shape3(1, 1, 1);
577 const dim_t *d = this->data();
582 for (
int i = 0; i < axis_begin; ++i) {
585 for (
int i = axis_begin; i <= axis_end; ++i) {
588 for (
int i = axis_end + 1; i <
ndim(); ++i) {
598 inline mshadow::Shape<3> FlatTo3D(
int axis)
const {
599 return FlatTo3D(axis, axis);
602 if (
ndim() != s.
ndim())
return false;
606 return !(*
this == s);
614 inline bool operator==(
const mshadow::Shape<dim> &s)
const {
615 if (
ndim_ != dim)
return false;
617 for (
size_t i = 0; i < dim; ++i) {
618 if (d[i] != s.shape_[i])
return false;
628 inline bool operator!=(
const mshadow::Shape<dim> &s)
const {
629 return !(*
this == s);
641 CHECK(idx >= 0 && idx < x.
ndim())
642 <<
"idx = " << idx <<
" exceeds shape dimension range [0, " << x.
ndim() <<
")";
650 for (
int i = 0; i < x.
ndim(); ++i) {
657 template<
typename SrcIter,
typename DstIter>
661 typedef typename std::iterator_traits<SrcIter>::value_type SrcDType;
662 typedef typename std::iterator_traits<DstIter>::value_type DstDType;
663 auto cast = [](
const SrcDType& dim) {
return static_cast<DstDType
>(dim); };
664 return std::transform(begin, end, dst_begin, cast);
668 template<
typename SrcIter>
670 size_t ndim = std::distance(begin, end);
677 template<
typename ValueType>
678 template<
typename DType,
typename TStream>
681 if (
typeid(DType) ==
typeid(ValueType)) {
682 strm->Write(
begin(),
sizeof(ValueType) *
ndim_);
684 std::vector<DType> buffer(
ndim_);
686 strm->Write(buffer.data(),
sizeof(DType) *
ndim_);
691 template<
typename ValueType>
692 template<
typename DType,
typename TStream>
696 size_t nread =
sizeof(DType) *
ndim_;
697 if (
typeid(DType) ==
typeid(ValueType)) {
698 if (strm->Read(
begin(), nread) != nread)
return false;
700 std::vector<DType> buffer(
ndim_);
701 if (strm->Read(buffer.data(), nread) != nread)
return false;
715 std::hash<int> hash_int;
716 size_t res = hash_int(val.
ndim());
717 for (
int i = 0; i < val.
ndim(); ++i) {
718 res = dmlc::HashCombine(res, val[i]);
729 std::hash<int> hash_int;
730 size_t res = hash_int(val.
ndim());
731 for (
int i = 0; i < val.
ndim(); ++i) {
732 res = dmlc::HashCombine(res, val[i]);
744 #if !defined(_MSC_VER) 748 return "tuple of <" + type_name<T>() +
">";
784 #endif // MXNET_TUPLE_H_ DMLC_DECLARE_TYPE_NAME(optional< mxnet::Tuple< int >>,"Shape or None")
Tuple< ValueType > & operator=(const Tuple< ValueType > &src)
assignment from another tuple.
Definition: tuple.h:134
bool operator==(const Tuple< ValueType > &s) const
Definition: tuple.h:164
size_t operator()(const mxnet::TShape &val) const
hash a TShape into unsigned int
Definition: tuple.h:728
Tuple< ValueType > & operator=(Tuple< ValueType > &&src)
assignment from rvalue of another tuple.
Definition: tuple.h:147
dim_t * data()
Definition: tuple.h:509
TShape(const Tuple< dim_t > &s)
copy constructor of TShape
Definition: tuple.h:416
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:349
namespace of mxnet
Definition: base.h:89
TShape & operator=(Tuple< dim_t > &&src)
move assignment function from tshape
Definition: tuple.h:471
Tuple(std::initializer_list< ValueType > init)
constructor from initializer list
Definition: tuple.h:77
size_t operator()(const mxnet::Tuple< T > &val) const
hash a Tuple into unsigned int
Definition: tuple.h:714
TShape()
default constructor
Definition: tuple.h:398
DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin)
helper function to cast type of container elements
Definition: tuple.h:658
Tuple(Tuple< ValueType > &&src)
move constructor from Tuple
Definition: tuple.h:92
void Load(dmlc::JSONReader *reader)
Load Tuple from JSON.
Definition: tuple.h:226
const dim_t * data() const
Definition: tuple.h:505
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:113
ValueType * end()
Definition: tuple.h:189
TShape(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator. This function is enforced with template arguments of ra...
Definition: tuple.h:449
ValueType & operator[](int i)
get corresponding index
Definition: tuple.h:201
const ValueType & operator[](int i) const
get corresponding index
Definition: tuple.h:210
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:768
size_t Size() const
Definition: tuple.h:476
Tuple(const Tuple< ValueType > &s)
copy constructor from another tuple
Definition: tuple.h:66
~Tuple()
destructor
Definition: tuple.h:59
Tuple(std::vector< ValueType > init)
constructor from vector
Definition: tuple.h:84
Definition: ndarray.h:1488
Tuple(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:102
bool dim_size_is_known(const dim_t dim_size)
Definition: tuple.h:377
Tuple()=default
default constructor
int num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:347
static std::string value()
Definition: tuple.h:747
static const int kStackCache
Definition: tuple.h:343
const ValueType * end() const
Definition: tuple.h:185
void SetDim(int ndim)
Definition: tuple.h:353
void swap(Tuple< ValueType > &other)
Swap current object with other.
Definition: tuple.h:123
TShape(Tuple< dim_t > &&s)
move constructor.
Definition: tuple.h:434
A dynamic sized array data structure that is optimized for storing small number of elements with same...
Definition: tuple.h:54
TShape(std::initializer_list< dim_t > init)
constructor from initializer list
Definition: tuple.h:427
const ValueType * begin() const
Definition: tuple.h:177
int64_t dim_t
data type to store dim size
Definition: c_api.h:62
int ndim_
number of dimension of the tuple
Definition: tuple.h:345
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:395
nnvm::FInferNodeEntryAttr< mxnet::TShape > FInferShape
Shape inference function. Update the shapes given the input shape information. TShape.ndim() == -1 means the shape is still unknown.
Definition: tuple.h:780
TShape & operator=(const Tuple< dim_t > &src)
assignment function from tshape
Definition: tuple.h:458
TShape(const int ndim, const dim_t value)
Definition: tuple.h:406
size_t ProdShape(int dimstart, int dimend) const
Definition: tuple.h:491
int ndim() const
Definition: tuple.h:193
Tuple< ValueType > & operator=(std::initializer_list< ValueType > init)
assignment from initializer list
Definition: tuple.h:156
bool operator!=(const Tuple< ValueType > &s) const
Definition: tuple.h:173
bool shape_is_known(const TShape &x)
Definition: tuple.h:648
bool ndim_is_known(const int ndim)
Definition: tuple.h:371
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:351
ValueType * begin()
Definition: tuple.h:181
friend std::istream & operator>>(std::istream &is, Tuple< ValueType > &t)
read tuple from the istream
Definition: tuple.h:261
void Save(dmlc::JSONWriter *writer) const
Save Tuple to JSON.
Definition: tuple.h:218