27 #ifndef MXNET_CPP_SHAPE_H_ 28 #define MXNET_CPP_SHAPE_H_ 48 num_heap_allocated_(0),
54 explicit Shape(
const std::vector<index_t> &v)
56 if (ndim_ <= kStackCache) {
58 num_heap_allocated_ = 0;
59 std::copy(v.begin(), v.end(), data_stack_);
61 data_heap_ =
new index_t[ndim_];
62 num_heap_allocated_ = ndim_;
63 std::copy(v.begin(), v.end(), data_heap_);
72 if (ndim_ <= kStackCache) {
74 num_heap_allocated_ = 0;
77 data_heap_ =
new index_t[ndim_];
78 num_heap_allocated_ = ndim_;
89 if (ndim_ <= kStackCache) {
91 num_heap_allocated_ = 0;
95 data_heap_ =
new index_t[ndim_];
96 num_heap_allocated_ = ndim_;
109 if (ndim_ <= kStackCache) {
111 num_heap_allocated_ = 0;
116 data_heap_ =
new index_t[ndim_];
117 num_heap_allocated_ = ndim_;
132 if (ndim_ <= kStackCache) {
134 num_heap_allocated_ = 0;
140 data_heap_ =
new index_t[ndim_];
141 num_heap_allocated_ = ndim_;
158 if (ndim_ <= kStackCache) {
160 num_heap_allocated_ = 0;
167 data_heap_ =
new index_t[ndim_];
168 num_heap_allocated_ = ndim_;
182 if (ndim_ <= kStackCache) {
184 num_heap_allocated_ = 0;
185 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
187 data_heap_ =
new index_t[ndim_];
188 num_heap_allocated_ = ndim_;
189 std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
199 num_heap_allocated_(s.num_heap_allocated_),
200 data_heap_(s.data_heap_) {
201 if (ndim_ <= kStackCache) {
202 std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
219 template<
typename RandomAccessIterator>
221 RandomAccessIterator end) {
222 this->SetDim(end - begin);
223 std::copy(begin, end,
data());
231 this->SetDim(shape.ndim_);
233 std::copy(src, src + ndim_,
data());
242 this->
CopyFrom(shape.begin(), shape.end());
247 return ndim_ <= kStackCache ? data_stack_ : data_heap_;
251 return ndim_ <= kStackCache ? data_stack_ : data_heap_;
274 inline size_t Size(
void)
const {
277 for (
index_t i = 0; i < ndim_; ++i) {
287 if (ndim_ != s.ndim_)
return false;
288 if (ndim_ <= kStackCache) {
289 for (
index_t i = 0; i < ndim_; ++i) {
290 if (data_stack_[i] != s.data_stack_[i])
return false;
293 for (
index_t i = 0; i < ndim_; ++i) {
294 if (data_heap_[i] != s.data_heap_[i])
return false;
304 return !(*
this == s);
315 static const index_t kStackCache = 5;
321 index_t data_stack_[kStackCache];
328 inline void SetDim(
index_t dim) {
329 if (dim > kStackCache &&
330 dim > num_heap_allocated_) {
334 num_heap_allocated_ = dim;
349 if (i != 0) os <<
',';
350 os << static_cast<int>(shape[i]);
353 if (shape.
ndim() == 1) os <<
',';
368 if (ch ==
'(')
break;
370 is.setstate(std::ios::failbit);
375 std::vector<index_t> tmp;
381 }
while (isspace(ch));
393 if (ch ==
')')
break;
394 }
else if (ch ==
')') {
397 is.setstate(std::ios::failbit);
401 shape.
CopyFrom(tmp.begin(), tmp.end());
408 #endif // MXNET_CPP_SHAPE_H_ ~Shape()
destructor
Definition: shape.h:209
index_t & operator[](index_t i)
get corresponding index
Definition: shape.h:262
const index_t & operator[](index_t i) const
get corresponding index
Definition: shape.h:270
friend std::ostream & operator<<(std::ostream &os, const Shape &shape)
allow string printing of the shape
Definition: shape.h:346
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
constructor for a four-dimensional shape
Definition: shape.h:130
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:274
namespace of mxnet
Definition: base.h:89
const index_t * data() const
Definition: shape.h:246
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
index_t ndim(void) const
return the number of dimensions of the tensor inside
Definition: shape.h:254
Shape(index_t s1)
constructor for a one-dimensional shape
Definition: shape.h:70
Shape(index_t s1, index_t s2, index_t s3)
constructor for a three-dimensional shape
Definition: shape.h:107
Shape & operator=(const std::vector< index_t > &shape)
assignment from vector
Definition: shape.h:241
Shape(const Shape &s)
constructor from Shape
Definition: shape.h:180
Shape(const std::vector< index_t > &v)
constructor from a vector of index_t
Definition: shape.h:54
Shape(index_t s1, index_t s2)
constructor for a two-dimensional shape
Definition: shape.h:87
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
constructor for a five-dimensional shape
Definition: shape.h:156
Shape()
constructor
Definition: shape.h:46
bool operator==(const Shape &s) const
Definition: shape.h:286
friend std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:364
bool operator!=(const Shape &s) const
Definition: shape.h:303
void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end)
copy shape from the content between two iterators
Definition: shape.h:220
index_t * data()
Definition: shape.h:250
unsigned index_t
Definition: base.h:37
Shape & operator=(const Shape &shape)
assignment from shape
Definition: shape.h:230