mxnet
operator_util.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
29 #ifndef MXNET_OPERATOR_UTIL_H_
30 #define MXNET_OPERATOR_UTIL_H_
31 
32 #ifdef _MSC_VER
33 #pragma warning(disable:4503) // disable warning: decorated name length exceeded.
34 #endif
35 
36 #include <dmlc/registry.h>
37 #include <dmlc/parameter.h>
38 #include <map>
39 #include <vector>
40 #include <string>
41 #include <utility>
42 #include "./base.h"
43 #include "./operator.h"
44 
45 #if DMLC_USE_CXX11
46 #include <functional>
47 #endif
48 
49 namespace mxnet {
51 namespace op {
56 };
57 
62 
67 
72 struct EnvArguments {
76  std::vector<std::pair<std::string, std::string> > kwargs;
78  std::vector<Resource> resource;
79 };
80 
89 typedef void (*SourceFunction)(const EnvArguments& env,
90  TBlob* ret,
91  OpReqType req,
92  RunContext ctx);
93 
99 typedef TShape (*SourceShapeFunction)(const EnvArguments& env);
100 
110 typedef void (*UnaryFunction)(const TBlob& src,
111  const EnvArguments& env,
112  TBlob* ret,
113  OpReqType req,
114  RunContext ctx);
121 typedef TShape (*UnaryShapeFunction)(const TShape& src,
122  const EnvArguments& env);
123 
132 typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
133  const EnvArguments& env,
134  TBlob* in_grad,
135  OpReqType req,
136  RunContext ctx);
146 typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
147  const OutputValue& out_value,
148  const EnvArguments& env,
149  TBlob* in_grad,
150  OpReqType req,
151  RunContext ctx);
161 typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
162  const Input0& in_data0,
163  const EnvArguments& env,
164  TBlob* in_grad,
165  OpReqType req,
166  RunContext ctx);
177 typedef void (*BinaryFunction)(const TBlob& lhs,
178  const TBlob& rhs,
179  const EnvArguments& env,
180  TBlob* ret,
181  OpReqType req,
182  RunContext ctx);
183 
191 typedef TShape (*BinaryShapeFunction)(const TShape& lhs,
192  const TShape& rhs,
193  const EnvArguments& env);
205 typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
206  const EnvArguments& env,
207  TBlob* lhs_grad,
208  TBlob* rhs_grad,
209  OpReqType req_lhs_grad,
210  OpReqType req_rhs_grad,
211  RunContext ctx);
224 typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
225  const Input0& lhs,
226  const Input1& rhs,
227  const EnvArguments& env,
228  TBlob* lhs_grad,
229  TBlob* rhs_grad,
230  OpReqType req_lhs_grad,
231  OpReqType req_rhs_grad,
232  RunContext ctx);
233 
246 };
247 
252 };
253 
258 };
259 
262  public:
266  std::string name;
273  virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
281  virtual TSelf& set_enable_scalar(
282  bool enable_scalar,
283  SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
290  virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
297  virtual TSelf& set_resource_request(
298  const std::vector<ResourceRequest>& reqs) = 0;
305  virtual TSelf& set_resource_request(ResourceRequest req) = 0;
310  virtual TSelf& set_shape_function(SourceShapeFunction fshapeinfer) = 0;
316  virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
322  virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
329  virtual TSelf& set_function(
330  int dev_mask,
331  SourceFunction fsource,
332  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
340  virtual TSelf& set_function(
341  int dev_mask,
342  UnaryFunction funary,
343  SimpleOpInplaceOption inplace_in_out,
344  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
352  virtual TSelf& set_function(
353  int dev_mask,
354  BinaryFunction fbinary,
355  SimpleOpInplaceOption inplace_lhs_out,
356  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
363  virtual TSelf& set_gradient(int dev_mask,
364  UnaryGradFunctionT0 fgrad,
365  SimpleOpInplaceOption inplace_out_in_grad) = 0;
372  virtual TSelf& set_gradient(int dev_mask,
373  UnaryGradFunctionT1 fgrad,
374  SimpleOpInplaceOption inplace_out_in_grad) = 0;
381  virtual TSelf& set_gradient(int dev_mask,
382  UnaryGradFunctionT2 fgrad,
383  SimpleOpInplaceOption inplace_out_in_grad) = 0;
390  virtual TSelf& set_gradient(int dev_mask,
391  BinaryGradFunctionT0 fgrad,
392  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
399  virtual TSelf& set_gradient(int dev_mask,
400  BinaryGradFunctionT1 fgrad,
401  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
407  virtual TSelf& describe(const std::string &description) = 0;
414  virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0;
416  virtual ~SimpleOpRegEntry() {}
417 };
418 
421  public:
427  SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name);
433  inline static const SimpleOpRegEntry *Find(const std::string &name) {
434  return Get()->fmap_.at(name);
435  }
437  static SimpleOpRegistry* Get();
438 
439  private:
440  // destructor
441  ~SimpleOpRegistry();
443  std::map<std::string, SimpleOpRegEntry*> fmap_;
444 };
445 
/*!
 * \brief Assign an expression to an output according to a write request.
 *
 * \param out the destination to write into (must be an lvalue).
 * \param req the write request type (OpReqType):
 *        - kNullOp: nothing is written and \a exp is not evaluated;
 *        - kWriteTo / kWriteInplace: \a out is overwritten with \a exp;
 *        - kAddTo: \a exp is accumulated into \a out.
 *        Any other value aborts via LOG(FATAL).
 * \param exp the expression to assign; evaluated only when a write happens.
 *
 * The body is wrapped in do { ... } while (0) so that an invocation behaves
 * as exactly one statement. With a bare brace block, the caller's trailing
 * semicolon becomes an extra empty statement, and that breaks
 * `if (c) ASSIGN_DISPATCH(o, r, e); else ...`.
 * Invocations must be terminated with a semicolon.
 */
#define ASSIGN_DISPATCH(out, req, exp)   \
  do {                                   \
    switch (req) {                       \
      case kNullOp:                      \
        break;                           \
      case kWriteTo:                     \
      case kWriteInplace:                \
        (out) = (exp);                   \
        break;                           \
      case kAddTo:                       \
        (out) += (exp);                  \
        break;                           \
      default:                           \
        LOG(FATAL) << "not reached";     \
    }                                    \
  } while (0)
470 
/*!
 * \brief Maximum number of dimensions (5) associated with "special" operator
 *        code paths.
 * NOTE(review): presumably an upper bound on ndim for specialized kernels;
 * confirm against the users of this macro. Kept as a macro (not constexpr)
 * so it stays usable in preprocessor conditionals.
 */
#define MXNET_SPECIAL_MAX_NDIM 5
475 
476 
477 //--------------------------------------------------------------
478 // The following part are API Registration of Simple Operators
479 //--------------------------------------------------------------
/*!
 * \brief Register (or find an existing) simple operator named \a Name for
 *        device \a DEV.
 *
 * Expands to a static reference, named
 * __make_SimpleOpRegEntry_<Name>__<DEV>__, that is bound to the entry returned
 * by SimpleOpRegistry::Get()->__REGISTER_OR_FIND__("<Name>"). The name is
 * distinct for every (Name, DEV) pair. The expansion does not end with a
 * semicolon, so the entry's setters (which return the entry by reference) can
 * be chained directly:
 *   MXNET_REGISTER_SIMPLE_OP(my_op, XPU).set_function(...).describe("...");
 */
#define MXNET_REGISTER_SIMPLE_OP(Name, DEV)                              \
  static ::mxnet::op::SimpleOpRegEntry &                                 \
      __make_SimpleOpRegEntry_ ## Name ## __ ## DEV ## __ =              \
          ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
501 
502 } // namespace op
503 } // namespace mxnet
504 #endif // MXNET_OPERATOR_UTIL_H_
Definition: operator_util.h:257
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:161
Gradient of output value.
Definition: operator_util.h:66
TShape(* SourceShapeFunction)(const EnvArguments &env)
Shape inference function to get the correct shape.
Definition: operator_util.h:99
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:255
Definition: operator_util.h:251
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:177
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:110
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:191
namespace of mxnet
Definition: base.h:126
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradient with respect to the inputs.
Definition: operator_util.h:224
registry for TBlob functions
Definition: operator_util.h:420
The resources that can be requested by Operator.
Definition: resource.h:36
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:134
nnvm::TShape TShape
Shape data structure used to record shape information.
Definition: base.h:136
std::string name
name of the operator
Definition: operator_util.h:266
in unary forward, allow inplace in with out
Definition: operator_util.h:239
execution time context. The information needed in runtime for actual execution.
Definition: base.h:238
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:74
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:433
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:416
registry entry to register simple operators via functions.
Definition: operator_util.h:261
Second input to the function.
Definition: operator_util.h:61
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:241
super class of all gradient function arguments
Definition: operator_util.h:53
TBlob data
The real data.
Definition: operator_util.h:55
do not allow inplace in arguments
Definition: operator_util.h:237
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes output value of function and computes gradient wrt to input...
Definition: operator_util.h:146
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:235
void(* SourceFunction)(const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Source function that generates output based on env. The result container is pre-allocated with the corr...
Definition: operator_util.h:89
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:44
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:72
Output value of the function.
Definition: operator_util.h:64
SimpleOpScalarOption
options in the registry to set the position of the scalar argument relative to the array argument
Definition: operator_util.h:249
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:132
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only output gradient and computes gradient wrt to input. We support total gradient as a whole to make it easy to combine a few ops.
Definition: operator_util.h:205
First input to the function.
Definition: operator_util.h:59
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:76
Definition: operator_util.h:250
Definition: operator_util.h:256
in binary forward, allow inplace left operand with out
Definition: operator_util.h:243
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:245
std::vector< Resource > resource
the resources requested by the operator
Definition: operator_util.h:78
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:264
tensor blob class that can be used to hold tensor of any dimension, any device and any data type...
Definition: tensor_blob.h:58
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:121