30 #ifndef MXNET_OPERATOR_UTIL_H_ 31 #define MXNET_OPERATOR_UTIL_H_ 34 #pragma warning(disable:4503) // disable warning: decorated name length exceeded. 37 #include <dmlc/registry.h> 38 #include <dmlc/parameter.h> 77 std::vector<std::pair<std::string, std::string> >
kwargs;
274 virtual TSelf& set_symbol_op_name(
char const* symbol_name) = 0;
282 virtual TSelf& set_enable_scalar(
291 virtual TSelf& set_enable_kwargs(
bool enable_kwargs) = 0;
298 virtual TSelf& set_resource_request(
299 const std::vector<ResourceRequest>& reqs) = 0;
330 virtual TSelf& set_function(
341 virtual TSelf& set_function(
353 virtual TSelf& set_function(
364 virtual TSelf& set_gradient(
int dev_mask,
373 virtual TSelf& set_gradient(
int dev_mask,
382 virtual TSelf& set_gradient(
int dev_mask,
391 virtual TSelf& set_gradient(
int dev_mask,
400 virtual TSelf& set_gradient(
int dev_mask,
408 virtual TSelf& describe(
const std::string &description) = 0;
415 virtual TSelf& add_arguments(
const std::vector<dmlc::ParamFieldInfo> &args) = 0;
435 return Get()->fmap_.at(name);
444 std::map<std::string, SimpleOpRegEntry*> fmap_;
455 #define ASSIGN_DISPATCH(out, req, exp) \ 461 case kWriteInplace: \ 468 LOG(FATAL) << "not reached"; \ 475 #define MXNET_SPECIAL_MAX_NDIM 5 498 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \ 499 static ::mxnet::op::SimpleOpRegEntry & \ 500 __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \ 501 ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name) 505 #endif // MXNET_OPERATOR_UTIL_H_ Definition: operator_util.h:258
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient w.r.t. the input...
Definition: operator_util.h:162
Gradient of output value.
Definition: operator_util.h:67
TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:100
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:256
Definition: operator_util.h:252
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated with...
Definition: operator_util.h:178
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes a src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:111
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:192
namespace of mxnet
Definition: base.h:118
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradients w.r.t. the inputs.
Definition: operator_util.h:225
registry for TBlob functions
Definition: operator_util.h:421
The resources that can be requested by Operator.
Definition: resource.h:38
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:126
nnvm::TShape TShape
Shape data structure used to record shape information.
Definition: base.h:128
std::string name
name of the operator
Definition: operator_util.h:267
in unary forward, allow inplace in with out
Definition: operator_util.h:240
execution time context. The information needed in runtime for actual execution.
Definition: base.h:257
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:75
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:434
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:417
registry entry to register simple operators via functions.
Definition: operator_util.h:262
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:242
super class of all gradient function argument
Definition: operator_util.h:54
TBlob data
The real data.
Definition: operator_util.h:56
do not allow inplace in arguments
Definition: operator_util.h:238
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient w.r.t. the input...
Definition: operator_util.h:147
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:236
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Source function that generates output based on env. The result container is pre-allocated with the corr...
Definition: operator_util.h:90
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:73
Output value of the function to the function.
Definition: operator_util.h:65
SimpleOpScalarOption
options in the registry to set symbolic registration
Definition: operator_util.h:250
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient w.r.t. the input...
Definition: operator_util.h:133
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradients w.r.t. the inputs. Both the lhs and rhs gradients are computed together in one call, which makes it easy to combine a few ops.
Definition: operator_util.h:206
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:77
Definition: operator_util.h:251
Definition: operator_util.h:257
in binary forward, allow inplace left operand with out
Definition: operator_util.h:244
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:246
std::vector< Resource > resource
pointer to the resources requested
Definition: operator_util.h:79
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:265
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:122