mxnet
operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_OPERATOR_H_
27 #define MXNET_CPP_OPERATOR_H_
28 
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/op_map.h"
#include "mxnet-cpp/symbol.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 class Mxnet;
42 class Operator {
43  public:
48  explicit Operator(const std::string &operator_name);
49  Operator &operator=(const Operator &rhs);
56  template <typename T>
57  Operator &SetParam(const std::string &name, const T &value) {
58  std::string value_str;
59  std::stringstream ss;
60  ss << value;
61  ss >> value_str;
62 
63  params_[name] = value_str;
64  return *this;
65  }
72  template <typename T>
73  Operator &SetParam(int pos, const T &value) {
74  std::string value_str;
75  std::stringstream ss;
76  ss << value;
77  ss >> value_str;
78 
79  params_[arg_names_[pos]] = value_str;
80  return *this;
81  }
88  Operator &SetInput(const std::string &name, Symbol symbol);
93  template<int N = 0>
94  void PushInput(const Symbol &symbol) {
95  input_symbols_.push_back(symbol.GetHandle());
96  }
101  Operator &operator()() { return *this; }
107  Operator &operator()(const Symbol &symbol) {
108  input_symbols_.push_back(symbol.GetHandle());
109  return *this;
110  }
116  Operator &operator()(const std::vector<Symbol> &symbols) {
117  for (auto &s : symbols) {
118  input_symbols_.push_back(s.GetHandle());
119  }
120  return *this;
121  }
127  Symbol CreateSymbol(const std::string &name = "");
128 
135  Operator &SetInput(const std::string &name, NDArray ndarray);
140  template<int N = 0>
141  Operator &PushInput(const NDArray &ndarray) {
142  input_ndarrays_.push_back(ndarray.GetHandle());
143  return *this;
144  }
148  template <class T, class... Args, int N = 0>
149  Operator &PushInput(const T &t, Args... args) {
150  SetParam(N, t);
151  PushInput<Args..., N+1>(args...);
152  return *this;
153  }
157  template <class T, int N = 0>
158  Operator &PushInput(const T &t) {
159  SetParam(N, t);
160  return *this;
161  }
167  Operator &operator()(const NDArray &ndarray) {
168  input_ndarrays_.push_back(ndarray.GetHandle());
169  return *this;
170  }
176  Operator &operator()(const std::vector<NDArray> &ndarrays) {
177  for (auto &s : ndarrays) {
178  input_ndarrays_.push_back(s.GetHandle());
179  }
180  return *this;
181  }
186  template <typename... Args>
187  Operator &operator()(Args... args) {
188  PushInput(args...);
189  return *this;
190  }
191  std::vector<NDArray> Invoke();
192  void Invoke(NDArray &output);
193  void Invoke(std::vector<NDArray> &outputs);
194 
195  private:
196  std::map<std::string, std::string> params_desc_;
197  bool variable_params_ = false;
198  std::map<std::string, std::string> params_;
199  std::vector<SymbolHandle> input_symbols_;
200  std::vector<NDArrayHandle> input_ndarrays_;
201  std::vector<std::string> input_keys_;
202  std::vector<std::string> arg_names_;
203  AtomicSymbolCreator handle_;
204  static OpMap*& op_map();
205 };
206 } // namespace cpp
207 } // namespace mxnet
208 
209 #endif // MXNET_CPP_OPERATOR_H_
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:42
definition of symbol
definition of OpMap
namespace of mxnet
Definition: base.h:126
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
void PushInput(const Symbol &symbol)
add an input symbol
Definition: operator.h:94
Operator & operator()(const Symbol &symbol)
add input symbols
Definition: operator.h:107
Operator(const std::string &operator_name)
Operator constructor.
Operator & operator()(Args...args)
add positional inputs (each argument is forwarded to PushInput)
Definition: operator.h:187
NDArray interface.
Definition: ndarray.h:120
Operator & operator()(const std::vector< NDArray > &ndarrays)
add a list of input ndarrays
Definition: operator.h:176
Operator & PushInput(const NDArray &ndarray)
add an input ndarray
Definition: operator.h:141
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
NDArrayHandle GetHandle() const
Definition: ndarray.h:439
Operator & operator()(const NDArray &ndarray)
add input ndarrays
Definition: operator.h:167
Operator & SetParam(int pos, const T &value)
set config parameters from positional inputs
Definition: operator.h:73
Operator & operator=(const Operator &rhs)
Operator & operator()(const std::vector< Symbol > &symbols)
add a list of input symbols
Definition: operator.h:116
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:57
Operator & PushInput(const T &t, Args...args)
add positional inputs
Definition: operator.h:149
Operator & PushInput(const T &t)
add the last positional input
Definition: operator.h:158
Operator & operator()()
no-op overload: adds no input and returns a reference to the operator
Definition: operator.h:101
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:68
SymbolHandle GetHandle() const
Definition: symbol.h:140
std::vector< NDArray > Invoke()
Operator interface.
Definition: operator.h:42
Symbol interface.
Definition: symbol.h:71