26 #ifndef MXNET_CPP_EXECUTOR_H_ 27 #define MXNET_CPP_EXECUTOR_H_ 50 const std::vector<OpReqType> &grad_reqs,
52 const std::map<std::string, Context> &group_to_ctx =
53 std::map<std::string, Context>(),
65 for (
mx_uint i = 0; i < out_size; ++i) {
79 void Backward(
const std::vector<NDArray> &head_grads =
80 std::vector<NDArray>()) {
81 std::vector<NDArrayHandle> head_grads_;
82 for (
auto d : head_grads) {
83 head_grads_.push_back(d.GetHandle());
85 if (head_grads_.size() > 0) {
125 std::map<std::string, NDArray> GetDict(
const std::vector<std::string> &names,
126 const std::vector<NDArray> &arrays) {
127 std::map<std::string, NDArray> ret;
128 std::set<std::string> name_set;
129 for (
const auto &s : names) {
130 CHECK(name_set.find(s) == name_set.end()) <<
"Duplicate names detected, " 134 CHECK_EQ(name_set.size(), arrays.size())
135 <<
"names size not equal to arrays size";
136 for (
size_t i = 0; i < names.size(); ++i) {
137 ret[names[i]] = arrays[i];
144 #endif // MXNET_CPP_EXECUTOR_H_
std::vector< NDArray > outputs
arrays store the outputs of forward
Definition: executor.h:109
void Forward(bool is_train)
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
Definition: executor.h:60
std::vector< NDArray > grad_arrays
Definition: executor.h:104
namespace of mxnet
Definition: base.h:126
Executor interface.
Definition: executor.h:44
MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle, mx_uint *out_size, NDArrayHandle **out)
Get executor's head NDArray.
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
void * ExecutorHandle
handle to an Executor
Definition: c_api.h:76
std::vector< NDArray > arg_arrays
Definition: executor.h:103
std::vector< std::string > ListArguments() const
List the arguments names.
std::string DebugStr()
Return a human-readable debug string describing this executor.
NDArray interface.
Definition: ndarray.h:120
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:116
std::vector< std::string > ListAuxiliaryStates() const
void * NDArrayHandle
handle to NDArray
Definition: c_api.h:64
std::vector< NDArray > aux_arrays
Definition: executor.h:105
unsigned int mx_uint
manually define unsigned int
Definition: c_api.h:57
Monitor interface.
Definition: monitor.h:54
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, mx_uint len, NDArrayHandle *head_grads)
Executor run backward.
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:113
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:79
Executor(const ExecutorHandle &h)
Definition: executor.h:55
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:110
Context interface.
Definition: ndarray.h:49
MXNET_DLL int MXExecutorFree(ExecutorHandle handle)
Delete the executor.
Symbol interface.
Definition: symbol.h:71
~Executor()
destructor, free the handle
Definition: executor.h:102
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train)
Executor forward method.