mxnet
op_attr_types.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_OP_ATTR_TYPES_H_
27 #define MXNET_OP_ATTR_TYPES_H_
28 
29 #include <mshadow/tensor.h>
30 #include <nnvm/op_attr_types.h>
31 
32 #include <vector>
33 #include <functional>
34 
35 #include "./base.h"
36 #include "./ndarray.h"
37 #include "./engine.h"
38 #include "./resource.h"
39 
40 namespace mxnet {
41 
42 using nnvm::NodeAttrs;
43 
/* \brief operation request type to Forward and Backward */
45 enum OpReqType {
// NOTE: the enumerators (original lines 46-57) are elided in this listing.
// Their documented meanings, in declaration order, are:
//   - no operation, do not write anything
//   - write gradient to provided space
//   - perform an inplace write (description truncated on this page)
//   - add to the provided space
58 };
59 
/* \brief All the possible information needed by Operator.Forward and Backward. */
// NOTE: the members `run_ctx` (RunContext related resources) and
// `async_on_complete` (the callback when operation completes) are declared at
// original lines 72 and 74, which are elided in this listing.
66 struct OpContext {
 /* \brief whether there is a backward phase to compute gradients. */
68  bool need_grad;
 /* \brief whether it is training phase */
70  bool is_train;
 /* \brief Resources requested by the operator. */
76  std::vector<Resource> requested;
 /* \brief get mshadow stream from Context; forwards to run_ctx.
  * \tparam xpu device type tag passed through to run_ctx.get_stream
  */
82  template<typename xpu>
83  inline mshadow::Stream<xpu>* get_stream() const {
84  return run_ctx.get_stream<xpu>();
85  }
86 #if MXNET_USE_CUDA
87 
 // Body of get_gpu_aux_stream() (its signature, original line 91, is elided in
 // this listing): get auxiliary gpu stream auto-syncing object from Context.
 // Only compiled when MXNET_USE_CUDA is set; forwards to run_ctx.
92  return run_ctx.get_gpu_aux_stream();
93  }
94 #endif
95 };
96 
/* \brief the execution type of the operator */
98 enum class ExecType {
 /* \brief Forward/Backward are synchronous calls. */
100  kSync,
 /* \brief Asynchronous function call. */
105  kAsync,
 // NOTE: further enumerators (original lines 106-118) are elided in this
 // listing. The page also describes a cross device copy operation and a
 // subgraph execution that happens in the main thread; presumably those are
 // the missing values -- confirm against the real header.
119 };
120 
/* \brief the dispatch mode of the operator */
122 enum class DispatchMode {
 // sentinel: no dispatch mode has been chosen yet (the only negative value)
123  kUndefined = -1,
124  // dispatch on FCompute or FStatefulCompute
125  kFCompute,
126  // dispatch on FComputeEx or FStatefulComputeEx, if available
127  kFComputeEx,
128  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
 // NOTE(review): the enumerator that the comment above documents (original
 // line 129) is missing from this listing.
130  // special dispatch mode for variables
131  kVariable,
132 };
133 
/* \brief Operator state. This is a pointer type, its content is mutable
 * even if OpStatePtr is const.
 *
 * Wraps a std::shared_ptr to an internal OpState record that pairs an engine
 * variable with a type-erased (void*) user state object.
 */
138 class OpStatePtr {
139  public:
140  /* \brief Create an OpStatePtr with state of type T.
141  * \param args Arguments passed to T's constructor.
 * \return an OpStatePtr owning the newly constructed T and a new engine variable
142  */
143  template<typename T, typename... Args>
144  static OpStatePtr Create(Args&&... args) {
145  OpStatePtr ret;
146  auto state = new T(std::forward<Args>(args)...);
147  auto var = Engine::Get()->NewVariable();
 // The custom deleter is needed because OpState only stores a void*: the
 // lambda remembers T so the state can be deleted through its real type.
148  ret.ptr_.reset(
149  new OpState(var, state),
150  [](OpState* p) {
 // First hand the engine variable to Engine::DeleteVariable (with a no-op
 // callback and a CPU context), then free the typed state and the record.
151  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
152  delete reinterpret_cast<T*>(p->state);
153  delete p;
154  });
155 
156  return ret;
157  }
158  /* \brief Get engine variable associated with this state */
 // NOTE: the signature line `engine::VarHandle get_var() const {` (original
 // line 159) is elided in this listing. The body dereferences ptr_ without a
 // null check.
160  return ptr_->var;
161  }
162  /* \brief Get state of type T */
 // T is not checked against the type used in Create(); the stored void* is
 // simply cast to T*. Dereferences ptr_ without a null check.
