mxnet
symbol.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_SYMBOL_H_
27 #define MXNET_CPP_SYMBOL_H_
28 
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "mxnet-cpp/op_map.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class Executor;
40 
44 struct SymBlob {
45  public:
49  SymBlob() : handle_(nullptr) {}
53  explicit SymBlob(SymbolHandle handle) : handle_(handle) {}
62 
63  private:
64  SymBlob(const SymBlob &);
65  SymBlob &operator=(const SymBlob &);
66 };
67 
71 class Symbol {
72  public:
73  Symbol() {}
78  explicit Symbol(SymbolHandle handle);
83  explicit Symbol(const char *name);
88  explicit Symbol(const std::string &name);
89  Symbol operator+(const Symbol &rhs) const;
90  Symbol operator-(const Symbol &rhs) const;
91  Symbol operator*(const Symbol &rhs) const;
92  Symbol operator/(const Symbol &rhs) const;
93  Symbol operator%(const Symbol &rhs) const;
94 
96  Symbol operator-(mx_float scalar) const;
97  Symbol operator*(mx_float scalar) const;
98  Symbol operator/(mx_float scalar) const;
99  Symbol operator%(mx_float scalar) const;
100  Symbol Copy() const;
105  static Symbol Variable(const std::string &name = "");
106  Symbol operator[](int index);
107  Symbol operator[](const std::string &index);
112  static Symbol Group(const std::vector<Symbol> &symbols);
117  static Symbol Load(const std::string &file_name);
122  static Symbol LoadJSON(const std::string &json_str);
127  void Save(const std::string &file_name) const;
131  std::string ToJSON() const;
136  Symbol GetInternals() const;
140  SymbolHandle GetHandle() const { return (blob_ptr_) ? blob_ptr_->handle_: nullptr; }
149  Symbol(const std::string &operator_name, const std::string &name,
150  std::vector<const char *> input_keys,
151  std::vector<SymbolHandle> input_values,
152  std::vector<const char *> config_keys,
153  std::vector<const char *> config_values);
162  void InferShape(
163  const std::map<std::string, std::vector<mx_uint> > &arg_shapes,
164  std::vector<std::vector<mx_uint> > *in_shape,
165  std::vector<std::vector<mx_uint> > *aux_shape,
166  std::vector<std::vector<mx_uint> > *out_shape) const;
175  std::vector<std::string> ListArguments() const;
177  std::vector<std::string> ListInputs() const;
179  std::vector<std::string> ListOutputs() const;
181  std::vector<std::string> ListAuxiliaryStates() const;
183  std::map<std::string, std::string> ListAttributes() const;
189  void SetAttribute(const std::string& key, const std::string& value);
194  void SetAttributes(const std::map<std::string, std::string>& attrs);
196  mx_uint GetNumOutputs() const;
198  mxnet::cpp::Symbol GetBackendSymbol(const std::string& backendName) const;
200  std::string GetName() const;
215  void InferExecutorArrays(
216  const Context &context, std::vector<NDArray> *arg_arrays,
217  std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
218  std::vector<NDArray> *aux_arrays,
219  const std::map<std::string, NDArray> &args_map,
220  const std::map<std::string, NDArray> &arg_grad_store =
221  std::map<std::string, NDArray>(),
222  const std::map<std::string, OpReqType> &grad_req_type =
223  std::map<std::string, OpReqType>(),
224  const std::map<std::string, NDArray> &aux_map =
225  std::map<std::string, NDArray>()) const;
233  void InferArgsMap(const Context &context,
234  std::map<std::string, NDArray> *args_map,
235  const std::map<std::string, NDArray> &known_args) const;
254  Executor *SimpleBind(const Context &context,
255  const std::map<std::string, NDArray> &args_map,
256  const std::map<std::string, NDArray> &arg_grad_store =
257  std::map<std::string, NDArray>(),
258  const std::map<std::string, OpReqType> &grad_req_type =
259  std::map<std::string, OpReqType>(),
260  const std::map<std::string, NDArray> &aux_map =
261  std::map<std::string, NDArray>());
280  Executor *Bind(const Context &context, const std::vector<NDArray> &arg_arrays,
281  const std::vector<NDArray> &grad_arrays,
282  const std::vector<OpReqType> &grad_reqs,
283  const std::vector<NDArray> &aux_arrays,
284  const std::map<std::string, Context> &group_to_ctx =
285  std::map<std::string, Context>(),
286  Executor *shared_exec = nullptr);
287 
288  private:
289  std::shared_ptr<SymBlob> blob_ptr_;
290  static OpMap*& op_map();
291 };
292 Symbol operator+(mx_float lhs, const Symbol &rhs);
293 Symbol operator-(mx_float lhs, const Symbol &rhs);
294 Symbol operator*(mx_float lhs, const Symbol &rhs);
295 Symbol operator/(mx_float lhs, const Symbol &rhs);
296 Symbol operator%(mx_float lhs, const Symbol &rhs);
297 } // namespace cpp
298 } // namespace mxnet
299 #endif // MXNET_CPP_SYMBOL_H_
Symbol operator/(mx_float lhs, const Symbol &rhs)
Symbol()
Definition: symbol.h:73
OpMap instance holds a map of all the symbol creators so we can get symbol creators by name...
Definition: op_map.h:42
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
void * SymbolHandle
handle to a symbol that can be bound as an operator
Definition: c_api.h:74
float mx_float
manually define float
Definition: c_api.h:59
definition of OpMap
namespace of mxnet
Definition: api_registry.h:33
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another with the same shape
Definition: tensor_cpu-inl.h:145
Executor interface.
Definition: executor.h:44
SymBlob(SymbolHandle handle)
construct with SymbolHandle to store
Definition: symbol.h:53
SymbolHandle handle_
the SymbolHandle to store
Definition: symbol.h:61
MXNET_DLL int MXSymbolFree(SymbolHandle symbol)
Free the symbol handle.
struct to store SymbolHandle
Definition: symbol.h:44
Symbol operator%(mx_float lhs, const Symbol &rhs)
SymbolHandle GetHandle() const
Definition: symbol.h:140
SymBlob()
default constructor
Definition: symbol.h:49
~SymBlob()
destructor, free the SymbolHandle
Definition: symbol.h:57
Symbol operator+(mx_float lhs, const Symbol &rhs)
Graph LoadJSON(const std::string &json_str)
Load a graph from JSON string, redirects to "LoadJSON" pass.
Definition: pass_functions.h:47
Symbol operator-(mx_float lhs, const Symbol &rhs)
Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="")
Infer shapes in the graph given the information.
Definition: pass_functions.h:97
Context interface.
Definition: ndarray.h:49
Symbol operator*(mx_float lhs, const Symbol &rhs)
Symbol interface.
Definition: symbol.h:71
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:57