mxnet
operator.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_OPERATOR_H_
28 #define MXNET_CPP_OPERATOR_H_
29 
30 #include <map>
31 #include <string>
32 #include <vector>
33 #include "mxnet-cpp/base.h"
34 #include "mxnet-cpp/op_map.h"
35 #include "mxnet-cpp/symbol.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 class Mxnet;
43 class Operator {
44  public:
49  explicit Operator(const std::string &operator_name);
50  Operator &operator=(const Operator &rhs);
57  template <typename T>
58  Operator &SetParam(const std::string &name, const T &value) {
59  std::string value_str;
60  std::stringstream ss;
61  ss << value;
62  ss >> value_str;
63 
64  params_[name] = value_str;
65  return *this;
66  }
73  template <typename T>
74  Operator &SetParam(int pos, const T &value) {
75  std::string value_str;
76  std::stringstream ss;
77  ss << value;
78  ss >> value_str;
79 
80  params_[arg_names_[pos]] = value_str;
81  return *this;
82  }
89  Operator &SetInput(const std::string &name, Symbol symbol);
94  template<int N = 0>
95  void PushInput(const Symbol &symbol) {
96  input_symbols_.push_back(symbol.GetHandle());
97  }
102  Operator &operator()() { return *this; }
108  Operator &operator()(const Symbol &symbol) {
109  input_symbols_.push_back(symbol.GetHandle());
110  return *this;
111  }
117  Operator &operator()(const std::vector<Symbol> &symbols) {
118  for (auto &s : symbols) {
119  input_symbols_.push_back(s.GetHandle());
120  }
121  return *this;
122  }
128  Symbol CreateSymbol(const std::string &name = "");
129 
136  Operator &SetInput(const std::string &name, NDArray ndarray);
141  template<int N = 0>
142  Operator &PushInput(const NDArray &ndarray) {
143  input_ndarrays_.push_back(ndarray.GetHandle());
144  return *this;
145  }
149  template <class T, class... Args, int N = 0>
150  Operator &PushInput(const T &t, Args... args) {
151  SetParam(N, t);
152  PushInput<Args..., N+1>(args...);
153  return *this;
154  }
158  template <class T, int N = 0>
159  Operator &PushInput(const T &t) {
160  SetParam(N, t);
161  return *this;
162  }
168  Operator &operator()(const NDArray &ndarray) {
169  input_ndarrays_.push_back(ndarray.GetHandle());
170  return *this;
171  }
177  Operator &operator()(const std::vector<NDArray> &ndarrays) {
178  for (auto &s : ndarrays) {
179  input_ndarrays_.push_back(s.GetHandle());
180  }
181  return *this;
182  }
187  template <typename... Args>
188  Operator &operator()(Args... args) {
189  PushInput(args...);
190  return *this;
191  }
192  std::vector<NDArray> Invoke();
193  void Invoke(NDArray &output);
194  void Invoke(std::vector<NDArray> &outputs);
195 
196  private:
197  std::map<std::string, std::string> params_desc_;
198  bool variable_params_ = false;
199  std::map<std::string, std::string> params_;
200  std::vector<SymbolHandle> input_symbols_;
201  std::vector<NDArrayHandle> input_ndarrays_;
202  std::vector<std::string> input_keys_;
203  std::vector<std::string> arg_names_;
204  AtomicSymbolCreator handle_;
205  static OpMap*& op_map();
206 };
207 } // namespace cpp
208 } // namespace mxnet
209 
210 #endif // MXNET_CPP_OPERATOR_H_
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
definition of symbol
definition of OpMap
namespace of mxnet
Definition: base.h:118
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
void PushInput(const Symbol &symbol)
add an input symbol
Definition: operator.h:95
Operator & operator()(const Symbol &symbol)
add an input symbol
Definition: operator.h:108
Operator(const std::string &operator_name)
Operator constructor.
Operator & operator()(Args...args)
add positional inputs (forwarded to PushInput)
Definition: operator.h:188
NDArray interface.
Definition: ndarray.h:121
Operator & operator()(const std::vector< NDArray > &ndarrays)
add a list of input ndarrays
Definition: operator.h:177
Operator & PushInput(const NDArray &ndarray)
add an input ndarray
Definition: operator.h:142
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
NDArrayHandle GetHandle() const
Definition: ndarray.h:466
Operator & operator()(const NDArray &ndarray)
add an input ndarray
Definition: operator.h:168
Operator & SetParam(int pos, const T &value)
set config parameters from positional inputs
Definition: operator.h:74
Operator & operator=(const Operator &rhs)
Operator & operator()(const std::vector< Symbol > &symbols)
add a list of input symbols
Definition: operator.h:117
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Operator & PushInput(const T &t, Args...args)
add positional inputs
Definition: operator.h:150
Operator & PushInput(const T &t)
add the last positional input
Definition: operator.h:159
Operator & operator()()
no-op call: returns the operator unchanged without adding inputs
Definition: operator.h:102
void * AtomicSymbolCreator
handle to a function that takes param and creates symbol
Definition: c_api.h:71
SymbolHandle GetHandle() const
Definition: symbol.h:141
std::vector< NDArray > Invoke()
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72