20 #ifndef MXNET_IMPERATIVE_H_    21 #define MXNET_IMPERATIVE_H_    26 #include <nnvm/symbolic.h>    28 #include <nnvm/graph.h>    33 #include <unordered_map>    52       grad_req(
kNullOp), fresh_out_grad(false) {}
    54     static void Clear(
const nnvm::NodePtr& node) {
    55       if (node == 
nullptr || node->info.empty()) 
return;
    62       return dmlc::get<AGInfo>(node->info);
    66       node->info.construct<
AGInfo>();
    71       return arr.entry_.node == 
nullptr || arr.entry_.node->info.empty();
    96       bool old = is_recording_;
   101   void RecordOp(nnvm::NodeAttrs&& attrs,
   102                 const std::vector<NDArray*>& inputs,
   103                 const std::vector<NDArray*>& 
outputs,
   105                 std::vector<bool>* p_save_inputs = 
nullptr,
   106                 std::vector<bool>* p_save_outputs = 
nullptr);
   109                     const nnvm::NodeAttrs& attrs,
   110                     const std::vector<NDArray*>& inputs,
   111                     const std::vector<NDArray*>& outputs);
   114                       const nnvm::NodeAttrs& attrs,
   115                       const std::vector<NDArray*>& inputs,
   116                       const std::vector<NDArray*>& outputs,
   117                       const std::vector<OpReqType>& req,
   122                      const std::vector<mx_uint>& grad_reqs,
   123                      const std::vector<NDArray*>& gradients);
   125   std::vector<NDArray*> 
Backward(
const std::vector<NDArray*>& outputs,
   126                                  const std::vector<NDArray*>& ograds,
   127                                  const std::vector<NDArray*>& variables,
   128                                  bool is_train, 
bool retain_graph,
   137     if (dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
   138       backward_bulk_size_ =  dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
   142   void GetBackwardDependency(
   143       const nnvm::NodePtr& node,
   144       uint32_t num_inputs, uint32_t num_outputs,
   145       std::vector<bool> *p_save_inputs,
   146       std::vector<bool> *p_save_outputs);
   148 #if DMLC_CXX11_THREAD_LOCAL   149   static thread_local 
bool is_train_;
   150   static thread_local 
bool is_recording_;
   152   static MX_THREAD_LOCAL 
bool is_train_;
   153   static MX_THREAD_LOCAL 
bool is_recording_;
   156   std::atomic<uint64_t> node_count_{0};
   158   std::atomic<uint64_t> variable_count_{0};
   160   int backward_bulk_size_{0};
   164 #endif  // MXNET_IMPERATIVE_H_ 
bool is_recording() const 
whether operator recording is on. 
Definition: imperative.h:91
 
static bool IsNone(const NDArray &arr)
Definition: imperative.h:70
 
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:65
 
bool is_training() const 
whether training mode is on. 
Definition: imperative.h:81
 
no operation, do not write anything 
Definition: op_attr_types.h:47
 
bool set_is_training(bool is_train)
turn on or turn off training mode for autograd. 
Definition: imperative.h:85
 
namespace of mxnet 
Definition: base.h:118
 
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t variables. 
 
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:54
 
Additional operator attributes besides the ones provided by NNVM. 
 
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
 
AGInfo()
Definition: imperative.h:51
 
DispatchMode
the dispatch mode of the operator 
Definition: op_attr_types.h:113
 
std::vector< NDArray > outputs
Definition: imperative.h:47
 
Definition: imperative.h:42
 
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd. 
Definition: imperative.h:95
 
bool fresh_out_grad
Definition: imperative.h:49
 
OpStatePtr state
Definition: imperative.h:46
 
std::vector< NDArray > out_grads
Definition: imperative.h:48
 
Data structures that can appear in graph attributes. 
 
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation (its attributes, inputs and outputs) for autograd. 
 
OpReqType grad_req
Definition: imperative.h:45
 
Context ctx
Definition: imperative.h:44
 
OpReqType
operation request type to Forward and Backward 
Definition: op_attr_types.h:45
 
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
 
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:74
 
runtime functions for NDArray 
Definition: imperative.h:39
 
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:61
 
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients. 
 
Context information about the execution environment. 
Definition: base.h:133
 
ndarray interface 
Definition: ndarray.h:82
 
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:129