23 #ifndef MXNET_TUPLE_H_ 24 #define MXNET_TUPLE_H_ 27 #include <type_traits> 56 template<
typename ValueType>
73 std::fill_n(
begin(), ndim, value);
91 inline Tuple(std::initializer_list<ValueType> init) {
92 this->
assign(init.begin(), init.end());
98 inline Tuple(std::vector<ValueType> init) {
99 this->
assign(init.begin(), init.end());
115 template<
typename RandomAccessIterator>
117 RandomAccessIterator
end) {
122 using namespace runtime;
123 ADT adt = Downcast<ADT, ObjectRef>(src);
125 for (
int i = 0; i <
ndim_; ++i) {
126 this->
begin()[i] = Downcast<Integer, ObjectRef>(adt[i])->value;
136 template<
typename RandomAccessIterator>
138 RandomAccessIterator
end) {
139 this->
SetDim(end - begin);
141 std::copy(begin, end, this->
begin());
159 if (src.
ndim() == -1) {
181 this->
assign(init.begin(), init.end());
190 if (
ndim() == -1)
return true;
198 return !(*
this == s);
201 inline const ValueType *
begin()
const {
209 inline const ValueType*
end()
const {
228 #pragma GCC diagnostic push 229 #pragma GCC diagnostic ignored "-Wstrict-overflow" 230 CHECK(i >= 0 && i <
ndim()) <<
"index = " << i <<
" must be in range [0, " <<
ndim() <<
")";
231 #pragma GCC diagnostic pop 242 #pragma GCC diagnostic push 243 #pragma GCC diagnostic ignored "-Wstrict-overflow" 244 CHECK(i >= 0 && i <
ndim()) <<
"index = " << i <<
" must be in range [0, " <<
ndim() <<
")";
245 #pragma GCC diagnostic pop 253 std::vector<ValueType> tmp(
begin(),
end());
261 std::vector<ValueType> tmp;
263 this->
assign(tmp.begin(), tmp.end());
271 friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
272 if (t.ndim() == -1) {
280 const ValueType*
begin = t.begin();
281 const ValueType*
end = t.end();
282 for (
const ValueType* it = begin; it !=
end; ++it) {
283 if (it != begin) os <<
',';
299 if (
isdigit(ch) || ch ==
'-') {
307 if (ch ==
'(' || ch ==
'[')
break;
312 if (tmp_val ==
"one") {
317 is.setstate(std::ios::failbit);
326 if (is.peek() ==
')' || is.peek() ==
']') {
333 std::vector<ValueType> tmp;
340 if (std::is_integral<ValueType>::value && ch ==
'L') {
349 if (ch ==
')' || ch ==
']') {
354 if (ch ==
')' || ch ==
']')
break;
355 }
else if (ch ==
')' || ch ==
']') {
358 is.setstate(std::ios::failbit);
362 t.
assign(tmp.begin(), tmp.end());
371 template<
typename DType = ValueType,
typename TStream>
372 inline void Save(TStream *strm)
const;
380 template<
typename DType = ValueType,
typename TStream>
381 inline bool Load(TStream *strm);
396 CHECK_GE(ndim, -1) <<
"ndim cannot be less than -1, received " <<
ndim;
397 if (ndim > kStackCache &&
402 }
else if (ndim <= 0 &&
data_heap_ !=
nullptr) {
414 CHECK_GE(ndim, -1) <<
"shape ndim must be >= -1, while received " <<
ndim;
420 CHECK_GE(dim_size, -1) <<
"shape dim size must be >= -1, while received " << dim_size;
421 return dim_size != -1;
451 std::fill_n(
begin(), ndim, value);
459 if (s.
ndim() == -1) {
469 inline TShape(std::initializer_list<dim_t> init) {
470 this->
assign(init.begin(), init.end());
487 template<
typename RandomAccessIterator,
488 typename std::enable_if<
489 std::is_same<typename std::iterator_traits<RandomAccessIterator>::iterator_category,
490 std::random_access_iterator_tag>::value,
int>::type = 0>
492 RandomAccessIterator
end) {
503 if (src.
ndim() == -1) {
524 for (
const dim_t* it = start; it != fin; ++it) {
525 CHECK(
dim_size_is_known(*it)) <<
"Shape dim size cannot be a negative value " << *it;
535 inline size_t ProdShape(
int dimstart,
int dimend)
const {
537 CHECK_GE(dimstart, 0) <<
"dimstart must be >= 0, while received " << dimstart;
538 CHECK_LE(dimend, this->
ndim()) <<
"dimend must be <= " << this->
ndim()
539 <<
", while received " << dimend;
541 const dim_t *d = this->data();
542 for (
int i = dimstart; i < dimend; ++i) {
543 CHECK(
dim_size_is_known(d[i])) <<
"Shape dim size must be known, while received " << d[i];
556 #ifdef MSHADOW_XINLINE 584 CHECK_EQ(dim,
ndim())
585 <<
"dimension do not match target dimension " << dim <<
" vs " <<
ndim();
586 const dim_t *d = this->data();
588 for (
int i = 0; i < dim; ++i) {
601 const dim_t *d = this->data();
604 for (
int i = 1; i <
ndim(); ++i) {
617 CHECK(axis_end >= axis_begin);
621 const dim_t *d = this->data();
626 for (
int i = 0; i < axis_begin; ++i) {
629 for (
int i = axis_begin; i <= axis_end; ++i) {
632 for (
int i = axis_end + 1; i <
ndim(); ++i) {
643 return FlatTo3D(axis, axis);
646 if (
ndim() != s.
ndim())
return false;
650 return !(*
this == s);
659 if (
ndim_ != dim)
return false;
661 for (
size_t i = 0; i < dim; ++i) {
662 if (d[i] != s.
shape_[i])
return false;
673 return !(*
this == s);
685 CHECK(idx >= 0 && idx < x.
ndim())
686 <<
"idx = " << idx <<
" exceeds shape dimension range [0, " << x.
ndim() <<
")";
694 for (
int i = 0; i < x.
ndim(); ++i) {
701 for (
const TShape& shape : shapes) {
708 template<
typename SrcIter,
typename DstIter>
712 typedef typename std::iterator_traits<SrcIter>::value_type SrcDType;
713 typedef typename std::iterator_traits<DstIter>::value_type DstDType;
714 auto cast = [](
const SrcDType& dim) {
return static_cast<DstDType
>(dim); };
715 return std::transform(begin, end, dst_begin, cast);
719 template<
typename SrcIter>
721 size_t ndim = std::distance(begin, end);
728 template<
typename ValueType>
729 template<
typename DType,
typename TStream>
732 if (
typeid(DType) ==
typeid(ValueType)) {
733 strm->Write(
begin(),
sizeof(ValueType) *
ndim_);
735 std::vector<DType> buffer(
ndim_);
737 strm->Write(buffer.data(),
sizeof(DType) *
ndim_);
742 template<
typename ValueType>
743 template<
typename DType,
typename TStream>
747 size_t nread =
sizeof(DType) *
ndim_;
748 if (
typeid(DType) ==
typeid(ValueType)) {
749 if (strm->Read(
begin(), nread) != nread)
return false;
751 std::vector<DType> buffer(
ndim_);
752 if (strm->Read(buffer.data(), nread) != nread)
return false;
766 std::hash<int> hash_int;
767 size_t res = hash_int(val.
ndim());
768 for (
int i = 0; i < val.
ndim(); ++i) {
780 std::hash<int> hash_int;
781 size_t res = hash_int(val.
ndim());
782 for (
int i = 0; i < val.
ndim(); ++i) {
795 #if !(defined(_MSC_VER) && _MSC_VER < 1900) 799 return "tuple of <" + type_name<T>() +
">";
835 #endif // MXNET_TUPLE_H_ size_t operator()(const mxnet::TShape &val) const
hash a TShape into unsigned int
Definition: tuple.h:779
#define DMLC_DECLARE_TYPE_NAME(Type, Name)
macro to quickly declare traits information
Definition: type_traits.h:133
bool operator!=(const Tuple< ValueType > &s) const
Definition: tuple.h:197
TShape(const ObjectRef &src)
Definition: tuple.h:496
helper class to construct a string that represents type name
Definition: type_traits.h:86
Tuple< ValueType > & operator=(const Tuple< ValueType > &src)
assignment from another tuple.
Definition: tuple.h:158
size_t ProdShape(int dimstart, int dimend) const
Definition: tuple.h:535
A dynamically sized array data structure that is optimized for storing a small number of elements with same...
Definition: tuple.h:51
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:327
uint32_t ndim_
number of dimensions of the tuple
Definition: tuple.h:321
Tuple< ValueType > & operator=(Tuple< ValueType > &&src)
assignment from rvalue of another tuple.
Definition: tuple.h:171
dim_t * data()
Definition: tuple.h:553
TShape(const Tuple< dim_t > &s)
copy constructor of TShape
Definition: tuple.h:458
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:391
namespace of mxnet
Definition: api_registry.h:33
TShape & operator=(Tuple< dim_t > &&src)
move assignment function from tshape
Definition: tuple.h:515
Tuple(std::initializer_list< ValueType > init)
constructor from initializer list
Definition: tuple.h:91
const ValueType * end() const
Definition: tuple.h:209
ValueType data_stack_[kStackCache]
in stack space used to store shape when it is small
Definition: tuple.h:325
TShape()
default constructor
Definition: tuple.h:440
int64_t dim_t
data type to store dim size
Definition: tuple.h:38
Definition: optional.h:251
DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin)
helper function to cast type of container elements
Definition: tuple.h:709
Tuple(Tuple< ValueType > &&src)
move constructor from Tuple
Definition: tuple.h:106
void Load(dmlc::JSONReader *reader)
Load Tuple from JSON.
Definition: tuple.h:260
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:137
Base class of all object reference.
Definition: object.h:499
const ValueType * end() const
Definition: tuple.h:172
ValueType * end()
Definition: tuple.h:213
TShape(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator. This function is enforced with template arguments of ra...
Definition: tuple.h:491
Data structures that can appear in graph attributes.
ValueType & operator[](int i)
get corresponding index
Definition: tuple.h:225
void SetDim(uint32_t ndim)
Definition: tuple.h:329
Tuple(const runtime::ObjectRef &src)
Definition: tuple.h:121
bool operator==(const Tuple< ValueType > &s) const
Definition: tuple.h:188
size_t HashCombine(size_t key, const T &value)
hash an object and combines the key with previous keys
Definition: common.h:37
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:819
Tuple(const Tuple< ValueType > &s)
copy constructor from another tuple
Definition: tuple.h:80
Lightweight JSON Reader to read any STL compositions and structs. The user needs to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
~Tuple()
destructor
Definition: tuple.h:62
Tuple(std::vector< ValueType > init)
constructor from vector
Definition: tuple.h:98
namespace for dmlc
Definition: array_view.h:12
Tuple(RandomAccessIterator begin, RandomAccessIterator end)
construct the Tuple from content of iterator
Definition: tuple.h:116
void Write(const ValueType &value)
Write value to json.
bool dim_size_is_known(const dim_t dim_size)
Definition: tuple.h:419
Tuple()=default
default constructor
int num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:389
static std::string value()
Definition: tuple.h:798
static const int kStackCache
Definition: tuple.h:385
void SetDim(int ndim)
Definition: tuple.h:395
const dim_t * data() const
Definition: tuple.h:549
std::function< bool(const NodeAttrs &attrs, std::vector< AttrType > *in_attrs, std::vector< AttrType > *out_attrs)> FInferNodeEntryAttr
Inference function of certain type.
Definition: op_attr_types.h:94
void swap(Tuple< ValueType > &other)
Swap current object with other.
Definition: tuple.h:147
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:75
Pass that can be applied to a graph.
TShape(Tuple< dim_t > &&s)
move constructor.
Definition: tuple.h:476
void Read(ValueType *out_value)
Read next ValueType.
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:216
size_t operator()(const mxnet::Tuple< T > &val) const
hash a Tuple into unsigned int
Definition: tuple.h:765
A dynamically sized array data structure that is optimized for storing a small number of elements with same...
Definition: tuple.h:57
uint32_t ndim() const
Definition: tuple.h:180
TShape(std::initializer_list< dim_t > init)
constructor from initializer list
Definition: tuple.h:469
Configuration of nnvm as well as basic data structure.
A managed object in MXNet runtime.
void assign(RandomAccessIterator begin, RandomAccessIterator end)
Assign content to tuple from iterator.
Definition: tuple.h:106
int ndim_
number of dimensions of the tuple
Definition: tuple.h:387
size_t Size() const
Definition: tuple.h:520
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:437
nnvm::FInferNodeEntryAttr< mxnet::TShape > FInferShape
Shape inference function. Update the shapes given the input shape information. TShape.ndim() == -1 means the shape is still unknown.
Definition: tuple.h:831
TShape & operator=(const Tuple< dim_t > &src)
assignment function from tshape
Definition: tuple.h:502
TShape(const int ndim, const dim_t value)
Definition: tuple.h:448
Base expr nodes in MXNet.
Tuple< ValueType > & operator=(std::initializer_list< ValueType > init)
assignment from initializer list
Definition: tuple.h:180
bool shape_is_known(const TShape &x)
Definition: tuple.h:692
bool ndim_is_known(const int ndim)
Definition: tuple.h:413
MSHADOW_XINLINE Shape< 3 > Shape3(index_t s0, index_t s1, index_t s2)
construct a three dimension shape, stride will equal s0
Definition: tensor.h:227
const ValueType & operator[](int i) const
get corresponding index
Definition: tuple.h:239
ValueType * data_heap_
space to store shape when dimension is big
Definition: tuple.h:393
ValueType * begin()
Definition: tuple.h:205
friend std::istream & operator>>(std::istream &is, Tuple< ValueType > &t)
read tuple from the istream
Definition: tuple.h:295
int64_t dim_t
data type to store dim size
Definition: c_api.h:61
const ValueType * begin() const
Definition: tuple.h:201
const ValueType * begin() const
Definition: tuple.h:164
uint32_t num_heap_allocated_
number of cells allocated in data_heap_
Definition: tuple.h:323
void Save(dmlc::JSONWriter *writer) const
Save Tuple to JSON.
Definition: tuple.h:252
int ndim() const
Definition: tuple.h:217
Tuple(const int ndim, const dim_t value)
Definition: tuple.h:70
bool isdigit(char c)
Inline implementation of isdigit(). Tests whether the given character is a decimal digit...
Definition: strtonum.h:46
Data structures that can appear in operator attributes.
Lightweight json to write any STL compositions.
Definition: json.h:190