mxnet
mxnet::Operator Class Reference (abstract)

Operator interface. Operator defines the basic operation unit of the optimized computation graph in mxnet. This interface relies on pre-allocated memory in TBlob; the caller needs to set the memory region in TBlob correctly before calling Forward and Backward.

#include <operator.h>


Public Member Functions

virtual ~Operator ()
 Destructor.
 
virtual void Forward (const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
 Performs the forward operation of the Operator, saving the output to TBlob.
 
virtual void Backward (const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
 Performs the backward operation, writing gradients to in_grad.
 
virtual ExecType exec_type () const final
 [Deprecated] Returns the execution type of the operator.

Detailed Description

Operator interface. Operator defines the basic operation unit of the optimized computation graph in mxnet. This interface relies on pre-allocated memory in TBlob; the caller needs to set the memory region in TBlob correctly before calling Forward and Backward.

Operator is generated by OperatorProperty. To add a new operator (a.k.a. a layer of a neural net) to mxnet, a developer needs to create a new OperatorProperty and its corresponding Operator.

See also
TBlob, TShape, OperatorProperty
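The memory contract described above can be illustrated with a small, self-contained sketch. It does not include the real operator.h; instead it defines simplified stand-ins for TBlob (a non-owning float pointer plus an element count), OpReqType and OpContext, together with a hypothetical ScaleOp operator and a hypothetical RunScale caller. What it shows is that the caller allocates the output buffer and Forward only writes into it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-ins for mxnet types (assumption, NOT the real headers):
// the real TBlob holds a typed pointer plus a TShape, and OpContext carries
// the run context. Only the "caller owns the memory" contract is modeled.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };
struct OpContext {};
struct TBlob {
  float* dptr;       // non-owning: the memory belongs to the caller
  std::size_t size;  // element count (stand-in for the shape)
};

// Mirrors the shape of mxnet::Operator's pure virtual Forward.
class Operator {
 public:
  virtual ~Operator() {}
  virtual void Forward(const OpContext& ctx,
                       const std::vector<TBlob>& in_data,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& out_data,
                       const std::vector<TBlob>& aux_states) = 0;
};

// Hypothetical operator computing out = scale * in. It never allocates;
// it only writes into the TBlob the caller prepared.
class ScaleOp : public Operator {
 public:
  explicit ScaleOp(float scale) : scale_(scale) {}
  void Forward(const OpContext&, const std::vector<TBlob>& in_data,
               const std::vector<OpReqType>& req,
               const std::vector<TBlob>& out_data,
               const std::vector<TBlob>&) override {
    assert(in_data.size() == 1 && out_data.size() == 1 && req.size() == 1);
    assert(req[0] == kWriteTo || req[0] == kWriteInplace);
    const TBlob& in = in_data[0];
    const TBlob& out = out_data[0];
    // The output must already be pre-allocated by the caller.
    assert(out.dptr != nullptr && out.size == in.size);
    for (std::size_t i = 0; i < in.size; ++i) out.dptr[i] = scale_ * in.dptr[i];
  }

 private:
  float scale_;
};

// Hypothetical caller: allocate the output first, then call Forward.
std::vector<float> RunScale(float scale, std::vector<float> input) {
  std::vector<float> output(input.size());  // caller-owned output memory
  TBlob in{input.data(), input.size()};
  TBlob out{output.data(), output.size()};
  ScaleOp op(scale);
  op.Forward(OpContext{}, {in}, {kWriteTo}, {out}, {});
  return output;
}
```

In real mxnet the shapes used for pre-allocation come from the OperatorProperty's InferShape; in this sketch the caller simply sizes the output like the input.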

Constructor & Destructor Documentation

virtual mxnet::Operator::~Operator ( )
[inline, virtual]

Destructor.

Member Function Documentation

virtual void mxnet::Operator::Backward ( const OpContext &  ctx,
const std::vector< TBlob > &  out_grad,
const std::vector< TBlob > &  in_data,
const std::vector< TBlob > &  out_data,
const std::vector< OpReqType > &  req,
const std::vector< TBlob > &  in_grad,
const std::vector< TBlob > &  aux_states 
)
[inline, virtual]

Performs the backward operation, writing gradients to in_grad.

Note
Convention:
out_grad.size() == OperatorProperty.NumVisibleOutputs()
out_data.size() == OperatorProperty.NumOutputs()
out_data can contain additional invisible outputs that remember state carried over from the Forward pass, for example the mask in Dropout. Gradients are passed in only for the visible outputs.
Not all of the TBlobs in the arguments will be available if you override DeclareBackwardDependency in the corresponding OperatorProperty class. Only the dependencies you declare are available at their corresponding positions; the remaining entries are dummies whose data pointer is nullptr. Using the default DeclareBackwardDependency is always safe, but declaring only what you need gives the engine more opportunities for optimization.
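These conventions can be sketched with a dropout-like operator (hypothetical DropoutLikeOp) that has one visible output and one invisible mask output. The sketch uses simplified stand-ins for TBlob, OpReqType and OpContext rather than the real mxnet headers, and it assumes that DeclareBackwardDependency was overridden so that in_data is not a dependency; that entry therefore arrives as a dummy with a null pointer. The mask is assumed to hold 0 for dropped units and the rescaling factor for kept ones.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-ins for mxnet types (assumption, NOT the real headers).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };
struct OpContext {};
struct TBlob {
  float* dptr;       // nullptr when the entry was not declared as a dependency
  std::size_t size;  // element count
};

// Hypothetical operator: out_data[0] is the visible output,
// out_data[1] is the invisible mask saved during Forward.
struct DropoutLikeOp {
  void Backward(const OpContext&, const std::vector<TBlob>& out_grad,
                const std::vector<TBlob>& in_data,
                const std::vector<TBlob>& out_data,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& in_grad,
                const std::vector<TBlob>&) {
    // out_grad covers only the visible output; out_data also holds the mask.
    assert(out_grad.size() == 1 && out_data.size() == 2);
    // in_data was not declared as a dependency here, so it is a dummy.
    assert(in_data.size() == 1 && in_data[0].dptr == nullptr);
    // This sketch only handles overwriting requests.
    assert(req[0] == kWriteTo || req[0] == kWriteInplace);
    const TBlob& grad = out_grad[0];
    const TBlob& mask = out_data[1];
    const TBlob& dx = in_grad[0];
    // Gradient flows only through the units the mask kept.
    for (std::size_t i = 0; i < dx.size; ++i)
      dx.dptr[i] = grad.dptr[i] * mask.dptr[i];
  }
};

// Hypothetical driver that prepares all buffers on the caller side.
std::vector<float> DropoutBackward(std::vector<float> out_grad,
                                   std::vector<float> mask) {
  std::size_t n = out_grad.size();
  std::vector<float> out(n, 0.0f);  // visible forward output (unused here)
  std::vector<float> dx(n, -1.0f);  // caller-allocated gradient buffer
  TBlob g{out_grad.data(), n}, y{out.data(), n}, m{mask.data(), n};
  TBlob gx{dx.data(), n};
  TBlob missing{nullptr, 0};        // undeclared dependency
  DropoutLikeOp op;
  op.Backward(OpContext{}, {g}, {missing}, {y, m}, {kWriteTo}, {gx}, {});
  return dx;
}
```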
Parameters
ctx: runtime context available to this call.
out_grad: the gradients received for the visible outputs of the Operator.
in_data: the array of input data.
out_data: the array of output data.
req: request types of the saving operation; all request types are allowed.
in_grad: the array of gradients to write to.
aux_states: auxiliary states of the operator. Normally an operator does not need them; special cases such as BatchNorm do.
See also
OperatorProperty, OpReqType, OpContext
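Because Backward may receive any request type, an implementation has to honor each one when writing in_grad. The hypothetical AssignWithReq helper below is one way to do that: kNullOp skips the write, kWriteTo and kWriteInplace overwrite, and kAddTo accumulates into the existing contents. The enum mirrors the four OpReqType values; IdentityBackward is an illustrative operator whose gradient simply passes out_grad through. Neither name is part of the mxnet API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the four mxnet OpReqType values (local stand-in, not the header).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Hypothetical helper: store `value` into `*dst` according to `req`.
inline void AssignWithReq(float* dst, OpReqType req, float value) {
  switch (req) {
    case kNullOp:  // the caller does not want this result
      break;
    case kWriteTo:
    case kWriteInplace:  // overwrite (the memory may alias an input)
      *dst = value;
      break;
    case kAddTo:  // accumulate into the gradient already stored there
      *dst += value;
      break;
  }
}

// Hypothetical identity operator: d(in) = d(out), written with req.
inline void IdentityBackward(const std::vector<float>& out_grad, OpReqType req,
                             std::vector<float>& in_grad) {
  assert(out_grad.size() == in_grad.size());
  for (std::size_t i = 0; i < in_grad.size(); ++i)
    AssignWithReq(&in_grad[i], req, out_grad[i]);
}
```

kAddTo matters when the same input feeds several operators and their gradient contributions must be summed into one buffer.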
virtual ExecType mxnet::Operator::exec_type ( ) const
[inline, final, virtual]
Returns
[Deprecated] execution type of the operator
virtual void mxnet::Operator::Forward ( const OpContext &  ctx,
const std::vector< TBlob > &  in_data,
const std::vector< OpReqType > &  req,
const std::vector< TBlob > &  out_data,
const std::vector< TBlob > &  aux_states 
)
[pure virtual]

Performs the forward operation of the Operator, saving the output to TBlob.

Parameters
ctx: runtime context available to this call.
in_data: array of input data; it is const.
req: request types of the saving operation; can only be kWriteTo or kWriteInplace.
out_data: array of output data. The TBlobs act only as holders: the space of each TBlob in out_data must be pre-allocated according to InferShape.
aux_states: auxiliary states of the operator. Normally an operator does not need them; special cases such as BatchNorm require them.
See also
OpReqType, OpContext
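Since Forward may be asked to write in place (kWriteInplace), an element-wise operator must tolerate out_data sharing memory with in_data. The hypothetical ReluForward below sketches the core of such a Forward body using a simplified TBlob stand-in (not the real mxnet type). It reads each input element before overwriting the output element at the same index, so the aliased case is safe.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-ins (assumption, NOT the real mxnet headers).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };
struct TBlob {
  float* dptr;       // non-owning, caller-allocated
  std::size_t size;  // element count
};

// Hypothetical element-wise ReLU forward. With kWriteInplace the caller
// may pass the same memory for `in` and `out`; each element is read
// before the element at the same index is overwritten.
inline void ReluForward(const TBlob& in, OpReqType req, const TBlob& out) {
  assert(req == kWriteTo || req == kWriteInplace);  // the only Forward requests
  assert(out.dptr != nullptr && in.size == out.size);
  for (std::size_t i = 0; i < in.size; ++i) {
    float v = in.dptr[i];
    out.dptr[i] = v > 0.0f ? v : 0.0f;
  }
}
```

This pattern is only safe for element-wise operators; an operator whose output element depends on several input elements would need a separate output buffer.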

The documentation for this class was generated from the following file: