mxnet
operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_OPERATOR_H_
26 #define MXNET_OPERATOR_H_
27 
#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <nnvm/node.h>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
#include "./resource.h"
#include "./op_attr_types.h"
40 
41 namespace mxnet {
54 class Operator {
55  public:
57  virtual ~Operator() {}
69  virtual void Forward(const OpContext &ctx,
70  const std::vector<TBlob> &in_data,
71  const std::vector<OpReqType> &req,
72  const std::vector<TBlob> &out_data,
73  const std::vector<TBlob> &aux_states) = 0;
102  virtual void Backward(const OpContext &ctx,
103  const std::vector<TBlob> &out_grad,
104  const std::vector<TBlob> &in_data,
105  const std::vector<TBlob> &out_data,
106  const std::vector<OpReqType> &req,
107  const std::vector<TBlob> &in_grad,
108  const std::vector<TBlob> &aux_states) {
109  LOG(FATAL) << "Backward is not implemented";
110  }
112  virtual ExecType exec_type() const final { // NOLINT(*) exec_type has been moved to OperatorProperty
113  return ExecType::kSync;
114  }
115 };
116 
117 #if DMLC_USE_CXX11
118 // OperatorProperty may use C++11 features, while Operator does not rely on them.
127  public:
131  virtual ~OperatorProperty() {}
137  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
142  virtual std::map<std::string, std::string> GetParams() const = 0;
147  virtual std::vector<std::string> ListArguments() const {
148  return {"data"};
149  }
154  virtual std::vector<std::string> ListOutputs() const {
155  return {"output"};
156  }
161  virtual std::vector<std::string> ListAuxiliaryStates() const {
162  return {};
163  }
165  virtual int NumOutputs() const {
166  return this->ListOutputs().size();
167  }
180  virtual int NumVisibleOutputs() const {
181  return NumOutputs();
182  }
200  virtual bool InferShape(std::vector<TShape> *in_shape,
201  std::vector<TShape> *out_shape,
202  std::vector<TShape> *aux_shape) const = 0;
220  virtual bool InferType(std::vector<int> *in_type,
221  std::vector<int> *out_type,
222  std::vector<int> *aux_type) const {
223  CHECK_LE(in_type->size(), this->ListArguments().size());
224  int n_in = this->ListArguments().size();
225  for (unsigned i = 0; i < in_type->size(); ++i) {
226  CHECK(in_type->at(i) == mshadow::default_type_flag ||
227  in_type->at(i) == -1) << "Unsupported data type " << in_type->at(i);
228  }
229  in_type->clear();
230  for (int i = 0; i < n_in; ++i ) in_type->push_back(mshadow::default_type_flag);
231 
232  int n_out = this->ListOutputs().size();
233  out_type->clear();
234  for (int i = 0; i < n_out; ++i ) out_type->push_back(mshadow::default_type_flag);
235 
236  int n_aux = this->ListAuxiliaryStates().size();
237  aux_type->clear();
238  for (int i = 0; i < n_aux; ++i ) aux_type->push_back(mshadow::default_type_flag);
239  return true;
240  }
245  virtual OperatorProperty* Copy() const = 0;
249  virtual Operator* CreateOperator(Context ctx) const = 0;
257  virtual Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
258  std::vector<int> *in_type) const {
259  std::vector<int> out_type, aux_type;
260  std::vector<TShape> out_shape, aux_shape;
261  out_type.resize(this->ListOutputs().size());
262  out_shape.resize(this->ListOutputs().size());
263  aux_type.resize(this->ListAuxiliaryStates().size());
264  aux_shape.resize(this->ListAuxiliaryStates().size());
265  CHECK(InferType(in_type, &out_type, &aux_type));
266  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
267  return CreateOperator(ctx);
268  }
274  virtual std::string TypeString() const = 0;
275  //--------------------------------------------------------
276  // All the below functions are optional to override.
277  //--------------------------------------------------------
285  virtual std::vector<ResourceRequest> ForwardResource(
286  const std::vector<TShape> &in_shape) const {
287  return std::vector<ResourceRequest>();
288  }
296  virtual std::vector<ResourceRequest> BackwardResource(
297  const std::vector<TShape> &in_shape) const {
298  return std::vector<ResourceRequest>();
299  }
322  virtual std::vector<int> DeclareBackwardDependency(
323  const std::vector<int> &out_grad,
324  const std::vector<int> &in_data,
325  const std::vector<int> &out_data) const {
326  // By default requires to see all the things.
327  // remember to override this function to get a better performance.
328  std::vector<int> ret = out_grad;
329  ret.insert(ret.end(), in_data.begin(), in_data.end());
330  ret.insert(ret.end(), out_data.begin(), out_data.end());
331  return ret;
332  }
354  virtual std::vector<std::pair<int, void*> > ForwardInplaceOption(
355  const std::vector<int> &in_data,
356  const std::vector<void*> &out_data) const {
357  return std::vector<std::pair<int, void*> >();
358  }
385  virtual std::vector<std::pair<int, void*> > BackwardInplaceOption(
386  const std::vector<int> &out_grad,
387  const std::vector<int> &in_data,
388  const std::vector<int> &out_data,
389  const std::vector<void*> &in_grad) const {
390  return std::vector<std::pair<int, void*> >();
391  }
404  template<typename T>
405  inline std::vector<T> BackwardInputs(const std::vector<T> &out_grad,
406  const std::vector<T> &in_data,
407  const std::vector<T> &out_data) const {
408  int counter = 0;
409  std::vector<int> out_grad_index(out_grad.size());
410  std::vector<int> in_data_index(in_data.size());
411  std::vector<int> out_data_index(out_data.size());
412  for (size_t i = 0; i < out_grad_index.size(); ++i) {
413  out_grad_index[i] = counter++;
414  }
415  for (size_t i = 0; i < in_data_index.size(); ++i) {
416  in_data_index[i] = counter++;
417  }
418  for (size_t i = 0; i < out_data_index.size(); ++i) {
419  out_data_index[i] = counter++;
420  }
421  std::vector<T> all_data;
422  all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
423  all_data.insert(all_data.end(), in_data.begin(), in_data.end());
424  all_data.insert(all_data.end(), out_data.begin(), out_data.end());
425 
426  std::vector<int> ret_index = this->DeclareBackwardDependency(
427  out_grad_index, in_data_index, out_data_index);
428 
429  std::vector<T> ret(ret_index.size());
430  for (size_t i = 0; i < ret_index.size(); ++i) {
431  ret[i] = all_data[ret_index[i]];
432  }
433  return ret;
434  }
440  static OperatorProperty *Create(const char* type_name);
442  virtual ExecType exec_type() const {
443  return ExecType::kSync;
444  }
445 };
446 
448 typedef std::function<OperatorProperty *()> OperatorPropertyFactory;
453  : public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
454  OperatorPropertyFactory> {
467  inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) { // NOLINT(*)
468  this->key_var_num_args = key;
469  return *this;
470  }
475  OperatorProperty *p = this->body();
476  std::string type = p->TypeString();
477  delete p;
478  CHECK_EQ(this->name, type)
479  << "Register Name and TypeString mismatch, name=\"" << this->name << "\","
480  << " but TypeString=\"" << type <<"\"";
481  return *this;
482  }
483 
485  std::string key_var_num_args;  ///< The key num_args name; assigned by set_key_var_num_args().
486 };
487 
488 //---------------------------------------------------------------------------------
489 // The following part is the API for registering operators.
490 // See also MXNET_REGISTER_SIMPLE_OP in operator_util.h for registering simple ops.
491 //---------------------------------------------------------------------------------
/*!
 * \brief Register an OperatorProperty subclass under the given name.
 *  Expands to DMLC_REGISTRY_REGISTER(...) and configures the returned entry:
 *  a body that creates a new OperatorPropertyType, the return type
 *  "NDArray-or-Symbol", and a final check_name() that fails via CHECK_EQ if
 *  the name differs from the type's TypeString().
 * \param name registered operator name
 * \param OperatorPropertyType the OperatorProperty subclass to register
 */
#define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType)                     \
  DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name)  \
      .set_body([]() -> ::mxnet::OperatorProperty* {                              \
        return new OperatorPropertyType();                                         \
      })                                                                           \
      .set_return_type("NDArray-or-Symbol")                                        \
      .check_name()
507 
508 #endif // DMLC_USE_CXX11
509 } // namespace mxnet
510 #endif // MXNET_OPERATOR_H_
virtual std::vector< int > DeclareBackwardDependency(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data) const
Declare the input requirement of Backward pass.
Definition: operator.h:322
OperatorPropertyReg & check_name()
Check if TypeString of the type matches the registered name.
Definition: operator.h:474
Forward/Backward are synchronous calls.
virtual ExecType exec_type() const
Definition: operator.h:442
namespace of mxnet
Definition: base.h:126
virtual void Forward(const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
Perform a forward operation of the Operator, saving the output to TBlob.
virtual void Backward(const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
Perform a Backward Operation, write gradient to the in_grad.
Definition: operator.h:102
Additional operator attributes beside the ones provided by NNVM.
virtual ~Operator()
destructor
Definition: operator.h:57
std::function< OperatorProperty *()> OperatorPropertyFactory
typedef the factory function of operator property
Definition: operator.h:448
std::vector< T > BackwardInputs(const std::vector< T > &out_grad, const std::vector< T > &in_data, const std::vector< T > &out_data) const
Get Backward Input Dependency for generic types of data. Normally T can be a pointer to Symbol::DataEnt...
Definition: operator.h:405
virtual std::vector< ResourceRequest > BackwardResource(const std::vector< TShape > &in_shape) const
Declare additional resource required in backward pass. These additional resources will be presented i...
Definition: operator.h:296
OperatorPropertyReg & set_key_var_num_args(const std::string &key)
Set key_var_num_args. When this is set, the API caller is required to pass in an argument with key=key_...
Definition: operator.h:467
virtual ~OperatorProperty()
virtual destructor
Definition: operator.h:131
virtual std::vector< std::string > ListOutputs() const
Get name of output values of Operator.
Definition: operator.h:154
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: op_attr_types.h:65
Operator interface. Operator defines basic operation unit of optimized computation graph in mxnet...
Definition: operator.h:54
Global resource allocation handling.
virtual std::vector< std::pair< int, void * > > ForwardInplaceOption(const std::vector< int > &in_data, const std::vector< void * > &out_data) const
Get possible forward inplace options. This function enables optimization to reuse memory of inputs in...
Definition: operator.h:354
OperatorProperty is an object that stores all information about Operator. It also contains methods to g...
Definition: operator.h:126
virtual Operator * CreateOperatorEx(Context ctx, std::vector< TShape > *in_shape, std::vector< int > *in_type) const
Create an Operator on a specific context with the given input shapes/types.
Definition: operator.h:257
virtual std::vector< std::pair< int, void * > > BackwardInplaceOption(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data, const std::vector< void * > &in_grad) const
Get possible backward inplace options. This function enables optimization to reuse memory of inputs i...
Definition: operator.h:385
virtual bool InferType(std::vector< int > *in_type, std::vector< int > *out_type, std::vector< int > *aux_type) const
infer the data types of outputs and unknown input arguments
Definition: operator.h:220
virtual std::vector< std::string > ListAuxiliaryStates() const
Get name of auxiliary states of Operator.
Definition: operator.h:161
virtual std::vector< std::string > ListArguments() const
Get input arguments of the Operator.
Definition: operator.h:147
virtual ExecType exec_type() const final
Definition: operator.h:112
std::string key_var_num_args
The key num_args name.
Definition: operator.h:485
virtual int NumVisibleOutputs() const
get number of visible return values during Symbol creation. If NumVisibleOutputs() = k...
Definition: operator.h:180
Registry entry for OperatorProperty factory functions.
Definition: operator.h:452
virtual std::vector< ResourceRequest > ForwardResource(const std::vector< TShape > &in_shape) const
Declare additional resource required in forward pass. These additional resources will be presented in...
Definition: operator.h:285
ExecType
the execution type of the operator
Definition: op_attr_types.h:86
virtual int NumOutputs() const
Definition: operator.h:165
Context information about the execution environment.
Definition: base.h:141
virtual std::string TypeString() const =0
Return the type string of the Operator; subclasses override this function.