20 #ifndef MXNET_IMPERATIVE_H_ 21 #define MXNET_IMPERATIVE_H_ 26 #include <nnvm/symbolic.h> 28 #include <nnvm/graph.h> 33 #include <unordered_map> 52 grad_req(
kNullOp), fresh_out_grad(false) {}
54 static void Clear(
const nnvm::NodePtr& node) {
55 if (node ==
nullptr || node->info.empty())
return;
62 return dmlc::get<AGInfo>(node->info);
66 node->info.construct<
AGInfo>();
71 return arr.entry_.node ==
nullptr || arr.entry_.node->info.empty();
96 bool old = is_recording_;
101 void RecordOp(nnvm::NodeAttrs&& attrs,
102 const std::vector<NDArray*>& inputs,
103 const std::vector<NDArray*>&
outputs,
105 std::vector<bool>* p_save_inputs =
nullptr,
106 std::vector<bool>* p_save_outputs =
nullptr);
109 const nnvm::NodeAttrs& attrs,
110 const std::vector<NDArray*>& inputs,
111 const std::vector<NDArray*>& outputs);
114 const nnvm::NodeAttrs& attrs,
115 const std::vector<NDArray*>& inputs,
116 const std::vector<NDArray*>& outputs,
117 const std::vector<OpReqType>& req,
122 const std::vector<mx_uint>& grad_reqs,
123 const std::vector<NDArray*>& gradients);
125 std::vector<NDArray*>
Backward(
const std::vector<NDArray*>& outputs,
126 const std::vector<NDArray*>& ograds,
127 const std::vector<NDArray*>& variables,
128 bool is_train,
bool retain_graph,
137 if (dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
138 backward_bulk_size_ = dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
142 void GetBackwardDependency(
143 const nnvm::NodePtr& node,
144 uint32_t num_inputs, uint32_t num_outputs,
145 std::vector<bool> *p_save_inputs,
146 std::vector<bool> *p_save_outputs);
148 #if DMLC_CXX11_THREAD_LOCAL 149 static thread_local
bool is_train_;
150 static thread_local
bool is_recording_;
152 static MX_THREAD_LOCAL
bool is_train_;
153 static MX_THREAD_LOCAL
bool is_recording_;
156 std::atomic<uint64_t> node_count_{0};
158 std::atomic<uint64_t> variable_count_{0};
160 int backward_bulk_size_{0};
164 #endif // MXNET_IMPERATIVE_H_
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:91
static bool IsNone(const NDArray &arr)
Definition: imperative.h:70
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:65
bool is_training() const
whether training mode is on.
Definition: imperative.h:81
no operation, do not write anything
Definition: op_attr_types.h:47
bool set_is_training(bool is_train)
turn on or turn off training mode.
Definition: imperative.h:85
namespace of mxnet
Definition: base.h:118
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t. variables.
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:54
Additional operator attributes besides the ones provided by NNVM.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:51
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:113
std::vector< NDArray > outputs
Definition: imperative.h:47
Definition: imperative.h:42
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:95
bool fresh_out_grad
Definition: imperative.h:49
OpStatePtr state
Definition: imperative.h:46
std::vector< NDArray > out_grads
Definition: imperative.h:48
Data structures that can appear in graph attributes.
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation for autograd.
OpReqType grad_req
Definition: imperative.h:45
Context ctx
Definition: imperative.h:44
OpReqType
operation request type passed to Forward and Backward
Definition: op_attr_types.h:45
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:74
runtime functions for NDArray
Definition: imperative.h:39
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:61
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
Context information about the execution environment.
Definition: base.h:133
ndarray interface
Definition: ndarray.h:82
Operator state. This is a pointer type; its content is mutable even if the OpStatePtr itself is const...
Definition: op_attr_types.h:129