Operator interface. An Operator is the basic unit of computation in mxnet's optimized computation graph. This interface relies on memory pre-allocated in TBlob: the caller must set up the memory region of every TBlob correctly before calling Forward and Backward.
#include <operator.h>
- virtual ~Operator ()
  destructor
- virtual void Forward (const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states) = 0
  Perform a forward operation of the Operator, saving the output to TBlob.
- virtual void Backward (const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
  Perform a backward operation, writing the gradients to in_grad.
- virtual ExecType exec_type () const final
Operators are generated by OperatorProperty. To add a new operator (a.k.a. a layer of a neural net) to mxnet, a developer needs to create a new OperatorProperty and its corresponding Operator.
- See also
- TBlob, TShape, OperatorProperty
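The caller-side contract described above (the caller owns and pre-allocates the memory; the Operator only reads from and writes into it) can be sketched as follows. This is a minimal, self-contained illustration, not mxnet code: BlobView, SimpleOperator, ScaleOp and RunScale are hypothetical stand-ins, the real TBlob also carries a shape, data type and device, and the OpContext, req and aux_states arguments are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mxnet::TBlob: a non-owning view of memory
// that the caller has already allocated.
struct BlobView {
  float* dptr;
  std::size_t size;
};

// Hypothetical, simplified interface mirroring Operator::Forward.
class SimpleOperator {
 public:
  virtual ~SimpleOperator() = default;
  virtual void Forward(const std::vector<BlobView>& in_data,
                       const std::vector<BlobView>& out_data) = 0;
};

// y = scale * x, written into memory owned by the caller.
class ScaleOp : public SimpleOperator {
 public:
  explicit ScaleOp(float scale) : scale_(scale) {}

  void Forward(const std::vector<BlobView>& in_data,
               const std::vector<BlobView>& out_data) override {
    const BlobView& x = in_data[0];
    const BlobView& y = out_data[0];
    // The operator never allocates: it trusts the caller's buffer sizes.
    assert(in_data.size() == 1 && out_data.size() == 1 && x.size == y.size);
    for (std::size_t i = 0; i < x.size; ++i) y.dptr[i] = scale_ * x.dptr[i];
  }

 private:
  float scale_;
};

// Caller side: allocate the output buffer (shape known from shape inference)
// before calling Forward.
std::vector<float> RunScale(const std::vector<float>& input, float scale) {
  std::vector<float> in_buf = input;         // mutable copy, since dptr is float*
  std::vector<float> out_buf(input.size());  // pre-allocated by the caller
  ScaleOp op(scale);
  op.Forward({BlobView{in_buf.data(), in_buf.size()}},
             {BlobView{out_buf.data(), out_buf.size()}});
  return out_buf;
}
```

For example, RunScale({1.0f, 2.0f, 3.0f}, 2.0f) fills the caller's buffer with 2, 4 and 6. Keeping allocation on the caller's side is what lets a graph executor plan and reuse memory across operators.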
virtual mxnet::Operator::~Operator () [inline, virtual]
virtual void mxnet::Operator::Backward (const OpContext &ctx,
    const std::vector< TBlob > &out_grad,
    const std::vector< TBlob > &in_data,
    const std::vector< TBlob > &out_data,
    const std::vector< OpReqType > &req,
    const std::vector< TBlob > &in_grad,
    const std::vector< TBlob > &aux_states) [inline, virtual]
Perform a backward operation, writing the gradients to in_grad.
- Note
- Convention: out_grad.size() == OperatorProperty.NumVisibleOutputs() and out_data.size() == OperatorProperty.NumOutputs(). out_data can contain additional invisible outputs that carry state over from the Forward pass, for example the mask in dropout. Gradients are passed in only for the visible outputs.
- Not all TBlobs in the arguments are available if you override DeclareBackwardDependency in the corresponding OperatorProperty class. Only the dependencies you declare are available at their corresponding positions; the remaining arguments are dummies for which you get a nullptr. Using the default DeclareBackwardDependency is always safe, but declaring only what you need gives the engine more opportunities for optimization.
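Both conventions in the notes above can be illustrated with a dropout-like operator: out_data has two entries (the visible output and an invisible mask), while out_grad has only one. The sketch below is a hypothetical, self-contained illustration, not mxnet's dropout implementation. Blob and DropoutSketch are stand-ins, the random draw is replaced by an explicit keep vector, the req and aux_states arguments are omitted (in_grad is always overwritten), and it assumes the OperatorProperty declared only out_grad[0] and the mask as backward dependencies, so in_data arrives as a dummy.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mxnet::TBlob (non-owning, caller-allocated).
struct Blob {
  float* dptr;
  std::size_t size;
};

// out_data[0] = visible output y, out_data[1] = invisible mask.
// out_grad holds a single blob: the gradient with respect to y.
struct DropoutSketch {
  // keep[i] decides whether element i survives (a real dropout would draw it
  // at random); scale is 1 / (1 - p).
  void Forward(const std::vector<Blob>& in_data,
               const std::vector<Blob>& out_data,
               const std::vector<bool>& keep, float scale) const {
    const Blob& x = in_data[0];
    const Blob& y = out_data[0];
    const Blob& mask = out_data[1];
    for (std::size_t i = 0; i < x.size; ++i) {
      mask.dptr[i] = keep[i] ? scale : 0.0f;  // state remembered for Backward
      y.dptr[i] = x.dptr[i] * mask.dptr[i];
    }
  }

  // Only out_grad[0] and out_data[1] were declared as dependencies, so
  // in_data may be a dummy with a nullptr dptr; it is never read here.
  void Backward(const std::vector<Blob>& out_grad,
                const std::vector<Blob>& in_data,
                const std::vector<Blob>& out_data,
                const std::vector<Blob>& in_grad) const {
    assert(out_grad.size() == 1 && out_data.size() == 2 && in_grad.size() == 1);
    (void)in_data;  // undeclared dependency: not available
    const Blob& gy = out_grad[0];
    const Blob& mask = out_data[1];
    const Blob& gx = in_grad[0];
    for (std::size_t i = 0; i < gy.size; ++i) {
      gx.dptr[i] = gy.dptr[i] * mask.dptr[i];
    }
  }
};
```

Declaring only the mask as a dependency means the engine can release the input x as soon as Forward finishes, which is the kind of optimization the note refers to.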
- Parameters
  - ctx: runtime context available to this call.
  - out_grad: the gradients with respect to the visible outputs of the Operator.
  - in_data: the array of input data.
  - out_data: the array of output data.
  - req: request types for writing each gradient; any OpReqType is allowed.
  - in_grad: the array of gradients to write to.
  - aux_states: auxiliary states of the operator. Normally operators don't need them; special cases such as Batch Norm do.
- See also
- OperatorProperty, OpReqType, OpContext
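Because Backward may receive any request type, each write into in_grad has to honor req. The sketch below is a hypothetical illustration, not mxnet's own helper: Req locally mirrors the four OpReqType values (kNullOp, kWriteTo, kWriteInplace, kAddTo), and AssignWithReq and TripleBackward are made-up names for the gradient of y = 3 * x.

```cpp
#include <cassert>
#include <cstddef>

// Local mirror of mxnet's four OpReqType values, for illustration only.
enum class Req { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical helper: store `value` into out[i] as the request type asks.
inline void AssignWithReq(float* out, std::size_t i, Req req, float value) {
  // With kNullOp the target may be an empty dummy; other requests need memory.
  assert(req == Req::kNullOp || out != nullptr);
  switch (req) {
    case Req::kNullOp:
      break;  // gradient not requested: write nothing
    case Req::kWriteTo:
    case Req::kWriteInplace:
      out[i] = value;  // overwrite (in place: out shares memory with an input)
      break;
    case Req::kAddTo:
      out[i] += value;  // accumulate into the existing gradient
      break;
  }
}

// Hypothetical Backward body for y = 3 * x: in_grad = 3 * out_grad, honoring req.
inline void TripleBackward(const float* out_grad, float* in_grad,
                           std::size_t n, Req req) {
  for (std::size_t i = 0; i < n; ++i) {
    AssignWithReq(in_grad, i, req, 3.0f * out_grad[i]);
  }
}
```

kAddTo matters when the same input feeds several operators: each one adds its contribution instead of overwriting the others'.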
virtual ExecType mxnet::Operator::exec_type () const [inline, final, virtual]
- Returns
- [Deprecated] execution type of the operator
virtual void mxnet::Operator::Forward (const OpContext &ctx,
    const std::vector< TBlob > &in_data,
    const std::vector< OpReqType > &req,
    const std::vector< TBlob > &out_data,
    const std::vector< TBlob > &aux_states) [pure virtual]
Perform a forward operation of the Operator, saving the output to TBlob.
- Parameters
  - ctx: runtime context available to this call.
  - in_data: array of input data (read-only).
  - req: request types for writing each output; only kWriteTo or kWriteInplace is allowed.
  - out_data: array of output data. Each TBlob only holds the space for an output, which must be pre-allocated with the shape given by InferShape.
  - aux_states: auxiliary states of the operator. Normally operators don't need them; special cases such as Batch Norm do.
- See also
- OpReqType, OpContext
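Since Forward only accepts kWriteTo or kWriteInplace, an element-wise operator can support both with the same loop, provided each element is read before it is overwritten. The sketch below is a hypothetical illustration, not mxnet code: Req locally mirrors the OpReqType values, and ReluForward is a made-up free function standing in for an Operator's Forward.

```cpp
#include <cassert>
#include <cstddef>

// Local mirror of mxnet's four OpReqType values, for illustration only.
enum class Req { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical element-wise ReLU forward. With kWriteInplace, `out` may alias
// `in`; this is safe because element i is read before it is overwritten and
// no other element is touched.
inline void ReluForward(const float* in, float* out, std::size_t n, Req req) {
  // The Forward contract only permits these two request types.
  assert(req == Req::kWriteTo || req == Req::kWriteInplace);
  for (std::size_t i = 0; i < n; ++i) {
    const float v = in[i];
    out[i] = v > 0.0f ? v : 0.0f;
  }
}
```

An operator that reads neighbouring elements (a convolution, for instance) could not be written this way in place, which is why in-place execution is something an operator opts into rather than a default.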
The documentation for this class was generated from the following file: