mxnet
ndarray.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_NDARRAY_H_
28 #define MXNET_CPP_NDARRAY_H_
29 
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <vector>
34 #include <iostream>
35 #include "mxnet-cpp/base.h"
36 #include "mxnet-cpp/shape.h"
37 
38 namespace mxnet {
39 namespace cpp {
40 
/*!
 * \brief Kind of device a Context refers to.
 *
 * NOTE(review): a cross-reference to ndarray.h:44 suggests a further
 * enumerator follows kGPU that is not shown here; confirm against the full
 * header before assuming this list is complete.
 */
enum DeviceType {
  kCPU = 1,
  kGPU = 2,
};

/*!
 * \brief A device descriptor: a DeviceType plus an integer device id.
 */
class Context {
 public:
  /*!
   * \brief Construct a context.
   * \param type kind of device
   * \param id index of the device of that kind
   *
   * DeviceType is a plain enum, so it is taken by value (as cheap as a
   * reference, without the indirection). Every argument accepted by the
   * former `const DeviceType &` parameter is still accepted.
   */
  Context(DeviceType type, int id) : type_(type), id_(id) {}

  /*! \return the device kind stored in this context */
  DeviceType GetDeviceType() const { return type_; }

  /*! \return the device id stored in this context */
  int GetDeviceId() const { return id_; }

  /*!
   * \brief Return a GPU context.
   * \param device_id GPU index, 0 by default
   * \return Context with type kGPU and the given id
   */
  static Context gpu(int device_id = 0) {
    return Context(DeviceType::kGPU, device_id);
  }

  /*!
   * \brief Return a CPU context.
   * \param device_id CPU index, 0 by default
   * \return Context with type kCPU and the given id
   */
  static Context cpu(int device_id = 0) {
    return Context(DeviceType::kCPU, device_id);
  }

 private:
  DeviceType type_;
  int id_;
};
89 
93 struct NDBlob {
94  public:
98  NDBlob() : handle_(nullptr) {}
103  explicit NDBlob(NDArrayHandle handle) : handle_(handle) {}
107  ~NDBlob() { MXNDArrayFree(handle_); }
112 
113  private:
114  NDBlob(const NDBlob &);
115  NDBlob &operator=(const NDBlob &);
116 };
117 
121 class NDArray {
122  public:
126  NDArray();
130  explicit NDArray(const NDArrayHandle &handle);
137  NDArray(const std::vector<mx_uint> &shape, const Context &context,
138  bool delay_alloc = true);
145  NDArray(const Shape &shape, const Context &context, bool delay_alloc = true);
146  NDArray(const mx_float *data, size_t size);
153  NDArray(const mx_float *data, const Shape &shape, const Context &context);
160  NDArray(const std::vector<mx_float> &data, const Shape &shape,
161  const Context &context);
162  explicit NDArray(const std::vector<mx_float> &data);
163  NDArray operator+(mx_float scalar);
164  NDArray operator-(mx_float scalar);
165  NDArray operator*(mx_float scalar);
166  NDArray operator/(mx_float scalar);
167  NDArray operator%(mx_float scalar);
168  NDArray operator+(const NDArray &);
169  NDArray operator-(const NDArray &);
170  NDArray operator*(const NDArray &);
171  NDArray operator/(const NDArray &);
172  NDArray operator%(const NDArray &);
178  NDArray &operator=(mx_float scalar);
185  NDArray &operator+=(mx_float scalar);
192  NDArray &operator-=(mx_float scalar);
199  NDArray &operator*=(mx_float scalar);
206  NDArray &operator/=(mx_float scalar);
213  NDArray &operator%=(mx_float scalar);
220  NDArray &operator+=(const NDArray &src);
227  NDArray &operator-=(const NDArray &src);
234  NDArray &operator*=(const NDArray &src);
241  NDArray &operator/=(const NDArray &src);
248  NDArray &operator%=(const NDArray &src);
249  NDArray ArgmaxChannel();
260  void SyncCopyFromCPU(const mx_float *data, size_t size);
270  void SyncCopyFromCPU(const std::vector<mx_float> &data);
281  void SyncCopyToCPU(mx_float *data, size_t size = 0);
292  void SyncCopyToCPU(std::vector<mx_float> *data, size_t size = 0);
298  NDArray CopyTo(NDArray * other) const;
304  NDArray Copy(const Context &) const;
311  size_t Offset(size_t h = 0, size_t w = 0) const;
319  size_t Offset(size_t c, size_t h, size_t w) const;
326  mx_float At(size_t h, size_t w) const;
334  mx_float At(size_t c, size_t h, size_t w) const;
341  NDArray Slice(mx_uint begin, mx_uint end) const;
347  NDArray Reshape(const Shape &new_shape) const;
352  void WaitToRead() const;
357  void WaitToWrite();
362  static void WaitAll();
369  static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out);
376  static void SampleUniform(mx_float begin, mx_float end, NDArray *out);
385  static void Load(const std::string &file_name,
386  std::vector<NDArray> *array_list = nullptr,
387  std::map<std::string, NDArray> *array_map = nullptr);
393  static std::map<std::string, NDArray> LoadToMap(const std::string &file_name);
399  static std::vector<NDArray> LoadToList(const std::string &file_name);
409  static void LoadFromBuffer(const void *buffer, size_t size,
410  std::vector<NDArray> *array_list = nullptr,
411  std::map<std::string, NDArray> *array_map = nullptr);
418  static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size);
425  static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size);
431  static void Save(const std::string &file_name,
432  const std::map<std::string, NDArray> &array_map);
438  static void Save(const std::string &file_name,
439  const std::vector<NDArray> &array_list);
443  size_t Size() const;
447  std::vector<mx_uint> GetShape() const;
451  int GetDType() const;
456  const mx_float *GetData() const;
457 
461  Context GetContext() const;
462 
466  NDArrayHandle GetHandle() const { return blob_ptr_->handle_; }
467 
468  private:
469  std::shared_ptr<NDBlob> blob_ptr_;
470 };
471 
472 std::ostream& operator<<(std::ostream& out, const NDArray &ndarray);
473 } // namespace cpp
474 } // namespace mxnet
475 
476 #endif // MXNET_CPP_NDARRAY_H_
NDBlob()
default constructor
Definition: ndarray.h:98
Symbol operator/(mx_float lhs, const Symbol &rhs)
namespace of mxnet
Definition: base.h:118
Definition: ndarray.h:44
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
MXNET_DLL int MXNDArrayFree(NDArrayHandle handle)
free the NDArray handle
~NDBlob()
destructor, free the NDArrayHandle
Definition: ndarray.h:107
static Context cpu(int device_id=0)
Return a CPU context.
Definition: ndarray.h:81
Symbol operator%(mx_float lhs, const Symbol &rhs)
Definition: ndarray.h:42
DeviceType
Definition: ndarray.h:41
NDArray interface.
Definition: ndarray.h:121
NDBlob(NDArrayHandle handle)
construct with a NDArrayHandle
Definition: ndarray.h:103
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample from a Gaussian distribution for each element of out.
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:67
NDArrayHandle GetHandle() const
Definition: ndarray.h:466
NDArrayHandle handle_
the NDArrayHandle
Definition: ndarray.h:111
Symbol operator+(mx_float lhs, const Symbol &rhs)
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample from a uniform distribution for each element of out.
DeviceType GetDeviceType() const
Definition: ndarray.h:61
std::ostream & operator<<(std::ostream &out, const NDArray &ndarray)
int GetDeviceId() const
Definition: ndarray.h:65
Definition: ndarray.h:43
float mx_float
manually define float
Definition: c_api.h:60
Symbol operator-(mx_float lhs, const Symbol &rhs)
struct to store NDArrayHandle
Definition: ndarray.h:93
definition of shape
static Context gpu(int device_id=0)
Return a GPU context.
Definition: ndarray.h:72
Context interface.
Definition: ndarray.h:50
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol operator*(mx_float lhs, const Symbol &rhs)
Context(const DeviceType &type, int id)
Context constructor.
Definition: ndarray.h:57