25 #ifndef MXNET_EXECUTOR_H_    26 #define MXNET_EXECUTOR_H_    28 #include <dmlc/base.h>    40 #if DMLC_USE_CXX11 == 0    41 #error "CXX11 was required for symbolic module"    60   virtual void Forward(
bool is_train) = 0;  //!< Run a forward pass of the bound graph; must be called before Backward(). is_train presumably selects training mode (TODO confirm against implementations).
    69   virtual void PartialForward(
bool is_train, 
int step, 
int *step_left) = 0;
    79   virtual void Backward(
const std::vector<NDArray> &head_grads, 
bool is_train = 
true) = 0;
    84   virtual void Print(std::ostream &os)
 const {} 
    89  virtual const std::vector<NDArray> &
outputs() 
const = 0;
    94   virtual const std::unordered_map<std::string, NDArray>& 
in_arg_map() 
const = 0;
    99   virtual const std::unordered_map<std::string, NDArray>& 
arg_grad_map() 
const = 0;
   104   virtual const std::unordered_map<std::string, NDArray>& 
aux_state_map() 
const = 0;
   121                         const std::map<std::string, Context>& group2ctx,
   122                         const std::vector<NDArray> &in_args,
   123                         const std::vector<NDArray> &arg_grad_store,
   124                         const std::vector<OpReqType> &grad_req_type,
   125                         const std::vector<NDArray> &aux_states,
   130                               const std::map<std::string, Context>& group2ctx,
   131                               const std::vector<Context>& in_arg_ctxes,
   132                               const std::vector<Context>& arg_grad_ctxes,
   133                               const std::vector<Context>& aux_state_ctxes,
   134                               const std::unordered_map<std::string, TShape>& arg_shape_map,
   135                               const std::unordered_map<std::string, int>& arg_dtype_map,
   136                               const std::unordered_map<std::string, int>& arg_stype_map,
   137                               const std::vector<OpReqType>& grad_req_types,
   138                               const std::unordered_set<std::string>& param_names,
   139                               std::vector<NDArray>* in_args,
   140                               std::vector<NDArray>* arg_grads,
   141                               std::vector<NDArray>* aux_states,
   142                               std::unordered_map<std::string, NDArray>*
   143                                 shared_data_arrays = 
nullptr,
   155 #endif  // MXNET_EXECUTOR_H_ Executor of a computation graph. Executor can be created by Binding a symbol. 
Definition: executor.h:52
 
virtual ~Executor()
destructor 
Definition: executor.h:55
 
static Executor * SimpleBind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< Context > &in_arg_ctxes, const std::vector< Context > &arg_grad_ctxes, const std::vector< Context > &aux_state_ctxes, const std::unordered_map< std::string, TShape > &arg_shape_map, const std::unordered_map< std::string, int > &arg_dtype_map, const std::unordered_map< std::string, int > &arg_stype_map, const std::vector< OpReqType > &grad_req_types, const std::unordered_set< std::string > &param_names, std::vector< NDArray > *in_args, std::vector< NDArray > *arg_grads, std::vector< NDArray > *aux_states, std::unordered_map< std::string, NDArray > *shared_data_arrays=nullptr, Executor *shared_exec=nullptr)
 
std::function< void(const char *, void *)> MonitorCallback
the prototype of user-defined monitor callback 
Definition: executor.h:148
 
virtual void Backward(const std::vector< NDArray > &head_grads, bool is_train=true)=0
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
 
namespace of mxnet 
Definition: base.h:126
 
static Executor * Bind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=NULL)
Create an executor by binding a symbol with context and arguments. If the user does not want to compute the grad...
 
virtual void Print(std::ostream &os) const 
print the execution plan info to output stream. 
Definition: executor.h:84
 
virtual const std::unordered_map< std::string, NDArray > & in_arg_map() const =0
get input argument map, key is arg name, value is arg's NDArray. 
 
virtual const std::vector< NDArray > & outputs() const =0
get array of outputs in the executor. 
 
virtual void SetMonitorCallback(const MonitorCallback &callback)
Install a callback to notify the completion of operation. 
Definition: executor.h:152
 
virtual void Forward(bool is_train)=0
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
 
virtual void PartialForward(bool is_train, int step, int *step_left)=0
Perform a Partial Forward operation of the Operator. Only issues the operation specified by step...
 
virtual const std::unordered_map< std::string, NDArray > & arg_grad_map() const =0
get input argument gradient map, key is arg name, value is gradient's NDArray. 
 
virtual const std::unordered_map< std::string, NDArray > & aux_state_map() const =0
get aux state map, key is arg name, value is aux state's NDArray. 
 
Context information about the execution environment. 
Definition: base.h:141