30 #ifndef MXNET_OPERATOR_UTIL_H_    31 #define MXNET_OPERATOR_UTIL_H_    34 #pragma warning(disable:4503)  // disable warning: decorated name length exceeded.    37 #include <dmlc/registry.h>    38 #include <dmlc/parameter.h>    77   std::vector<std::pair<std::string, std::string> > 
kwargs;
   274   virtual TSelf& set_symbol_op_name(
char const* symbol_name) = 0;
   282   virtual TSelf& set_enable_scalar(
   291   virtual TSelf& set_enable_kwargs(
bool enable_kwargs) = 0;
   298   virtual TSelf& set_resource_request(
   299       const std::vector<ResourceRequest>& reqs) = 0;
   330   virtual TSelf& set_function(
   341   virtual TSelf& set_function(
   353   virtual TSelf& set_function(
   364   virtual TSelf& set_gradient(
int dev_mask,
   373   virtual TSelf& set_gradient(
int dev_mask,
   382   virtual TSelf& set_gradient(
int dev_mask,
   391   virtual TSelf& set_gradient(
int dev_mask,
   400   virtual TSelf& set_gradient(
int dev_mask,
   408   virtual TSelf& describe(
const std::string &description) = 0;
   415   virtual TSelf& add_arguments(
const std::vector<dmlc::ParamFieldInfo> &args) = 0;
   435     return Get()->fmap_.at(name);
   444   std::map<std::string, SimpleOpRegEntry*> fmap_;
   455 #define ASSIGN_DISPATCH(out, req, exp)  \   461       case kWriteInplace:               \   468         LOG(FATAL) << "not reached";    \   475 #define MXNET_SPECIAL_MAX_NDIM 5   498 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV)                             \   499   static ::mxnet::op::SimpleOpRegEntry &                                \   500   __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ =          \   501       ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)   505 #endif  // MXNET_OPERATOR_UTIL_H_ Definition: operator_util.h:258
 
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:162
 
Gradient of output value. 
Definition: operator_util.h:67
 
TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape. 
Definition: operator_util.h:100
 
SimpleOpRegOption
options in the registry to set symbolic registration 
Definition: operator_util.h:256
 
Definition: operator_util.h:252
 
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:178
 
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:111
 
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes. 
Definition: operator_util.h:192
 
namespace of mxnet 
Definition: base.h:127
 
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradient with respect to the inputs. 
Definition: operator_util.h:225
 
registry for TBlob functions 
Definition: operator_util.h:421
 
The resources that can be requested by Operator. 
Definition: resource.h:38
 
mshadow::default_real_t real_t
data type that will be used to store ndarray 
Definition: base.h:135
 
nnvm::TShape TShape
Shape data structure used to record shape information. 
Definition: base.h:137
 
std::string name
name of the operator 
Definition: operator_util.h:267
 
in unary forward, allow the input to be used in place as the output 
Definition: operator_util.h:240
 
Execution-time context: the information needed at runtime for the actual execution. 
Definition: base.h:253
 
real_t scalar
scalar argument, if enabled 
Definition: operator_util.h:75
 
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name. 
Definition: operator_util.h:434
 
virtual ~SimpleOpRegEntry()
virtual destructor 
Definition: operator_util.h:417
 
registry entry to register simple operators via functions. 
Definition: operator_util.h:262
 
in unary backward, allow inplace out_grad with in_grad 
Definition: operator_util.h:242
 
super class of all gradient function arguments 
Definition: operator_util.h:54
 
TBlob data
The real data. 
Definition: operator_util.h:56
 
do not allow inplace in arguments 
Definition: operator_util.h:238
 
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:147
 
SimpleOpInplaceOption
options in the registry to set inplace of operator 
Definition: operator_util.h:236
 
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Source function that generates output based on env. The result container is pre-allocated with the corr...
Definition: operator_util.h:90
 
OpReqType
operation request type to Forward and Backward 
Definition: op_attr_types.h:45
 
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:73
 
Output value of the function, passed to the gradient function. 
Definition: operator_util.h:65
 
SimpleOpScalarOption
options in the registry to set symbolic registration 
Definition: operator_util.h:250
 
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:133
 
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the inputs. The total gradient is computed as a whole, which makes it easy to combine several ops. 
Definition: operator_util.h:206
 
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments 
Definition: operator_util.h:77
 
Definition: operator_util.h:251
 
Definition: operator_util.h:257
 
in binary forward, allow inplace left operand with out 
Definition: operator_util.h:244
 
in binary backward, allow inplace out_grad with lhs_grad 
Definition: operator_util.h:246
 
std::vector< Resource > resource
pointer to the resources requested 
Definition: operator_util.h:79
 
SimpleOpRegEntry TSelf
declare self type 
Definition: operator_util.h:265
 
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:59
 
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source. 
Definition: operator_util.h:122