163  template<typename T>
164  T& get_state() const {
165  return *reinterpret_cast<T*>(ptr_->state);
166  }
167  /* \brief clear state */
 // Drops this instance's reference; the deleter installed by Create() runs
 // once the last reference is released.
168  void reset() {
169  ptr_.reset();
170  }
171  /* \brief checks whether the managed object is managed only by the current
172  OpStatePtr instance */
 // NOTE(review): std::shared_ptr::unique() is deprecated in C++17 and removed
 // in C++20 -- confirm the language standard this header must build under.
173  bool unique() const {
174  return ptr_.unique();
175  }
176  /* \brief Whether a state is held: true when non-empty, false when empty */
177  explicit operator bool() const {
178  return ptr_ ? true : false;
179  }
180 
181  private:
182  /* \brief state structure */
183  struct OpState {
 // engine variable associated with this state
184  engine::VarHandle var;
 // type-erased pointer to the user state object created in Create()
185  void* state;
186 
187  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
 // non-copyable: copy construction and copy assignment are deleted
188  OpState(const OpState& other) = delete;
189  OpState& operator=(const OpState& other) = delete;
190  };
191  /* \brief shared pointer to state */
192  std::shared_ptr<OpState> ptr_;
193 };
194 
/* \brief Create a Layer style, forward/backward operator.
 * The registered function receives the node attributes, a Context, the input
 * shapes and the input type codes, and returns the OpStatePtr for the operator.
 */
207 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
208  Context ctx,
209  const mxnet::ShapeVector& in_shape,
210  const std::vector<int>& in_type)>;
/* \brief Execution mode of this operator; returns an ExecType for the given node attributes. */
214 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
/* \brief Register a compute function for stateful operator (TBlob interface).
 * OpStatePtr is a pointer type, so the state it refers to can be modified
 * even though it is passed as a const reference.
 */
222 using FStatefulCompute = std::function<void (const OpStatePtr& state,
223  const OpContext& ctx,
224  const std::vector<TBlob>& inputs,
225  const std::vector<OpReqType>& req,
226  const std::vector<TBlob>& outputs)>;
/* \brief Register a compute function for stateful operator using NDArray interface.
 * OpStatePtr is a pointer type, so the state it refers to can be modified
 * even though it is passed as a const reference.
 */
234 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
235  const OpContext& ctx,
236  const std::vector<NDArray>& inputs,
237  const std::vector<OpReqType>& req,
238  const std::vector<NDArray>& outputs)>;
/* \brief The resource request from the operator.
 * An operator could register ResourceRequestEx, or ResourceRequest, or neither.
 */
245 using FResourceRequest = std::function<
246  std::vector<ResourceRequest> (const NodeAttrs& n)>;
/* \brief The resource request from the operator.
 * An operator could register ResourceRequestEx, or ResourceRequest, or neither.
 * If an operator registers both ResourceRequestEx and ResourceRequest,
 * ResourceRequest is ignored.
 * Unlike FResourceRequest, this form also receives dev_mask and the DispatchMode.
 */
255 using FResourceRequestEx = std::function<
256  std::vector<ResourceRequest> (const NodeAttrs& n,
257  const int dev_mask,
258  const DispatchMode dispatch_mode)>;
/* \brief Register an operator called as an NDArray function.
 * Outputs are passed as a pointer to a vector of NDArray.
 */
264 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
265  const std::vector<NDArray>& inputs,
266  std::vector<NDArray>* outputs)>;
/* \brief Register a compute function for simple stateless forward only operator (TBlob interface). */
272 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
273  const OpContext& ctx,
274  const std::vector<TBlob>& inputs,
275  const std::vector<OpReqType>& req,
276  const std::vector<TBlob>& outputs)>;
/* \brief Register an NDArray compute function for simple stateless forward only operator. */
282 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
283  const OpContext& ctx,
284  const std::vector<NDArray>& inputs,
285  const std::vector<OpReqType>& req,
286  const std::vector<NDArray>& outputs)>;
287 
/* \brief Register a storage and dispatch mode inference function based on
 * storage types of the inputs and outputs.
 * dispatch_mode, in_attrs and out_attrs are passed as mutable pointers;
 * the function returns a bool.
 */
