20 #ifndef MXNET_IMPERATIVE_H_ 21 #define MXNET_IMPERATIVE_H_ 26 #include <nnvm/symbolic.h> 28 #include <nnvm/graph.h> 31 #include <unordered_map> 50 grad_req(
kNullOp), fresh_out_grad(false) {}
52 static void Clear(
const nnvm::NodePtr& node) {
53 if (node ==
nullptr || node->info.empty())
return;
60 return dmlc::get<AGInfo>(node->info);
64 node->info.construct<
AGInfo>();
69 return arr.entry_.node ==
nullptr || arr.entry_.node->info.empty();
80 explicit CachedOp(
const nnvm::Symbol& sym);
82 return fwd_graph_.indexed_graph().input_nodes().size();
85 return fwd_graph_.outputs.size();
88 return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
97 return fwd_graph_.indexed_graph().mutable_input_nodes();
99 nnvm::Graph GetForwardGraph(
const bool recording,
100 const std::vector<NDArray*>& inputs);
102 const std::vector<OpReqType>& reqs,
103 const std::vector<NDArray*>& inputs);
104 std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
105 const std::vector<nnvm::NodeEntry>& ograds);
106 OpStatePtr Forward(
const std::vector<NDArray*>& inputs,
107 const std::vector<NDArray*>&
outputs);
108 void Backward(
const bool retain_graph,
110 const std::vector<NDArray*>& inputs,
111 const std::vector<OpReqType>& reqs,
112 const std::vector<NDArray*>& outputs);
115 struct CachedOpState {
116 std::vector<NDArray> buff;
117 std::vector<OpStatePtr> states;
120 nnvm::Graph fwd_graph_;
121 nnvm::Graph grad_graph_;
122 nnvm::Graph full_graph_;
123 std::vector<bool> curr_grad_req_;
124 std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
125 std::vector<uint32_t> bwd_input_eid_;
126 std::vector<bool> save_inputs_, save_outputs_;
134 bool old = is_train_;
135 is_train_ = is_train;
140 return is_recording_;
144 bool old = is_recording_;
149 void RecordOp(nnvm::NodeAttrs&& attrs,
150 const std::vector<NDArray*>& inputs,
151 const std::vector<NDArray*>&
outputs,
153 std::vector<bool>* p_save_inputs =
nullptr,
154 std::vector<bool>* p_save_outputs =
nullptr);
157 const nnvm::NodeAttrs& attrs,
158 const std::vector<NDArray*>& inputs,
159 const std::vector<NDArray*>& outputs);
162 const nnvm::NodeAttrs& attrs,
163 const std::vector<NDArray*>& inputs,
164 const std::vector<NDArray*>& outputs,
165 const std::vector<OpReqType>& req,
170 const std::vector<mx_uint>& grad_reqs,
171 const std::vector<NDArray*>& gradients);
173 std::vector<NDArray*>
Backward(
const std::vector<NDArray*>& outputs,
174 const std::vector<NDArray*>& ograds,
175 const std::vector<NDArray*>& variables,
176 bool is_train,
bool retain_graph,
186 void GetBackwardDependency(
187 const nnvm::NodePtr& node,
188 uint32_t num_inputs, uint32_t num_outputs,
189 std::vector<bool> *p_save_inputs,
190 std::vector<bool> *p_save_outputs);
192 const bool retain_graph,
193 const nnvm::IndexedGraph& idx,
194 const std::vector<NDArray*> arrays,
195 size_t node_start,
size_t node_end,
196 std::vector<OpReqType>&& array_reqs,
197 std::vector<uint32_t>&& ref_count,
198 std::vector<OpStatePtr> *p_states,
201 #if DMLC_CXX11_THREAD_LOCAL 202 static thread_local
bool is_train_;
203 static thread_local
bool is_recording_;
205 static MX_THREAD_LOCAL
bool is_train_;
206 static MX_THREAD_LOCAL
bool is_recording_;
209 std::atomic<uint64_t> node_count_{0};
211 std::atomic<uint64_t> variable_count_{0};
217 #endif // MXNET_IMPERATIVE_H_
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:139
static bool IsNone(const NDArray &arr)
Definition: imperative.h:68
uint32_t num_outputs()
Definition: imperative.h:84
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:63
std::vector< DispatchMode > DispatchModeVector
The result holder of dispatch mode of each Node in the graph.
Definition: graph_attr_types.h:60
bool is_training() const
whether training mode is on.
Definition: imperative.h:129
no operation, do not write anything
Definition: op_attr_types.h:46
bool set_is_training(bool is_train)
turn on or turn off training mode.
Definition: imperative.h:133
namespace of mxnet
Definition: base.h:126
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t. variables.
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:52
Additional operator attributes besides the ones provided by NNVM.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:49
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:106
std::vector< NDArray > outputs
Definition: imperative.h:45
Definition: imperative.h:40
uint32_t num_backward_inputs()
Definition: imperative.h:87
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:143
Definition: imperative.h:78
bool fresh_out_grad
Definition: imperative.h:47
OpStatePtr state
Definition: imperative.h:44
std::vector< NDArray > out_grads
Definition: imperative.h:46
Data structures that can appear in graph attributes.
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation (its attributes, inputs, outputs and optional state).
std::shared_ptr< Imperative::CachedOp > CachedOpPtr
Definition: imperative.h:214
OpReqType grad_req
Definition: imperative.h:43
Context ctx
Definition: imperative.h:42
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:44
std::vector< bool > & save_inputs()
Definition: imperative.h:90
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:72
runtime functions for NDArray
Definition: imperative.h:37
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:59
const std::unordered_set< uint32_t > & mutable_input_nodes()
Definition: imperative.h:96
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
uint32_t num_inputs()
Definition: imperative.h:81
Context information about the execution environment.
Definition: base.h:141
ndarray interface
Definition: ndarray.h:69
std::vector< bool > & save_outputs()
Definition: imperative.h:93
Operator state. This is a pointer type; its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:122