26 #ifndef MXNET_OPERATOR_H_    27 #define MXNET_OPERATOR_H_    29 #include <dmlc/base.h>    30 #include <dmlc/json.h>    31 #include <dmlc/logging.h>    32 #include <dmlc/registry.h>    33 #include <nnvm/node.h>    71                        const std::vector<TBlob> &in_data,
    72                        const std::vector<OpReqType> &req,
    73                        const std::vector<TBlob> &out_data,
    74                        const std::vector<TBlob> &aux_states) = 0;
   104                         const std::vector<TBlob> &out_grad,
   105                         const std::vector<TBlob> &in_data,
   106                         const std::vector<TBlob> &out_data,
   107                         const std::vector<OpReqType> &req,
   108                         const std::vector<TBlob> &in_grad,
   109                         const std::vector<TBlob> &aux_states) {
   110     LOG(FATAL) << 
"Backward is not implemented";
   138   virtual void Init(
const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
   143   virtual std::map<std::string, std::string> GetParams() 
const = 0;
   167     return this->ListOutputs().size();
   201   virtual bool InferShape(std::vector<TShape> *in_shape,
   202                           std::vector<TShape> *out_shape,
   203                           std::vector<TShape> *aux_shape) 
const = 0;
   222                           std::vector<int> *out_type,
   223                           std::vector<int> *aux_type)
 const {
   224     CHECK_LE(in_type->size(), this->ListArguments().size());
   225     int n_in = this->ListArguments().size();
   226     for (
unsigned i = 0; i < in_type->size(); ++i) {
   227       CHECK(in_type->at(i) == mshadow::default_type_flag ||
   228             in_type->at(i) == -1) << 
"Unsupported data type " << in_type->at(i);
   231     for (
int i = 0; i < n_in; ++i ) in_type->push_back(mshadow::default_type_flag);
   233     int n_out = this->ListOutputs().size();
   235     for (
int i = 0; i < n_out; ++i ) out_type->push_back(mshadow::default_type_flag);
   237     int n_aux = this->ListAuxiliaryStates().size();
   239     for (
int i = 0; i < n_aux; ++i ) aux_type->push_back(mshadow::default_type_flag);
   259                                      std::vector<int> *in_type)
 const {
   260     std::vector<int> out_type, aux_type;
   261     std::vector<TShape> out_shape, aux_shape;
   262     out_type.resize(this->ListOutputs().size());
   263     out_shape.resize(this->ListOutputs().size());
   264     aux_type.resize(this->ListAuxiliaryStates().size());
   265     aux_shape.resize(this->ListAuxiliaryStates().size());
   266     CHECK(InferType(in_type, &out_type, &aux_type));
   267     CHECK(InferShape(in_shape, &out_shape, &aux_shape));
   268     return CreateOperator(ctx);
   275   virtual std::string TypeString() 
const = 0;
   287       const std::vector<TShape> &in_shape)
 const {
   288     return std::vector<ResourceRequest>();
   298       const std::vector<TShape> &in_shape)
 const {
   299     return std::vector<ResourceRequest>();
   324       const std::vector<int> &out_grad,
   325       const std::vector<int> &in_data,
   326       const std::vector<int> &out_data)
 const {
   329     std::vector<int> ret = out_grad;
   330     ret.insert(ret.end(), in_data.begin(), in_data.end());
   331     ret.insert(ret.end(), out_data.begin(), out_data.end());
   356       const std::vector<int> &in_data,
   357       const std::vector<void*> &out_data)
 const {
   358     return std::vector<std::pair<int, void*> >();
   387       const std::vector<int> &out_grad,
   388       const std::vector<int> &in_data,
   389       const std::vector<int> &out_data,
   390       const std::vector<void*> &in_grad)
 const {
   391     return std::vector<std::pair<int, void*> >();
   407                                        const std::vector<T> &in_data,
   408                                        const std::vector<T> &out_data)
 const {
   410     std::vector<int> out_grad_index(out_grad.size());
   411     std::vector<int> in_data_index(in_data.size());
   412     std::vector<int> out_data_index(out_data.size());
   413     for (
size_t i = 0; i < out_grad_index.size(); ++i) {
   414       out_grad_index[i] = counter++;
   416     for (
size_t i = 0; i < in_data_index.size(); ++i) {
   417       in_data_index[i] = counter++;
   419     for (
size_t i = 0; i < out_data_index.size(); ++i) {
   420       out_data_index[i] = counter++;
   422     std::vector<T> all_data;
   423     all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
   424     all_data.insert(all_data.end(), in_data.begin(), in_data.end());
   425     all_data.insert(all_data.end(), out_data.begin(), out_data.end());
   427     std::vector<int> ret_index = this->DeclareBackwardDependency(
   428         out_grad_index, in_data_index, out_data_index);
   430     std::vector<T> ret(ret_index.size());
   431     for (
size_t i = 0; i < ret_index.size(); ++i) {
   432       ret[i] = all_data[ret_index[i]];
   454     : 
public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
   455                                         OperatorPropertyFactory> {
   469     this->key_var_num_args = key;
   479     CHECK_EQ(this->name, type)
   480         << 
"Register Name and TypeString mismatch, name=\"" << this->name << 
"\","   481         << 
" but TypeString=\"" << type <<
"\"";
   503 #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType)          \   504   DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \   505   .set_body([]() { return new OperatorPropertyType(); })                \   506   .set_return_type("NDArray-or-Symbol") \   509 #endif  // DMLC_USE_CXX11   511 #endif  // MXNET_OPERATOR_H_ virtual std::vector< int > DeclareBackwardDependency(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data) const 
Declare the input requirement of the Backward pass. 
Definition: operator.h:323
OperatorPropertyReg & check_name()
Check if TypeString of the type matches the registered name. 
Definition: operator.h:475
Forward/Backward are synchronous calls. 
virtual ExecType exec_type() const 
Definition: operator.h:443
namespace of mxnet 
Definition: base.h:118
virtual void Forward(const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
perform a forward operation of Operator, save the output to TBlob. 
virtual void Backward(const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
Perform a Backward Operation, write gradient to the in_grad. 
Definition: operator.h:103
Additional operator attributes beside the ones provided by NNVM. 
virtual ~Operator()
destructor 
Definition: operator.h:58
std::function< OperatorProperty *()> OperatorPropertyFactory
typedef the factory function of operator property 
Definition: operator.h:449
std::vector< T > BackwardInputs(const std::vector< T > &out_grad, const std::vector< T > &in_data, const std::vector< T > &out_data) const 
Get Backward Input Dependency for generic types of data. Normally T can be a pointer to Symbol::DataEnt...
Definition: operator.h:406
virtual std::vector< ResourceRequest > BackwardResource(const std::vector< TShape > &in_shape) const 
Declare additional resources required in the backward pass. These additional resources will be presented i...
Definition: operator.h:297
OperatorPropertyReg & set_key_var_num_args(const std::string &key)
Set key_var_num_args. When this is set, the API caller is required to pass in an argument with key=key_...
Definition: operator.h:468
virtual ~OperatorProperty()
virtual destructor 
Definition: operator.h:132
virtual std::vector< std::string > ListOutputs() const 
Get name of output values of Operator. 
Definition: operator.h:155
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: op_attr_types.h:66
Operator interface. Operator defines the basic operation unit of an optimized computation graph in mxnet...
Definition: operator.h:55
Global resource allocation handling. 
virtual std::vector< std::pair< int, void * > > ForwardInplaceOption(const std::vector< int > &in_data, const std::vector< void * > &out_data) const 
Get possible forward inplace options. This function enables optimization to reuse memory of inputs in...
Definition: operator.h:355
OperatorProperty is an object that stores all information about an Operator. It also contains methods to g...
Definition: operator.h:127
virtual Operator * CreateOperatorEx(Context ctx, std::vector< TShape > *in_shape, std::vector< int > *in_type) const 
Create an Operator on a specific context and input shape/type. 
Definition: operator.h:258
virtual std::vector< std::pair< int, void * > > BackwardInplaceOption(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data, const std::vector< void * > &in_grad) const 
Get possible backward inplace options. This function enables optimization to reuse memory of inputs i...
Definition: operator.h:386
virtual bool InferType(std::vector< int > *in_type, std::vector< int > *out_type, std::vector< int > *aux_type) const 
infer the data types of outputs and unknown input arguments 
Definition: operator.h:221
virtual std::vector< std::string > ListAuxiliaryStates() const 
Get name of auxiliary states of Operator. 
Definition: operator.h:162
virtual std::vector< std::string > ListArguments() const 
Get input arguments of the Operator. 
Definition: operator.h:148
virtual ExecType exec_type() const final
Definition: operator.h:113
std::string key_var_num_args
The key num_args name. 
Definition: operator.h:486
virtual int NumVisibleOutputs() const 
get number of visible return values during Symbol creation. If NumVisibleOutputs() = k...
Definition: operator.h:181
Registry entry for OperatorProperty factory functions. 
Definition: operator.h:453
virtual std::vector< ResourceRequest > ForwardResource(const std::vector< TShape > &in_shape) const 
Declare additional resources required in the forward pass. These additional resources will be presented in...
Definition: operator.h:286
ExecType
the execution type of the operator 
Definition: op_attr_types.h:87
virtual int NumOutputs() const 
Definition: operator.h:166
Context information about the execution environment. 
Definition: base.h:133
virtual std::string TypeString() const =0
Return the type string of the Operator; subclasses override this function.