20 #ifndef MXNET_IMPERATIVE_H_ 21 #define MXNET_IMPERATIVE_H_ 26 #include <nnvm/symbolic.h> 28 #include <nnvm/graph.h> 31 #include <unordered_map> 50 grad_req(
kNullOp), fresh_out_grad(false) {}
52 static void Clear(
const nnvm::NodePtr& node) {
53 if (node ==
nullptr || node->info.empty())
return;
60 return dmlc::get<AGInfo>(node->info);
64 node->info.construct<
AGInfo>();
69 return arr.entry_.node ==
nullptr || arr.entry_.node->info.empty();
80 explicit CachedOp(
const nnvm::Symbol& sym);
82 return fwd_graph_.indexed_graph().input_nodes().size();
85 return fwd_graph_.outputs.size();
88 return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
97 return fwd_graph_.indexed_graph().mutable_input_nodes();
99 nnvm::Graph GetForwardGraph(
const bool recording,
100 const std::vector<NDArray*>& inputs);
102 const std::vector<OpReqType>& reqs,
103 const std::vector<NDArray*>& inputs);
104 std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
105 const std::vector<nnvm::NodeEntry>& ograds);
106 OpStatePtr Forward(
const std::vector<NDArray*>& inputs,
107 const std::vector<NDArray*>&
outputs);
108 void Backward(
const bool retain_graph,
110 const std::vector<NDArray*>& inputs,
111 const std::vector<OpReqType>& reqs,
112 const std::vector<NDArray*>& outputs);
115 struct CachedOpState {
116 std::vector<NDArray> buff;
117 std::vector<OpStatePtr> states;
120 nnvm::Graph fwd_graph_;
121 nnvm::Graph grad_graph_;
122 nnvm::Graph full_graph_;
123 std::vector<nnvm::NodeEntry> ograd_entries_;
124 std::vector<bool> curr_grad_req_;
125 std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
126 std::vector<uint32_t> bwd_input_eid_;
127 std::vector<bool> save_inputs_, save_outputs_;
135 bool old = is_train_;
136 is_train_ = is_train;
141 return is_recording_;
145 bool old = is_recording_;
150 void RecordOp(nnvm::NodeAttrs&& attrs,
151 const std::vector<NDArray*>& inputs,
152 const std::vector<NDArray*>&
outputs,
154 std::vector<bool>* p_save_inputs =
nullptr,
155 std::vector<bool>* p_save_outputs =
nullptr);
158 const nnvm::NodeAttrs& attrs,
159 const std::vector<NDArray*>& inputs,
160 const std::vector<NDArray*>& outputs);
163 const nnvm::NodeAttrs& attrs,
164 const std::vector<NDArray*>& inputs,
165 const std::vector<NDArray*>& outputs,
166 const std::vector<OpReqType>& req,
171 const std::vector<mx_uint>& grad_reqs,
172 const std::vector<NDArray*>& gradients);
174 std::vector<NDArray*>
Backward(
const std::vector<NDArray*>& outputs,
175 const std::vector<NDArray*>& ograds,
176 const std::vector<NDArray*>& variables,
177 bool is_train,
bool retain_graph,
187 void GetBackwardDependency(
188 const nnvm::NodePtr& node,
189 uint32_t num_inputs, uint32_t num_outputs,
190 std::vector<bool> *p_save_inputs,
191 std::vector<bool> *p_save_outputs);
193 const bool retain_graph,
194 const nnvm::IndexedGraph& idx,
195 const std::vector<NDArray*> arrays,
196 size_t node_start,
size_t node_end,
197 std::vector<OpReqType>&& array_reqs,
198 std::vector<uint32_t>&& ref_count,
199 std::vector<OpStatePtr> *p_states,
202 #if DMLC_CXX11_THREAD_LOCAL 203 static thread_local
bool is_train_;
204 static thread_local
bool is_recording_;
206 static MX_THREAD_LOCAL
bool is_train_;
207 static MX_THREAD_LOCAL
bool is_recording_;
210 std::atomic<uint64_t> node_count_{0};
212 std::atomic<uint64_t> variable_count_{0};
218 #endif // MXNET_IMPERATIVE_H_
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:140
static bool IsNone(const NDArray &arr)
Definition: imperative.h:68
uint32_t num_outputs()
Definition: imperative.h:84
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:63
std::vector< DispatchMode > DispatchModeVector
The result holder of dispatch mode of each Node in the graph.
Definition: graph_attr_types.h:60
bool is_training() const
whether training mode is on.
Definition: imperative.h:130
no operation, do not write anything
Definition: op_attr_types.h:47
bool set_is_training(bool is_train)
turn on or turn off training mode; returns the previous setting.
Definition: imperative.h:134
namespace of mxnet
Definition: base.h:127
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t. variables.
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:52
Additional operator attributes besides the ones provided by NNVM.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:49
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:105
std::vector< NDArray > outputs
Definition: imperative.h:45
Definition: imperative.h:40
uint32_t num_backward_inputs()
Definition: imperative.h:87
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:144
Definition: imperative.h:78
bool fresh_out_grad
Definition: imperative.h:47
OpStatePtr state
Definition: imperative.h:44
std::vector< NDArray > out_grads
Definition: imperative.h:46
Data structures that can appear in graph attributes.
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation (its inputs, outputs and optional state) for autograd.
std::shared_ptr< Imperative::CachedOp > CachedOpPtr
Definition: imperative.h:215
OpReqType grad_req
Definition: imperative.h:43
Context ctx
Definition: imperative.h:42
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::vector< bool > & save_inputs()
Definition: imperative.h:90
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:72
runtime functions for NDArray
Definition: imperative.h:37
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:59
const std::unordered_set< uint32_t > & mutable_input_nodes()
Definition: imperative.h:96
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
uint32_t num_inputs()
Definition: imperative.h:81
Context information about the execution environment.
Definition: base.h:142
ndarray interface
Definition: ndarray.h:79
std::vector< bool > & save_outputs()
Definition: imperative.h:93
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:121