29 #ifndef MXNET_OPERATOR_UTIL_H_ 30 #define MXNET_OPERATOR_UTIL_H_ 33 #pragma warning(disable:4503) // disable warning: decorated name length exceeded. 36 #include <dmlc/registry.h> 37 #include <dmlc/parameter.h> 76 std::vector<std::pair<std::string, std::string> >
kwargs;
273 virtual TSelf& set_symbol_op_name(
char const* symbol_name) = 0;
281 virtual TSelf& set_enable_scalar(
290 virtual TSelf& set_enable_kwargs(
bool enable_kwargs) = 0;
297 virtual TSelf& set_resource_request(
298 const std::vector<ResourceRequest>& reqs) = 0;
329 virtual TSelf& set_function(
340 virtual TSelf& set_function(
352 virtual TSelf& set_function(
363 virtual TSelf& set_gradient(
int dev_mask,
372 virtual TSelf& set_gradient(
int dev_mask,
381 virtual TSelf& set_gradient(
int dev_mask,
390 virtual TSelf& set_gradient(
int dev_mask,
399 virtual TSelf& set_gradient(
int dev_mask,
407 virtual TSelf& describe(
const std::string &description) = 0;
414 virtual TSelf& add_arguments(
const std::vector<dmlc::ParamFieldInfo> &args) = 0;
434 return Get()->fmap_.at(name);
443 std::map<std::string, SimpleOpRegEntry*> fmap_;
454 #define ASSIGN_DISPATCH(out, req, exp) \ 460 case kWriteInplace: \ 467 LOG(FATAL) << "not reached"; \ 474 #define MXNET_SPECIAL_MAX_NDIM 5 497 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \ 498 static ::mxnet::op::SimpleOpRegEntry & \ 499 __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \ 500 ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name) 504 #endif // MXNET_OPERATOR_UTIL_H_ Definition: operator_util.h:257
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:161
Gradient of output value.
Definition: operator_util.h:66
TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:99
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:255
Definition: operator_util.h:251
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated with the correct shape.
Definition: operator_util.h:177
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes src and saves the result to ret. The result container is pre-allocated with the correct shape.
Definition: operator_util.h:110
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:191
namespace of mxnet
Definition: base.h:126
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradients with respect to those inputs.
Definition: operator_util.h:224
registry for TBlob functions
Definition: operator_util.h:420
The resources that can be requested by Operator.
Definition: resource.h:36
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:134
nnvm::TShape TShape
Shape data structure used to record shape information.
Definition: base.h:136
std::string name
name of the operator
Definition: operator_util.h:266
In unary forward, allow the output to be computed in place over the input.
Definition: operator_util.h:239
Execution-time context: the information needed at runtime for actual execution.
Definition: base.h:238
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:74
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:433
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:416
registry entry to register simple operators via functions.
Definition: operator_util.h:261
In unary backward, allow in_grad to be computed in place over out_grad.
Definition: operator_util.h:241
Superclass of all gradient function arguments.
Definition: operator_util.h:53
TBlob data
The real data.
Definition: operator_util.h:55
Do not allow in-place computation on any argument.
Definition: operator_util.h:237
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:146
SimpleOpInplaceOption
Options in the registry to set the in-place behavior of an operator.
Definition: operator_util.h:235
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Source function that generates output based on env. The result container is pre-allocated with the correct shape.
Definition: operator_util.h:89
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:72
Output value of the function, passed to the gradient function.
Definition: operator_util.h:64
SimpleOpScalarOption
options in the registry to set symbolic registration
Definition: operator_util.h:249
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:132
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradients with respect to both inputs. Both gradients (lhs_grad and rhs_grad) are produced in a single call, which makes it easy to combine a few ops.
Definition: operator_util.h:205
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:76
Definition: operator_util.h:250
Definition: operator_util.h:256
In binary forward, allow the output to be computed in place over the left operand.
Definition: operator_util.h:243
In binary backward, allow lhs_grad to be computed in place over out_grad.
Definition: operator_util.h:245
std::vector< Resource > resource
The resources requested by the operator.
Definition: operator_util.h:78
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:264
Tensor blob class that can be used to hold a tensor of any dimension, on any device, and of any data type...
Definition: tensor_blob.h:58
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:121