27 #ifndef MXNET_CPP_EXECUTOR_H_ 28 #define MXNET_CPP_EXECUTOR_H_ 51 const std::vector<OpReqType> &grad_reqs,
53 const std::map<std::string, Context> &group_to_ctx =
54 std::map<std::string, Context>(),
66 for (
mx_uint i = 0; i < out_size; ++i) {
80 void Backward(
const std::vector<NDArray> &head_grads =
81 std::vector<NDArray>()) {
82 std::vector<NDArrayHandle> head_grads_;
83 for (
auto d : head_grads) {
84 head_grads_.push_back(d.GetHandle());
86 if (head_grads_.size() > 0) {
126 std::map<std::string, NDArray> GetDict(
const std::vector<std::string> &names,
127 const std::vector<NDArray> &arrays) {
128 std::map<std::string, NDArray> ret;
129 std::set<std::string> name_set;
130 for (
const auto &s : names) {
131 CHECK(name_set.find(s) == name_set.end()) <<
"Duplicate names detected, " 135 CHECK_EQ(name_set.size(), arrays.size())
136 <<
"names size not equal to arrays size";
137 for (
size_t i = 0; i < names.size(); ++i) {
138 ret[names[i]] = arrays[i];
145 #endif // MXNET_CPP_EXECUTOR_H_
std::vector< NDArray > outputs
Arrays that store the outputs of the forward pass.
Definition: executor.h:110
void Forward(bool is_train)
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
Definition: executor.h:61
std::vector< NDArray > grad_arrays
Definition: executor.h:105
namespace of mxnet
Definition: base.h:118
Executor interface.
Definition: executor.h:45
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out)
Get executor's head NDArray.
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
void * ExecutorHandle
handle to an Executor
Definition: c_api.h:79
std::vector< NDArray > arg_arrays
Definition: executor.h:104
std::vector< std::string > ListArguments() const
List the argument names.
std::string DebugStr()
update the arguments with given learning rate and optimizer
NDArray interface.
Definition: ndarray.h:121
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:117
std::vector< std::string > ListAuxiliaryStates() const
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:67
std::vector< NDArray > aux_arrays
Definition: executor.h:106
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:58
Monitor interface.
Definition: monitor.h:55
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads)
Executor run backward.
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:114
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:80
Executor(const ExecutorHandle &h)
Definition: executor.h:56
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:111
Context interface.
Definition: ndarray.h:50
MXNET_DLL int MXExecutorFree(ExecutorHandle handle)
Delete the executor.
Symbol interface.
Definition: symbol.h:72
~Executor()
destructor, free the handle
Definition: executor.h:103
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train)
Executor forward method.