26 #ifndef MXNET_CPP_SHAPE_H_ 27 #define MXNET_CPP_SHAPE_H_ 47 num_heap_allocated_(0),
53 explicit Shape(
const std::vector<index_t> &v)
55 if (ndim_ <= kStackCache) {
57 num_heap_allocated_ = 0;
58 std::copy(v.begin(), v.end(), data_stack_);
60 data_heap_ =
new index_t[ndim_];
61 num_heap_allocated_ = ndim_;
62 std::copy(v.begin(), v.end(), data_heap_);
71 if (ndim_ <= kStackCache) {
73 num_heap_allocated_ = 0;
76 data_heap_ =
new index_t[ndim_];
77 num_heap_allocated_ = ndim_;
88 if (ndim_ <= kStackCache) {
90 num_heap_allocated_ = 0;
94 data_heap_ =
new index_t[ndim_];
95 num_heap_allocated_ = ndim_;
108 if (ndim_ <= kStackCache) {
110 num_heap_allocated_ = 0;
115 data_heap_ =
new index_t[ndim_];
116 num_heap_allocated_ = ndim_;
131 if (ndim_ <= kStackCache) {
133 num_heap_allocated_ = 0;
139 data_heap_ =
new index_t[ndim_];
140 num_heap_allocated_ = ndim_;
157 if (ndim_ <= kStackCache) {
159 num_heap_allocated_ = 0;
166 data_heap_ =
new index_t[ndim_];
167 num_heap_allocated_ = ndim_;
181 if (ndim_ <= kStackCache) {
183 num_heap_allocated_ = 0;
184 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
186 data_heap_ =
new index_t[ndim_];
187 num_heap_allocated_ = ndim_;
188 std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
198 num_heap_allocated_(s.num_heap_allocated_),
199 data_heap_(s.data_heap_) {
200 if (ndim_ <= kStackCache) {
201 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
218 template<
typename RandomAccessIterator>
220 RandomAccessIterator end) {
221 this->SetDim(end - begin);
222 std::copy(begin, end,
data());
230 this->SetDim(shape.ndim_);
232 std::copy(src, src + ndim_,
data());
241 this->
CopyFrom(shape.begin(), shape.end());
246 return ndim_ <= kStackCache ? data_stack_ : data_heap_;
250 return ndim_ <= kStackCache ? data_stack_ : data_heap_;
273 inline size_t Size(
void)
const {
276 for (
index_t i = 0; i < ndim_; ++i) {
286 if (ndim_ != s.ndim_)
return false;
287 if (ndim_ <= kStackCache) {
288 for (
index_t i = 0; i < ndim_; ++i) {
289 if (data_stack_[i] != s.data_stack_[i])
return false;
292 for (
index_t i = 0; i < ndim_; ++i) {
293 if (data_heap_[i] != s.data_heap_[i])
return false;
303 return !(*
this == s);
314 static const index_t kStackCache = 5;
320 index_t data_stack_[kStackCache];
327 inline void SetDim(
index_t dim) {
328 if (dim > kStackCache &&
329 dim > num_heap_allocated_) {
333 num_heap_allocated_ = dim;
348 if (i != 0) os <<
',';
349 os << static_cast<int>(shape[i]);
352 if (shape.
ndim() == 1) os <<
',';
367 if (ch ==
'(')
break;
369 is.setstate(std::ios::failbit);
374 std::vector<index_t> tmp;
380 }
while (isspace(ch));
392 if (ch ==
')')
break;
393 }
else if (ch ==
')') {
396 is.setstate(std::ios::failbit);
400 shape.
CopyFrom(tmp.begin(), tmp.end());
407 #endif // MXNET_CPP_SHAPE_H_ ~Shape()
destructor
Definition: shape.h:208
index_t & operator[](index_t i)
get corresponding index
Definition: shape.h:261
const index_t & operator[](index_t i) const
get corresponding index
Definition: shape.h:269
friend std::ostream & operator<<(std::ostream &os, const Shape &shape)
allow string printing of the shape
Definition: shape.h:345
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
constructor for a four-dimensional shape
Definition: shape.h:129
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:273
namespace of mxnet
Definition: base.h:126
const index_t * data() const
Definition: shape.h:245
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:42
index_t ndim(void) const
return the number of dimensions of the tensor inside
Definition: shape.h:253
Shape(index_t s1)
constructor for a one-dimensional shape
Definition: shape.h:69
Shape(index_t s1, index_t s2, index_t s3)
constructor for a three-dimensional shape
Definition: shape.h:106
Shape & operator=(const std::vector< index_t > &shape)
assignment from vector
Definition: shape.h:240
Shape(const Shape &s)
constructor from Shape
Definition: shape.h:179
Shape(const std::vector< index_t > &v)
constructor from a vector of index_t
Definition: shape.h:53
Shape(index_t s1, index_t s2)
constructor for a two-dimensional shape
Definition: shape.h:86
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
constructor for a five-dimensional shape
Definition: shape.h:155
Shape()
constructor
Definition: shape.h:45
bool operator==(const Shape &s) const
Definition: shape.h:285
friend std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:363
bool operator!=(const Shape &s) const
Definition: shape.h:302
void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end)
copy shape from the content between two iterators
Definition: shape.h:219
index_t * data()
Definition: shape.h:249
unsigned index_t
Definition: base.h:36
Shape & operator=(const Shape &shape)
assignment from shape
Definition: shape.h:229