mxnet
operator_util.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
30 #ifndef MXNET_OPERATOR_UTIL_H_
31 #define MXNET_OPERATOR_UTIL_H_
32 
33 #ifdef _MSC_VER
34 #pragma warning(disable:4503) // disable warning: decorated name length exceeded.
35 #endif
36 
37 #include <dmlc/registry.h>
38 #include <dmlc/parameter.h>
39 #include <map>
40 #include <vector>
41 #include <string>
42 #include <utility>
43 #include "./base.h"
44 #include "./operator.h"
45 
46 #if DMLC_USE_CXX11
47 #include <functional>
48 #endif
49 
50 namespace mxnet {
52 namespace op {
57 };
58 
63 
68 
/*!
 * \brief Environment arguments handed to the simple-operator functions:
 *  keyword arguments and requested resources (plus a scalar argument, see note).
 * NOTE(review): the `real_t scalar` member ("scalar argument, if enabled",
 *  operator_util.h:75) does not appear in this listing.
 */
73 struct EnvArguments {
  /*! \brief keyword arguments, as (key, value) string pairs */
77  std::vector<std::pair<std::string, std::string> > kwargs;
  /*! \brief resources requested by the operator (presumably those declared
   *  via set_resource_request — confirm) */
79  std::vector<Resource> resource;
80 };
81 
90 typedef void (*SourceFunction)(const EnvArguments& env,
91  TBlob* ret,
92  OpReqType req,
93  RunContext ctx);
94 
101 
111 typedef void (*UnaryFunction)(const TBlob& src,
112  const EnvArguments& env,
113  TBlob* ret,
114  OpReqType req,
115  RunContext ctx);
123  const EnvArguments& env);
124 
133 typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
134  const EnvArguments& env,
135  TBlob* in_grad,
136  OpReqType req,
137  RunContext ctx);
147 typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
148  const OutputValue& out_value,
149  const EnvArguments& env,
150  TBlob* in_grad,
151  OpReqType req,
152  RunContext ctx);
162 typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
163  const Input0& in_data0,
164  const EnvArguments& env,
165  TBlob* in_grad,
166  OpReqType req,
167  RunContext ctx);
178 typedef void (*BinaryFunction)(const TBlob& lhs,
179  const TBlob& rhs,
180  const EnvArguments& env,
181  TBlob* ret,
182  OpReqType req,
183  RunContext ctx);
184 
193  const mxnet::TShape& rhs,
194  const EnvArguments& env);
206 typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
207  const EnvArguments& env,
208  TBlob* lhs_grad,
209  TBlob* rhs_grad,
210  OpReqType req_lhs_grad,
211  OpReqType req_rhs_grad,
212  RunContext ctx);
225 typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
226  const Input0& lhs,
227  const Input1& rhs,
228  const EnvArguments& env,
229  TBlob* lhs_grad,
230  TBlob* rhs_grad,
231  OpReqType req_lhs_grad,
232  OpReqType req_rhs_grad,
233  RunContext ctx);
234 
247 };
248 
253 };
254 
259 };
260 
263  public:
  // NOTE(review): the class header and the `TSelf` typedef (SimpleOpRegEntry,
  // "declare self type") are not shown in this listing. Every setter below
  // returns TSelf&, which lets calls be chained (presumably on *this).
  /*! \brief name of the operator */
267  std::string name;
  /*! \brief set the name used when the operator is registered symbolically
   *  (presumably only needed when it differs from `name` — confirm) */
274  virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
  /*! \brief enable or disable the extra scalar argument of the operator.
   *  \param type_mask SimpleOpScalarOption, defaulting to kArrayBeforeScalar
   *   (by its name, array operands come before the scalar — confirm) */
