mxnet
symbol.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_SYMBOL_H_
28 #define MXNET_CPP_SYMBOL_H_
29 
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
40 class Executor;
41 
45 struct SymBlob {
46  public:
50  SymBlob() : handle_(nullptr) {}
54  explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
63 
64  private:
65  SymBlob(const SymBlob &);
66  SymBlob &operator=(const SymBlob &);
67 };
68 
72 class Symbol {
73  public:
74  Symbol() {}
79  explicit Symbol(SymbolHandle handle);
84  explicit Symbol(const char *name);
89  explicit Symbol(const std::string &name);
90  Symbol operator+(const Symbol &rhs) const;
91  Symbol operator-(const Symbol &rhs) const;
92  Symbol operator*(const Symbol &rhs) const;
93  Symbol operator/(const Symbol &rhs) const;
94  Symbol operator%(const Symbol &rhs) const;
95 
96  Symbol operator+(mx_float scalar) const;
97  Symbol operator-(mx_float scalar) const;
98  Symbol operator*(mx_float scalar) const;
99  Symbol operator/(mx_float scalar) const;
100  Symbol operator%(mx_float scalar) const;
101  Symbol Copy() const;
106  static Symbol Variable(const std::string &name = "");
107  Symbol operator[](int index);
108  Symbol operator[](const std::string &index);
113  static Symbol Group(const std::vector<Symbol> &symbols);
118  static Symbol Load(const std::string &file_name);
123  static Symbol LoadJSON(const std::string &json_str);
128  void Save(const std::string &file_name) const;
132  std::string ToJSON() const;
137  Symbol GetInternals() const;
141  SymbolHandle GetHandle() const { return (blob_ptr_) ? blob_ptr_->handle_: NULL; }
150  Symbol(const std::string &operator_name, const std::string &name,
151  std::vector<const char *> input_keys,
152  std::vector<SymbolHandle> input_values,
153  std::vector<const char *> config_keys,
154  std::vector<const char *> config_values);
163  void InferShape(
164  const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
165  std::vector<std::vector<mx_uint> > *in_shape,
166  std::vector<std::vector<mx_uint> > *aux_shape,
167  std::vector<std::vector<mx_uint> > *out_shape) const;
176  std::vector<std::string> ListArguments() const;
178  std::vector<std::string> ListOutputs() const;
180  std::vector<std::string> ListAuxiliaryStates() const;
182  std::string GetName() const;
197  void InferExecutorArrays(
198  const Context &context, std::vector<NDArray> *arg_arrays,
199  std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
200  std::vector<NDArray> *aux_arrays,
201  const std::map<std::string, NDArray> &args_map,
202  const std::map<std::string, NDArray> &arg_grad_store =
203  std::map<std::string, NDArray>(),
204  const std::map<std::string, OpReqType> &grad_req_type =
205  std::map<std::string, OpReqType>(),
206  const std::map<std::string, NDArray> &aux_map =
207  std::map<std::string, NDArray>()) const;
215  void InferArgsMap(const Context &context,
216  std::map<std::string, NDArray> *args_map,
217  const std::map<std::string, NDArray> &known_args) const;
236  Executor *SimpleBind(const Context &context,
237  const std::map<std::string, NDArray> &args_map,
238  const std::map<std::string, NDArray> &arg_grad_store =
239  std::map<std::string, NDArray>(),
240  const std::map<std::string, OpReqType> &grad_req_type =
241  std::map<std::string, OpReqType>(),
242  const std::map<std::string, NDArray> &aux_map =
243  std::map<std::string, NDArray>());
262  Executor *Bind(const Context &context, const std::vector<NDArray> &arg_arrays,
263  const std::vector<NDArray> &grad_arrays,
264  const std::vector<OpReqType> &grad_reqs,
265  const std::vector<NDArray> &aux_arrays,
266  const std::map<std::string, Context> &group_to_ctx =
267  std::map<std::string, Context>(),
268  Executor *shared_exec = nullptr);
269 
270  private:
271  std::shared_ptr<SymBlob> blob_ptr_;
272  static OpMap*& op_map();
273 };
274 Symbol operator+(mx_float lhs, const Symbol &rhs);
275 Symbol operator-(mx_float lhs, const Symbol &rhs);
276 Symbol operator*(mx_float lhs, const Symbol &rhs);
277 Symbol operator/(mx_float lhs, const Symbol &rhs);
278 Symbol operator%(mx_float lhs, const Symbol &rhs);
279 } // namespace cpp
280 } // namespace mxnet
281 #endif // MXNET_CPP_SYMBOL_H_
Symbol operator/(mx_float lhs, const Symbol &rhs)
Symbol()
Definition: symbol.h:74
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:43
definition of OpMap
namespace of mxnet
Definition: base.h:118
Executor interface.
Definition: executor.h:45
SymBlob(SymbolHandle handle)
construct with SymbolHandle to store
Definition: symbol.h:54
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:62
struct to store SymbolHandle
Definition: symbol.h:45
Symbol operator%(mx_float lhs, const Symbol &rhs)
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:75
SymBlob()
default constructor
Definition: symbol.h:50
~SymBlob()
destructor, free the SymbolHandle
Definition: symbol.h:58
Symbol operator+(mx_float lhs, const Symbol &rhs)
float mx_float
manually defined floating-point type (an alias for float)
Definition: c_api.h:60
Symbol operator-(mx_float lhs, const Symbol &rhs)
SymbolHandle GetHandle() const
Definition: symbol.h:141
Context interface.
Definition: ndarray.h:50
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
Symbol operator*(mx_float lhs, const Symbol &rhs)
Symbol interface.
Definition: symbol.h:72