294 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
295  const int dev_mask,
296  DispatchMode* dispatch_mode,
297  std::vector<int>* in_attrs,
298  std::vector<int>* out_attrs)>;
299 
/* \brief Register a quantized node creation function based on the attrs of the node. */
304 using FQuantizedOp = std::function<nnvm::NodePtr (const NodeAttrs& attrs)>;
305 
/* \brief Register a function to determine if the output of a quantized operator needs to be requantized. */
312 using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
313 
/* \brief Register a function to determine if the input of a quantized operator needs to be quantized.
 * The function receives the node attributes and an input index.
 */
319 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
320  size_t index)>;
321 
322 } // namespace mxnet
323 
324 #endif // MXNET_OP_ATTR_TYPES_H_
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for stateful operator. OpStatePtr is a pointer type, its content is mutable...
Definition: op_attr_types.h:226
void reset()
Definition: op_attr_types.h:168
Forward/Backward are synchronous calls.
Engine that schedules all the operations according to dependency.
no operation, do not write anything
Definition: op_attr_types.h:47
write gradient to provided space
Definition: op_attr_types.h:49
namespace of mxnet
Definition: base.h:89
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.
Definition: op_attr_types.h:246
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for stateful operator using NDArray interface. OpStatePtr is a pointer type...
Definition: op_attr_types.h:238
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:358
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized...
Definition: op_attr_types.h:312
Asynchronous function call.
bool is_train
whether it is training phase
Definition: op_attr_types.h:70
engine::VarHandle get_var() const
Definition: op_attr_types.h:159
Cross device copy operation, this is a special operator that indicates it will copy across devices...
execution time context. The information needed in runtime for actual execution.
Definition: base.h:337
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
base class of engine variables.
Definition: engine.h:44
Provides automatic coordination of an auxiliary stream with a primary one. This object, upon construction, prepares an aux stream for use by syncing it with enqueued primary-stream work. Object destruction will sync again so future primary-stream work will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults simply: the primary stream will equal the aux stream and the syncs will be executed as nops. See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
Definition: base.h:302
T & get_state() const
Definition: op_attr_types.h:164
engine::CallbackOnComplete async_on_complete
the callback when operation completes, used by asynchronous ops
Definition: op_attr_types.h:74
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: op_attr_types.h:66
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:768
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:214
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
static OpStatePtr Create(Args &&...args)
Definition: op_attr_types.h:144
std::function< nnvm::NodePtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:304
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:68
SyncedGPUAuxStream get_gpu_aux_stream() const
get auxiliary gpu stream auto-syncing object from Context
Definition: op_attr_types.h:91
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as an NDArray function.
Definition: op_attr_types.h:266
Global resource allocation handling.
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a Layer style, forward/backward operator. This makes it easy to write code that contains state...
Definition: op_attr_types.h:210
perform an inplace write. This option only happens when the target shares memory with one of the input arguments...
Definition: op_attr_types.h:55
A subgraph execution should happen in the main thread, instead of in the execution engine...
bool unique() const
Definition: op_attr_types.h:173
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for simple stateless forward only operator.
Definition: op_attr_types.h:286
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:76
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:72
static Context CPU(int32_t dev_id=0)
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for simple stateless forward only operator.
Definition: op_attr_types.h:276
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:73
SyncedGPUAuxStream get_gpu_aux_stream() const
get an RAII object that transparently handles the syncing of the auxiliary stream.
Definition: base.h:366
static Engine * Get()
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on storage types of the inputs and outp...
Definition: op_attr_types.h:298
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. If an operator registers both ResourceRequestEx and ResourceRequest, ResourceRequest is ignored.
Definition: op_attr_types.h:258
add to the provided space
Definition: op_attr_types.h:57
std::function< bool(const NodeAttrs &attrs, size_t index)> FAvoidQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized...
Definition: op_attr_types.h:320
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
Context information about the execution environment.
Definition: base.h:102
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:83
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:138