mxnet
executor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_EXECUTOR_H_
28 #define MXNET_CPP_EXECUTOR_H_
29 
30 #include <vector>
31 #include <map>
32 #include <set>
33 #include <string>
34 #include "mxnet-cpp/base.h"
35 #include "mxnet-cpp/symbol.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
40 class Optimizer;
41 
45 class Executor {
46  friend class Monitor;
47  public:
48  Executor(const Symbol &symbol, Context context,
49  const std::vector<NDArray> &arg_arrays,
50  const std::vector<NDArray> &grad_arrays,
51  const std::vector<OpReqType> &grad_reqs,
52  const std::vector<NDArray> &aux_arrays,
53  const std::map<std::string, Context> &group_to_ctx =
54  std::map<std::string, Context>(),
55  Executor *shared_exec = nullptr);
56  explicit Executor(const ExecutorHandle &h) { handle_ = h; }
61  void Forward(bool is_train) {
62  MXExecutorForward(handle_, is_train ? 1 : 0);
63  mx_uint out_size;
64  NDArrayHandle *out_array;
65  CHECK_EQ(MXExecutorOutputs(handle_, &out_size, &out_array), 0);
66  for (mx_uint i = 0; i < out_size; ++i) {
67  outputs[i] = NDArray(out_array[i]);
68  }
69  }
80  void Backward(const std::vector<NDArray> &head_grads =
81  std::vector<NDArray>()) {
82  std::vector<NDArrayHandle> head_grads_;
83  for (auto d : head_grads) {
84  head_grads_.push_back(d.GetHandle());
85  }
86  if (head_grads_.size() > 0) {
87  MXExecutorBackward(handle_, head_grads_.size(), head_grads_.data());
88  } else {
89  MXExecutorBackward(handle_, 0, nullptr);
90  }
91  }
92  // TODO(zhangchen-qinyinghua)
93  // To implement reshape function
94  void Reshape();
99  std::string DebugStr();
103  ~Executor() { MXExecutorFree(handle_); }
104  std::vector<NDArray> arg_arrays;
105  std::vector<NDArray> grad_arrays;
106  std::vector<NDArray> aux_arrays;
110  std::vector<NDArray> outputs;
111  std::map<std::string, NDArray> arg_dict() {
112  return GetDict(symbol_.ListArguments(), arg_arrays);
113  }
114  std::map<std::string, NDArray> grad_dict() {
115  return GetDict(symbol_.ListArguments(), grad_arrays);
116  }
117  std::map<std::string, NDArray> aux_dict() {
118  return GetDict(symbol_.ListAuxiliaryStates(), aux_arrays);
119  }
120 
121  private:
122  Executor(const Executor &e);
123  Executor &operator=(const Executor &e);
124  ExecutorHandle handle_;
125  Symbol symbol_;
126  std::map<std::string, NDArray> GetDict(const std::vector<std::string> &names,
127  const std::vector<NDArray> &arrays) {
128  std::map<std::string, NDArray> ret;
129  std::set<std::string> name_set;
130  for (const auto &s : names) {
131  CHECK(name_set.find(s) == name_set.end()) << "Duplicate names detected, "
132  << s;
133  name_set.insert(s);
134  }
135  CHECK_EQ(name_set.size(), arrays.size())
136  << "names size not equal to arrays size";
137  for (size_t i = 0; i < names.size(); ++i) {
138  ret[names[i]] = arrays[i];
139  }
140  return ret;
141  }
142 };
143 } // namespace cpp
144 } // namespace mxnet
145 #endif // MXNET_CPP_EXECUTOR_H_
definition of symbol
std::vector< NDArray > outputs
Arrays that store the outputs of Forward.
Definition: executor.h:110
void Forward(bool is_train)
Perform a Forward operation of Operator After this operation, user can get the result by using functi...
Definition: executor.h:61
std::vector< NDArray > grad_arrays
Definition: executor.h:105
namespace of mxnet
Definition: base.h:89
Executor interface.
Definition: executor.h:45
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out)
Get executor's head NDArray.
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
void * ExecutorHandle
handle to an Executor
Definition: c_api.h:79
std::vector< NDArray > arg_arrays
Definition: executor.h:104
std::vector< std::string > ListArguments() const
List the arguments names.
std::string DebugStr()
Return a debug string describing the executor.
NDArray interface.
Definition: ndarray.h:121
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:117
std::vector< std::string > ListAuxiliaryStates() const
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:67
std::vector< NDArray > aux_arrays
Definition: executor.h:106
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
Monitor interface.
Definition: monitor.h:55
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads)
Executor run backward.
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:114
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:80
Executor(const ExecutorHandle &h)
Definition: executor.h:56
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:111
Context interface.
Definition: ndarray.h:50
MXNET_DLL int MXExecutorFree(ExecutorHandle handle)
Delete the executor.
Symbol interface.
Definition: symbol.h:72
~Executor()
destructor, free the handle
Definition: executor.h:103
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train)
Executor forward method.