mxnet
imperative.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_IMPERATIVE_H_
21 #define MXNET_IMPERATIVE_H_
22 
23 #include <mxnet/op_attr_types.h>
24 #include <mxnet/graph_attr_types.h>
25 #include <mxnet/c_api.h>
26 #include <nnvm/symbolic.h>
27 #include <nnvm/op.h>
28 #include <nnvm/graph.h>
29 #include <vector>
30 #include <atomic>
31 #include <utility>
32 #include <string>
33 #include <unordered_map>
34 
35 #include "./ndarray.h"
36 
37 namespace mxnet {
39 class Imperative {
40  public:
42  class AGInfo {
43  public:
47  std::vector<NDArray> outputs;
48  std::vector<NDArray> out_grads;
50 
51  AGInfo() :
52  grad_req(kNullOp), fresh_out_grad(false) {}
53 
54  static void Clear(const nnvm::NodePtr& node) {
55  if (node == nullptr || node->info.empty()) return;
56  AGInfo& info = Get(node);
57  if (info.grad_req != kNullOp) return;
58  node->info.clear();
59  }
60 
61  static AGInfo& Get(const nnvm::NodePtr& node) {
62  return dmlc::get<AGInfo>(node->info);
63  }
64 
65  static AGInfo& Create(const nnvm::NodePtr& node) {
66  node->info.construct<AGInfo>();
67  return Get(node);
68  }
69 
70  static bool IsNone(const NDArray& arr) {
71  return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
72  }
73 
74  static bool IsVariable(const nnvm::NodePtr& node) {
75  AGInfo& info = Get(node);
76  return info.grad_req != kNullOp && info.outputs.size() == 1
77  && info.out_grads.size() == 1;
78  }
79  };
81  bool is_training() const {
82  return is_train_;
83  }
85  bool set_is_training(bool is_train) {
86  bool old = is_train_;
87  is_train_ = is_train;
88  return old;
89  }
91  bool is_recording() const {
92  return is_recording_;
93  }
96  bool old = is_recording_;
97  is_recording_ = is_recording;
98  return old;
99  }
101  void RecordOp(nnvm::NodeAttrs&& attrs,
102  const std::vector<NDArray*>& inputs,
103  const std::vector<NDArray*>& outputs,
104  const OpStatePtr& state = OpStatePtr(),
105  std::vector<bool>* p_save_inputs = nullptr,
106  std::vector<bool>* p_save_outputs = nullptr);
108  OpStatePtr Invoke(const Context& default_ctx,
109  const nnvm::NodeAttrs& attrs,
110  const std::vector<NDArray*>& inputs,
111  const std::vector<NDArray*>& outputs);
114  const nnvm::NodeAttrs& attrs,
115  const std::vector<NDArray*>& inputs,
116  const std::vector<NDArray*>& outputs,
117  const std::vector<OpReqType>& req,
118  const DispatchMode dispatch_mode,
121  void MarkVariables(const std::vector<NDArray*>& variables,
122  const std::vector<mx_uint>& grad_reqs,
123  const std::vector<NDArray*>& gradients);
125  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
126  const std::vector<NDArray*>& ograds,
127  const std::vector<NDArray*>& variables,
128  bool is_train, bool retain_graph,
129  bool create_graph);
131  static Imperative* Get();
132 
133  private:
134  friend class NDArray;
136  Imperative() {
137  if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
138  backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
139  }
140  }
142  void GetBackwardDependency(
143  const nnvm::NodePtr& node,
144  uint32_t num_inputs, uint32_t num_outputs,
145  std::vector<bool> *p_save_inputs,
146  std::vector<bool> *p_save_outputs);
148 #if DMLC_CXX11_THREAD_LOCAL
149  static thread_local bool is_train_;
150  static thread_local bool is_recording_;
151 #else
152  static MX_THREAD_LOCAL bool is_train_;
153  static MX_THREAD_LOCAL bool is_recording_;
154 #endif
155 
156  std::atomic<uint64_t> node_count_{0};
158  std::atomic<uint64_t> variable_count_{0};
160  int backward_bulk_size_{0};
161 };
162 
163 } // namespace mxnet
164 #endif // MXNET_IMPERATIVE_H_
C API of mxnet.
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:91
static bool IsNone(const NDArray &arr)
Definition: imperative.h:70
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:65
bool is_training() const
whether training mode is on.
Definition: imperative.h:81
no operation, do not write anything
Definition: op_attr_types.h:47
bool set_is_training(bool is_train)
turn training mode on or off; returns the previous setting.
Definition: imperative.h:85
namespace of mxnet
Definition: base.h:118
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t variables.
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:54
Additional operator attributes beside the ones provided by NNVM.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:51
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:113
std::vector< NDArray > outputs
Definition: imperative.h:47
Definition: imperative.h:42
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:95
bool fresh_out_grad
Definition: imperative.h:49
OpStatePtr state
Definition: imperative.h:46
std::vector< NDArray > out_grads
Definition: imperative.h:48
Data structures that can appear in graph attributes.
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation for autograd (returns nothing).
OpReqType grad_req
Definition: imperative.h:45
Context ctx
Definition: imperative.h:44
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:74
runtime functions for NDArray
Definition: imperative.h:39
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:61
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
Context information about the execution environment.
Definition: base.h:133
ndarray interface
Definition: ndarray.h:82
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:129