mxnet
ndarray.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_NDARRAY_H_
28 #define MXNET_CPP_NDARRAY_H_
29 
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <vector>
34 #include <iostream>
35 #include "mxnet-cpp/base.h"
36 #include "mxnet-cpp/shape.h"
37 
38 namespace mxnet {
39 namespace cpp {
40 
/*!
 * \brief Kind of device a Context refers to.
 *
 * Numeric values are fixed explicitly (CPU = 1, GPU = 2).
 */
enum DeviceType {
  kCPU = 1,  //!< CPU device (see Context::cpu)
  kGPU = 2   //!< GPU device (see Context::gpu)
};
46 
50 class Context {
51  public:
57  Context(const DeviceType &type, int id) : type_(type), id_(id) {}
61  DeviceType GetDeviceType() const { return type_; }
65  int GetDeviceId() const { return id_; }
66 
72  static Context gpu(int device_id = 0) {
73  return Context(DeviceType::kGPU, device_id);
74  }
75 
81  static Context cpu(int device_id = 0) {
82  return Context(DeviceType::kCPU, device_id);
83  }
84 
85  private:
86  DeviceType type_;
87  int id_;
88 };
89 
93 struct NDBlob {
94  public:
98  NDBlob() : handle_(nullptr) {}
103  explicit NDBlob(NDArrayHandle handle) : handle_(handle) {}
107  ~NDBlob() { MXNDArrayFree(handle_); }
112 
113  private:
114  NDBlob(const NDBlob &);
115  NDBlob &operator=(const NDBlob &);
116 };
117 
121 class NDArray {
122  public:
126  NDArray();
130  explicit NDArray(const NDArrayHandle &handle);
138  NDArray(const std::vector<mx_uint> &shape, const Context &context,
139  bool delay_alloc = true, int dtype = 0);
147  NDArray(const Shape &shape, const Context &context,
148  bool delay_alloc = true, int dtype = 0);
149  NDArray(const mx_float *data, size_t size);
156  NDArray(const mx_float *data, const Shape &shape, const Context &context);
163  NDArray(const std::vector<mx_float> &data, const Shape &shape,
164  const Context &context);
165  explicit NDArray(const std::vector<mx_float> &data);
166  NDArray operator+(mx_float scalar);
167  NDArray operator-(mx_float scalar);
168  NDArray operator*(mx_float scalar);
169  NDArray operator/(mx_float scalar);
170  NDArray operator%(mx_float scalar);
171  NDArray operator+(const NDArray &);
172  NDArray operator-(const NDArray &);
173  NDArray operator*(const NDArray &);
174  NDArray operator/(const NDArray &);
175  NDArray operator%(const NDArray &);
181  NDArray &operator=(mx_float scalar);
188  NDArray &operator+=(mx_float scalar);
195  NDArray &operator-=(mx_float scalar);
202  NDArray &operator*=(mx_float scalar);
209  NDArray &operator/=(mx_float scalar);
216  NDArray &operator%=(mx_float scalar);
223  NDArray &operator+=(const NDArray &src);
230  NDArray &operator-=(const NDArray &src);
237  NDArray &operator*=(const NDArray &src);
244  NDArray &operator/=(const NDArray &src);
251  NDArray &operator%=(const NDArray &src);
252  NDArray ArgmaxChannel();
263  void SyncCopyFromCPU(const mx_float *data, size_t size);
273  void SyncCopyFromCPU(const std::vector<mx_float> &data);
284  void SyncCopyToCPU(mx_float *data, size_t size = 0);
295  void SyncCopyToCPU(std::vector<mx_float> *data, size_t size = 0);
301  NDArray CopyTo(NDArray * other) const;
307  NDArray Copy(const Context &) const;
314  size_t Offset(size_t h = 0, size_t w = 0) const;
322  size_t Offset(size_t c, size_t h, size_t w) const;
328  mx_float At(size_t index) const;
335  mx_float At(size_t h, size_t w) const;
343  mx_float At(size_t c, size_t h, size_t w) const;
350  NDArray Slice(mx_uint begin, mx_uint end) const;
356  NDArray Reshape(const Shape &new_shape) const;
361  void WaitToRead() const;
366  void WaitToWrite();
371  static void WaitAll();
378  static void SampleGaussian(mx_float mu, mx_float sigma, NDArray *out);
385  static void SampleUniform(mx_float begin, mx_float end, NDArray *out);
394  static void Load(const std::string &file_name,
395  std::vector<NDArray> *array_list = nullptr,
396  std::map<std::string, NDArray> *array_map = nullptr);
402  static std::map<std::string, NDArray> LoadToMap(const std::string &file_name);
408  static std::vector<NDArray> LoadToList(const std::string &file_name);
418  static void LoadFromBuffer(const void *buffer, size_t size,
419  std::vector<NDArray> *array_list = nullptr,
420  std::map<std::string, NDArray> *array_map = nullptr);
427  static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size);
434  static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size);
440  static void Save(const std::string &file_name,
441  const std::map<std::string, NDArray> &array_map);
447  static void Save(const std::string &file_name,
448  const std::vector<NDArray> &array_list);
452  size_t Size() const;
456  std::vector<mx_uint> GetShape() const;
460  int GetDType() const;
465  const mx_float *GetData() const;
466 
470  Context GetContext() const;
471 
475  NDArrayHandle GetHandle() const { return blob_ptr_->handle_; }
476 
477  private:
478  std::shared_ptr<NDBlob> blob_ptr_;
479 };
480 
481 std::ostream& operator<<(std::ostream& out, const NDArray &ndarray);
482 } // namespace cpp
483 } // namespace mxnet
484 
485 #endif // MXNET_CPP_NDARRAY_H_
NDBlob()
default constructor
Definition: ndarray.h:98
Symbol operator/(mx_float lhs, const Symbol &rhs)
namespace of mxnet
Definition: base.h:89
Definition: ndarray.h:44
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
MXNET_DLL int MXNDArrayFree(NDArrayHandle handle)
free the NDArray handle
~NDBlob()
destructor, free the NDArrayHandle
Definition: ndarray.h:107
static Context cpu(int device_id=0)
Return a CPU context.
Definition: ndarray.h:81
Symbol operator%(mx_float lhs, const Symbol &rhs)
Definition: ndarray.h:42
DeviceType
Definition: ndarray.h:41
NDArray interface.
Definition: ndarray.h:121
NDBlob(NDArrayHandle handle)
construct with a NDArrayHandle
Definition: ndarray.h:103
void SampleGaussian(real_t mu, real_t sigma, NDArray *out)
Sample from a Gaussian distribution for each element of out.
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:67
NDArrayHandle GetHandle() const
Definition: ndarray.h:475
NDArrayHandle handle_
the NDArrayHandle
Definition: ndarray.h:111
Symbol operator+(mx_float lhs, const Symbol &rhs)
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
void SampleUniform(real_t begin, real_t end, NDArray *out)
Sample from a uniform distribution for each element of out.
DeviceType GetDeviceType() const
Definition: ndarray.h:61
std::ostream & operator<<(std::ostream &out, const NDArray &ndarray)
int GetDeviceId() const
Definition: ndarray.h:65
Definition: ndarray.h:43
float mx_float
manually define float
Definition: c_api.h:60
Symbol operator-(mx_float lhs, const Symbol &rhs)
struct to store NDArrayHandle
Definition: ndarray.h:93
definition of shape
static Context gpu(int device_id=0)
Return a GPU context.
Definition: ndarray.h:72
Context interface.
Definition: ndarray.h:50
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Reshapes the input array.
Definition: op.h:344
Symbol operator*(mx_float lhs, const Symbol &rhs)
Context(const DeviceType &type, int id)
Context constructor.
Definition: ndarray.h:57