26 #ifndef MXNET_OP_ATTR_TYPES_H_ 27 #define MXNET_OP_ATTR_TYPES_H_ 29 #include <mshadow/tensor.h> 30 #include <nnvm/op_attr_types.h> 42 using nnvm::NodeAttrs;
82 template<
typename xpu>
134 template<
typename T,
typename... Args>
137 auto state =
new T(std::forward<Args>(args)...);
140 new OpState(var, state),
143 delete reinterpret_cast<T*
>(p->state);
156 return *
reinterpret_cast<T*
>(ptr_->state);
165 return ptr_.unique();
168 explicit operator bool()
const {
169 return ptr_ ?
true :
false;
179 OpState(
const OpState& other) =
delete;
180 OpState& operator=(
const OpState& other) =
delete;
183 std::shared_ptr<OpState> ptr_;
200 const std::vector<TShape>& in_shape,
201 const std::vector<int>& in_type)>;
205 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
215 const std::vector<TBlob>& inputs,
216 const std::vector<OpReqType>& req,
217 const std::vector<TBlob>& outputs)>;
227 const std::vector<NDArray>& inputs,
228 const std::vector<OpReqType>& req,
229 const std::vector<NDArray>& outputs)>;
237 std::vector<ResourceRequest> (
const NodeAttrs& n)>;
245 std::vector<ResourceRequest> (
const NodeAttrs& n,
254 const std::vector<NDArray>& inputs,
255 std::vector<NDArray>* outputs)>;
261 using FCompute = std::function<void (
const nnvm::NodeAttrs& attrs,
263 const std::vector<TBlob>& inputs,
264 const std::vector<OpReqType>& req,
265 const std::vector<TBlob>& outputs)>;
271 using FComputeEx = std::function<void (
const nnvm::NodeAttrs& attrs,
273 const std::vector<NDArray>& inputs,
274 const std::vector<OpReqType>& req,
275 const std::vector<NDArray>& outputs)>;
286 std::vector<int>* in_attrs,
287 std::vector<int>* out_attrs)>;
293 using FQuantizedOp = std::function<nnvm::NodePtr (const NodeAttrs& attrs)>;
305 #endif // MXNET_OP_ATTR_TYPES_H_ std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutab...
Definition: op_attr_types.h:217
void reset()
Definition: op_attr_types.h:159
Forward/Backward are synchronous calls.
Engine that schedules all the operations according to dependency.
no operation, do not write anything
Definition: op_attr_types.h:47
write gradient to provided space
Definition: op_attr_types.h:49
namespace of mxnet
Definition: base.h:118
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.
Definition: op_attr_types.h:237
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer typ...
Definition: op_attr_types.h:229
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:270
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized...
Definition: op_attr_types.h:301
Asynchronous function call.
bool is_train
whether it is training phase
Definition: op_attr_types.h:70
engine::VarHandle get_var() const
Definition: op_attr_types.h:150
Cross device copy operation, this is a special operator that indicates it will copy across devices...
execution time context. The information needed in runtime for actual execution.
Definition: base.h:257
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:113
T & get_state() const
Definition: op_attr_types.h:155
engine::CallbackOnComplete async_on_complete
the callback invoked when the operation completes; used by asynchronous ops
Definition: op_attr_types.h:74
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const std::vector< TShape > &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a layer-style forward/backward operator. This makes it easy to write code that contains state...
Definition: op_attr_types.h:201
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: op_attr_types.h:66
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:205
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
static OpStatePtr Create(Args &&...args)
Definition: op_attr_types.h:135
std::function< nnvm::NodePtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:293
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:68
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as an NDArray function.
Definition: op_attr_types.h:255
Global resource allocation handling.
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
Var * VarHandle
Variable pointer type, usually hold by user used to specify dependencies.
Definition: engine.h:47
perform an in-place write. This option only happens when Target shares memory with one of the input argumen...
Definition: op_attr_types.h:55
A subgraph execution should happen in the main thread, instead of in the execution engine...
bool unique() const
Definition: op_attr_types.h:164
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for a simple, stateless, forward-only operator.
Definition: op_attr_types.h:275
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:76
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:72
static Context CPU(int32_t dev_id=0)
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for a simple, stateless, forward-only operator.
Definition: op_attr_types.h:265
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:56
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage-type and dispatch-mode inference function based on the storage types of the inputs and outpu...
Definition: op_attr_types.h:287
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.
Definition: op_attr_types.h:247
add to the provided space
Definition: op_attr_types.h:57
ExecType
the execution type of the operator
Definition: op_attr_types.h:89
Context information about the execution environment.
Definition: base.h:133
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:83
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:129