282  virtual TSelf& set_enable_scalar(
283  bool enable_scalar,
284  SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
  /*! \brief enable or disable keyword arguments (see EnvArguments::kwargs) */
291  virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
  /*! \brief declare the resources the operator needs (list overload) */
298  virtual TSelf& set_resource_request(
299  const std::vector<ResourceRequest>& reqs) = 0;
  /*! \brief declare a single resource the operator needs */
306  virtual TSelf& set_resource_request(ResourceRequest req) = 0;
  /*! \brief shape inference for a source (no-input) operator */
311  virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0;
  /*! \brief shape inference for a unary operator */
317  virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
  /*! \brief shape inference for a binary operator */
323  virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
  /*! \brief register a source forward function.
   *  \param dev_mask presumably the device type the function is for — confirm
   *  \param register_symbolic SimpleOpRegOption, default kRegisterSymbolic */
330  virtual TSelf& set_function(
331  int dev_mask,
332  SourceFunction fsource,
333  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
  /*! \brief register a unary forward function.
   *  \param inplace_in_out whether input and output may share storage */
341  virtual TSelf& set_function(
342  int dev_mask,
343  UnaryFunction funary,
344  SimpleOpInplaceOption inplace_in_out,
345  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
  /*! \brief register a binary forward function.
   *  \param inplace_lhs_out whether lhs and output may share storage */
353  virtual TSelf& set_function(
354  int dev_mask,
355  BinaryFunction fbinary,
356  SimpleOpInplaceOption inplace_lhs_out,
357  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
  // Gradient registration. One overload exists per gradient-function flavour.
  // The inplace_* option says whether the output gradient may share storage
  // with the input gradient (unary) or with the lhs gradient (binary).
364  virtual TSelf& set_gradient(int dev_mask,
365  UnaryGradFunctionT0 fgrad,
366  SimpleOpInplaceOption inplace_out_in_grad) = 0;
373  virtual TSelf& set_gradient(int dev_mask,
374  UnaryGradFunctionT1 fgrad,
375  SimpleOpInplaceOption inplace_out_in_grad) = 0;
382  virtual TSelf& set_gradient(int dev_mask,
383  UnaryGradFunctionT2 fgrad,
384  SimpleOpInplaceOption inplace_out_in_grad) = 0;
391  virtual TSelf& set_gradient(int dev_mask,
392  BinaryGradFunctionT0 fgrad,
393  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
400  virtual TSelf& set_gradient(int dev_mask,
401  BinaryGradFunctionT1 fgrad,
402  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
  /*! \brief attach a textual description to the operator */
408  virtual TSelf& describe(const std::string &description) = 0;
  /*! \brief attach argument metadata (dmlc::ParamFieldInfo) to the operator */
415  virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0;
  /*! \brief virtual destructor */
417  virtual ~SimpleOpRegEntry() {}
418 };
419 
422  public:
428  SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name);
434  inline static const SimpleOpRegEntry *Find(const std::string &name) {
435  return Get()->fmap_.at(name);
436  }
438  static SimpleOpRegistry* Get();
439 
440  private:
441  // destructor; being private, it cannot be invoked (e.g. via delete) by non-members
442  ~SimpleOpRegistry();
  // Registered entries keyed by operator name; Find() looks names up here.
  // NOTE(review): holds raw pointers — whether the registry owns and frees
  // them (e.g. in the destructor) is not visible here.
444  std::map<std::string, SimpleOpRegEntry*> fmap_;
445 };
446 
/*!
 * \brief Write the expression \p exp into \p out according to request \p req.
 *
 *  - kNullOp: nothing is written.
 *  - kWriteTo / kWriteInplace: plain assignment, (out) = (exp).
 *  - kAddTo: accumulation, (out) += (exp).
 *  - any other value: LOG(FATAL) << "not reached".
 *
 * \p req is evaluated once. \p exp is evaluated only on the write/add paths.
 *
 * The body is wrapped in do { ... } while (0) so that a use written as
 * `ASSIGN_DISPATCH(out, req, exp);` is exactly one statement. With the
 * previous bare { ... } block, the caller's trailing semicolon became an
 * extra empty statement, so `if (c) ASSIGN_DISPATCH(...); else ...` did not
 * compile. The `break`s below bind to the macro's own switch, not to the
 * do/while or to any loop of the caller.
 */
#define ASSIGN_DISPATCH(out, req, exp) \
  do { \
    switch (req) { \
      case kNullOp: \
        break; \
      case kWriteTo: \
      case kWriteInplace: \
        (out) = (exp); \
        break; \
      case kAddTo: \
        (out) += (exp); \
        break; \
      default: \
        LOG(FATAL) << "not reached"; \
    } \
  } while (0)
471 
/*!
 * NOTE(review): by its name, the maximum number of dimensions (5) handled by
 *  some "special" code path. No user of it is visible in this header — confirm.
 */
475 #define MXNET_SPECIAL_MAX_NDIM 5
476 
477 
478 //--------------------------------------------------------------
479 // The following part is the API registration of simple operators
480 //--------------------------------------------------------------
/*!
 * \brief Register (or look up) the simple operator \p Name, tagged with \p DEV.
 *
 * Expands to the definition of a `static` reference
 *   ::mxnet::op::SimpleOpRegEntry &__make_SimpleOpRegEntry_<Name>__<DEV>__
 * initialized from SimpleOpRegistry::Get()->__REGISTER_OR_FIND__("<Name>").
 * The expansion ends in an expression of type SimpleOpRegEntry&, and the
 * entry's setters return TSelf&. Setter calls can therefore be chained onto
 * the macro before the terminating semicolon.
 * \p DEV only contributes to the variable name, so registering the same
 * \p Name with different \p DEV values in one translation unit yields
 * distinct variables.
 * NOTE(review): the generated identifier contains double underscores, which
 * C++ reserves for the implementation.
 */
498 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
499  static ::mxnet::op::SimpleOpRegEntry & \
500  __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
501  ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
502 
503 } // namespace op
504 } // namespace mxnet
505 #endif // MXNET_OPERATOR_UTIL_H_
Definition: operator_util.h:258
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:162
Gradient of output value.
Definition: operator_util.h:67
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:256
Definition: operator_util.h:252
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated with...
Definition: operator_util.h:178
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes src and saves the result to ret. The result container is pre-allocated with...
Definition: operator_util.h:111
namespace of mxnet
Definition: base.h:89
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradients with respect to the inputs.
Definition: operator_util.h:225
registry for TBlob functions
Definition: operator_util.h:421
The resources that can be requested by Operator.
Definition: resource.h:38
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:97
mxnet::TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:100
std::string name
name of the operator
Definition: operator_util.h:267
in unary forward, allow inplace in with out
Definition: operator_util.h:240
Execution-time context: the information needed at runtime for the actual execution.
Definition: base.h:337
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:75
mxnet::TShape(* BinaryShapeFunction)(const mxnet::TShape &lhs, const mxnet::TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:192
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:434
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:417
registry entry to register simple operators via functions.
Definition: operator_util.h:262
Second input to the function.
Definition: operator_util.h:62
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:242
super class of all gradient function argument
Definition: operator_util.h:54
TBlob data
The real data.
Definition: operator_util.h:56
do not allow inplace in arguments
Definition: operator_util.h:238
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:147
SimpleOpInplaceOption
options in the registry to set the in-place behavior of the operator
Definition: operator_util.h:236
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Source function that generates output based on env. The result container is pre-allocated with...
Definition: operator_util.h:90
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
Environment arguments that are used by the function. These can be things like scalar arguments...
Definition: operator_util.h:73
Output value of the function, passed to the gradient function.
Definition: operator_util.h:65
SimpleOpScalarOption
options in the registry to configure the scalar argument (the type_mask of set_enable_scalar)
Definition: operator_util.h:250
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:133
A Shape class that is used to represent shape of each tensor.
Definition: tuple.h:395
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradients with respect to the inputs. The lhs and rhs gradients are computed together as a whole to make it easy to combine a few ops.
Definition: operator_util.h:206
First input to the function.
Definition: operator_util.h:60
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:77
Definition: operator_util.h:251
Definition: operator_util.h:257
in binary forward, allow inplace left operand with out
Definition: operator_util.h:244
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:246
std::vector< Resource > resource
the resources requested
Definition: operator_util.h:79
mxnet::TShape(* UnaryShapeFunction)(const mxnet::TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:122
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:265
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:66