20 #ifndef MXNET_IMPERATIVE_H_ 21 #define MXNET_IMPERATIVE_H_ 26 #include <nnvm/symbolic.h> 28 #include <nnvm/graph.h> 33 #include <unordered_map> 44 DMLC_DECLARE_FIELD(inline_limit)
46 .describe(
"Maximum number of operators that can be inlined.");
47 DMLC_DECLARE_FIELD(forward_bulk_size)
48 .set_default(dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
49 .describe(
"Segment size of bulk execution during forward pass.");
50 DMLC_DECLARE_FIELD(backward_bulk_size)
51 .set_default(dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
52 .describe(
"Segment size of bulk execution during backward pass.");
69 grad_req(
kNullOp), fresh_out_grad(false) {}
71 static void Clear(
const nnvm::NodePtr& node) {
72 if (node ==
nullptr || node->info.empty())
return;
79 return dmlc::get<AGInfo>(node->info);
83 node->info.construct<
AGInfo>();
88 return arr.entry_.node ==
nullptr || arr.entry_.node->info.empty();
100 const std::vector<std::pair<std::string, std::string> >& kwargs);
102 return fwd_graph_.indexed_graph().input_nodes().size();
105 return fwd_graph_.outputs.size();
108 return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
114 return save_outputs_;
117 return fwd_graph_.indexed_graph().mutable_input_nodes();
// Declaration only; the body is not part of this listing.
// Takes a `recording` flag and the input NDArray pointers and returns an
// nnvm::Graph by value.
// NOTE(review): presumably produces the forward graph prepared for these
// inputs -- confirm against the definition.
119 nnvm::Graph GetForwardGraph(
const bool recording,
120 const std::vector<NDArray*>& inputs);
// Declaration only; the body is not part of this listing.
// Takes an operator state handle, one OpReqType per request in `reqs`, and
// the input NDArray pointers, and returns an nnvm::Graph by value.
// NOTE(review): presumably produces the graph used for the backward pass --
// confirm against the definition.
121 nnvm::Graph GetBackwardGraph(
const OpStatePtr& state,
122 const std::vector<OpReqType>& reqs,
123 const std::vector<NDArray*>& inputs);
// Declaration only; the body is not part of this listing.
// Takes a graph node and a vector of node entries named `ograds`, and
// returns a vector of node entries.
// NOTE(review): `ograds` presumably means "output gradients" and the result
// the corresponding input-gradient entries -- confirm against the definition.
124 std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
125 const std::vector<nnvm::NodeEntry>& ograds);
// Declaration only; the body is not part of this listing.
// Takes a shared pointer to a CachedOp plus vectors of input and output
// NDArray pointers; returns nothing.
// NOTE(review): results are presumably written through `outputs` -- confirm
// against the definition.
126 void Forward(
const std::shared_ptr<CachedOp>& op_ptr,
127 const std::vector<NDArray*>& inputs,
128 const std::vector<NDArray*>& outputs);
129 void Backward(
const bool retain_graph,
131 const std::vector<NDArray*>& inputs,
132 const std::vector<OpReqType>& reqs,
133 const std::vector<NDArray*>& outputs);
136 struct CachedOpState {
137 std::vector<NDArray> buff;
138 std::vector<OpStatePtr> states;
// Graph members. fwd_graph_ is the graph the accessors in this listing query
// for the input-node count, the output count and the mutable input nodes.
142 nnvm::Graph fwd_graph_;
143 nnvm::Graph grad_graph_;
144 nnvm::Graph full_graph_;
// NOTE(review): presumably the node entries standing for output gradients --
// confirm where this vector is filled.
146 std::vector<nnvm::NodeEntry> ograd_entries_;
147 std::vector<bool> curr_grad_req_;
// The sizes of these three vectors are summed to give the number of backward
// inputs (see the return statement at listing line 108).
148 std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
149 std::vector<uint32_t> bwd_input_eid_;
// save_outputs_ is returned by reference from the accessor at listing line 114.
150 std::vector<bool> save_inputs_, save_outputs_;
158 bool old = is_train_;
159 is_train_ = is_train;
164 return is_recording_;
168 bool old = is_recording_;
169 is_recording_ = is_recording;
173 void RecordOp(nnvm::NodeAttrs&& attrs,
174 const std::vector<NDArray*>& inputs,
175 const std::vector<NDArray*>& outputs,
177 std::vector<bool>* p_save_inputs =
nullptr,
178 std::vector<bool>* p_save_outputs =
nullptr);
181 const nnvm::NodeAttrs& attrs,
182 const std::vector<NDArray*>& inputs,
183 const std::vector<NDArray*>& outputs);
186 const nnvm::NodeAttrs& attrs,
187 const std::vector<NDArray*>& inputs,
188 const std::vector<NDArray*>& outputs,
189 const std::vector<OpReqType>& req,
// Declaration only; the body is not part of this listing.
// Takes three vectors: the variable NDArray pointers, one mx_uint gradient
// request per entry in `grad_reqs`, and the gradient NDArray pointers.
// NOTE(review): the three vectors are presumably index-aligned (entry i of
// each describes the same variable) -- confirm against the definition.
193 void MarkVariables(
const std::vector<NDArray*>& variables,
194 const std::vector<mx_uint>& grad_reqs,
195 const std::vector<NDArray*>& gradients);
197 std::vector<NDArray*> Backward(
const std::vector<NDArray*>& outputs,
198 const std::vector<NDArray*>& ograds,
199 const std::vector<NDArray*>& variables,
200 bool is_train,
bool retain_graph,
209 if (dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
210 backward_bulk_size_ = dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
// Declaration only; the body is not part of this listing.
// Takes a graph node, its input and output counts, and two pointers to
// std::vector<bool>; returns nothing.
// NOTE(review): the two pointer arguments are presumably output parameters
// that receive per-input / per-output "must be saved for backward" flags --
// confirm against the definition.
214 void GetBackwardDependency(
215 const nnvm::NodePtr& node,
216 uint32_t num_inputs, uint32_t num_outputs,
217 std::vector<bool> *p_save_inputs,
218 std::vector<bool> *p_save_outputs);
220 const bool retain_graph,
221 const nnvm::IndexedGraph& idx,
222 const std::vector<NDArray*> arrays,
223 size_t node_start,
size_t node_end,
224 std::vector<OpReqType>&& array_reqs,
225 std::vector<uint32_t>&& ref_count,
226 std::vector<OpStatePtr> *p_states,
229 #if DMLC_CXX11_THREAD_LOCAL 230 static thread_local
bool is_train_;
231 static thread_local
bool is_recording_;
233 static MX_THREAD_LOCAL
bool is_train_;
234 static MX_THREAD_LOCAL
bool is_recording_;
237 std::atomic<uint64_t> node_count_{0};
239 std::atomic<uint64_t> variable_count_{0};
241 int backward_bulk_size_{0};
247 #endif // MXNET_IMPERATIVE_H_
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:163
static bool IsNone(const NDArray &arr)
Definition: imperative.h:87
uint32_t num_outputs()
Definition: imperative.h:104
uint32_t backward_bulk_size
Definition: imperative.h:42
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:82
std::vector< DispatchMode > DispatchModeVector
The result holder of dispatch mode of each Node in the graph.
Definition: graph_attr_types.h:60
bool is_training() const
whether training mode is on.
Definition: imperative.h:153
no operation, do not write anything
Definition: op_attr_types.h:47
bool set_is_training(bool is_train)
turn on or turn off training mode.
Definition: imperative.h:157
namespace of mxnet
Definition: base.h:127
uint32_t inline_limit
Definition: imperative.h:40
CachedOp Parameters.
Definition: imperative.h:39
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:71
Additional operator attributes besides the ones provided by NNVM.
AGInfo()
Definition: imperative.h:68
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:105
std::vector< NDArray > outputs
Definition: imperative.h:64
Definition: imperative.h:59
uint32_t num_backward_inputs()
Definition: imperative.h:107
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:167
Definition: imperative.h:97
bool fresh_out_grad
Definition: imperative.h:66
OpStatePtr state
Definition: imperative.h:63
std::vector< NDArray > out_grads
Definition: imperative.h:65
Data structures that can appear in graph attributes.
std::shared_ptr< Imperative::CachedOp > CachedOpPtr
Definition: imperative.h:244
OpReqType grad_req
Definition: imperative.h:62
uint32_t forward_bulk_size
Definition: imperative.h:41
DMLC_DECLARE_PARAMETER(CachedOpParam)
Definition: imperative.h:43
Context ctx
Definition: imperative.h:61
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::vector< bool > & save_inputs()
Definition: imperative.h:110
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:91
runtime functions for NDArray
Definition: imperative.h:56
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:78
const std::unordered_set< uint32_t > & mutable_input_nodes()
Definition: imperative.h:116
uint32_t num_inputs()
Definition: imperative.h:101
Context information about the execution environment.
Definition: base.h:142
ndarray interface
Definition: ndarray.h:79
std::vector< bool > & save_outputs()
Definition: imperative.h:113
Operator state. This is a pointer type; its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:121