mxnet
shape.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_SHAPE_H_
28 #define MXNET_CPP_SHAPE_H_
29 
30 #include <istream>
31 #include <ostream>
32 #include <algorithm>
33 #include <vector>
34 #include "mxnet-cpp/base.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
43 struct Shape {
44  public:
47  : ndim_(0),
48  num_heap_allocated_(0),
49  data_heap_(NULL) {}
54  explicit Shape(const std::vector<index_t> &v)
55  : ndim_(v.size()) {
56  if (ndim_ <= kStackCache) {
57  data_heap_ = NULL;
58  num_heap_allocated_ = 0;
59  std::copy(v.begin(), v.end(), data_stack_);
60  } else {
61  data_heap_ = new index_t[ndim_];
62  num_heap_allocated_ = ndim_;
63  std::copy(v.begin(), v.end(), data_heap_);
64  }
65  }
70  explicit Shape(index_t s1)
71  : ndim_(1) {
72  if (ndim_ <= kStackCache) {
73  data_heap_ = NULL;
74  num_heap_allocated_ = 0;
75  data_stack_[0] = s1;
76  } else {
77  data_heap_ = new index_t[ndim_];
78  num_heap_allocated_ = ndim_;
79  data_heap_[0] = s1;
80  }
81  }
88  : ndim_(2) {
89  if (ndim_ <= kStackCache) {
90  data_heap_ = NULL;
91  num_heap_allocated_ = 0;
92  data_stack_[0] = s1;
93  data_stack_[1] = s2;
94  } else {
95  data_heap_ = new index_t[ndim_];
96  num_heap_allocated_ = ndim_;
97  data_heap_[0] = s1;
98  data_heap_[1] = s2;
99  }
100  }
108  : ndim_(3) {
109  if (ndim_ <= kStackCache) {
110  data_heap_ = NULL;
111  num_heap_allocated_ = 0;
112  data_stack_[0] = s1;
113  data_stack_[1] = s2;
114  data_stack_[2] = s3;
115  } else {
116  data_heap_ = new index_t[ndim_];
117  num_heap_allocated_ = ndim_;
118  data_heap_[0] = s1;
119  data_heap_[1] = s2;
120  data_heap_[2] = s3;
121  }
122  }
131  : ndim_(4) {
132  if (ndim_ <= kStackCache) {
133  data_heap_ = NULL;
134  num_heap_allocated_ = 0;
135  data_stack_[0] = s1;
136  data_stack_[1] = s2;
137  data_stack_[2] = s3;
138  data_stack_[3] = s4;
139  } else {
140  data_heap_ = new index_t[ndim_];
141  num_heap_allocated_ = ndim_;
142  data_heap_[0] = s1;
143  data_heap_[1] = s2;
144  data_heap_[2] = s3;
145  data_heap_[3] = s4;
146  }
147  }
157  : ndim_(5) {
158  if (ndim_ <= kStackCache) {
159  data_heap_ = NULL;
160  num_heap_allocated_ = 0;
161  data_stack_[0] = s1;
162  data_stack_[1] = s2;
163  data_stack_[2] = s3;
164  data_stack_[3] = s4;
165  data_stack_[4] = s5;
166  } else {
167  data_heap_ = new index_t[ndim_];
168  num_heap_allocated_ = ndim_;
169  data_heap_[0] = s1;
170  data_heap_[1] = s2;
171  data_heap_[2] = s3;
172  data_heap_[3] = s4;
173  data_heap_[5] = s5;
174  }
175  }
180  Shape(const Shape &s)
181  : ndim_(s.ndim_) {
182  if (ndim_ <= kStackCache) {
183  data_heap_ = NULL;
184  num_heap_allocated_ = 0;
185  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
186  } else {
187  data_heap_ = new index_t[ndim_];
188  num_heap_allocated_ = ndim_;
189  std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
190  }
191  }
192 #if MSHADOW_IN_CXX11
193 
197  Shape(Shape &&s)
198  : ndim_(s.ndim_),
199  num_heap_allocated_(s.num_heap_allocated_),
200  data_heap_(s.data_heap_) {
201  if (ndim_ <= kStackCache) {
202  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
203  }
204  // remove data heap space from s
205  s.data_heap_ = NULL;
206  }
207 #endif
208 
209  ~Shape() {
210  // data_heap_ can be NULL
211  delete[] data_heap_;
212  }
219  template<typename RandomAccessIterator>
220  inline void CopyFrom(RandomAccessIterator begin,
221  RandomAccessIterator end) {
222  this->SetDim(end - begin);
223  std::copy(begin, end, data());
224  }
230  inline Shape &operator=(const Shape &shape) {
231  this->SetDim(shape.ndim_);
232  const index_t *src = shape.data();
233  std::copy(src, src + ndim_, data());
234  return *this;
235  }
241  inline Shape &operator=(const std::vector<index_t> &shape) {
242  this->CopyFrom(shape.begin(), shape.end());
243  return *this;
244  }
246  inline const index_t *data() const {
247  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
248  }
250  inline index_t *data() {
251  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
252  }
254  inline index_t ndim(void) const {
255  return ndim_;
256  }
263  return data()[i];
264  }
270  inline const index_t &operator[](index_t i) const {
271  return data()[i];
272  }
274  inline size_t Size(void) const {
275  size_t size = 1;
276  const index_t *d = this->data();
277  for (index_t i = 0; i < ndim_; ++i) {
278  size *= d[i];
279  }
280  return size;
281  }
286  inline bool operator==(const Shape &s) const {
287  if (ndim_ != s.ndim_) return false;
288  if (ndim_ <= kStackCache) {
289  for (index_t i = 0; i < ndim_; ++i) {
290  if (data_stack_[i] != s.data_stack_[i]) return false;
291  }
292  } else {
293  for (index_t i = 0; i < ndim_; ++i) {
294  if (data_heap_[i] != s.data_heap_[i]) return false;
295  }
296  }
297  return true;
298  }
303  inline bool operator!=(const Shape &s) const {
304  return !(*this == s);
305  }
306 
307  friend std::ostream &operator<<(std::ostream &os, const Shape &shape);
308  friend std::istream &operator>>(std::istream &is, Shape &shape);
309 
310  private:
311  // the shape will be stored in data_stack_
312  // when dimension is smaller than kStackCache
313  // when it is bigger, it will be stored in data_heap_;
315  static const index_t kStackCache = 5;
317  index_t ndim_;
319  index_t num_heap_allocated_;
321  index_t data_stack_[kStackCache];
323  index_t *data_heap_;
328  inline void SetDim(index_t dim) {
329  if (dim > kStackCache &&
330  dim > num_heap_allocated_) {
331  // data_heap_ can be NULL
332  delete[] data_heap_;
333  data_heap_ = new index_t[dim];
334  num_heap_allocated_ = dim;
335  }
336  ndim_ = dim;
337  }
338 };
339 
346 inline std::ostream &operator<<(std::ostream &os, const Shape &shape) {
347  os << '(';
348  for (index_t i = 0; i < shape.ndim(); ++i) {
349  if (i != 0) os << ',';
350  os << static_cast<int>(shape[i]); // Supports negative Shape 'special codes' for inferring
351  }
352  // python style tuple
353  if (shape.ndim() == 1) os << ',';
354  os << ')';
355  return os;
356 }
357 
364 inline std::istream &operator>>(std::istream &is, Shape &shape) {
365  // get (
366  while (true) {
367  char ch = is.get();
368  if (ch == '(') break;
369  if (!isspace(ch)) {
370  is.setstate(std::ios::failbit);
371  return is;
372  }
373  }
374  index_t idx;
375  std::vector<index_t> tmp;
376  while (is >> idx) {
377  tmp.push_back(idx);
378  char ch;
379  do {
380  ch = is.get();
381  } while (isspace(ch));
382  if (ch == ',') {
383  while (true) {
384  ch = is.peek();
385  if (isspace(ch)) {
386  is.get(); continue;
387  }
388  if (ch == ')') {
389  is.get(); break;
390  }
391  break;
392  }
393  if (ch == ')') break;
394  } else if (ch == ')') {
395  break;
396  } else {
397  is.setstate(std::ios::failbit);
398  return is;
399  }
400  }
401  shape.CopyFrom(tmp.begin(), tmp.end());
402  return is;
403 }
404 
405 } // namespace cpp
406 } // namespace mxnet
407 
408 #endif // MXNET_CPP_SHAPE_H_
~Shape()
destructor
Definition: shape.h:209
index_t & operator[](index_t i)
get corresponding index
Definition: shape.h:262
const index_t & operator[](index_t i) const
get corresponding index
Definition: shape.h:270
friend std::ostream & operator<<(std::ostream &os, const Shape &shape)
allow string printing of the shape
Definition: shape.h:346
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
constructor for a four-dimensional shape
Definition: shape.h:130
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:274
namespace of mxnet
Definition: base.h:127
const index_t * data() const
Definition: shape.h:246
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
index_t ndim(void) const
return number of dimension of the tensor inside
Definition: shape.h:254
Shape(index_t s1)
constructor for a one-dimensional shape
Definition: shape.h:70
Shape(index_t s1, index_t s2, index_t s3)
constructor for a three-dimensional shape
Definition: shape.h:107
Shape & operator=(const std::vector< index_t > &shape)
assignment from vector
Definition: shape.h:241
Shape(const Shape &s)
constructor from Shape
Definition: shape.h:180
Shape(const std::vector< index_t > &v)
constructor from a vector of index_t
Definition: shape.h:54
Shape(index_t s1, index_t s2)
constructor for a two-dimensional shape
Definition: shape.h:87
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
constructor for a five-dimensional shape
Definition: shape.h:156
Shape()
constructor
Definition: shape.h:46
bool operator==(const Shape &s) const
Definition: shape.h:286
friend std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:364
bool operator!=(const Shape &s) const
Definition: shape.h:303
void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end)
copy shape from content between two iterators
Definition: shape.h:220
index_t * data()
Definition: shape.h:250
unsigned index_t
Definition: base.h:37
Shape & operator=(const Shape &shape)
assignment from shape
Definition: shape.h:230