mxnet
executor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_EXECUTOR_H_
27 #define MXNET_CPP_EXECUTOR_H_
28 
29 #include <vector>
30 #include <map>
31 #include <set>
32 #include <string>
33 #include "mxnet-cpp/base.h"
34 #include "mxnet-cpp/symbol.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class Optimizer;
40 
44 class Executor {
45  friend class Monitor;
46  public:
47  Executor(const Symbol &symbol, Context context,
48  const std::vector<NDArray> &arg_arrays,
49  const std::vector<NDArray> &grad_arrays,
50  const std::vector<OpReqType> &grad_reqs,
51  const std::vector<NDArray> &aux_arrays,
52  const std::map<std::string, Context> &group_to_ctx =
53  std::map<std::string, Context>(),
54  Executor *shared_exec = nullptr);
55  explicit Executor(const ExecutorHandle &h) { handle_ = h; }
60  void Forward(bool is_train) {
61  MXExecutorForward(handle_, is_train ? 1 : 0);
62  mx_uint out_size;
63  NDArrayHandle *out_array;
64  CHECK_EQ(MXExecutorOutputs(handle_, &out_size, &out_array), 0);
65  for (mx_uint i = 0; i < out_size; ++i) {
66  outputs[i] = NDArray(out_array[i]);
67  }
68  }
79  void Backward(const std::vector<NDArray> &head_grads =
80  std::vector<NDArray>()) {
81  std::vector<NDArrayHandle> head_grads_;
82  for (auto d : head_grads) {
83  head_grads_.push_back(d.GetHandle());
84  }
85  if (head_grads_.size() > 0) {
86  MXExecutorBackward(handle_, head_grads_.size(), head_grads_.data());
87  } else {
88  MXExecutorBackward(handle_, 0, nullptr);
89  }
90  }
91  // TODO(zhangchen-qinyinghua)
92  // To implement reshape function
93  void Reshape();
98  std::string DebugStr();
102  ~Executor() { MXExecutorFree(handle_); }
103  std::vector<NDArray> arg_arrays;
104  std::vector<NDArray> grad_arrays;
105  std::vector<NDArray> aux_arrays;
109  std::vector<NDArray> outputs;
110  std::map<std::string, NDArray> arg_dict() {
111  return GetDict(symbol_.ListArguments(), arg_arrays);
112  }
113  std::map<std::string, NDArray> grad_dict() {
114  return GetDict(symbol_.ListArguments(), grad_arrays);
115  }
116  std::map<std::string, NDArray> aux_dict() {
117  return GetDict(symbol_.ListAuxiliaryStates(), aux_arrays);
118  }
119 
120  private:
121  Executor(const Executor &e);
122  Executor &operator=(const Executor &e);
123  ExecutorHandle handle_;
124  Symbol symbol_;
125  std::map<std::string, NDArray> GetDict(const std::vector<std::string> &names,
126  const std::vector<NDArray> &arrays) {
127  std::map<std::string, NDArray> ret;
128  std::set<std::string> name_set;
129  for (const auto &s : names) {
130  CHECK(name_set.find(s) == name_set.end()) << "Duplicate names detected, "
131  << s;
132  name_set.insert(s);
133  }
134  CHECK_EQ(name_set.size(), arrays.size())
135  << "names size not equal to arrays size";
136  for (size_t i = 0; i < names.size(); ++i) {
137  ret[names[i]] = arrays[i];
138  }
139  return ret;
140  }
141 };
142 } // namespace cpp
143 } // namespace mxnet
144 #endif // MXNET_CPP_EXECUTOR_H_
definition of symbol
std::vector< NDArray > outputs
arrays store the outputs of forward
Definition: executor.h:109
void Forward(bool is_train)
Perform a Forward operation of the Operator. After this operation, the results can be read from `outputs`.
Definition: executor.h:60
std::vector< NDArray > grad_arrays
Definition: executor.h:104
namespace of mxnet
Definition: base.h:126
Executor interface.
Definition: executor.h:44
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out)
Get executor's head NDArray.
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
void * ExecutorHandle
handle to an Executor
Definition: c_api.h:76
std::vector< NDArray > arg_arrays
Definition: executor.h:103
std::vector< std::string > ListArguments() const
List the arguments names.
std::string DebugStr()
Return a debug string describing the executor.
NDArray interface.
Definition: ndarray.h:120
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:116
std::vector< std::string > ListAuxiliaryStates() const
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:64
std::vector< NDArray > aux_arrays
Definition: executor.h:105
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:57
Monitor interface.
Definition: monitor.h:54
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads)
Executor run backward.
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:113
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:79
Executor(const ExecutorHandle &h)
Definition: executor.h:55
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:110
Context interface.
Definition: ndarray.h:49
MXNET_DLL int MXExecutorFree(ExecutorHandle handle)
Delete the executor.
Symbol interface.
Definition: symbol.h:71
~Executor()
destructor, free the handle
Definition: executor.h:102
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train)
Executor forward method.