mxnet
shape.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_SHAPE_H_
27 #define MXNET_CPP_SHAPE_H_
28 
#include <algorithm>
#include <cctype>
#include <istream>
#include <ostream>
#include <vector>
#include "mxnet-cpp/base.h"
34 
35 namespace mxnet {
36 namespace cpp {
37 
42 struct Shape {
43  public:
46  : ndim_(0),
47  num_heap_allocated_(0),
48  data_heap_(NULL) {}
53  explicit Shape(const std::vector<index_t> &v)
54  : ndim_(v.size()) {
55  if (ndim_ <= kStackCache) {
56  data_heap_ = NULL;
57  num_heap_allocated_ = 0;
58  std::copy(v.begin(), v.end(), data_stack_);
59  } else {
60  data_heap_ = new index_t[ndim_];
61  num_heap_allocated_ = ndim_;
62  std::copy(v.begin(), v.end(), data_heap_);
63  }
64  }
69  explicit Shape(index_t s1)
70  : ndim_(1) {
71  if (ndim_ <= kStackCache) {
72  data_heap_ = NULL;
73  num_heap_allocated_ = 0;
74  data_stack_[0] = s1;
75  } else {
76  data_heap_ = new index_t[ndim_];
77  num_heap_allocated_ = ndim_;
78  data_heap_[0] = s1;
79  }
80  }
87  : ndim_(2) {
88  if (ndim_ <= kStackCache) {
89  data_heap_ = NULL;
90  num_heap_allocated_ = 0;
91  data_stack_[0] = s1;
92  data_stack_[1] = s2;
93  } else {
94  data_heap_ = new index_t[ndim_];
95  num_heap_allocated_ = ndim_;
96  data_heap_[0] = s1;
97  data_heap_[1] = s2;
98  }
99  }
107  : ndim_(3) {
108  if (ndim_ <= kStackCache) {
109  data_heap_ = NULL;
110  num_heap_allocated_ = 0;
111  data_stack_[0] = s1;
112  data_stack_[1] = s2;
113  data_stack_[2] = s3;
114  } else {
115  data_heap_ = new index_t[ndim_];
116  num_heap_allocated_ = ndim_;
117  data_heap_[0] = s1;
118  data_heap_[1] = s2;
119  data_heap_[2] = s3;
120  }
121  }
130  : ndim_(4) {
131  if (ndim_ <= kStackCache) {
132  data_heap_ = NULL;
133  num_heap_allocated_ = 0;
134  data_stack_[0] = s1;
135  data_stack_[1] = s2;
136  data_stack_[2] = s3;
137  data_stack_[3] = s4;
138  } else {
139  data_heap_ = new index_t[ndim_];
140  num_heap_allocated_ = ndim_;
141  data_heap_[0] = s1;
142  data_heap_[1] = s2;
143  data_heap_[2] = s3;
144  data_heap_[3] = s4;
145  }
146  }
156  : ndim_(5) {
157  if (ndim_ <= kStackCache) {
158  data_heap_ = NULL;
159  num_heap_allocated_ = 0;
160  data_stack_[0] = s1;
161  data_stack_[1] = s2;
162  data_stack_[2] = s3;
163  data_stack_[3] = s4;
164  data_stack_[4] = s5;
165  } else {
166  data_heap_ = new index_t[ndim_];
167  num_heap_allocated_ = ndim_;
168  data_heap_[0] = s1;
169  data_heap_[1] = s2;
170  data_heap_[2] = s3;
171  data_heap_[3] = s4;
172  data_heap_[5] = s5;
173  }
174  }
179  Shape(const Shape &s)
180  : ndim_(s.ndim_) {
181  if (ndim_ <= kStackCache) {
182  data_heap_ = NULL;
183  num_heap_allocated_ = 0;
184  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
185  } else {
186  data_heap_ = new index_t[ndim_];
187  num_heap_allocated_ = ndim_;
188  std::copy(s.data_heap_, s.data_heap_ + ndim_, data_heap_);
189  }
190  }
191 #if MSHADOW_IN_CXX11
192 
196  Shape(Shape &&s)
197  : ndim_(s.ndim_),
198  num_heap_allocated_(s.num_heap_allocated_),
199  data_heap_(s.data_heap_) {
200  if (ndim_ <= kStackCache) {
201  std::copy(s.data_stack_, s.data_stack_ + ndim_, data_stack_);
202  }
203  // remove data heap space from s
204  s.data_heap_ = NULL;
205  }
206 #endif
207 
208  ~Shape() {
209  // data_heap_ can be NULL
210  delete[] data_heap_;
211  }
218  template<typename RandomAccessIterator>
219  inline void CopyFrom(RandomAccessIterator begin,
220  RandomAccessIterator end) {
221  this->SetDim(end - begin);
222  std::copy(begin, end, data());
223  }
229  inline Shape &operator=(const Shape &shape) {
230  this->SetDim(shape.ndim_);
231  const index_t *src = shape.data();
232  std::copy(src, src + ndim_, data());
233  return *this;
234  }
240  inline Shape &operator=(const std::vector<index_t> &shape) {
241  this->CopyFrom(shape.begin(), shape.end());
242  return *this;
243  }
245  inline const index_t *data() const {
246  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
247  }
249  inline index_t *data() {
250  return ndim_ <= kStackCache ? data_stack_ : data_heap_;
251  }
253  inline index_t ndim(void) const {
254  return ndim_;
255  }
262  return data()[i];
263  }
269  inline const index_t &operator[](index_t i) const {
270  return data()[i];
271  }
273  inline size_t Size(void) const {
274  size_t size = 1;
275  const index_t *d = this->data();
276  for (index_t i = 0; i < ndim_; ++i) {
277  size *= d[i];
278  }
279  return size;
280  }
285  inline bool operator==(const Shape &s) const {
286  if (ndim_ != s.ndim_) return false;
287  if (ndim_ <= kStackCache) {
288  for (index_t i = 0; i < ndim_; ++i) {
289  if (data_stack_[i] != s.data_stack_[i]) return false;
290  }
291  } else {
292  for (index_t i = 0; i < ndim_; ++i) {
293  if (data_heap_[i] != s.data_heap_[i]) return false;
294  }
295  }
296  return true;
297  }
302  inline bool operator!=(const Shape &s) const {
303  return !(*this == s);
304  }
305 
306  friend std::ostream &operator<<(std::ostream &os, const Shape &shape);
307  friend std::istream &operator>>(std::istream &is, Shape &shape);
308 
309  private:
310  // the shape will be stored in data_stack_
311  // when dimension is smaller than kStackCache
312  // when it is bigger, it will be stored in data_heap_;
314  static const index_t kStackCache = 5;
316  index_t ndim_;
318  index_t num_heap_allocated_;
320  index_t data_stack_[kStackCache];
322  index_t *data_heap_;
327  inline void SetDim(index_t dim) {
328  if (dim > kStackCache &&
329  dim > num_heap_allocated_) {
330  // data_heap_ can be NULL
331  delete[] data_heap_;
332  data_heap_ = new index_t[dim];
333  num_heap_allocated_ = dim;
334  }
335  ndim_ = dim;
336  }
337 };
338 
345 inline std::ostream &operator<<(std::ostream &os, const Shape &shape) {
346  os << '(';
347  for (index_t i = 0; i < shape.ndim(); ++i) {
348  if (i != 0) os << ',';
349  os << static_cast<int>(shape[i]); // Supports negative Shape 'special codes' for inferring
350  }
351  // python style tuple
352  if (shape.ndim() == 1) os << ',';
353  os << ')';
354  return os;
355 }
356 
363 inline std::istream &operator>>(std::istream &is, Shape &shape) {
364  // get (
365  while (true) {
366  char ch = is.get();
367  if (ch == '(') break;
368  if (!isspace(ch)) {
369  is.setstate(std::ios::failbit);
370  return is;
371  }
372  }
373  index_t idx;
374  std::vector<index_t> tmp;
375  while (is >> idx) {
376  tmp.push_back(idx);
377  char ch;
378  do {
379  ch = is.get();
380  } while (isspace(ch));
381  if (ch == ',') {
382  while (true) {
383  ch = is.peek();
384  if (isspace(ch)) {
385  is.get(); continue;
386  }
387  if (ch == ')') {
388  is.get(); break;
389  }
390  break;
391  }
392  if (ch == ')') break;
393  } else if (ch == ')') {
394  break;
395  } else {
396  is.setstate(std::ios::failbit);
397  return is;
398  }
399  }
400  shape.CopyFrom(tmp.begin(), tmp.end());
401  return is;
402 }
403 
404 } // namespace cpp
405 } // namespace mxnet
406 
407 #endif // MXNET_CPP_SHAPE_H_
~Shape()
destructor
Definition: shape.h:208
index_t & operator[](index_t i)
get corresponding index
Definition: shape.h:261
const index_t & operator[](index_t i) const
get corresponding index
Definition: shape.h:269
friend std::ostream & operator<<(std::ostream &os, const Shape &shape)
allow string printing of the shape
Definition: shape.h:345
Shape(index_t s1, index_t s2, index_t s3, index_t s4)
constructor four dimension shape
Definition: shape.h:129
size_t Size(void) const
total number of elements in the tensor
Definition: shape.h:273
namespace of mxnet
Definition: base.h:126
const index_t * data() const
Definition: shape.h:245
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:42
index_t ndim(void) const
return number of dimension of the tensor inside
Definition: shape.h:253
Shape(index_t s1)
constructor one dimension shape
Definition: shape.h:69
Shape(index_t s1, index_t s2, index_t s3)
constructor three dimension shape
Definition: shape.h:106
Shape & operator=(const std::vector< index_t > &shape)
assignment from vector
Definition: shape.h:240
Shape(const Shape &s)
constructor from Shape
Definition: shape.h:179
Shape(const std::vector< index_t > &v)
constructor from a vector of index_t
Definition: shape.h:53
Shape(index_t s1, index_t s2)
constructor two dimension shape
Definition: shape.h:86
Shape(index_t s1, index_t s2, index_t s3, index_t s4, index_t s5)
constructor five dimension shape
Definition: shape.h:155
Shape()
constructor
Definition: shape.h:45
bool operator==(const Shape &s) const
Definition: shape.h:285
friend std::istream & operator>>(std::istream &is, Shape &shape)
read shape from the istream
Definition: shape.h:363
bool operator!=(const Shape &s) const
Definition: shape.h:302
void CopyFrom(RandomAccessIterator begin, RandomAccessIterator end)
copy shape from content between two iterators
Definition: shape.h:219
index_t * data()
Definition: shape.h:249
unsigned index_t
Definition: base.h:36
Shape & operator=(const Shape &shape)
assignment from shape
Definition: shape.h:229